Transformer models are the standard choice for NLP tasks today, and almost all of these tasks involve generating text. However, text is not the direct output of the model. You expect the model to help you generate text that is coherent and contextually relevant. While this partly depends on the quality of the model, the generation parameters also play a crucial role in the quality of the generated text.
In this post, you will explore the key parameters that control text generation in transformer models. You will see how these parameters affect the quality of the generated text and how to tune them for different applications. In particular, you will learn:
- The core parameters that control text generation in transformer models
- The different decoding strategies
- How to control the creativity and coherence of generated text
- How to fine-tune generation parameters for specific applications
Let’s get started!
Understanding Text Generation Parameters in Transformers
Photo by Anton Klyuchnikov. Some rights reserved.
Overview
This post is divided into seven parts; they are:
– Core Text Generation Parameters
– Experimenting with Temperature
– Top-K and Top-P Sampling
– Controlling Repetition
– Greedy Decoding and Sampling
– Parameters for Specific Applications
– Beam Search and Multiple Sequences Generation
Core Text Generation Parameters
Let’s pick the GPT-2 model as an example. It is a small transformer model that does not require a lot of computational resources but is still capable of generating high-quality text. A simple example to generate text using the GPT-2 model is as follows:
```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# create model and tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# tokenize input prompt to sequence of ids
prompt = "Artificial intelligence is"
inputs = tokenizer(prompt, return_tensors="pt")

# generate output as a sequence of token ids
output = model.generate(
    **inputs,
    max_length=50,
    num_return_sequences=1,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
    repetition_penalty=1.0,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
)

# convert token ids into text strings
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

print(f"Prompt: {prompt}")
print("Generated Text:")
print(generated_text)
```
If you run this code, you may see:
```
Prompt: Artificial intelligence is
Generated Text:
Artificial intelligence is used in the production of technology, the delivery of which is determined by technological change. For example, an autonomous car can change its steering wheel to help avoid driving traffic. In the case of artificial intelligence, this can change what consumers
```
You provided a prompt of only three words, and the model generated a long piece of text. The text is not produced in one shot; rather, the model is invoked repeatedly in an iterative process.
You can see the numerous parameters used in the `generate()` function. The first one you used is `max_length`. Trivially, this controls how long the generated text should be, in number of tokens. The model generates one token at a time, using the prompt as context; the newly generated token is then appended to the prompt to generate the next one. Therefore, the longer you want the generated text to be, the more time it takes to generate. Note that the length is counted in tokens, not words, because the GPT-2 model uses a subword tokenizer. One token may be just a subword unit, not a full word.
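If you are curious how words map to tokens, you can ask the tokenizer directly. The snippet below is a small aside (the example strings are arbitrary), showing that a single word may be split into several subword tokens:

```python
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# a single word may be split into more than one subword token
print(tokenizer.tokenize("Artificial intelligence is"))
# count the tokens in the prompt, which may differ from the word count
print(len(tokenizer("Artificial intelligence is")["input_ids"]))
```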
Note, however, that the model does not emit any single token directly. Instead, it produces "logits", a vector of scores that is exactly as long as the vocabulary, with one score per candidate token. These scores are converted (via softmax) into a probability distribution over all possible "next tokens". From this distribution, you can pick the token with the highest probability (when you set `do_sample=False`), or sample any token with non-zero probability (when you set `do_sample=True`). This is what all the other parameters are for.
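To make this iterative process concrete, here is a minimal sketch of a greedy decoding loop driven by hand instead of calling `generate()`. It is an illustration of the idea, not the library's actual implementation:

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

ids = tokenizer("Artificial intelligence is", return_tensors="pt")["input_ids"]
for _ in range(20):                          # generate 20 more tokens
    with torch.no_grad():
        logits = model(ids).logits           # shape: (1, sequence length, vocabulary size)
    next_token_logits = logits[0, -1, :]     # scores for the next token only
    probs = torch.softmax(next_token_logits, dim=-1)
    next_id = torch.argmax(probs)            # greedy: always take the most probable token
    ids = torch.cat([ids, next_id.view(1, 1)], dim=1)

print(tokenizer.decode(ids[0], skip_special_tokens=True))
```

Sampling replaces the `argmax` with a random draw from `probs`, and the remaining parameters reshape that distribution before the draw.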
The `temperature` parameter skews the probability distribution. A lower temperature emphasizes the most likely token, while a higher temperature diminishes the difference between likely and unlikely tokens. The default temperature is 1.0, and it must be a positive value. The `top_k` parameter then keeps only the $k$ most probable tokens rather than the entire vocabulary, and the probabilities are renormalized to sum to 1. Next, if `top_p` is set, this set of tokens is further filtered to keep only the most probable ones whose total probability adds up to $p$. This final set of tokens is then used to sample the next token; sampling from such a top-$p$ set is known as **nucleus sampling**.
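To see what these filters do to a distribution, consider the following toy sketch over a made-up vocabulary of five tokens. The numbers are arbitrary, and the filtering rules are simplified illustrations of the idea rather than the library's exact implementation:

```python
import torch

logits = torch.tensor([3.0, 2.5, 1.0, 0.5, -1.0])   # made-up scores for five tokens

# temperature: divide the logits before softmax
for temperature in (0.5, 1.0, 2.0):
    probs = torch.softmax(logits / temperature, dim=-1)
    print(temperature, [round(p, 3) for p in probs.tolist()])
# lower temperature concentrates probability mass on the top token

# top-k: keep only the k most probable tokens, then renormalize
probs = torch.softmax(logits, dim=-1)
topk_probs, topk_ids = torch.topk(probs, k=3)
topk_probs = topk_probs / topk_probs.sum()

# top-p (nucleus): keep the smallest set of tokens whose probabilities sum to roughly p
sorted_probs, sorted_ids = torch.sort(probs, descending=True)
keep = torch.cumsum(sorted_probs, dim=-1) <= 0.9
keep[0] = True                                       # always keep the most probable token
nucleus_probs = sorted_probs[keep] / sorted_probs[keep].sum()
```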
Remember that you are generating a sequence of tokens, one at a time. At every step, the same high-probability tokens tend to reappear, so the generated sequence can end up repeating itself. This is usually not what you want, so you may want to decrease the probability of tokens that have already been produced. That is what the `repetition_penalty` parameter is for.
Experimenting with Temperature
Now that you know what the various parameters do, let's see how the output changes when you adjust some of them.
The temperature parameter has a significant impact on the creativity and randomness of the generated text. You can see its effect with the following example:
```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

prompt = "The future of artificial intelligence is"
inputs = tokenizer(prompt, return_tensors="pt")

# Generate text with different temperature values
temperatures = [0.2, 0.5, 1.0, 1.5]
print(f"Prompt: {prompt}")
for temp in temperatures:
    print()
    print(f"Temperature: {temp}")
    output = model.generate(
        **inputs,
        max_length=100,
        num_return_sequences=1,
        temperature=temp,
        top_k=50,
        top_p=1.0,
        repetition_penalty=1.0,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print("Generated Text:")
    print(generated_text)
```
When you run this code, you may see:
```
Prompt: The future of artificial intelligence is

Temperature: 0.2
Generated Text:
The future of artificial intelligence is uncertain. The future of artificial intelligence is uncertain.

The future of artificial intelligence is uncertain. The future of artificial intelligence is uncertain.

The future of artificial intelligence is uncertain. The future of artificial intelligence is uncertain.

The future of artificial intelligence is uncertain. The future of artificial intelligence is uncertain.

The future of artificial intelligence is uncertain. The future of artificial intelligence is uncertain.

The future of artificial intelligence is uncertain. The future

Temperature: 0.5
Generated Text:
The future of artificial intelligence is uncertain.

"There is a lot of work to be done on this," said Eric Schmitt, a professor of computer science and engineering at the University of California, Berkeley.

"We're looking for a way to make AI more like computers. We need to take a step back and look at how we think about it and how we interact with it."

Schmitt said he's confident that artificial intelligence will eventually be able to do more than

Temperature: 1.0
Generated Text:
The future of artificial intelligence is not yet clear, however."

"Is the process that we are trying to do through computer vision and the ability to look at a person at multiple points without any loss of intelligence due to not seeing a person at multiple points?" asked Richard. "I also think the people who are doing this research are extremely interesting to me due to being able to see humans at a range of different points in time. In particular, they've shown how to do a pretty complex

Temperature: 1.5
Generated Text:
The future of artificial intelligence is an era to remember as much as Google in search results, particularly ones not supported by much else for some years — and it might look like the search giant is now just as good without artificial intelligence. [Graphic image from Shutterstock]
```
With a low temperature (e.g., 0.2), the text becomes more focused and deterministic, often sticking to common phrases and conventional ideas. You also see that it keeps repeating the same sentence because the probability is concentrated on a few tokens, limiting diversity. This can be resolved by using the repetition penalty parameter that is covered in a section below.
With a medium temperature (e.g., 0.5 to 1.0), the text has a good balance of coherence and creativity. The generated text may not be factual, but the language is natural.
With a high temperature (e.g., 1.5), the text becomes more random and creative, but may also be less coherent and sometimes illogical. The language may be difficult to understand, just like the example above.
Choosing the right temperature depends on your application. If you are creating a helper for code completion or writing, a lower temperature is often better. For creative writing or brainstorming, a higher temperature can produce more diverse and interesting results.
Top-K and Top-P Sampling
The nucleus sampling parameters control how much freedom the model has in picking the next token. Should you adjust the `top_k` parameter or the `top_p` parameter? Let's see their effect in an example:
```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

prompt = "The best way to learn programming is"
inputs = tokenizer(prompt, return_tensors="pt")

# Generate text with different top_k values
top_k_values = [5, 20, 50]
print(f"Prompt: {prompt}")

for top_k in top_k_values:
    print()
    print(f"Top-K = {top_k}")
    output = model.generate(
        **inputs,
        max_length=100,
        num_return_sequences=1,
        temperature=1.0,
        top_k=top_k,
        top_p=1.0,
        repetition_penalty=1.0,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print("Generated Text:")
    print(generated_text)

# Generate text with different top_p values
top_p_values = [0.5, 0.7, 0.9]
for top_p in top_p_values:
    print()
    print(f"Top-P = {top_p}")
    output = model.generate(
        **inputs,
        max_length=100,
        num_return_sequences=1,
        temperature=1.0,
        top_k=0,
        top_p=top_p,
        repetition_penalty=1.0,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print("Generated Text:")
    print(generated_text)
```
When you run this code, you may see:
```
Prompt: The best way to learn programming is

Top-K = 5
Generated Text:
The best way to learn programming is to be able to learn the basics in a very short amount of time, and then learn to use them effectively and quickly.

If you want to be a successful programmer in this way, you should learn to use the techniques in the above video to learn the basics of programming.

If you want to learn to code more effectively, you can also get more experienced programmers by doing the following:

Learning to Code

Learning to code is very

Top-K = 20
Generated Text:
The best way to learn programming is to learn it.

In order to get started with Ruby you're going to have to make a few mistakes, some of them can be fairly obvious.

First of all, you're going to have to write a function that takes in a value. What this means is that you're going to make a new instance of the Ruby function. You can read more about this in Part 1 of this course, or just try it out from the REPL.

Top-K = 50
Generated Text:
The best way to learn programming is to become familiar with the language and the software. One of the first and most common forms of programming is to create, modify, and distribute code.

However, there are very few programming libraries that can provide us with all that we need.

The following sample programming program uses some of the above, but does not show the best way to learn programming. It was written in Java and in C or C++.

The original source code is

Top-P = 0.5
Generated Text:
The best way to learn programming is to be able to create a tool for you. That's what I do.

That's why I'm here today.

I'm here to talk about the basics of programming, and I'm going to tell you how to learn programming.

I'm here to talk about learning programming.

It's easy to forget that you don't have to know how to program. It's easy to forget that you don't have to know how

Top-P = 0.7
Generated Text:
The best way to learn programming is to practice programming. Learn the principles of programming by observing and performing exercises.

I used to work in a world of knowledge which included all sorts of things, and was able to catch up on them and understand them from their perspective. For instance, I learned to sit up straight and do five squats. Then, I would have to practice some type of overhead training. I would try to learn the best technique and add that to my repertoire.

What

Top-P = 0.9
Generated Text:
The best way to learn programming is to become a good hacker. Don't use any programming tools. Just a regular dot-com user, an occasional coding learner, and stick with it.

— Victoria E. Nichols
```
You can see that with a small $k$ value, such as 5, the model has fewer options to choose from, resulting in more predictable text. At the extreme, when $k=1$, the model always picks the single token with the highest probability, which is greedy decoding, and typically produces poor output. With a larger $k$, such as 50, the model has more options to choose from, resulting in more diverse text.
Similarly, for the `top_p` parameter, a smaller $p$ means the model selects from a smaller set of high-probability tokens, resulting in more focused text. With a larger $p$, such as 0.9, the model has a wider selection, potentially leading to more varied text. However, the number of candidate tokens for a given $p$ is not fixed: it depends on the probability distribution the model predicts. When the model is very confident about the next token (for example, when constrained by grammar), only a very small set of tokens passes the filter. This adaptive nature is also why top-p sampling is often preferred over top-k sampling.
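You can observe this adaptive behavior yourself by counting how many tokens it takes to cover a probability mass of $p$ for different contexts. The sketch below uses a simplified version of the top-p rule (not the library's exact filter), and the prompts are arbitrary examples:

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

def nucleus_size(prompt, p=0.9):
    """Count how many tokens are needed to cover probability mass p for the next token."""
    ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
    with torch.no_grad():
        next_token_logits = model(ids).logits[0, -1, :]
    probs = torch.softmax(next_token_logits, dim=-1)
    sorted_probs, _ = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    return int((cumulative < p).sum()) + 1   # number of tokens needed to just exceed p

# a constrained context typically needs far fewer tokens than an open-ended one
print(nucleus_size("The capital of France is"))
print(nucleus_size("My favorite thing about weekends is"))
```

The exact counts depend on the model, but the constrained prompt usually yields a much smaller nucleus.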
Controlling Repetition
Repetition is a common issue in text generation. The `repetition_penalty` parameter helps address this by penalizing tokens that have already appeared in the generated text. Let's see how it works:
```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

prompt = "Once upon a time, there was a"
inputs = tokenizer(prompt, return_tensors="pt")

# Generate text with different repetition penalties
penalties = [1.0, 1.2, 1.5, 2.0]
print(f"Prompt: {prompt}")
for penalty in penalties:
    print()
    print(f"Repetition penalty: {penalty}")
    output = model.generate(
        **inputs,
        max_length=100,
        num_return_sequences=1,
        temperature=0.3,
        top_k=50,
        top_p=1.0,
        repetition_penalty=penalty,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    print("Generated Text:")
    print(generated_text)
```
When you run this code, you may see:
```
Prompt: Once upon a time, there was a

Repetition penalty: 1.0
Generated Text:
Once upon a time, there was a great deal of confusion about what was going on. The first thing that came to mind was the fact that the government had already been in place for a long time, and that the government had been in place for a long time. And it was clear that the government had been in place for a long time. And it was clear that the government had been in place for a long time. And it was clear that the government had been in place for a long

Repetition penalty: 1.2
Generated Text:
Once upon a time, there was a great deal of talk about the possibility that this would be an opportunity for us to see more and better things in our lives. We had been talking on Facebook all day long with people who were interested in what we could do next or how they might help others find their own way out." "We've always wanted to make sure everyone has access," he continued; "but it's not like you can just go into your room at night looking around without seeing

Repetition penalty: 1.5
Generated Text:
Once upon a time, there was a man who had been called to the service of God. He came and said: "I am an apostle from Jerusalem." And he answered him with great joy, saying that it is not possible for me now in this life without having received Jesus Christ as our Lord; but I will be saved through Him alone because my Father has sent Me into all things by His Holy Spirit (John 1). The Christian Church teaches us how much more than any other religion can

Repetition penalty: 2.0
Generated Text:
Once upon a time, there was a man who had been sent to the city of Nausicaa by his father. The king's son and brother were killed in battle at that place; but when he returned with them they found him dead on their way back from war-time.[1] The King gave orders for an expedition against this strange creature called "the Gorgon," which came out into space during one night after it attacked Earth[2]. It is said that these creatures
```
In the code above, the temperature is set to 0.3 to emphasize the effect of the repetition penalty. With a penalty of 1.0 (i.e., no penalty), you can see that the model repeats the same phrase over and over again; the model easily gets stuck in loops when the other settings limit the candidate tokens to a small subset. At a high penalty, such as 2.0 or above, the model strongly avoids repetition, which can sometimes lead to less natural text. A moderate penalty (e.g., 1.2 to 1.5) is often a good compromise that maintains coherence.
Ultimately, the parameters you set in the `generate()` function are there to keep the text flowing naturally. You may want to tune them by experimentation to see which combination works best for your particular application. Note that the best values may also depend on the model you are using, since each model may generate tokens with a different distribution.
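For intuition about how the penalty is applied, a common approach is to rescale the scores of tokens that have already appeared before the next token is sampled: positive scores are divided by the penalty and negative scores are multiplied by it, so a previously seen token always becomes less likely. The sketch below illustrates the idea with made-up numbers; it is not the library's exact code:

```python
import torch

def apply_repetition_penalty(next_token_logits, generated_ids, penalty=1.5):
    """Lower the scores of tokens that already appeared in generated_ids."""
    for token_id in set(generated_ids):
        score = next_token_logits[token_id]
        # dividing a positive score or multiplying a negative one both reduce the probability
        next_token_logits[token_id] = score / penalty if score > 0 else score * penalty
    return next_token_logits

# toy example: token 2 already appeared, so its score is pushed down before sampling
logits = torch.tensor([2.0, 1.0, 3.0, -0.5])
print(apply_repetition_penalty(logits.clone(), generated_ids=[2], penalty=2.0))
```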
Greedy Decoding and Sampling
The `do_sample` parameter controls whether the model uses sampling (probabilistic selection of tokens) or greedy decoding (always selecting the most probable token). Let's compare these approaches:
```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

prompt = "The secret to happiness is"
inputs = tokenizer(prompt, return_tensors="pt")

# Generate text with greedy decoding vs. sampling
print(f"Prompt: {prompt}\n")

print("Greedy Decoding (do_sample=False):")
output = model.generate(
    **inputs,
    max_length=100,
    num_return_sequences=1,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
    repetition_penalty=1.0,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("Generated Text:")
print(generated_text)

print()
print("Sampling (do_sample=True):")
output = model.generate(
    **inputs,
    max_length=100,
    num_return_sequences=1,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
    repetition_penalty=1.0,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("Generated Text:")
print(generated_text)
```
Try running this code multiple times and observe the output. You will notice that the output of greedy decoding is always the same, while the output of sampling is different each time. Greedy decoding is deterministic for a fixed prompt: the model produces a probability distribution, and the most probable token is selected, with no randomness involved. Its output is more likely to be repetitive and not very useful.
The sampling output is stochastic because each token is drawn from the model's predicted probability distribution. The randomness allows the model to generate more diverse and creative text, while the output remains coherent as long as the other generation parameters are set properly. With sampling, you can also set `num_return_sequences` to a number greater than 1 to generate multiple sequences in parallel for the same prompt. This parameter is meaningless for greedy decoding.
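For example, a minimal sketch of sampling several continuations of one prompt in a single call might look like this (the prompt and the number of sequences are arbitrary choices):

```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

inputs = tokenizer("The secret to happiness is", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_length=50,
    do_sample=True,            # sampling is what makes the sequences differ
    num_return_sequences=3,    # three independent continuations of the same prompt
    pad_token_id=tokenizer.eos_token_id,
)
for idx, output in enumerate(outputs):
    print(f"Sequence {idx + 1}:")
    print(tokenizer.decode(output, skip_special_tokens=True))
```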
Parameters for Specific Applications
For a particular application, what parameter values should you set? There is no concrete answer. You surely need to run some experiments to find the best combination, but you may use the following as a starting point (a short sketch further below shows one way to collect these suggestions into presets):
- Factual Generation:
  - Lower `temperature` (0.2 to 0.4) for more deterministic output
  - Moderate `top_p` (0.8 to 0.9) to filter out unlikely tokens
  - Higher `repetition_penalty` (1.2 to 1.5) to avoid repetitive statements
- Creative Writing:
  - Higher `temperature` (1.0 to 1.3) for more creative and diverse output
  - Higher `top_p` (0.9 to 0.95) to allow for more possibilities
  - Lower `repetition_penalty` (1.0 to 1.1) to allow some stylistic repetition
- Code Generation:
  - Lower `temperature` (0.1 to 0.3) for more precise and correct code
  - Lower `top_p` (0.7 to 0.8) to focus on the most likely tokens
  - Higher `repetition_penalty` (1.3 to 1.5) to avoid redundant code
- Dialogue Generation:
  - Moderate `temperature` (0.6 to 0.8) for natural but focused responses
  - Moderate `top_p` (0.9) for a good balance of creativity and coherence
  - Moderate `repetition_penalty` (1.2) to avoid repetitive phrases
Remember that the language model is not a perfect oracle; it may make mistakes. The above parameters help you fit the generation process to the expected style of the output, but they do not guarantee correctness. The output you get may contain errors.
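One convenient way to use such recommendations is to keep them in a small dictionary of presets and unpack the chosen one into `generate()`. The preset names and values below are ad hoc restatements of the ranges above, not part of the transformers API:

```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# rough starting points; tune them for your model and application
PRESETS = {
    "factual":  {"temperature": 0.3, "top_p": 0.85, "repetition_penalty": 1.3},
    "creative": {"temperature": 1.2, "top_p": 0.95, "repetition_penalty": 1.05},
    "code":     {"temperature": 0.2, "top_p": 0.75, "repetition_penalty": 1.4},
    "dialogue": {"temperature": 0.7, "top_p": 0.9,  "repetition_penalty": 1.2},
}

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

inputs = tokenizer("The Eiffel Tower is located in", return_tensors="pt")
output = model.generate(
    **inputs,
    max_length=50,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
    **PRESETS["factual"],      # unpack the chosen preset into generate()
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```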
Beam Search and Multiple Sequences Generation
In the above examples, the generation process is **autoregressive**. It is an iterative process that generates one token at a time.
Since each step samples one token, nothing prevents you from keeping several candidate tokens at each step instead of just one. If you do that, you branch into multiple output sequences for a single input prompt. Theoretically, if you keep $k$ candidates at each step and generate sequences of length $n$, you end up with $k^n$ sequences. This can be a huge number, and you may want to limit it to only a few.
The first way to generate multiple sequences is to set `num_return_sequences` to a number $k$. You generate $k$ tokens in the first step and then complete a sequence for each of them. This essentially duplicates the prompt $k$ times in the generation.
The second way is to use beam search, a more sophisticated way to generate multiple sequences. It keeps track of the most promising sequences and explores them in parallel. Instead of generating $k^n$ sequences and overwhelming memory, it keeps only the $k$ best sequences at each step. Each token generation step temporarily expands this set and then prunes it back to the $k$ best sequences.
To use beam search, you need to set `num_beams` to a number $k$. Each step extends each of the $k$ sequences by one more token, scores the resulting candidates, and keeps the best $k$ sequences for the next step. You may also set `early_stopping=True` to stop a beam once it reaches the end-of-sequence token, and you should set `num_return_sequences` to limit how many of the final beams are returned.
The selection of a sequence is usually based on the cumulative probability of the tokens in the sequence. But you may also skew the selection by other criteria, such as adding a length penalty or avoiding repeating n-grams. Below is an example of using beam search:
```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

prompt = "The key to successful machine learning is"
inputs = tokenizer(prompt, return_tensors="pt")

# Generate multiple sequences with beam search
print(f"Prompt: {prompt}\n")
outputs = model.generate(
    **inputs,
    num_beams=5,             # Number of beams to use
    early_stopping=True,     # Stop when all beams have finished
    no_repeat_ngram_size=2,  # Avoid repeating n-grams
    num_return_sequences=3,  # Return multiple sequences
    max_length=100,
    temperature=1.5,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
)
for idx, output in enumerate(outputs):
    generated_text = tokenizer.decode(output, skip_special_tokens=True)
    print(f"Generated Text ({idx+1}):")
    print(generated_text)
```
You may add more generation parameters (such as `length_penalty`) to control the generation process. The example above sets a higher temperature to highlight the behavior of beam search. If you run this code, you may see:
```
Prompt: The key to successful machine learning is

Generated Text (1):
The key to successful machine learning is to be able to learn from the world around you. It is our job to make sure that we are learning from people, rather than just from machines.

So, let's take a step back and look at how we can learn. Here's a list of the tools we use to help us do that. We're going to go over a few of them here and give you a general idea of what they are and how you can use them to create

Generated Text (2):
The key to successful machine learning is to be able to learn from the world around you. It is our job to make sure that we are learning from people, rather than just from machines.

So, let's take a step back and look at how we can learn. Here's a list of the tools we use to help us do that. We're going to go over a few of them here and give you a general idea of what they are and how you can use them and what

Generated Text (3):
The key to successful machine learning is to be able to learn from the world around you. It is our job to make sure that we are learning from people, rather than just from machines.

So, let's take a step back and look at how we can learn. Here's a list of the tools we use to help us do that. We're going to go over a few of them here and give you a general idea of what they are and how they work. You can use
```
The number of output sequences is still controlled by `num_return_sequences`, but the process that generates them uses the beam search algorithm. It is not easy to tell from the output alone whether beam search was used. One sign is that beam search output is less diverse than what you get by just setting `num_return_sequences` with sampling, since many more candidate sequences are generated and only those with the highest cumulative probabilities are kept. This filtering indeed reduces the diversity of the output.
Summary
In this post, you saw how the many parameters of the `generate()` function can be used to control the generation process. You can adjust these parameters to make the output fit the style you would expect for your application. Specifically, you learned:
- How to use temperature to control the probability distribution of the output
- How to use top-k and top-p to control the diversity of the output
- How to control output using repetition penalty, beam search, and greedy decoding
By understanding and tuning these parameters, you can optimize text generation for different applications, from factual writing to creative storytelling, code generation, and dialogue systems.