生成式AI文本密码:Transformer参数全解码​ 原创

发布于 2025-5-9 08:10
浏览
0收藏

本文详细介绍Transformer模型中控制文本生成的关键参数,包括温度、Top-K和Top-P采样、重复惩罚等,并探讨这些参数对生成文本质量的影响及针对不同应用的调整方法。

Transformer模型是当今NLP任务的标准模型。几乎所有NLP任务都涉及文本生成,但文本生成并非模型的直接输出。你可能希望模型能够帮助你生成连贯且与上下文相关的文本。虽然这在一定程度上与模型的质量有关,但生成参数也对生成文本的质量起着至关重要的作用。

在本文中,让我们来一起探索控制Transformer模型中文本生成的关键参数。你将了解这些参数如何影响生成文本的质量,以及如何针对不同的应用进行调整。具体而言,你将学习到:

  • Transformer模型中控制文本生成的核心参数​
  • 不同的解码策略​
  • 如何控制生成文本的创造性和连贯性​
  • 如何针对特定应用微调生成参数​

让我们开始吧!

概述

本文将划分为七个部分进行介绍,它们是:

  • 核心文本生成参数​
  • 温度实验​
  • Top-K和Top-P采样​
  • 控制重复​
  • 贪婪解码和采样​
  • 特定应用的参数​
  • 集束搜索和多序列生成

核心文本生成参数

我们以GPT-2模型为例。它是一个小型Transformer模型,不需要大量计算资源,但仍能生成高质量的文本。使用GPT-2模型生成文本的一个简单示例如下:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
#创建模型和分词器
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
#将输入提示分词为ID序列
prompt = "Artificial intelligence is"
inputs = tokenizer(prompt, return_tensors="pt")
# 将输出作为一系列标记ID生成
output = model.generate(
 **inputs,
 max_length=50,
 num_return_sequences=1,
 temperature=1.0,
 top_k=50,
 top_p=1.0,
 repetition_penalty=1.0,
 do_sample=True,
 pad_token_id=tokenizer.eos_token_id,
)
#将标记ID转换为文本字符串
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(f"Prompt: {prompt}")
print("Generated Text:")
print(generated_text)

如果运行此代码,你可能会看到如下输出内容:

Prompt: Artificial intelligence is
Generated Text:
Artificial intelligence is used in the production of technology, the delivery of
which is determined by technological change. For example, an autonomous car can
change its steering wheel to help avoid driving traffic. In the case of artificial
intelligence, this can change what consumers

本例中,你只提供了三个单词的提示,模型就生成了一段很长的文本。这并非一次性生成,而是在迭代过程中多次调用模型。

你可以看到generate()函数中使用的众多参数。你使用的第一个参数是max_length,它控制生成的文本的长度(以标记数量表示)。通常,模型使用提示作为上下文,一次生成一个标记。然后,将新生成的标记附加到提示中并生成下一个标记。因此,你希望生成的文本越长,生成它所需的时间就越长。请注意,这里关注的是标记,而不是单词,因为你在GPT-2模型中使用了子词标记器。一个标记可能只是一个子词单元,而不是一个完整的单词。

然而,该模型并非专门生成任何单个标记。相反,它生成一个“logit”,即下一个标记概率的向量。logit是一个长向量,恰好与词汇表的大小相同。鉴于它是所有可能的“下一个标记”的概率分布,你可以选择概率最高的标记(当设置do_sample=False时),或者任何其他概率非零的标记(当设置do_sample=True时)。这就是所有其他参数的目的。

temperature参数会扭曲概率分布。较低的温度会强调最可能的标记,而较高的温度会缩小可能的标记和不太可能的标记之间的差异。默认温度为1.0,并且应为正值。然后,top_k参数仅选择最靠前的标记标记,而不是整个标记词汇表。然后重新计算概率,总和为1。接下来,如果设置了top_p,则这一组k个标记的集合进一步过滤,保留构成总概率p的那些顶级标记。然后使用这组最终的标记来对下一个标记进行采样,这个过程称为核采样。

请记住,你正在生成一个标记序列,一次一个。你很可能会在每一步中重复看到相同的标记,并且你可能会在序列中看到相同的标记。这通常不是你想要的结果,因此你可能希望在再次看到这些标记时降低其出现的概率。这就是上面repetition_penalty参数的作用所在。

温度实验

假设到目前你已经知道了各个参数的作用,那么接下来,让我们看看当你调整其中一些参数时输出如何变化。

温度参数对生成文本的创造性和随机性有显著的影响。你可以通过以下示例看到其效果:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
prompt = "The future of artificial intelligence is"
inputs = tokenizer(prompt, return_tensors="pt")
# 生成不同温度值的文本
temperatures = [0.2, 0.5, 1.0, 1.5]
print(f"Prompt: {prompt}")
for temp in temperatures:
 print()
 print(f"Temperature: {temp}")
 output = model.generate(
 **inputs,
 max_length=100,
 num_return_sequences=1,
 temperature=temp,
 top_k=50,
 top_p=1.0,
 repetition_penalty=1.0,
 do_sample=True,
 pad_token_id=tokenizer.eos_token_id,
 )
 generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
 print("Generated Text:")
 print(generated_text)

如果运行此代码,你可能会看到如下输出内容:

Prompt: The future of artificial intelligence is
Temperature: 0.2
Generated Text:
The future of artificial intelligence is uncertain. The future of artificial
intelligence is uncertain.
The future of artificial intelligence is uncertain. The future of artificial
intelligence is uncertain.
The future of artificial intelligence is uncertain. The future of artificial
intelligence is uncertain.
The future of artificial intelligence is uncertain. The future of artificial
intelligence is uncertain.
The future of artificial intelligence is uncertain. The future of artificial
intelligence is uncertain.
The future of artificial intelligence is uncertain. The future
Temperature: 0.5
Generated Text:
The future of artificial intelligence is uncertain.
"There is a lot of work to be done on this," said Eric Schmitt, a professor
of computer science and engineering at the University of California, Berkeley.
"We're looking for a way to make AI more like computers. We need to take a step
back and look at how we think about it and how we interact with it."
Schmitt said he's confident that artificial intelligence will eventually be
able to do more than
Temperature: 1.0
Generated Text:
The future of artificial intelligence is not yet clear, however."
"Is the process that we are trying to do through computer vision and the ability to
look at a person at multiple points without any loss of intelligence due to not
seeing a person at multiple points?" asked Richard. "I also think the people who
are doing this research are extremely interesting to me due to being able to see
humans at a range of different points in time. In particular, they've shown how
to do a pretty complex
Temperature: 1.5
Generated Text:
The future of artificial intelligence is an era to remember as much as Google in
search results, particularly ones not supported by much else for some years -- and
it might look like the search giant is now just as good without artificial
intelligence. [Graphic image from Shutterstock]

当温度较低(例如0.2)时,文本会变得更加集中和确定,通常会坚持使用常用短语和传统观点。你还会看到,由于概率集中在少数几个标记上,文本会不断重复相同的句子,从而限制了多样性。这个问题可以通过使用重复惩罚参数来解决,该参数将在下一节中介绍。

中等温度(例如0.5到1.0)的文本在连贯性和创造性之间取得了良好的平衡。生成的文本可能并非基于事实,但语言自然。

当温度较高(例如1.5)时,文本会变得更加随意和富有创意,但也可能变得缺乏连贯性,有时甚至缺乏逻辑性。语言可能难以理解,就像上面的例子一样。

选择合适的温度取决于你的应用。如果你正在创建代码补全或写作助手,通常较低的温度更佳。对于创意写作或头脑风暴,较高的温度可以产生更多样化、更有趣的结果。

Top-K和Top-P采样

核采样参数控制着模型选择下一个标记的灵活性。你应该调整top_k参数还是top_p参数?让我们通过一个例子来看一下它们的效果:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
prompt = "The best way to learn programming is"
inputs = tokenizer(prompt, return_tensors="pt")
#使用不同top_k值生成文本
top_k_values = [5, 20, 50]
print(f"Prompt: {prompt}")
for top_k in top_k_values:
 print()
 print(f"Top-K = {top_k}")
 output = model.generate(
 **inputs,
 max_length=100,
 num_return_sequences=1,
 temperature=1.0,
 top_k=top_k,
 top_p=1.0,
 repetition_penalty=1.0,
 do_sample=True,
 pad_token_id=tokenizer.eos_token_id,
 )
 generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
 print("Generated Text:")
 print(generated_text)
# 使用不同top_p值生成文本
top_p_values = [0.5, 0.7, 0.9]
for top_p in top_p_values:
 print()
 print(f"Top-P = {top_p}")
 output = model.generate(
 **inputs,
 max_length=100,
 num_return_sequences=1,
 temperature=1.0,
 top_k=0,
 top_p=top_p,
 repetition_penalty=1.0,
 do_sample=True,
 pad_token_id=tokenizer.eos_token_id,
 )
 generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
 print("Generated Text:")
 print(generated_text)
如果运行此代码,你可能会看到如下输出内容:
Prompt: The best way to learn programming is
Top-K = 5
Generated Text:
The best way to learn programming is to be able to learn the basics in a very short
amount of time, and then learn to use them effectively and quickly.
If you want to be a successful programmer in this way, you should learn to use the
techniques in the above video to learn the basics of programming.
If you want to learn to code more effectively, you can also get more experienced
programmers by doing the following:
Learning to Code
Learning to code is very
Top-K = 20
Generated Text:
The best way to learn programming is to learn it.
In order to get started with Ruby you're going to have to make a few mistakes, some
of them can be fairly obvious.
First of all, you're going to have to write a function that takes in a value. What
this means is that you're going to make a new instance of the Ruby function. You can
read more about this in Part 1 of this course, or just try it out from the REPL.
Top-K = 50
Generated Text:
The best way to learn programming is to become familiar with the language and the
software. One of the first and most common forms of programming is to create,
modify, and distribute code.
However, there are very few programming libraries that can provide us with all
that we need.
The following sample programming program uses some of the above, but does not show
the best way to learn programming. It was written in Java and in C or C++.
The original source code is
Top-P = 0.5
Generated Text:
The best way to learn programming is to be able to create a tool for you. That's
what I do.
That's why I'm here today.
I'm here to talk about the basics of programming, and I'm going to tell you how to
learn programming.
I'm here to talk about learning programming.
It's easy to forget that you don't have to know how to program. It's easy to forget
that you don't have to know how
Top-P = 0.7
Generated Text:
The best way to learn programming is to practice programming. Learn the principles
of programming by observing and performing exercises.
I used to work in a world of knowledge which included all sorts of things, and was
able to catch up on them and understand them from their perspective. For instance, I
learned to sit up straight and do five squats. Then, I would have to practice some
type of overhead training. I would try to learn the best technique and add that to
my repertoire.
What
Top-P = 0.9
Generated Text:
The best way to learn programming is to become a good hacker. Don't use any
programming tools. Just a regular dot-com user, an occasional coding learner, and
stick with it.
— Victoria E. Nichols

你可以通过一个小的k值,例如5,看到模型可供选择的选项较少,从而导致文本更可预测。在极端情况下,当k=1时,模型总是选择概率最高的单个标记,这是贪婪解码,通常会产生较差的输出。当使用一个较大的k值,比如50,模型就有更多的选项可以选择,从而产生更加多样化的文本。

类似地,对于top_p参数,较小的p值意味着模型从一组较小的高概率标记中进行选择,从而产生更有针对性的文本。使用较大的p值,例如0.9,则模型的选择范围更广,可能会产生更多样化的文本。但是,对于给定的文本,你可以选择多少个选项并非固定不变,它取决于模型预测的概率分布。当模型对下一个标记非常有信心时(例如受某些语法规则限制),只允许使用非常小的标记集合。这种自适应特性也是为什么top-p采样通常比top-k采样更受欢迎的原因。

控制重复

重复是文本生成中常见的问题。repetition_penalty参数通过惩罚已在生成文本中出现过的标记来帮助解决这个问题。让我们看看它是如何工作的:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
prompt = "Once upon a time, there was a"
inputs = tokenizer(prompt, return_tensors="pt")
# 使用不同的重复惩罚生成文本
penalties = [1.0, 1.2, 1.5, 2.0]
print(f"Prompt: {prompt}")
for penalty in penalties:
 print()
 print(f"Repetition penalty: {penalty}")
 output = model.generate(
 **inputs,
 max_length=100,
 num_return_sequences=1,
 temperature=0.3,
 top_k=50,
 top_p=1.0,
 repetition_penalty=penalty,
 do_sample=True,
 pad_token_id=tokenizer.eos_token_id,
 )
 generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
 print("Generated Text:")
 print(generated_text)

如果运行此代码,你可能会看到如下输出内容:

Prompt: Once upon a time, there was a
Repetition penalty: 1.0
Generated Text:
Once upon a time, there was a great deal of confusion about what was going on. The
first thing that came to mind was the fact that the government had already been in
place for a long time, and that the government had been in place for a long time.
And it was clear that the government had been in place for a long time. And it was
clear that the government had been in place for a long time. And it was clear that
the government had been in place for a long
Repetition penalty: 1.2
Generated Text:
Once upon a time, there was a great deal of talk about the possibility that this
would be an opportunity for us to see more and better things in our lives. We had
been talking on Facebook all day long with people who were interested in what we
could do next or how they might help others find their own way out."
"We've always wanted to make sure everyone has access," he continued; "but it's not
like you can just go into your room at night looking around without seeing
Repetition penalty: 1.5
Generated Text:
Once upon a time, there was a man who had been called to the service of God. He
came and said: "I am an apostle from Jerusalem." And he answered him with great joy,
saying that it is not possible for me now in this life without having received
Jesus Christ as our Lord; but I will be saved through Him alone because my Father
has sent Me into all things by His Holy Spirit (John 1).
The Christian Church teaches us how much more than any other religion can
Repetition penalty: 2.0
Generated Text:
Once upon a time, there was a man who had been sent to the city of Nausicaa by his
father. The king's son and brother were killed in battle at that place; but when
he returned with them they found him dead on their way back from war-time.[1]
The King gave orders for an expedition against this strange creature called "the
Gorgon," which came out into space during one night after it attacked Earth[2]. It
is said that these creatures

在上面的代码中,为了强调重复惩罚的效果,我们将温度设置为0.3。当惩罚值较低(例如1.0)时,你可以看到模型一遍又一遍地重复同一个短语。当其他设置将候选标记限制在较小的子集时,模型很容易陷入循环。但是,当惩罚值较高(例如2.0或更高)时,模型会强烈避免重复,这有时会导致文本的自然性降低。中等惩罚值(例如1.2到1.5)通常是保持连贯性的良好折衷方案。

毕竟,generate()函数中设置的参数是为了保持文本自然流畅。你可能需要通过实验来调整这些参数,以找到最适合你特定应用的参数。请注意,这些参数可能取决于你使用的模型,因为每个模型生成的标记可能具有不同的分布。

贪婪解码和采样

do_sample参数控制模型是使用采样(基于概率选择标记)还是贪婪解码(始终选择最可能的标记)。让我们比较一下这两种方法:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
prompt = "The secret to happiness is"
inputs = tokenizer(prompt, return_tensors="pt")
# 使用贪婪解码与采样生成文本
print(f"Prompt: {prompt}\n")
print("Greedy Decoding (do_sample=False):")
output = model.generate(
 **inputs,
 max_length=100,
 num_return_sequences=1,
 temperature=1.0,
 top_k=50,
 top_p=1.0,
 repetition_penalty=1.0,
 do_sample=False,
 pad_token_id=tokenizer.eos_token_id,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("Generated Text:")
print(generated_text)
print()
print("Sampling (do_sample=True):")
output = model.generate(
 **inputs,
 max_length=100,
 num_return_sequences=1,
 temperature=1.0,
 top_k=50,
 top_p=1.0,
 repetition_penalty=1.0,
 do_sample=True,
 pad_token_id=tokenizer.eos_token_id,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print("Generated Text:")
print(generated_text)

尝试多次运行此代码并观察输出结果。你会注意到,贪婪解码的输出始终相同,而采样的输出每次都不同。对于固定的提示,贪婪解码是确定性的。该模型生成概率分布,并选择最可能的标记,不涉及随机性,输出更有可能重复且无用。

采样输出是随机的,因为输出标记是根据模型预测的概率分布选择的。这种随机性使模型能够生成更加多样化和富有创意的文本;同时,只要其他生成参数设置得当,输出仍然保持一致。在采样输出的情况下,你可以将num_return_sequences设置为大于1的数字,以便为同一提示并行生成多个序列。此参数对于贪婪解码毫无意义。

特定应用的参数

对于特定的应用,应该设置哪些参数值?并没有明确的答案。你肯定需要进行一些实验来找到最佳组合。但是,你可以参考以下建议:

  • 事实生成:

​A.提供更低的temperature参数值(0.2至0.4)以获得更确定的输出
B.使用中等大小的top_p参数值(0.8到0.9),过滤掉不太可能的标记​
C.使用更高的repetition_penalty参数值(1.2至1.5),以避免重复陈述​

  • 创意写作:

​A.提供更高一些的temperature参数值(1.0到1.3),可实现更具创意和多样化的输出

B.提供更高的top_p参数值(0.9到0.95),以提供更多可能性
C.提供较低的repetition_penalty参数值(1.0到1.1),以允许一些风格重复​

  • 代码生成:

​A.提供更低的temperature参数值(0.1到0.3),可获得更精确、更正确的代码
B.提供较低的top_p参数值(0.7至0.8),以关注最可能的标记​
C.提供更高的repetition_penalty参数值(1.3到1.5),以避免冗余代码​

  • 对话生成:​

A.提供中等大小的temperature参数值(0.6至0.8),反应自然但集中

B.提供中等大小的top_p参数值(0.9),创造力和连贯性达到良好平衡

C.提供中等大小的repetition_penalty参数值(1.2),避免重复的短语

请记住,语言模型并非完美的预言机,它也可能会出错。上述参数旨在帮助你将生成过程与预期的输出风格相匹配,但并不能保证其正确性。你得到的输出可能包含错误。

集束搜索和多序列生成

在上面的例子中,生成过程是自回归的。它是一个迭代过程,每次生成一个标记。

由于每个步骤都会通过采样生成一个标记,因此你可以同时生成多个标记。这样一来,你将为一个输入提示生成多个输出序列。理论上,如果你每一步生成k个标记,并且设置返回的长度为n,你将生成kn个序列。这个数字可能很大,你可能希望将其限制为几个。

生成多个序列的第一种方法是设置num_return_sequences为数字k。你在第一步中生成k个标记。然后,完成每个标记的序列。这基本上确定了在生成中复制了提示k次。

第二种方法是使用集束搜索。这是一种生成多个序列的更复杂的方法。它会跟踪最有希望的序列并并行探索它们。它不是生成kn个序列以淹没记忆,它只保留每一步的最佳序列。每个标记生成步骤都会暂时扩展这个集合,然后将其修剪回最佳序列。

要使用集束搜索,你需要设置num_beams为一个数字k。每一步都会扩大k个序列以再添加一个标记,结果生成k2个序列,然后选择最佳k个序列继续下一步。你还可以通过设置early_stopping=True,以便在到达序列末尾时停止生成。你还应该设置num_return_sequences在输出时限制最终选择。

序列的选择通常基于序列中标记的累积概率。但你也可以通过其他标准来调整选择,例如添加长度惩罚或避免重复n-grams。以下是使用集束搜索的示例:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

prompt = "The key to successful machine learning is"
inputs = tokenizer(prompt, return_tensors="pt")

#使用贪婪解码与采样生成文本
print(f"Prompt: {prompt}\n")
outputs = model.generate(
 **inputs,
 num_beams=5, # 要使用的光束数量
 early_stopping=True, # 当所有光束都完成时停止
 no_repeat_ngram_size=2, # 避免重复n-gram
 num_return_sequences=3, # 返回多个序列
 max_length=100,
 temperature=1.5,
 do_sample=True,
 pad_token_id=tokenizer.eos_token_id,
)
for idx, output in enumerate(outputs):
 generated_text = tokenizer.decode(output, skip_special_tokens=True)
 print(f"Generated Text ({idx+1}):")
 print(generated_text)

你可以添加更多生成参数(例如length_penalty)来控制生成过程。上面的示例设置了更高的温度,以突出集束搜索的输出。运行此代码,你可能会看到:

Prompt: The key to successful machine learning is

Generated Text (1):
The key to successful machine learning is to be able to learn from the world around
you. It is our job to make sure that we are learning from people, rather than just
from machines.

So, let's take a step back and look at how we can learn. Here's a list of the tools
we use to help us do that. We're going to go over a few of them here and give you
a general idea of what they are and how you can use them to create

Generated Text (2):
The key to successful machine learning is to be able to learn from the world around
you. It is our job to make sure that we are learning from people, rather than just
from machines.

So, let's take a step back and look at how we can learn. Here's a list of the tools
we use to help us do that. We're going to go over a few of them here and give you
a general idea of what they are and how you can use them and what

Generated Text (3):
The key to successful machine learning is to be able to learn from the world around
you. It is our job to make sure that we are learning from people, rather than just
from machines.

So, let's take a step back and look at how we can learn. Here's a list of the tools
we use to help us do that. We're going to go over a few of them here and give you
a general idea of what they are and how they work. You can use

输出序列的数量仍然受num_return_sequences控制,但生成序列的过程使用了集束搜索算法。不过,从输出结果很难判断是否使用了集束搜索。一个迹象是,集束搜索的输出不像单纯的设置num_return_sequences那样具有多样性,因为生成的序列更多并且选择了累积概率更高的序列。这种过滤确实降低了输出的多样性。

进一步阅读

以下是一些你可能觉得有用的补充阅读材料:

总结

在本文中,你了解了如何使用generate()函数中的众多参数来控制生成过程。你可以调整这些参数,使输出符合你应用程序的预期样式。具体来说,你学习了:

  • 如何利用温度来控制输出的概率分布​
  • 如何使用top-k和top-p来控制输出的多样性​
  • 如何使用重复惩罚、集束搜索和贪婪解码来控制输出​

通过理解和调整这些参数,你可以优化不同应用的文本生成,从事实写作到创意叙事、代码生成和对话系统等各个领域。

译者介绍

朱先忠,51CTO社区编辑,51CTO专家博客、讲师,潍坊一所高校计算机教师,自由编程界老兵一枚。

原文标题:​Understanding Text Generation Parameters in Transformers​,作者:Muhammad Asad Iqbal Khan

©著作权归作者所有,如需转载,请注明出处,否则将追究法律责任
已于2025-5-9 08:16:07修改
收藏
回复
举报
回复
相关推荐