Fine-Tuning DistilBERT for Question Answering


The transformers library provides a clean and well-documented interface to many popular transformer models. Not only does it make the source code easier to read and understand, it also provides a standardized way to interact with the models. You have seen in the previous post how to use a model such as DistilBERT for natural language processing tasks. In this post, you will learn how to fine-tune the model for your own purposes, expanding the use of the model from inference to training. Specifically, you will learn:

  • How to prepare the dataset for training
  • How to train a model using a helper library

Let’s get started.

Fine-Tuning DistilBERT for Question Answering
Photo by Lea Fabienne. Some rights reserved.

Overview

This post is divided into three parts; they are:

  • Fine-tuning DistilBERT for Custom Q&A
  • Dataset and Preprocessing
  • Running the Training

Fine-tuning DistilBERT for Custom Q&A

The simplest way to use a model in the transformers library is to create a pipeline, which hides many details about how to interact with it.

One reason you may want to set up the model yourself rather than create a pipeline is to fine-tune it on your own dataset. This is impossible with a pipeline, because fine-tuning requires feeding the model’s raw output into a loss function, and the pipeline hides that raw output from you.

Usually, the pre-trained model is created using a general-purpose dataset. However, it may not work well for a specific domain, especially if the language in the domain is significantly different from the general usage. This is where fine-tuning may be attempted.

The main difficulty in fine-tuning is usually the availability of a good dataset, which is expensive and time-consuming to create. For illustration purposes, the example below uses a general-purpose, publicly available dataset called SQuAD (Stanford Question Answering Dataset).

Thanks to the clean, highly generalized design of the transformers library, fine-tuning the model is straightforward. Below is an example of fine-tuning DistilBERT on the SQuAD dataset.
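The listing below is a sketch of one way to do it. The hyperparameters, the output paths, and the distilbert-base-uncased checkpoint are illustrative choices rather than the only options:

```python
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    DistilBertForQuestionAnswering,
    Trainer,
    TrainingArguments,
)

# Load the pre-trained model and its tokenizer
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = DistilBertForQuestionAnswering.from_pretrained(model_name)

# Load the SQuAD dataset from the Hugging Face hub
squad = load_dataset("squad")

def preprocess_function(examples):
    """Convert a batch of SQuAD examples into model inputs with the
    start and end token positions of each answer."""
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=384,
        truncation="only_second",   # truncate the context only, never the question
        padding="max_length",
        return_offsets_mapping=True,
    )

    # offset_mapping is only needed during preprocessing, not as a model input
    offset_mapping = inputs.pop("offset_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offsets in enumerate(offset_mapping):
        # Character span of the answer within the context
        answer = answers[i]
        start_char = answer["answer_start"][0]
        end_char = start_char + len(answer["text"][0])

        # sequence_ids: None for special tokens, 0 for question tokens, 1 for context tokens
        sequence_ids = inputs.sequence_ids(i)
        context_start = sequence_ids.index(1)
        context_end = len(sequence_ids) - 1 - sequence_ids[::-1].index(1)

        # Find the tokens covering the answer; 0 (the [CLS] token) means "not found"
        start_pos = end_pos = 0
        for idx in range(context_start, context_end + 1):
            token_start, token_end = offsets[idx]
            if token_start <= start_char < token_end:
                start_pos = idx
            if token_start < end_char <= token_end:
                end_pos = idx
        start_positions.append(start_pos)
        end_positions.append(end_pos)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

# Transform every example; the original string columns are no longer needed
tokenized_squad = squad.map(
    preprocess_function,
    batched=True,
    remove_columns=squad["train"].column_names,
)

# Reasonable starting hyperparameters, not tuned values
training_args = TrainingArguments(
    output_dir="./distilbert-squad",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_squad["train"],
    eval_dataset=tokenized_squad["validation"],
)
trainer.train()

# Save the fine-tuned model together with the tokenizer
model_path = "./distilbert-squad-finetuned"
trainer.save_model(model_path)
tokenizer.save_pretrained(model_path)
```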

This code is a bit complex. Let’s break it down step by step.

Dataset and Preprocessing

The SQuAD dataset is a popular dataset for question answering and it is available on the Hugging Face hub. You can load it using the load_dataset() function from Hugging Face’s datasets library.

Every dataset is different, but this particular one is dictionary-like, with the keys “title”, “context”, “question”, and “answers”. The “context” is a piece of moderately long text. The “question” is a question sentence. The “answers” is a dictionary with the keys “text” and “answer_start”: “text” maps to a short string that answers the question, and “answer_start” maps to the character position in the context where the answer starts. The “title” can be ignored, as it only gives the title of the article the context was extracted from.
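As a quick standalone check (assuming the datasets library is installed), you can load the dataset and look at one training record:

```python
from datasets import load_dataset

squad = load_dataset("squad")

sample = squad["train"][0]
print(sample["question"])       # a question sentence
print(sample["context"][:80])   # the beginning of the supporting passage
print(sample["answers"])        # {"text": [...], "answer_start": [...]}
```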

To use the dataset for training, you need to know how the model expects its input and what kind of output it produces. In the case of DistilBERT for question answering, both are fixed by the implementation of the DistilBertForQuestionAnswering class, unless you decide to write your own model implementation. This class expects the input as a sequence of integer token IDs, and it produces two vectors of logits: one for the start position and one for the end position of the answer.

You can find the details of the model’s input and output format in the previous post, or in the documentation of the DistilBertForQuestionAnswering class.

In order to use the dataset for training, you need to do some preprocessing. This is to transform the dataset into a format that matches the model’s input and output format. The dataset object loaded from the Hugging Face hub allows you to do this with the map() method, in which the transformation is implemented as a custom function, preprocess_function().

Note that preprocess_function() accepts a batch of examples from the dataset because batched=True is passed to the map() method.
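In the sketch above, the transformation is applied like this; remove_columns drops the original string columns once the tokenized features have been computed:

```python
tokenized_squad = squad.map(
    preprocess_function,
    batched=True,
    remove_columns=squad["train"].column_names,
)
```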

In the preprocess_function(), the tokenizer is invoked with the questions from examples["question"] and the context from examples["context"]. The question is stripped of extra spaces and the context is truncated to fit in the maximum length of 384 tokens. The use of a tokenizer in this function is different from what you have seen in the previous post:
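```python
# From the sketch above: tokenize question/context pairs together,
# keeping the character offsets of each token
inputs = tokenizer(
    questions,                 # questions with extra spaces stripped
    examples["context"],       # the matching contexts
    max_length=384,
    truncation="only_second",  # only the context may be truncated
    padding="max_length",
    return_offsets_mapping=True,
)
```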

Firstly, the tokenizer is invoked with a batch of questions and a batch of contexts. For potentially ragged input, the tokenizer pads the shorter sequences so that every entry in the batch has the same length. Secondly, with return_offsets_mapping=True, the tokenizer returns a dictionary with the keys “input_ids”, “attention_mask”, and “offset_mapping”. The “input_ids” is the sequence of integer token IDs. The “attention_mask” is a binary mask that indicates which tokens are real (1) and which are padding (0). The “offset_mapping” is added by setting return_offsets_mapping=True; it is a list of tuples giving the character positions (start and end offsets) of each token in the original text.

The input_ids from the tokenizer output concatenate the question and the context in the following format:
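```
[CLS] question tokens [SEP] context tokens [SEP]
```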

This is what the model expects. The answer from the dataset is given as a string, together with the character offset at which it can be found in the original context. This is different from what the model produces, namely logits over token positions. Therefore, a for-loop is used in preprocess_function() to recreate the start and end token positions of the answer.

In preprocess_function(), the tokenizer is invoked with additional arguments. Setting return_offsets_mapping=True makes the returned object contain offset_mapping, a list of tuples identifying the start and end character positions of each token in each input text.

First, the offset_mapping is popped from the object returned by the tokenizer, since it is not one of the model’s inputs. Then, for each answer, you compute the character start and end offsets within the context. You can verify this with code like the following:
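```python
# Sanity check on a raw SQuAD record (not part of the training script):
# slicing the context with the character offsets recovers the answer text
sample = squad["train"][0]
answer = sample["answers"]
start_char = answer["answer_start"][0]
end_char = start_char + len(answer["text"][0])
assert sample["context"][start_char:end_char] == answer["text"][0]
```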

Even though you know the character offsets, the model operates on token positions, so you still need to convert one into the other.

Recall that the tokenizer concatenated the question and the context. Fortunately, its output provides the clue needed to identify where the context starts and ends. The value inputs.sequence_ids(i) is a Python list of integers or None, corresponding to element i of the batch. The list holds None at positions occupied by special tokens and an integer at positions occupied by tokens from the actual input. Since you invoked the tokenizer with the question first and the context second, the integer 0 corresponds to the question and 1 corresponds to the context.

Therefore, you can identify the start and end token positions of the context by checking where the integer 1 first and last appears in the sequence_ids list:
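```python
# From the sketch above: index() finds the first context token,
# and searching the reversed list finds the last one
sequence_ids = inputs.sequence_ids(i)
context_start = sequence_ids.index(1)
context_end = len(sequence_ids) - 1 - sequence_ids[::-1].index(1)
```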

Given the start and end token positions of the context, you still need to check which tokens cover the answer. This is done with a for-loop that iterates over each pair of offsets and checks whether the answer’s start and end character positions fall within a token. If so, the position of that token is remembered as start_positions or end_positions. For answers that cannot be found (e.g., because the context was clipped by truncation), the positions are set to 0.
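In the sketch above, this step looks like the following; position 0 (the [CLS] token) serves as the “not found” marker:

```python
start_pos = end_pos = 0
for idx in range(context_start, context_end + 1):
    token_start, token_end = offsets[idx]
    if token_start <= start_char < token_end:  # token covers the answer's first character
        start_pos = idx
    if token_start < end_char <= token_end:    # token covers the answer's last character
        end_pos = idx
start_positions.append(start_pos)
end_positions.append(end_pos)
```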

At the end of preprocess_function(), the object inputs is returned. It is dictionary-like, with the keys input_ids, attention_mask, start_positions, and end_positions. You must not change the names of these keys because the DistilBertForQuestionAnswering class expects exactly these arguments in its forward() method.

The DistilBERT model expects to be called with the argument input_ids. If you call it with a padded batch, attention_mask is required as well, to tell the model which tokens in the input are padding. If you also pass the optional start and end positions, the cross-entropy loss is computed too. This is how the transformers library is designed: the same interface serves both inference and training.
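As a quick illustration, assuming the model and tokenized_squad objects from the sketch above, you can call the model directly on a single preprocessed example and get both the loss and the logits:

```python
import torch

# Build a batch of size 1 from the first preprocessed training example
example = tokenized_squad["train"][0]
batch = {
    key: torch.tensor([example[key]])
    for key in ("input_ids", "attention_mask", "start_positions", "end_positions")
}

outputs = model(**batch)
print(outputs.loss)                # cross-entropy loss, because positions were supplied
print(outputs.start_logits.shape)  # one logit per token for the start position
print(outputs.end_logits.shape)    # one logit per token for the end position
```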

Running the Training

To run this code, you need to install the following packages:
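```
pip install torch transformers datasets accelerate
```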

While torch, transformers, and datasets are expected requirements, the accelerate package may be less obvious: it is a dependency when you use the Trainer class from the transformers library.

You may expect training a complex model like DistilBERT to require a lot of code. Indeed, it is not easy, since you need to choose an optimizer, the number of epochs to train, and hyperparameters such as the batch size, learning rate, and weight decay. You even need to handle checkpointing so that you can resume training after an interruption.

This is why the Trainer class exists. You just need to set up the training arguments, create a Trainer with the model and the preprocessed dataset, and run the training.
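In the sketch above, that part looks like the following; the hyperparameter values are reasonable starting points rather than tuned choices:

```python
training_args = TrainingArguments(
    output_dir="./distilbert-squad",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_squad["train"],
    eval_dataset=tokenized_squad["validation"],
)
trainer.train()
```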

The Trainer will handle the checkpointing, the logging, and the evaluation in one function call. You just need to save the fine-tuned model (together with the tokenizer since they are loaded together) in the Hugging Face format once the training is complete:
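```python
# As in the sketch above; the output path is an arbitrary choice
model_path = "./distilbert-squad-finetuned"
trainer.save_model(model_path)         # saves the model weights and config
tokenizer.save_pretrained(model_path)  # saves the tokenizer files alongside them
```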

That’s all you need to do. Even if you did not specify using a GPU for the training, the Trainer will automatically discover the GPU on your system and use it to speed up the process. The code above, although not very long, is the complete code for fine-tuning DistilBERT on the SQuAD dataset.

If you run this code, the Trainer prints a progress bar and reports the training loss as it works through the dataset.

This takes some time to run, even on a decent GPU. However, because you are fine-tuning a pre-trained model on a new dataset, it is still orders of magnitude faster and easier than training from scratch.

Once you have finished the training, you can load the model in your other projects by pointing to the saved path:
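```python
from transformers import AutoTokenizer, DistilBertForQuestionAnswering

# model_path should point to wherever the fine-tuned model was saved;
# the path below matches the sketch above
model_path = "./distilbert-squad-finetuned"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = DistilBertForQuestionAnswering.from_pretrained(model_path)
```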

Make sure that model_path points to the directory containing the saved model files, as seen from your project.

Further Reading

The classes and methods used in this post, such as load_dataset() from the datasets library and DistilBertForQuestionAnswering, TrainingArguments, and Trainer from the transformers library, are all documented on the Hugging Face documentation site.

Summary

In this post, you have learned how to fine-tune DistilBERT for a custom question-answering task. Although DistilBERT and question answering are used as the example, you can apply the same process to other models and tasks. In particular, you learned:

  • How to prepare the dataset for training
  • How to train or fine-tune the model using the Trainer interface from the transformers library

 
