Training a Sequence Labeler

Let us now look into a short tutorial on training a sequence labeler (token classifier) using a pre-trained language model.

For this tutorial, we provide a sample corpus in the folder demo/data/ner/en/. The data is taken from WikiANN-NER (https://huggingface.co/datasets/wikiann).

$ ls demo/data/ner/
en

$ ls demo/data/ner/en/
dev.csv
test.csv
train.csv

The train, dev, and test files are in CoNLL format (despite the .csv extension). A sample snippet of the training corpus is shown below:

$ cat demo/data/ner/en/train.csv
This        O
is  O
not O
Romeo       B-PER
,   O
he’s        O
some        O
other       O
where.      O

Your        O
plantain    O
leaf        O
is  O
excellent   O
for O
that.       O

Each word is on its own line, followed by either a space or a tab and then the entity label. Successive sentences are separated by an empty line.

The files should be named train.csv, dev.csv, and test.csv, exactly as shown above.
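Before converting, it can help to check that a data folder matches this layout. The following is a minimal sketch, not part of the toolkit: the function name check_conll_folder is hypothetical, and it assumes every non-empty line holds the word first and the label after whitespace, with blank lines separating sentences.

```python
from pathlib import Path

# File names the tutorial expects in the data folder.
EXPECTED_FILES = ("train.csv", "dev.csv", "test.csv")


def check_conll_folder(data_dir):
    """Return the number of sentences in each expected CoNLL file.

    Raises FileNotFoundError if a file is missing and ValueError if a
    non-empty line does not contain at least a word and a label.
    """
    counts = {}
    for name in EXPECTED_FILES:
        path = Path(data_dir) / name
        if not path.is_file():
            raise FileNotFoundError(f"missing {path}")
        sentences = 0
        in_sentence = False
        with path.open(encoding="utf-8") as handle:
            for line_no, line in enumerate(handle, start=1):
                fields = line.split()  # splits on any run of spaces or tabs
                if not fields:
                    # An empty line closes the current sentence.
                    in_sentence = False
                    continue
                if len(fields) < 2:
                    raise ValueError(f"{path}:{line_no}: expected '<word> <label>'")
                if not in_sentence:
                    # First token after a blank line starts a new sentence.
                    sentences += 1
                    in_sentence = True
        counts[name] = sentences
    return counts
```

For the sample corpus, check_conll_folder("demo/data/ner/en/") returns a dictionary mapping each file name to its sentence count.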

Convert CoNLL Files to JSON Format

We need to convert the CoNLL files to JSON format so that the data can be loaded easily for training. We use the following script to perform the conversion:

$ python src/tokenclassifier/helper_scripts/conll_to_json_converter.py \
    --data_dir <path to folder containing CoNLL files> \
    --column_number <column number containing the labels>

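In our example, the labels are in the second column of each file (columns are numbered from 0, so the word is column 0), and we run the following command: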

$ python src/tokenclassifier/helper_scripts/conll_to_json_converter.py \
    --data_dir demo/data/ner/en/ \
    --column_number 1
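The tutorial does not show the JSON schema that conll_to_json_converter.py produces. As an illustration only, the sketch below assumes a JSON Lines file with one object per sentence holding hypothetical "tokens" and "tags" keys, and treats the column number as a 0-based index; the real script's output may differ.

```python
import json


def conll_to_records(lines, column_number=1):
    """Group CoNLL lines into sentences, taking labels from `column_number`.

    Columns are counted from 0, so column 0 holds the word.
    """
    records, tokens, tags = [], [], []
    for line in lines:
        fields = line.split()
        if not fields:
            if tokens:  # a blank line ends the current sentence
                records.append({"tokens": tokens, "tags": tags})
                tokens, tags = [], []
            continue
        tokens.append(fields[0])
        tags.append(fields[column_number])
    if tokens:  # the last sentence may not be followed by a blank line
        records.append({"tokens": tokens, "tags": tags})
    return records


def write_json_lines(records, path):
    """Write one JSON object per line (JSON Lines), keeping non-ASCII as-is."""
    with open(path, "w", encoding="utf-8") as handle:
        for record in records:
            handle.write(json.dumps(record, ensure_ascii=False) + "\n")
```

An open file object can be passed directly to conll_to_records, since iterating over it yields one line at a time.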

Training a Token Classifier

We can train a token classifier directly by specifying the hyper-parameters as follows:

$ python src/tokenclassifier/train_tc.py \
    --data <path to data folder/huggingface dataset name> \
    --model_name <model name or path> \
    --tokenizer_name <tokenizer name or path> \
    --task_name <ner or pos> \
    --output_dir <output folder where the model will be saved> \
    --batch_size <batch size to be used> \
    --learning_rate <learning rate to be used> \
    --train_steps <maximum number of training steps> \
    --eval_steps <steps after which evaluation on dev set is performed> \
    --save_steps <steps after which the model is saved> \
    --config_name <configuration name> \
    --max_seq_len <maximum sequence length; longer sequences are trimmed> \
    --perform_grid_search <Perform grid search where only the result would be stored> \
    --seed <random seed used> \
    --eval_only <Perform evaluation only>
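If you prefer to launch train_tc.py from Python (for example, from a notebook), a small helper can assemble the argument list. This is a hypothetical sketch: it assumes every option listed above except --perform_grid_search and --eval_only takes a single value, and it leaves those two out because the tutorial does not show how they are passed.

```python
# Value-taking options listed for train_tc.py, in the documented order.
# --perform_grid_search and --eval_only are deliberately omitted (assumption:
# their expected form is not shown in the tutorial).
VALUE_OPTIONS = (
    "data", "model_name", "tokenizer_name", "task_name", "output_dir",
    "batch_size", "learning_rate", "train_steps", "eval_steps",
    "save_steps", "config_name", "max_seq_len", "seed",
)


def build_train_command(options, script="src/tokenclassifier/train_tc.py"):
    """Return an argv list for train_tc.py from a dict of option values."""
    unknown = set(options) - set(VALUE_OPTIONS)
    if unknown:
        raise ValueError(f"unknown options: {sorted(unknown)}")
    argv = ["python", script]
    for name in VALUE_OPTIONS:  # emit options in the documented order
        if name in options:
            argv += [f"--{name}", str(options[name])]
    return argv
```

The returned list can be passed to the standard library's subprocess.run(argv, check=True), which raises an error if training exits with a non-zero status.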

Hyper-Parameter Tuning

We first have to select the best hyper-parameter values. For this, we monitor the loss, accuracy, or F1-score on the dev set and keep the best-performing configuration. We perform a grid search over batch size and learning rate only, using the following values:

Hyper-Parameter    Values
Batch Size         8, 16, 32
Learning Rate      1e-3, 1e-4, 1e-5, 1e-6, 3e-3, 3e-4, 3e-5, 3e-6, 5e-3, 5e-4, 5e-5, 5e-6
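The grid is the Cartesian product of the two value lists, which gives 3 × 12 = 36 training runs. A minimal sketch using the standard library's itertools.product, with the twelve learning rates that appear in the tuning output further down:

```python
from itertools import product

# Search space: three batch sizes and twelve learning rates.
BATCH_SIZES = [8, 16, 32]
LEARNING_RATES = [
    1e-3, 1e-4, 1e-5, 1e-6,
    3e-3, 3e-4, 3e-5, 3e-6,
    5e-3, 5e-4, 5e-5, 5e-6,
]


def grid(batch_sizes=BATCH_SIZES, learning_rates=LEARNING_RATES):
    """Return every (batch_size, learning_rate) pair in the search space."""
    return list(product(batch_sizes, learning_rates))


configs = grid()
print(len(configs))  # prints 36
```

The tuning script reports its runs in a different order (batch size 16, then 32, then 8), so this sketch shows the size and contents of the search space, not the script's iteration order.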

We now perform hyper-parameter tuning of the sequence labeler:

$ python src/tokenclassifier/helper_scripts/tune_hyper_parameter.py \
    --data_dir demo/data/ner/en/ \
    --configuration_name bert-custom \
    --model_name demo/model/mlm/checkpoint-200/ \
    --output_dir demo/model/ner/en/ \
    --tokenizer_name demo/model/tokenizer/ \
    --log_dir logs

The script runs the grid search, and the Aim library tracks each experiment in the logs folder (set by --log_dir).

Fine-Tuning Using the Best Hyper-Parameters

We now run the script src/tokenclassifier/helper_scripts/get_best_hyper_parameter_and_train.py, which finds the best hyper-parameter configuration and fine-tunes the model with it:

$ python src/tokenclassifier/helper_scripts/get_best_hyper_parameter_and_train.py \
    --data_dir demo/data/ner/en/ \
    --configuration_name bert-custom \
    --model_name demo/model/mlm/checkpoint-200/ \
    --output_dir demo/model/ner/en/ \
    --tokenizer_name demo/model/tokenizer/ \
    --log_dir logs

+----+------------+-------------+----------------+
|    |   F1-Score |   BatchSize |   LearningRate |
+====+============+=============+================+
|  0 |  0         |          16 |         0.001  |
+----+------------+-------------+----------------+
|  1 |  0.08      |          16 |         0.0001 |
+----+------------+-------------+----------------+
|  2 |  0.0833333 |          16 |         1e-05  |
+----+------------+-------------+----------------+
|  3 |  0.0833333 |          16 |         1e-06  |
+----+------------+-------------+----------------+
|  4 |  0         |          16 |         0.003  |
+----+------------+-------------+----------------+
|  5 |  0         |          16 |         0.0003 |
+----+------------+-------------+----------------+
|  6 |  0.0833333 |          16 |         3e-05  |
+----+------------+-------------+----------------+
|  7 |  0.0833333 |          16 |         3e-06  |
+----+------------+-------------+----------------+
|  8 |  0         |          16 |         0.005  |
+----+------------+-------------+----------------+
|  9 |  0         |          16 |         0.0005 |
+----+------------+-------------+----------------+
| 10 |  0.0833333 |          16 |         5e-05  |
+----+------------+-------------+----------------+
| 11 |  0.0833333 |          16 |         5e-06  |
+----+------------+-------------+----------------+
| 12 |  0         |          32 |         0.001  |
+----+------------+-------------+----------------+
| 13 |  0.08      |          32 |         0.0001 |
+----+------------+-------------+----------------+
| 14 |  0.0833333 |          32 |         1e-05  |
+----+------------+-------------+----------------+
| 15 |  0.0833333 |          32 |         1e-06  |
+----+------------+-------------+----------------+
| 16 |  0         |          32 |         0.003  |
+----+------------+-------------+----------------+
| 17 |  0         |          32 |         0.0003 |
+----+------------+-------------+----------------+
| 18 |  0.0833333 |          32 |         3e-05  |
+----+------------+-------------+----------------+
| 19 |  0.0833333 |          32 |         3e-06  |
+----+------------+-------------+----------------+
| 20 |  0         |          32 |         0.005  |
+----+------------+-------------+----------------+
| 21 |  0         |          32 |         0.0005 |
+----+------------+-------------+----------------+
| 22 |  0.0833333 |          32 |         5e-05  |
+----+------------+-------------+----------------+
| 23 |  0.0833333 |          32 |         5e-06  |
+----+------------+-------------+----------------+
| 24 |  0         |           8 |         0.001  |
+----+------------+-------------+----------------+
| 25 |  0.08      |           8 |         0.0001 |
+----+------------+-------------+----------------+
| 26 |  0.0833333 |           8 |         1e-05  |
+----+------------+-------------+----------------+
| 27 |  0.0833333 |           8 |         1e-06  |
+----+------------+-------------+----------------+
| 28 |  0         |           8 |         0.003  |
+----+------------+-------------+----------------+
| 29 |  0         |           8 |         0.0003 |
+----+------------+-------------+----------------+
| 30 |  0.0833333 |           8 |         3e-05  |
+----+------------+-------------+----------------+
| 31 |  0.0833333 |           8 |         3e-06  |
+----+------------+-------------+----------------+
| 32 |  0         |           8 |         0.005  |
+----+------------+-------------+----------------+
| 33 |  0         |           8 |         0.0005 |
+----+------------+-------------+----------------+
| 34 |  0.0833333 |           8 |         5e-05  |
+----+------------+-------------+----------------+
| 35 |  0.0833333 |           8 |         5e-06  |
+----+------------+-------------+----------------+
Model is demo/model/mlm/checkpoint-200/
Best Configuration is 16 1e-05
Best F1 is 0.08333333333333334
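The selection step keeps the configuration with the highest dev F1-score. How get_best_hyper_parameter_and_train.py breaks ties is not documented; the sketch below assumes the first run with the maximum score wins. That assumption is consistent with the output above, where many runs tie at 0.0833333 and the reported best is batch size 16 with learning rate 1e-05. The rows are copied from the first four lines of that table.

```python
# (F1-score, batch size, learning rate) from the first rows of the table above.
results = [
    (0.0, 16, 0.001),
    (0.08, 16, 0.0001),
    (0.0833333, 16, 1e-05),
    (0.0833333, 16, 1e-06),
]


def best_configuration(rows):
    """Return the row with the highest F1-score.

    max() returns the first maximal element it encounters, so ties are
    broken by the order in which the runs are listed.
    """
    return max(rows, key=lambda row: row[0])


f1, batch_size, learning_rate = best_configuration(results)
print(batch_size, learning_rate, f1)  # prints: 16 1e-05 0.0833333
```

With tiny dev sets like this one, many configurations tie, so the chosen configuration depends heavily on the tie-breaking rule.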

The command then fine-tunes the model with the best configuration using 5 different random seeds. The resulting models are saved in the folder demo/model/ner/en/:

$ ls -lh demo/model/ner/en/ | grep '^d' | awk '{print $9}'
bert-custom-model_ner_16_1e-05_4_1
bert-custom-model_ner_16_1e-05_4_2
bert-custom-model_ner_16_1e-05_4_3
bert-custom-model_ner_16_1e-05_4_4
bert-custom-model_ner_16_1e-05_4_5

Each model folder contains the following files:

$ ls -lh demo/model/ner/en/bert-custom-model_ner_16_1e-05_4_1/ | awk '{print $5, $9}'
224B GOAT
884B config.json
417B dev_predictions.txt
188B dev_results.txt
3.6M pytorch_model.bin
96B runs
262B test_predictions.txt
169B test_results.txt
2.9K training_args.bin

The files test_predictions.txt and dev_predictions.txt contain the model's predictions on the test and dev sets, respectively. Similarly, test_results.txt and dev_results.txt contain the evaluation results (e.g., precision, recall, and F1-score) on the test and dev sets, respectively.

A sample snippet of test_predictions.txt is shown below (dev_predictions.txt follows the same format):

$ head demo/model/ner/en/bert-custom-model_ner_16_1e-05_4_1/test_predictions.txt
This O O
is O O
not O O
Romeo B-PER O
, O O
he’s O O
some O O
other O O
where. O O

The first column is the word, the second is the ground-truth label, and the third is the predicted label.
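test_results.txt reports precision, recall, and F1. Assuming these are entity-level scores computed from BIO tags, as the seqeval library (used by Hugging Face's token-classification examples) computes them, they can be recomputed from a predictions file with a sketch like the one below. This tutorial does not say which metric implementation is used. The sketch also assumes sentences in the predictions file are separated by blank lines, and it treats an I- tag that does not continue an entity of the same type as the start of a new entity.

```python
def bio_spans(tags):
    """Return the set of (type, start, end) entity spans in a BIO tag list."""
    spans, start, kind = set(), None, None
    for i, tag in enumerate(list(tags) + ["O"]):  # sentinel closes the last span
        # A B- tag, or an I- tag of a different type, opens a new entity.
        starts_new = tag.startswith("B-") or (
            tag.startswith("I-") and tag[2:] != kind
        )
        if kind is not None and (tag == "O" or starts_new):
            spans.add((kind, start, i))  # end index is exclusive
            kind = None
        if starts_new:
            start, kind = i, tag[2:]
    return spans


def read_predictions(lines):
    """Split 'word gold predicted' lines into sentences of (gold, pred) tags."""
    sentences, gold, pred = [], [], []
    for line in lines:
        fields = line.split()
        if not fields:
            if gold:  # a blank line ends the current sentence
                sentences.append((gold, pred))
                gold, pred = [], []
            continue
        gold.append(fields[1])
        pred.append(fields[2])
    if gold:
        sentences.append((gold, pred))
    return sentences


def entity_scores(sentences):
    """Micro-averaged entity-level precision, recall and F1."""
    tp = n_gold = n_pred = 0
    for gold, pred in sentences:
        gold_spans, pred_spans = bio_spans(gold), bio_spans(pred)
        tp += len(gold_spans & pred_spans)
        n_gold += len(gold_spans)
        n_pred += len(pred_spans)
    precision = tp / n_pred if n_pred else 0.0
    recall = tp / n_gold if n_gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
    return precision, recall, f1
```

On the nine-line snippet above, the only gold entity (Romeo, B-PER) is predicted as O, so all three scores come out as 0.0.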

$ head demo/model/ner/en/bert-custom-model_ner_16_1e-05_4_1/test_results.txt
test_loss = 1.888014554977417
test_precision = 0.0
test_recall = 0.0
test_f1 = 0.0
test_runtime = 0.0331
test_samples_per_second = 60.493
test_steps_per_second = 30.246
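Because the model is fine-tuned with 5 seeds, a common way to report results is the mean and standard deviation of a metric across the runs. The sketch below parses the "<name> = <value>" lines of test_results.txt and aggregates one metric over the seed folders. The function names are hypothetical, and it assumes every line containing "=" holds a numeric value, as in the file shown above.

```python
import statistics
from pathlib import Path


def read_results(path):
    """Parse a '<name> = <value>' results file into a dict of floats."""
    metrics = {}
    for line in Path(path).read_text(encoding="utf-8").splitlines():
        if "=" not in line:
            continue
        name, value = line.split("=", 1)
        metrics[name.strip()] = float(value)  # float() ignores surrounding spaces
    return metrics


def summarize_seeds(model_root, metric="test_f1", pattern="*/test_results.txt"):
    """Return the mean and sample standard deviation of `metric` across runs."""
    paths = sorted(Path(model_root).glob(pattern))
    if not paths:
        raise FileNotFoundError(f"no {pattern} under {model_root}")
    values = [read_results(p)[metric] for p in paths]
    spread = statistics.stdev(values) if len(values) > 1 else 0.0
    return statistics.mean(values), spread
```

To restrict the aggregation to the five seed folders from this run, pass a narrower glob, for example pattern="bert-custom-model_ner_16_1e-05_*/test_results.txt" with model_root="demo/model/ner/en/".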

The scores are poor because the model was trained on a tiny corpus; training on a larger corpus should give considerably better results.