Training a Sequence Classifier

Let us now walk through a short tutorial on training a sequence classifier using a pre-trained language model.

For this tutorial, we provide a sample corpus in the folder demo/data/sentiment/.

$ ls demo/data/sentiment/
dev.txt
test.txt
train.txt

The train, dev, and test files are in tab-separated format. A sample snippet of the train corpus is shown below. The first line of the file must contain sentence as the name of the first column and Label as the name of the second column (the column containing the class labels).

$ cat demo/data/sentiment/train.txt
sentence    Label
I liked the movie   1
I hated the movie   0
The movie was good  1

The filenames must match those shown above (train.txt, dev.txt, and test.txt).
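
To sanity-check that a split follows this format, you can load it with pandas. This is a minimal sketch, assuming pandas is installed; it only uses the file path and column names shown above.

import pandas as pd

# Load the tab-separated training split; the header row must name the
# columns "sentence" and "Label" as described above.
train = pd.read_csv("demo/data/sentiment/train.txt", sep="\t")

assert list(train.columns) == ["sentence", "Label"]
print(train.head())
print(train["Label"].value_counts())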

Hyper-Parameter Tuning

We first have to select the best hyper-parameter values. To do so, we monitor the loss/accuracy/F1-score on the dev set and pick the configuration that performs best. We perform a grid search over batch size and learning rate only; the search space is given in the table below, and a conceptual sketch of the loop follows the table.

Hyper-Parameter    Values
Batch Size         8, 16, 32
Learning Rate      1e-3, 1e-4, 1e-5, 1e-6, 3e-3, 3e-4, 3e-5, 3e-6, 5e-3, 5e-4, 5e-5, 5e-6
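
Conceptually, the grid search amounts to the loop sketched below. This is only an illustration of the idea, not the repo's tune_hyper_parameter.py; train_and_evaluate_on_dev is a hypothetical stand-in for a single fine-tuning run that returns the F1-score on the dev set.

import itertools

# The hyper-parameter grid searched in this tutorial.
batch_sizes = [8, 16, 32]
learning_rates = [1e-3, 1e-4, 1e-5, 1e-6,
                  3e-3, 3e-4, 3e-5, 3e-6,
                  5e-3, 5e-4, 5e-5, 5e-6]

def train_and_evaluate_on_dev(batch_size, learning_rate):
    """Hypothetical stand-in: fine-tune the pre-trained model with the given
    hyper-parameters and return the F1-score measured on dev.txt.
    Here it just returns a dummy score so the sketch runs."""
    return 0.0

best_score, best_config = float("-inf"), None
for batch_size, lr in itertools.product(batch_sizes, learning_rates):
    score = train_and_evaluate_on_dev(batch_size, lr)
    if score > best_score:
        best_score, best_config = score, (batch_size, lr)

print("Best configuration (batch size, learning rate):", best_config)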

We now perform hyper-parameter tuning of the sequence classifier:

$ python src/sequenceclassifier/helper_scripts/tune_hyper_parameter.py \
    --data_dir demo/data/sentiment/ \
    --configuration_name bert-custom \
    --model_name demo/model/mlm/checkpoint-200/ \
    --output_dir demo/model/sentiment/ \
    --tokenizer_name demo/model/tokenizer/ \
    --task_name sentiment \
    --log_dir logs

The script performs the hyper-parameter tuning, and the Aim library tracks each experiment in the logs folder.

Fine-Tuning Using the Best Hyper-Parameter

We now run the script src/sequenceclassifier/helper_scripts/get_best_hyper_parameter_and_train.py, which finds the best hyper-parameter configuration and fine-tunes the model with it.

$ python src/sequenceclassifier/helper_scripts/get_best_hyper_parameter_and_train.py \
    --data_dir demo/data/sentiment/ \
    --configuration_name bert-custom \
    --model_name demo/model/mlm/checkpoint-200/ \
    --output_dir demo/model/sentiment/ \
    --tokenizer_name demo/model/tokenizer/ \
    --log_dir logs

    +----+------------+-------------+----------------+
    |    |   F1-Score |   BatchSize |   LearningRate |
    +====+============+=============+================+
    |  0 |   0.666667 |          16 |         0.001  |
    +----+------------+-------------+----------------+
    |  1 |   0.666667 |          16 |         0.0001 |
    +----+------------+-------------+----------------+
    |  2 |   0        |          16 |         1e-05  |
    +----+------------+-------------+----------------+
    |  3 |   0        |          16 |         1e-06  |
    +----+------------+-------------+----------------+
    |  4 |   0.666667 |          16 |         0.003  |
    +----+------------+-------------+----------------+
    |  5 |   0.666667 |          16 |         0.0003 |
    +----+------------+-------------+----------------+
    |  6 |   0        |          16 |         3e-05  |
    +----+------------+-------------+----------------+
    |  7 |   0        |          16 |         3e-06  |
    +----+------------+-------------+----------------+
    |  8 |   0        |          16 |         0.005  |
    +----+------------+-------------+----------------+
    |  9 |   0.666667 |          16 |         0.0005 |
    +----+------------+-------------+----------------+
    | 10 |   0        |          16 |         5e-05  |
    +----+------------+-------------+----------------+
    | 11 |   0        |          16 |         5e-06  |
    +----+------------+-------------+----------------+
    | 12 |   0.666667 |          32 |         0.001  |
    +----+------------+-------------+----------------+
    | 13 |   0.666667 |          32 |         0.0001 |
    +----+------------+-------------+----------------+
    | 14 |   0        |          32 |         1e-05  |
    +----+------------+-------------+----------------+
    | 15 |   0        |          32 |         1e-06  |
    +----+------------+-------------+----------------+
    | 16 |   0.666667 |          32 |         0.003  |
    +----+------------+-------------+----------------+
    | 17 |   0.666667 |          32 |         0.0003 |
    +----+------------+-------------+----------------+
    | 18 |   0        |          32 |         3e-05  |
    +----+------------+-------------+----------------+
    | 19 |   0        |          32 |         3e-06  |
    +----+------------+-------------+----------------+
    | 20 |   0        |          32 |         0.005  |
    +----+------------+-------------+----------------+
    | 21 |   0.666667 |          32 |         0.0005 |
    +----+------------+-------------+----------------+
    | 22 |   0        |          32 |         5e-05  |
    +----+------------+-------------+----------------+
    | 23 |   0        |          32 |         5e-06  |
    +----+------------+-------------+----------------+
    | 24 |   0.666667 |           8 |         0.001  |
    +----+------------+-------------+----------------+
    | 25 |   0.666667 |           8 |         0.0001 |
    +----+------------+-------------+----------------+
    | 26 |   0        |           8 |         1e-05  |
    +----+------------+-------------+----------------+
    | 27 |   0        |           8 |         1e-06  |
    +----+------------+-------------+----------------+
    | 28 |   0.666667 |           8 |         0.003  |
    +----+------------+-------------+----------------+
    | 29 |   0.666667 |           8 |         0.0003 |
    +----+------------+-------------+----------------+
    | 30 |   0        |           8 |         3e-05  |
    +----+------------+-------------+----------------+
    | 31 |   0        |           8 |         3e-06  |
    +----+------------+-------------+----------------+
    | 32 |   0        |           8 |         0.005  |
    +----+------------+-------------+----------------+
    | 33 |   0.666667 |           8 |         0.0005 |
    +----+------------+-------------+----------------+
    | 34 |   0        |           8 |         5e-05  |
    +----+------------+-------------+----------------+
    | 35 |   0        |           8 |         5e-06  |
    +----+------------+-------------+----------------+
    Model is demo/model/mlm/checkpoint-200/
    Best Configuration is 16 0.001
    Best F1 is 0.6666666666666666

The command fine-tunes the model with 5 different random seeds. The resulting models can be found in the folder demo/model/sentiment/.

$ ls -lh demo/model/sentiment/ | grep '^d' | awk '{print $9}'
bert-custom-model_sentiment_16_0.001_4_1
bert-custom-model_sentiment_16_0.001_4_2
bert-custom-model_sentiment_16_0.001_4_3
bert-custom-model_sentiment_16_0.001_4_4
bert-custom-model_sentiment_16_0.001_4_5
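
The folder names concatenate the configuration name, the task, the batch size, the learning rate, a fifth field that appears to be the number of training epochs (the eval snippet further below reports "epoch": 4.0), and the random seed. The sketch below decodes them; the meaning of the fifth field is an assumption, not something documented above.

from pathlib import Path

# Decode the folder naming convention observed above:
#   <configuration>-model_<task>_<batch size>_<learning rate>_<epochs?>_<seed>
# The fifth field ("4") being the epoch count is an assumption.
for folder in sorted(Path("demo/model/sentiment/").glob("*_sentiment_*")):
    prefix, task, batch_size, lr, epochs, seed = folder.name.rsplit("_", 5)
    print(f"{folder.name}: batch={batch_size} lr={lr} epochs={epochs} seed={seed}")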

Each of these folders contains the following files:

$ ls -lh demo/model/sentiment/bert-custom-model_sentiment_16_0.001_4_1/ | awk '{print $5, $9}'
386B all_results.json
700B config.json
219B eval_results.json
41B predict_results_sentiment.txt
3.6M pytorch_model.bin
96B runs
48B test_predictions.txt
147B test_results.json
187B train_results.json
808B trainer_state.json
2.9K training_args.bin

The file test_predictions.txt contains the predictions from the model on the test set. Similarly, the files test_results.json and eval_results.json contain the results (F1-Score, Accuracy, etc.) from the model on the test and dev sets respectively.

A sample snippet of eval_results.json is presented here:

$ head demo/model/sentiment/bert-custom-model_sentiment_16_0.001_4_1/eval_results.json
{
"epoch": 4.0,
"eval_f1": 0.6666666666666666,
"eval_loss": 0.7115099430084229,
"eval_runtime": 0.0788,
"eval_samples": 6,
"eval_samples_per_second": 76.159,
"eval_steps_per_second": 12.693
}
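
Since the same configuration is fine-tuned with five random seeds, a natural next step is to report the mean and standard deviation of the score across seeds. This is a minimal sketch, assuming the folder names produced by the run above and the eval_f1 key shown in the snippet.

import json
from pathlib import Path
from statistics import mean, stdev

# Collect the dev-set F1 of every seed run of the best configuration.
scores = []
for run_dir in sorted(Path("demo/model/sentiment/").glob("bert-custom-model_sentiment_16_0.001_4_*")):
    with open(run_dir / "eval_results.json") as f:
        scores.append(json.load(f)["eval_f1"])

print(f"dev F1 over {len(scores)} seeds: {mean(scores):.4f} +/- {stdev(scores):.4f}")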

The scores are poor because we trained on a tiny corpus; training on a larger corpus should give better results.