These plots suggest that when a dataset’s Rg distribution spans multiple orders of magnitude, or has non-negligible representation in both the Rg>1 and Rg<1 regions (as is the case with OpenOrca and other datasets with R̅g>1), the distribution can become highly skewed. As a result, the arithmetic mean may be disproportionately influenced by larger values, potentially misrepresenting the distribution’s central tendency. In such cases, computing the mean in log-space (then optionally transforming it back to the original scale) might provide a more meaningful summary statistic. In other words, it could make sense to use the geometric mean, i.e. the exponential of the arithmetic mean of log(Rg).
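As a rough illustration of how far the two means can drift apart, the sketch below compares them on a made-up set of Rg values. The numbers and the helper name geometric_mean are hypothetical and not taken from any of the datasets above.

```python
import math

def geometric_mean(values):
    """Mean computed in log-space, then mapped back to the original scale."""
    if not values or any(v <= 0 for v in values):
        raise ValueError("geometric mean needs a non-empty list of positive values")
    log_mean = sum(math.log(v) for v in values) / len(values)
    return math.exp(log_mean)

# Hypothetical Rg values spanning several orders of magnitude.
rg_values = [0.01, 0.1, 1.0, 10.0, 100.0]

arithmetic = sum(rg_values) / len(rg_values)
geometric = geometric_mean(rg_values)

print(f"arithmetic mean: {arithmetic:.3f}")  # dominated by the largest value
print(f"geometric mean:  {geometric:.3f}")   # sits at the log-scale midpoint
```

Here the single value of 100 pulls the arithmetic mean above 20, while the geometric mean stays near 1, the midpoint of the values on a log scale.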
The RACE Reading Comprehension Dataset
Based on the above R̅g table, I decided the RACE ReAding Comprehension Dataset from Examinations (R̅g=0.01) would be a good candidate for investigation. Multiple choice QA seemed like an ideal test-bed for exploring the effects of prompt-masking, since the prompt is naturally very long relative to the completion. Regardless of prompt length, the completion is always 1 character long, namely A, B, C or D (if you ignore special tokens, delimiters, etc). My hunch was that if there are any effects from modulating prompt token weights, they would certainly be noticeable here.
As stated in the dataset card:
RACE is a large-scale reading comprehension dataset with more than 28,000 passages and nearly 100,000 questions. The dataset is collected from English examinations in China, which are designed for middle school and high school students. The dataset can be served as the training and test sets for machine comprehension.
The QA schema is simple: the prompt presents a question, possibly some context (the article field), and then lists four options. The completion (answer) is always one of: A, B, C, D. This dataset viewer hosted on HuggingFace allows browsing the full set, but here’s a small example:
Before we jump into the full implementation of prompt-loss-weight, and try it out on the RACE data, we need a basic understanding of loss and where it comes from. Simply put, loss is a measure of how well our model (LLM) “fits” (explains, predicts) our data. During fine-tuning (and also pre-training), we “move” the model closer to the data by tweaking the network weights in such a way that decreases the loss. The chain rule (of calculus) gives us a precise algorithm for computing these tweaks, given the loss function and the network architecture.
The most common loss function in LLM fine-tuning is called Cross Entropy Loss (CEL). For this reason, most discussions of CEL are framed around the definition of cross-entropy, which comes from information theory. While it’s true that “cross-entropy” is right there in the name, a more intuitive understanding can be achieved when approaching CEL through the lens of maximum likelihood estimation (MLE). I’ll try to explain it from both angles.
We have already established that LLMs are wired for next token prediction. What this means is that the LLM is basically just a mathematical function that takes as input a sequence of tokens, and outputs a conditional probability distribution for the next token over the entire token vocabulary V. In other words, it outputs a vector of probability values of dimension |V| that sums to 1. (In set notation, |S| denotes the number of elements, or cardinality, of a set S.)
Let’s take a small toy example to illustrate how this works. Imagine that our training data contains the 4-token sequence: "The bird flew away". Given the first 3 tokens ("The bird flew"), an LLM might output the following vector of probabilities for every possible 4ᵗʰ token. For the sake of simplicity, we’ll imagine that the 5 candidate tokens listed (in magenta) are the only possibilities (i.e. |V|=5). The function p(⋅) represents the conditional probabilities output by the LLM (notice they sum to 1):
When training (or fine-tuning) an LLM on a token sequence, we step through the sequence token-by-token and compare the next-token-distribution generated by the LLM to the actual next token in the sequence, and from there we calculate the CEL for that token.
Notice here that the actual 4ᵗʰ token in the sequence ("away") does not have the highest probability in the table. During training, we would like to tweak the weights slightly so as to increase the probability of "away", while decreasing the others. The key is having the right loss function: it allows us to compute exactly how much to tweak each weight, for each token.
Once the loss is computed for each token, the final loss is computed as the average per-token-loss over all tokens. But first we must establish the formula for this per-token-loss.
Information Theory Interpretation
Continuing the toy problem, to compute CEL for the 4ᵗʰ token position, we compare the actual 4ᵗʰ token to the generated distribution p(⋅) over all 5 possible 4ᵗʰ tokens. In fact, we treat the actual 4ᵗʰ token as a distribution q(⋅) in its own right (albeit a degenerate one) that has a value of 1 for the token appearing in the data ("away") and a value of 0 for all other possible 4ᵗʰ tokens (this is sometimes called one-hot encoding).
The reason we contort the training data into this strange one-hot encoded probability representation q(⋅) is so we can apply the formula for cross-entropy, which is a measure of the divergence between two discrete probability distributions (BTW, not symmetric w.r.t. q,p):
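In the notation of this section, the standard definition of the cross-entropy of p relative to q can be written as follows (the symbol H is an assumed name for it):

```latex
H(q, p) = -\sum_{x} q(x) \, \log p(x)
```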
where x indexes over all possible states (i.e. 5 tokens). Because q(x) is 1 for "away" and 0 for every other token, this works out to −log p(away).
So basically CEL is just using the q vector to select from the p vector the single value corresponding to the token that actually appears in the data, "away" (i.e. multiplying it by 1), and throwing away all other values (i.e. multiplying by 0). So we are indexing over all possible states (tokens) only to select one and ignore the rest.
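The selection effect described above can be checked numerically. The following is a minimal sketch in plain Python. The candidate tokens and their probabilities are hypothetical stand-ins for the table in the toy example; the only properties carried over are that there are 5 candidates, they sum to 1, and "away" is not the most probable one.

```python
import math

# Hypothetical 5-token vocabulary and next-token probabilities p(.) after
# "The bird flew"; the values are made up for illustration and sum to 1.
vocab = ["away", "south", "home", "off", "fast"]
p = [0.25, 0.35, 0.20, 0.15, 0.05]

# One-hot distribution q(.) for the token that actually appears in the data.
actual = "away"
q = [1.0 if tok == actual else 0.0 for tok in vocab]

def cross_entropy(q, p):
    """-sum_x q(x) * log p(x), skipping terms where q(x) == 0."""
    return -sum(qx * math.log(px) for qx, px in zip(q, p) if qx > 0)

loss_full_sum = cross_entropy(q, p)
loss_lookup = -math.log(p[vocab.index(actual)])  # just pick out p(away)

print(loss_full_sum, loss_lookup)  # the two values are identical
```

Summing over the whole vocabulary and simply looking up p(away) give the same per-token loss, which is why implementations typically index into the predicted distribution rather than build q explicitly.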
MLE Interpretation
When fine-tuning an LLM, we seek the LLM weights θ that maximize the probability of the training data given those weights, often called the likelihood of the weights ℒ(θ) = ℙ(D|θ). And so we require an expression for this quantity. Luckily, there’s an easy way to compute this from next token probabilities, which the LLM already gives us.
Starting with the other chain rule (of probability), we decompose the joint probability of a token sequence S into a product of conditional probabilities:
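For a sequence S = (t₁, t₂, …, t_N), where N is an assumed symbol for the number of tokens, the standard factorization reads:

```latex
\mathbb{P}(S) = \mathbb{P}(t_1, t_2, \ldots, t_N) = \prod_{i=1}^{N} \mathbb{P}(t_i \mid t_1, \ldots, t_{i-1})
```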
This decomposition establishes the connection between next-token-prediction and the joint probability of the full token sequence — the joint probability is just the product of all the conditionals.
Using i to index over the tokens of a token sequence S = (t₁,t₂,t₃,…, tᵢ ,…), we’ll use the following shorthand to denote the conditional probability output by an LLM for the iᵗʰ token in a sequence, given the LLM weights θ and the previous i-1 tokens:
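In symbols, and reusing the ℙ notation from the likelihood above, this shorthand can be written along these lines:

```latex
p_i = \mathbb{P}(t_i \mid t_1, \ldots, t_{i-1}; \theta)
```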
It should be emphasized that pᵢ is not a vector here (i.e. a distribution over all possible next tokens) but represents only the probability computed for the actual iᵗʰ token, i.e. the yellow highlighted row in the above example.
If we take the logarithm of the joint probability of a sequence, a product becomes a sum (since log is monotonic, this doesn’t affect optimization):
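Written out with the shorthand pᵢ, and with N as an assumed symbol for the number of tokens in S, the step looks like this:

```latex
\log \mathbb{P}(S \mid \theta) = \log \prod_{i=1}^{N} p_i = \sum_{i=1}^{N} \log p_i
```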
Now we can connect the final sum-of-logs expression to the formula for Average Cross Entropy Loss L over a token sequence:
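With N again denoting the (assumed) number of tokens in the sequence, the average of the per-token losses −log pᵢ is:

```latex
L = -\frac{1}{N} \sum_{i=1}^{N} \log p_i
```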
which is the causal language model objective function. Often the “Average” is dropped from the name, and it’s just called “Cross Entropy Loss,” but it’s good to remember that CEL is technically computed at the token level, and then averaged across tokens. From this final expression it should hopefully be clear that minimizing the CEL is equivalent to maximizing the probability of the token sequence, which is what MLE seeks.
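The equivalence can be sanity-checked with a few numbers. This is a small sketch using hypothetical per-token probabilities pᵢ (none of these values come from a real model):

```python
import math

# Hypothetical probabilities p_i assigned to each actual token of a 4-token sequence.
token_probs = [0.5, 0.25, 0.8, 0.1]

joint_prob = math.prod(token_probs)                # chain rule: product of conditionals
log_joint = sum(math.log(p) for p in token_probs)  # log turns the product into a sum
avg_cel = -log_joint / len(token_probs)            # average per-token loss

print(joint_prob, avg_cel)
```

Since avg_cel is just the negative log of joint_prob divided by a constant, any change to the weights that lowers the loss necessarily raises the joint probability of the sequence.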
One convenience resulting from the form of this expression is that it is very easy to modify if we want to compute the loss over any subset of the tokens. Recall that we may sometimes be interested in finding the LLM weights θ that maximize the probability of the completion given the prompt:
We could easily adjust the loss for this scenario by simply averaging only over the completion tokens. If we use “𝕀c” to denote the set of all completion token indices, then we can express completion loss as:
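Using L_c as an assumed name for this quantity, and |𝕀c| for the number of completion tokens, one way to write it is:

```latex
L_c = -\frac{1}{|\mathbb{I}_c|} \sum_{i \in \mathbb{I}_c} \log p_i
```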
Since the loss for each token is already conditioned on all previous tokens in the sequence, this means that the prompt is automatically accounted for in the conditional, even if we average over completion tokens only.
Now that we have established CEL as an average of per-token losses over a token sequence, we can define the weighted average version of CEL:
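One form consistent with the special cases discussed in this section, assuming the weights are normalized by their sum (L_w is an assumed name), is:

```latex
L_w = -\frac{1}{\sum_{i=1}^{N} w_i} \sum_{i=1}^{N} w_i \log p_i
```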
Depending on how we set the weights wᵢ, we can use this formula to define multiple losses. For example, if we set all weights wᵢ =1 then we recover the standard, full sequence CEL from before. However, if we set wᵢ =1 only for completion tokens, and wᵢ = 0 for prompt tokens, then we get completion loss. And likewise, prompt loss is defined by setting wᵢ =1 only over prompt tokens, and wᵢ = 0 otherwise.
Since we rarely (if ever) want to down-weight the completion tokens, we fix the completion token weights at wᵢ =1, but for the prompt tokens we can define a continuous value on the [0, 1] interval called prompt_loss_weight. This way we can tune how much to weight the prompt tokens during training, from wᵢ = 0 (completion loss) all the way to wᵢ =1 (standard full sequence loss). Or, we could even use wᵢ =0.1 to give the prompt tokens a small but non-zero weight.
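To see how a single prompt_loss_weight value moves the loss between these two extremes, here is a minimal pure-Python sketch. It assumes the weighted average is normalized by the sum of the weights; the function name weighted_cel, the boolean prompt mask, and the probabilities are all hypothetical and are not the implementation used later in the experiments.

```python
import math

def weighted_cel(token_probs, is_prompt, prompt_loss_weight=1.0):
    """Weighted average of per-token losses -log(p_i).

    token_probs: probability the model assigned to each actual token (p_i)
    is_prompt:   True for prompt tokens, False for completion tokens
    """
    if not 0.0 <= prompt_loss_weight <= 1.0:
        raise ValueError("prompt_loss_weight must lie in [0, 1]")
    # Completion tokens always get weight 1; prompt tokens get the tunable weight.
    weights = [prompt_loss_weight if flag else 1.0 for flag in is_prompt]
    total_weight = sum(weights)
    if total_weight == 0:
        raise ValueError("at least one token must have non-zero weight")
    weighted_sum = sum(w * -math.log(p) for w, p in zip(weights, token_probs))
    return weighted_sum / total_weight

# Hypothetical per-token probabilities: 4 prompt tokens followed by 1 completion token.
probs = [0.5, 0.25, 0.5, 0.25, 0.125]
mask = [True, True, True, True, False]

full_loss = weighted_cel(probs, mask, prompt_loss_weight=1.0)        # standard CEL
completion_loss = weighted_cel(probs, mask, prompt_loss_weight=0.0)  # completion only
blended_loss = weighted_cel(probs, mask, prompt_loss_weight=0.1)     # in between
```

With these numbers the blended loss falls between the full-sequence loss and the completion loss, which is the interpolation behaviour the weight is meant to provide.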
Loss Implementation
Let’s take a look under the hood at how loss is normally computed in the HuggingFace transformers package. Since we’ll be fine-tuning the Llama-2-7b-chat-hf model in our experiments, we’ll look at LlamaForCausalLM, specifically at the forward pass, where loss is computed during training.
Recall that loss is a way of comparing each actual token to the LLM’s prediction for that token (given the preceding actual tokens), and so the loss function needs access to these two data structures. In this case, loss is fed two tensors: logits and labels. The labels tensor holds the actual tokens (token ids to be exact). The logits tensor holds the predicted next-token scores prior to softmax normalization (softmax is what forces them to sum to 1; it turns out that it’s more efficient to leave these values in their raw, pre-normalized form).
The logits tensor is 3D, with shape [B,N,|V|], where B is batch size, N is sequence length (in tokens), and |V| is token vocabulary size. The 2D labels tensor just contains the token sequence itself, so it has shape [B,N]. Here is the key section of code where CEL is normally computed:
# Shift-by-1 so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tensors
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
# Compute loss
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits, shift_labels)
For each position i along the 2nd dimension of logits, this tensor contains the (pre-softmax) scores for predicting the next token (token i+1) given all the preceding tokens up through the iᵗʰ token. These scores need to be compared to the actual (i+1)ᵗʰ token in labels. This is why the shift-by-1 happens in the first several lines: to bring these two values into alignment for each token.
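To make the alignment concrete, here is a minimal pure-Python sketch of the same shift for a single sequence (no batch dimension). The helper names log_softmax and causal_lm_loss and all the numbers are hypothetical; the real code operates on PyTorch tensors and delegates the per-token computation to CrossEntropyLoss.

```python
import math

def log_softmax(row):
    """Numerically stable log-softmax over one row of raw logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]

def causal_lm_loss(logits, labels):
    """Average next-token cross-entropy for a single sequence.

    logits: N rows of |V| raw scores; row i scores the token at position i+1
    labels: N token ids (the sequence itself)
    """
    shift_logits = logits[:-1]  # drop the last row: nothing follows it
    shift_labels = labels[1:]   # drop the first token: nothing predicts it
    per_token = [-log_softmax(row)[tok]
                 for row, tok in zip(shift_logits, shift_labels)]
    return sum(per_token) / len(per_token)

# Hypothetical 3-token vocabulary and a 4-token sequence.
labels = [0, 2, 1, 2]
logits = [
    [0.0, 0.0, 0.0],   # after token 0: uniform guess for position 1
    [0.0, 0.0, 0.0],   # after tokens 0..1: uniform guess for position 2
    [0.0, 0.0, 0.0],   # after tokens 0..2: uniform guess for position 3
    [5.0, -1.0, 2.0],  # after the full sequence: ignored by the shift
]

loss = causal_lm_loss(logits, labels)
```

Because the first three rows are uniform over 3 tokens, the loss comes out to log(3) regardless of what the last row contains, which shows that the final position never contributes after the shift.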