Training a Sequence Labeler
Let us now look at a short tutorial on training a sequence labeler (token classifier) using a pre-trained language model.
For this tutorial, we provide a sample corpus in the folder demo/data/ner/en/. The data is taken from WikiANN-NER (https://huggingface.co/datasets/wikiann).
$ ls demo/data/ner/
en

$ ls demo/data/ner/en/
dev.csv
test.csv
train.csv
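If you prefer to pull the data directly from the HuggingFace Hub instead of using the bundled sample, the datasets library can load WikiANN. A minimal sketch (assuming datasets is installed; newer versions may additionally require trust_remote_code=True for script-based datasets):

# Sketch: load the English split of WikiANN from the HuggingFace Hub.
from datasets import load_dataset

wikiann = load_dataset("wikiann", "en")
print(wikiann["train"][0])  # {'tokens': [...], 'ner_tags': [...], ...}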
The train, dev, and test files are in CoNLL format. A sample snippet of the train corpus is shown here:
$ cat demo/data/ner/en/train.csv
This O
is O
not O
Romeo B-PER
, O
he’s O
some O
other O
where. O

Your O
plantain O
leaf O
is O
excellent O
for O
that. O
Every word is on its own line, followed by either a space or a tab, followed by the entity label. Successive sentences are separated by an empty line. The filenames should be the same as mentioned above (train.csv, dev.csv, and test.csv). A minimal parser for this format is sketched below.
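For readers who want to inspect or preprocess the data themselves, here is a small sketch of a parser for this two-column format (the function read_conll is our own illustration, not part of the repository):

# Sketch: parse a two-column CoNLL-style file into (tokens, labels) pairs.
def read_conll(path):
    sentences, tokens, labels = [], [], []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:  # empty line marks a sentence boundary
                if tokens:
                    sentences.append((tokens, labels))
                    tokens, labels = [], []
            else:
                word, label = line.split()  # handles both space and tab separators
                tokens.append(word)
                labels.append(label)
    if tokens:  # flush the last sentence if the file lacks a trailing blank line
        sentences.append((tokens, labels))
    return sentences

print(read_conll("demo/data/ner/en/train.csv")[0])
# (['This', 'is', 'not', 'Romeo', ',', 'he’s', 'some', 'other', 'where.'],
#  ['O', 'O', 'O', 'B-PER', 'O', 'O', 'O', 'O', 'O'])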
Converting CoNLL Files to JSON Format
We need to convert the CoNLL files to JSON format so that we can easily load the data and perform training. We use the following script to perform the conversion.
$ python src/tokenclassifier/helper_scripts/conll_to_json_converter.py \
    --data_dir <path to folder containing CoNLL files> \
    --column_number <column number containing the labels>
For our example, the labels are in column 1 (0-indexed), so we run the following command:
$ python src/tokenclassifier/helper_scripts/conll_to_json_converter.py \
    --data_dir demo/data/ner/en/ \
    --column_number 1
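The exact JSON schema is defined by the converter script; a common layout for token classification (which we assume here purely for illustration, not as the script's guaranteed output) stores one sentence per record with parallel token and tag lists:

# Sketch: what one converted sentence might look like (assumed schema,
# not necessarily the exact output of conll_to_json_converter.py).
import json

record = {
    "tokens": ["This", "is", "not", "Romeo", ",", "he’s", "some", "other", "where."],
    "ner_tags": ["O", "O", "O", "B-PER", "O", "O", "O", "O", "O"],
}
print(json.dumps(record, ensure_ascii=False))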
Training a Token Classifier
We can directly train a token classifier by specifying the hyper-parameters as follows:
$ python src/tokenclassifier/train_tc.py \
    --data <path to data folder or HuggingFace dataset name> \
    --model_name <model name or path> \
    --tokenizer_name <tokenizer name or path> \
    --task_name <ner or pos> \
    --output_dir <output folder where the model will be saved> \
    --batch_size <batch size to be used> \
    --learning_rate <learning rate to be used> \
    --train_steps <maximum number of training steps> \
    --eval_steps <steps after which evaluation on the dev set is performed> \
    --save_steps <steps after which the model is saved> \
    --config_name <configuration name> \
    --max_seq_len <maximum sequence length after which the sequence is trimmed> \
    --perform_grid_search <perform grid search where only the result is stored> \
    --seed <random seed> \
    --eval_only <perform evaluation only>
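Under the hood, a training script like this typically wires together the HuggingFace transformers Trainer. The following is a minimal sketch of that pattern, not the repository's actual implementation; the label list and step counts are illustrative, and the dataset preparation (tokenization with label alignment) is elided:

# Sketch: the core transformers pieces a train_tc.py-style script usually wraps.
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

labels = ["O", "B-PER", "I-PER"]  # illustrative subset of the tag set

tokenizer = AutoTokenizer.from_pretrained("demo/model/tokenizer/")
model = AutoModelForTokenClassification.from_pretrained(
    "demo/model/mlm/checkpoint-200/", num_labels=len(labels)
)

args = TrainingArguments(
    output_dir="demo/model/ner/en/",
    per_device_train_batch_size=16,
    learning_rate=1e-5,
    max_steps=500,
    eval_steps=100,
    save_steps=100,
    seed=1,
)

# train_ds / dev_ds would be tokenized datasets with labels aligned to subwords:
# trainer = Trainer(model=model, args=args,
#                   train_dataset=train_ds, eval_dataset=dev_ds)
# trainer.train()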
Hyper-Parameter Tuning
We first have to select the best hyper-parameter values. For this, we monitor the loss/accuracy/F1-score on the dev set and select the hyper-parameter combination that performs best. We perform a grid search over batch size and learning rate only; a sketch of the underlying grid loop follows the table.
Hyper-Parameter | Values
--- | ---
Batch Size | 8, 16, 32
Learning Rate | 1e-3, 1e-4, 1e-5, 1e-6, 3e-3, 3e-4, 3e-5, 3e-6, 5e-3, 5e-4, 5e-5, 5e-6
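Conceptually, the tuner loops over the Cartesian product of these values, trains one model per combination, and records the dev F1-score. A minimal sketch, where train_and_eval is a hypothetical stand-in for a single training run:

# Sketch: grid search over batch size and learning rate.
from itertools import product

batch_sizes = [8, 16, 32]
learning_rates = [1e-3, 1e-4, 1e-5, 1e-6, 3e-3, 3e-4, 3e-5, 3e-6,
                  5e-3, 5e-4, 5e-5, 5e-6]

results = []
for bs, lr in product(batch_sizes, learning_rates):
    dev_f1 = train_and_eval(batch_size=bs, learning_rate=lr)  # hypothetical helper
    results.append((dev_f1, bs, lr))

best_f1, best_bs, best_lr = max(results)  # pick the combination with the best dev F1
print(f"Best Configuration is {best_bs} {best_lr}")
print(f"Best F1 is {best_f1}")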
We now perform hyper-parameter tuning of the sequence labeler:
$ python src/tokenclassifier/helper_scripts/tune_hyper_parameter.py \
    --data_dir demo/data/ner/en/ \
    --configuration_name bert-custom \
    --model_name demo/model/mlm/checkpoint-200/ \
    --output_dir demo/model/ner/en/ \
    --tokenizer_name demo/model/tokenizer/ \
    --log_dir logs
The code performs hyper-parameter tuning, and the Aim library tracks the experiments in the logs folder.
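For reference, Aim tracking boils down to creating a Run pointed at that folder and logging metrics against it. A minimal sketch (the hyper-parameter values are illustrative):

# Sketch: how metrics are typically tracked with Aim.
from aim import Run

run = Run(repo="logs")                      # same folder as --log_dir
run["hparams"] = {"batch_size": 16, "learning_rate": 1e-5}
run.track(0.0833, name="dev_f1", step=100)  # log dev F1 at a given step

The tracked runs can then be browsed in Aim's web UI, e.g. with aim up --repo logs (exact flags may vary across Aim versions).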
Fine-Tuning Using the Best Hyper-Parameters
We now run the script src/tokenclassifier/helper_scripts/get_best_hyper_parameter_and_train.py to find the best hyper-parameters and fine-tune the model using them:
$ python src/tokenclassifier/helper_scripts/get_best_hyper_parameter_and_train.py \
    --data_dir demo/data/ner/en/ \
    --configuration_name bert-custom \
    --model_name demo/model/mlm/checkpoint-200/ \
    --output_dir demo/model/ner/en/ \
    --tokenizer_name demo/model/tokenizer/ \
    --log_dir logs

+----+------------+-------------+----------------+
|    |   F1-Score |   BatchSize |   LearningRate |
+====+============+=============+================+
|  0 |          0 |          16 |          0.001 |
+----+------------+-------------+----------------+
|  1 |       0.08 |          16 |         0.0001 |
+----+------------+-------------+----------------+
|  2 |  0.0833333 |          16 |          1e-05 |
+----+------------+-------------+----------------+
|  3 |  0.0833333 |          16 |          1e-06 |
+----+------------+-------------+----------------+
|  4 |          0 |          16 |          0.003 |
+----+------------+-------------+----------------+
|  5 |          0 |          16 |         0.0003 |
+----+------------+-------------+----------------+
|  6 |  0.0833333 |          16 |          3e-05 |
+----+------------+-------------+----------------+
|  7 |  0.0833333 |          16 |          3e-06 |
+----+------------+-------------+----------------+
|  8 |          0 |          16 |          0.005 |
+----+------------+-------------+----------------+
|  9 |          0 |          16 |         0.0005 |
+----+------------+-------------+----------------+
| 10 |  0.0833333 |          16 |          5e-05 |
+----+------------+-------------+----------------+
| 11 |  0.0833333 |          16 |          5e-06 |
+----+------------+-------------+----------------+
| 12 |          0 |          32 |          0.001 |
+----+------------+-------------+----------------+
| 13 |       0.08 |          32 |         0.0001 |
+----+------------+-------------+----------------+
| 14 |  0.0833333 |          32 |          1e-05 |
+----+------------+-------------+----------------+
| 15 |  0.0833333 |          32 |          1e-06 |
+----+------------+-------------+----------------+
| 16 |          0 |          32 |          0.003 |
+----+------------+-------------+----------------+
| 17 |          0 |          32 |         0.0003 |
+----+------------+-------------+----------------+
| 18 |  0.0833333 |          32 |          3e-05 |
+----+------------+-------------+----------------+
| 19 |  0.0833333 |          32 |          3e-06 |
+----+------------+-------------+----------------+
| 20 |          0 |          32 |          0.005 |
+----+------------+-------------+----------------+
| 21 |          0 |          32 |         0.0005 |
+----+------------+-------------+----------------+
| 22 |  0.0833333 |          32 |          5e-05 |
+----+------------+-------------+----------------+
| 23 |  0.0833333 |          32 |          5e-06 |
+----+------------+-------------+----------------+
| 24 |          0 |           8 |          0.001 |
+----+------------+-------------+----------------+
| 25 |       0.08 |           8 |         0.0001 |
+----+------------+-------------+----------------+
| 26 |  0.0833333 |           8 |          1e-05 |
+----+------------+-------------+----------------+
| 27 |  0.0833333 |           8 |          1e-06 |
+----+------------+-------------+----------------+
| 28 |          0 |           8 |          0.003 |
+----+------------+-------------+----------------+
| 29 |          0 |           8 |         0.0003 |
+----+------------+-------------+----------------+
| 30 |  0.0833333 |           8 |          3e-05 |
+----+------------+-------------+----------------+
| 31 |  0.0833333 |           8 |          3e-06 |
+----+------------+-------------+----------------+
| 32 |          0 |           8 |          0.005 |
+----+------------+-------------+----------------+
| 33 |          0 |           8 |         0.0005 |
+----+------------+-------------+----------------+
| 34 |  0.0833333 |           8 |          5e-05 |
+----+------------+-------------+----------------+
| 35 |  0.0833333 |           8 |          5e-06 |
+----+------------+-------------+----------------+
Model is demo/model/mlm/checkpoint-200/
Best Configuration is 16 1e-05
Best F1 is 0.08333333333333334
The command fine-tunes the model with 5 different random seeds. The resulting models can be found in the folder demo/model/ner/en/.
$ ls -lh demo/model/ner/en/ | grep '^d' | awk '{print $9}'
bert-custom-model_ner_16_1e-05_4_1
bert-custom-model_ner_16_1e-05_4_2
bert-custom-model_ner_16_1e-05_4_3
bert-custom-model_ner_16_1e-05_4_4
bert-custom-model_ner_16_1e-05_4_5
Each model folder contains the following files:
$ ls -lh demo/model/ner/en/bert-custom-model_ner_16_1e-05_4_1/ | awk '{print $5, $9}'
224B GOAT
884B config.json
417B dev_predictions.txt
188B dev_results.txt
3.6M pytorch_model.bin
96B runs
262B test_predictions.txt
169B test_results.txt
2.9K training_args.bin
The files test_predictions.txt and dev_predictions.txt contain the predictions from the model on the test and dev sets, respectively. Similarly, the files test_results.txt and dev_results.txt contain the results (F1-score, accuracy, etc.) from the model on the test and dev sets, respectively.
A sample snippet of test_predictions.txt (dev_predictions.txt has the same layout) is presented here:
$ head demo/model/ner/en/bert-custom-model_ner_16_1e-05_4_1/test_predictions.txt
This O O
is O O
not O O
Romeo B-PER O
, O O
he’s O O
some O O
other O O
where. O O
The first column is the word, the second column is the ground-truth label, and the third column is the predicted label.
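Entity-level precision/recall/F1 for files in this format are conventionally computed with the seqeval library. A minimal sketch (assuming seqeval is installed, and reading only the two label columns):

# Sketch: compute entity-level F1 from a predictions file with seqeval.
from seqeval.metrics import f1_score

gold, pred = [[]], [[]]
path = "demo/model/ner/en/bert-custom-model_ner_16_1e-05_4_1/test_predictions.txt"
with open(path, encoding="utf-8") as f:
    for line in f:
        parts = line.split()
        if not parts:  # blank line: sentence boundary
            gold.append([])
            pred.append([])
        else:          # columns: word, gold label, predicted label
            gold[-1].append(parts[1])
            pred[-1].append(parts[2])

gold = [g for g in gold if g]  # drop empty sentences
pred = [p for p in pred if p]
print(f1_score(gold, pred))

The aggregated numbers are also written to test_results.txt: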
$ head demo/model/ner/en/bert-custom-model_ner_16_1e-05_4_1/test_results.txt
test_loss = 1.888014554977417
test_precision = 0.0
test_recall = 0.0
test_f1 = 0.0
test_runtime = 0.0331
test_samples_per_second = 60.493
test_steps_per_second = 30.246
The scores are poor because we trained on a tiny corpus; training on a larger corpus should give much better results.