对大模型输出的 logits 进行处理,从而控制文本的生成

对大模型输出的 logits 进行处理,从而控制文本的生成

flyfish

在文本生成任务中,模型输出的 logits 代表了每个词被选为下一个生成词的未归一化概率得分。通过对 logits 进行处理,可以精确地控制文本的生成

基本原理

在每一步生成过程中,模型会输出一个 logits 向量,其长度等于词汇表的大小,每个元素对应词汇表中一个词的得分。通常,会对 logits 应用 softmax 函数将其转换为概率分布,然后根据这个概率分布来选择下一个生成的词。而 logits_processor 就是在应用 softmax 函数之前,对 logits 进行修改,从而改变最终的概率分布和词的选择。

具体控制方式

1. 避免重复
  • 重复惩罚(RepetitionPenaltyLogitsProcessor
    • 机制:对于已经在生成文本中出现过的词,降低其 logits 的值。具体来说,会将这些词的 logits 除以一个大于 1 的惩罚系数,使得它们在后续生成中被选中的概率降低。
    • 示例:假设生成的文本中已经出现了“苹果”这个词,当模型再次预测下一个词时,“苹果”对应的 logits 会被惩罚,从而减少再次生成“苹果”的可能性,避免文本中出现过多重复内容。
  • 不重复 n - gram(NoRepeatNGramLogitsProcessor
    • 机制:检查生成的文本中是否已经存在某个 n - gram(连续的 n 个词),如果存在,则将可能导致该 n - gram 重复出现的词的 logits 设为负无穷。这样,在后续的概率计算中,这些词的概率会变为 0,不会被选中。
    • 示例:如果 n = 2,当前生成的文本是“我 喜欢”,那么在选择下一个词时,会避免选择那些会导致“我 喜欢”这个 2 - gram 重复出现的词,如“我”或“喜欢”,从而提高文本的多样性。
2. 控制生成长度
  • 最小长度限制(MinLengthLogitsProcessor
    • 机制:在生成的文本长度未达到指定的最小长度之前,将结束标记(EOS)的 logits 设为负无穷。这样,在 softmax 处理后,结束标记的概率会变为 0,模型不会选择结束生成,确保文本达到一定的长度。
    • 示例:如果设置最小长度为 10 个词,在生成的词数小于 10 时,结束标记的 logits 始终为负无穷,模型会继续生成,直到达到最小长度要求。
  • 最小新标记数(MinNewTokensLengthLogitsProcessor
    • 机制:类似于最小长度限制,不过是针对新生成的标记数量。在新生成的标记数未达到指定数量之前,降低结束标记的 logits,保证生成足够数量的新内容。
3. 采样策略调整
  • 温度调整(TemperatureLogitsWarper
    • 机制:将 logits 除以一个温度参数 temperature。温度越高,logits 之间的差异会被缩小,经过 softmax 处理后,概率分布会更加均匀,采样会更随机;温度越低,logits 之间的差异会被放大,概率分布会更集中,更倾向于选择概率最大的词。
    • 示例:当 temperature = 1 时,保持原始的 logits 分布;当 temperature > 1 时,模型可能会生成一些更具创意但可能不太准确的文本;当 temperature < 1 时,模型会更保守,生成的文本更符合常见的表达。
  • Top - k 采样(TopKLogitsWarper
    • 机制:只保留 logits 中概率最高的 k 个词,将其余词的 logits 设为负无穷。这样,在后续的采样中,只会从这 k 个词中选择下一个生成的词,限制了采样范围,提高了生成的稳定性。
    • 示例:如果 k = 5,模型会在每次生成时,只考虑概率最高的 5 个词,排除其他词的干扰。
  • Top - p 采样(TopPLogitsWarper
    • 机制:选择累积概率达到 p 的最小词集合,只保留这些词的 logits,其余词的 logits 设为负无穷。这种方法结合了概率和词的数量,既能控制采样范围,又能适应不同的概率分布。
    • 示例:如果 p = 0.9,模型会选择累积概率达到 0.9 的最小词集合,从这个集合中进行采样。
4. 约束生成内容
  • 禁用词过滤(NoBadWordsLogitsProcessor
    • 机制:将禁用词的 logits 设为负无穷,使得这些词在后续的概率计算中概率为 0,不会被选中,从而避免生成包含禁用词的文本。
    • 示例:如果禁用词列表中包含“脏话”,那么在生成过程中,“脏话”对应的 logits 会被设为负无穷,不会出现在生成的文本中。
  • 前缀约束(PrefixConstrainedLogitsProcessor
    • 机制:根据给定的前缀允许标记函数,限制生成的词必须符合特定的前缀约束。不符合约束的词的 logits 会被设为负无穷,从而保证生成的文本符合特定的前缀要求。
    • 示例:如果要求生成的文本必须以“今天”开头,那么在生成第一个词时,只有与“今天”相关的词的 logits 会被保留,其他词的 logits 会被设为负无穷。

配置参数

参数数据类型默认值含义
guidance_scalefloatNone引导比例,用于无批量分类器自由引导,值不为 1 时会添加相应的 logits 处理器,影响生成过程的引导程度。
sequence_bias-None序列偏差,用于控制特定序列的生成概率,设置后会添加序列偏差 logits 处理器。
diversity_penaltyfloatNone多样性惩罚,大于 0 时会添加汉明多样性 logits 处理器,鼓励生成结果更具多样性。
encoder_repetition_penaltyfloatNone编码器重复惩罚,不为 1 且编码器输入 ID 形状符合要求时,会添加编码器重复惩罚 logits 处理器,减少编码器输入相关的重复内容。
repetition_penaltyfloatNone重复惩罚,不为 1 时会添加重复惩罚 logits 处理器,防止生成结果出现过多重复。
no_repeat_ngram_sizeintNone不重复 n - gram 大小,大于 0 时会添加不重复 n - gram logits 处理器,避免生成的文本中出现重复的 n - gram 片段。
encoder_no_repeat_ngram_sizeintNone编码器不重复 n - gram 大小,大于 0 且编码器输入 ID 形状符合要求时,会添加编码器不重复 n - gram logits 处理器,减少编码器输入相关的重复 n - gram 内容。
bad_words_ids-None禁用词 ID,设置后会添加禁用词 logits 处理器,防止生成包含指定禁用词的文本。
min_lengthintNone最小长度,大于 0 且有结束标记张量时,会添加最小长度 logits 处理器,确保生成的文本达到最小长度要求。
min_new_tokensintNone最小新标记数,大于 0 且有结束标记张量时,会添加最小新标记长度 logits 处理器,保证生成的新标记数量达到要求。
forced_bos_token_idintNone强制起始标记 ID,设置后会添加强制起始标记 logits 处理器,确保生成的文本以指定的标记开始。
forced_eos_token_idintNone强制结束标记 ID,设置后会添加强制结束标记 logits 处理器,使生成的文本在达到指定标记时结束。
remove_invalid_valuesboolFalse是否移除无效值,为 True 时会添加移除无穷大和 NaN 值的 logits 处理器,保证生成过程中 logits 的有效性。
exponential_decay_length_penalty-None指数衰减长度惩罚,设置后会添加指数衰减长度惩罚处理器,对生成文本的长度进行惩罚控制。
suppress_tokens-None抑制标记,设置后会添加抑制标记 logits 处理器,降低指定标记的生成概率。
begin_suppress_tokens-None起始抑制标记,设置后会添加起始抑制标记 logits 处理器,在生成的起始阶段抑制指定标记的生成。
forced_decoder_ids-None强制解码器 ID,不建议使用,设置后会抛出异常,提示使用 input_idsdecoder_input_ids 代替。
do_sampleboolFalse是否使用采样策略,为 True 时会根据其他采样相关参数添加相应的 logits 调整器。
temperaturefloatNone采样温度,不为 1 时会添加温度 logits 调整器,控制采样的随机性,值越大随机性越强。
top_kintNonetop - k 采样值,不为 0 时会添加 top - k logits 调整器,只考虑概率最高的 k 个标记进行采样。
top_pfloatNonetop - p 采样值,小于 1 时会添加 top - p logits 调整器,只考虑累积概率达到 p 的标记进行采样。
min_pfloatNone最小概率阈值,设置后会添加最小概率 logits 调整器,在温度缩放后应用,控制采样的最小概率。
typical_pfloatNone典型概率采样值,小于 1 时会添加典型概率 logits 调整器,基于典型概率进行采样。
epsilon_cutofffloatNoneepsilon 截断值,在 0 到 1 之间时会添加 epsilon logits 调整器,用于截断低概率标记。
eta_cutofffloatNoneeta 截断值,在 0 到 1 之间时会添加 eta logits 调整器,结合设备信息对低概率标记进行截断。
watermarking_config-None水印配置,设置后会添加水印处理器,在生成的文本中添加水印。
renormalize_logitsboolFalse是否重新归一化 logits,为 True 时会添加 logit 归一化处理器,确保 logits 归一化。

logits 说明

logits 是模型在进行分类或预测任务时,最后一层神经元的原始输出值,它是未经过归一化处理的数值。在文本生成场景中,logits 代表了模型预测词汇表中每个词作为下一个生成词的得分,这些得分反映了模型对每个词成为下一个词的相对可能性判断,但并非是概率值。

数学公式

1. 线性变换得到 logits

在许多深度学习模型中,logits 通常是通过对前一层的输出进行线性变换得到的。假设模型前一层的输出为向量 h \mathbf{h} h,权重矩阵为 W \mathbf{W} W,偏置向量为 b \mathbf{b} b,则 logits 向量 z \mathbf{z} z 的计算公式如下:

z = W h + b \mathbf{z} = \mathbf{W}\mathbf{h} + \mathbf{b} z=Wh+b

其中, h \mathbf{h} h 是前一层输出的特征向量,维度通常为 d h d_h dh W \mathbf{W} W 是权重矩阵,维度为 V × d h V \times d_h V×dh V V V 是词汇表的大小; b \mathbf{b} b 是偏置向量,维度为 V V V z \mathbf{z} zlogits 向量,维度为 V V V,每个元素 z i z_i zi 对应词汇表中第 i i i 个词的得分。

2. logits 转换为概率分布

为了将 logits 转换为概率分布,通常会使用 softmax 函数。softmax 函数可以将 logits 向量中的每个元素转换为一个在 [ 0 , 1 ] [0, 1] [0,1] 范围内的值,且所有元素之和为 1,符合概率分布的定义。softmax 函数的数学公式如下:

P ( y i ) = e z i ∑ j = 1 V e z j P(y_i) = \frac{e^{z_i}}{\sum_{j=1}^{V} e^{z_j}} P(yi)=j=1Vezjezi

其中, P ( y i ) P(y_i) P(yi) 是词汇表中第 i i i 个词被选为下一个生成词的概率, z i z_i zilogits 向量中第 i i i 个元素的值, V V V 是词汇表的大小。

示例

假设词汇表大小 V = 3 V = 3 V=3,模型输出的 logits 向量为 z = [ 2 , 1 , 3 ] \mathbf{z} = [2, 1, 3] z=[2,1,3],下面计算经过 softmax 函数处理后的概率分布:

首先,计算分母的值:

∑ j = 1 3 e z j = e 2 + e 1 + e 3 ≈ 7.389 + 2.718 + 20.086 = 30.193 \sum_{j=1}^{3} e^{z_j} = e^2 + e^1 + e^3 \approx 7.389 + 2.718 + 20.086 = 30.193 j=13ezj=e2+e1+e37.389+2.718+20.086=30.193

然后,分别计算每个词的概率:

P ( y 1 ) = e 2 30.193 ≈ 7.389 30.193 ≈ 0.245 P(y_1) = \frac{e^2}{30.193} \approx \frac{7.389}{30.193} \approx 0.245 P(y1)=30.193e230.1937.3890.245

P ( y 2 ) = e 1 30.193 ≈ 2.718 30.193 ≈ 0.090 P(y_2) = \frac{e^1}{30.193} \approx \frac{2.718}{30.193} \approx 0.090 P(y2)=30.193e130.1932.7180.090

P ( y 3 ) = e 3 30.193 ≈ 20.086 30.193 ≈ 0.665 P(y_3) = \frac{e^3}{30.193} \approx \frac{20.086}{30.193} \approx 0.665 P(y3)=30.193e330.19320.0860.665

可以看到,经过 softmax 函数处理后,得到了一个概率分布 [ 0.245 , 0.090 , 0.665 ] [0.245, 0.090, 0.665] [0.245,0.090,0.665],表示词汇表中三个词被选为下一个生成词的概率。

在模型中的作用

在文本生成任务中,模型会根据 logits 转换后的概率分布来选择下一个生成的词。常见的选择方法有贪心搜索(选择概率最大的词)、采样搜索(根据概率分布随机采样)等。同时,logits_processor 会对 logits 进行调整,从而影响最终的概率分布和词的选择,以控制文本生成的行为和质量。

代码说明

logits_processor 是 _get_logits_processor 方法的一个参数,它是一个可选的 LogitsProcessorList 对象。这个方法会根据 GenerationConfig 中的各种配置参数,创建一系列不同的 LogitsProcessor 实例,并将它们添加到 processors 列表中。最后,如果传入了 logits_processor,还会将其与新创建的处理器列表进行合并。

def _get_logits_processor(
        self,
        generation_config: GenerationConfig,
        input_ids_seq_length: int,
        encoder_input_ids: torch.LongTensor,
        prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
        logits_processor: Optional[LogitsProcessorList],
        device: str = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    ) -> LogitsProcessorList:
        """
        此函数返回一个 `LogitsProcessorList` 对象,该对象包含所有用于修改语言模型头部得分的相关 `LogitsProcessor` 实例。
        这些处理器会对模型预测的 logits 进行调整,以控制文本生成的行为,例如避免重复、控制生成长度等。

        参数:
            generation_config (GenerationConfig): 生成配置对象,包含了文本生成过程中的各种配置参数。
            input_ids_seq_length (int): 输入 ID 序列的长度。
            encoder_input_ids (torch.LongTensor): 编码器的输入 ID。
            prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]): 一个可调用对象,用于指定允许的前缀标记。
            logits_processor (Optional[LogitsProcessorList]): 可选的 logits 处理器列表。
            device (str, optional): 设备名称,如 'cuda' 或 'cpu'。默认为 None。
            model_kwargs (Optional[Dict[str, Any]], optional): 模型的其他关键字参数。默认为 None。
            negative_prompt_ids (Optional[torch.Tensor], optional): 负提示的 ID。默认为 None。
            negative_prompt_attention_mask (Optional[torch.Tensor], optional): 负提示的注意力掩码。默认为 None。

        返回:
            LogitsProcessorList: 包含所有 logits 处理器的列表。
        """
        # 实例化一个空的处理器列表
        processors = LogitsProcessorList()

        # 如果配置了引导比例且不为 1,则添加无批量分类器自由引导 logits 处理器
        if generation_config.guidance_scale is not None and generation_config.guidance_scale != 1:
            processors.append(
                UnbatchedClassifierFreeGuidanceLogitsProcessor(
                    generation_config.guidance_scale,
                    self,
                    unconditional_ids=negative_prompt_ids,
                    unconditional_attention_mask=negative_prompt_attention_mask,
                    use_cache=generation_config.use_cache,
                )
            )
        # 如果配置了序列偏差,则添加序列偏差 logits 处理器
        if generation_config.sequence_bias is not None:
            processors.append(SequenceBiasLogitsProcessor(sequence_bias=generation_config.sequence_bias))

        # 如果配置了多样性惩罚且大于 0,则添加汉明多样性 logits 处理器
        if generation_config.diversity_penalty is not None and generation_config.diversity_penalty > 0.0:
            processors.append(
                HammingDiversityLogitsProcessor(
                    diversity_penalty=generation_config.diversity_penalty,
                    num_beams=generation_config.num_beams,
                    num_beam_groups=generation_config.num_beam_groups,
                )
            )
        # 如果配置了编码器重复惩罚且不为 1,并且编码器输入 ID 的形状为二维,则添加编码器重复惩罚 logits 处理器
        if (
            generation_config.encoder_repetition_penalty is not None
            and generation_config.encoder_repetition_penalty != 1.0
        ):
            if len(encoder_input_ids.shape) == 2:
                processors.append(
                    EncoderRepetitionPenaltyLogitsProcessor(
                        penalty=generation_config.encoder_repetition_penalty,
                        encoder_input_ids=encoder_input_ids,
                    )
                )
            else:
                # 如果编码器输入 ID 形状不符合要求,发出警告
                warnings.warn(
                    "Passing `encoder_repetition_penalty` requires some form of `input_ids` to be passed to "
                    "`generate`, ignoring the argument.",
                    UserWarning,
                )
        # 如果配置了重复惩罚且不为 1,则添加重复惩罚 logits 处理器
        if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
            processors.append(RepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty))
        # 如果配置了不重复 n-gram 大小且大于 0,则添加不重复 n-gram logits 处理器
        if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:
            processors.append(NoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size))
        # 如果配置了编码器不重复 n-gram 大小且大于 0,并且编码器输入 ID 的形状为二维,则添加编码器不重复 n-gram logits 处理器
        if (
            generation_config.encoder_no_repeat_ngram_size is not None
            and generation_config.encoder_no_repeat_ngram_size > 0
        ):
            if len(encoder_input_ids.shape) == 2:
                processors.append(
                    EncoderNoRepeatNGramLogitsProcessor(
                        generation_config.encoder_no_repeat_ngram_size,
                        encoder_input_ids,
                    )
                )
            else:
                # 如果编码器输入 ID 形状不符合要求,发出警告
                warnings.warn(
                    "Passing `encoder_no_repeat_ngram_size` requires some form of `input_ids` to be passed to "
                    "`generate`, ignoring the argument.",
                    UserWarning,
                )
        # 如果配置了禁用词 ID,则添加禁用词 logits 处理器
        if generation_config.bad_words_ids is not None:
            processors.append(
                NoBadWordsLogitsProcessor(
                    generation_config.bad_words_ids,
                    generation_config._eos_token_tensor,
                )
            )
        # 如果配置了最小长度且大于 0,并且有结束标记张量,则添加最小长度 logits 处理器
        if (
            generation_config.min_length is not None
            and generation_config._eos_token_tensor is not None
            and generation_config.min_length > 0
        ):
            processors.append(
                MinLengthLogitsProcessor(
                    generation_config.min_length,
                    generation_config._eos_token_tensor,
                    device=device,
                )
            )
        # 如果配置了最小新标记数且大于 0,并且有结束标记张量,则添加最小新标记长度 logits 处理器
        if (
            generation_config.min_new_tokens is not None
            and generation_config._eos_token_tensor is not None
            and generation_config.min_new_tokens > 0
        ):
            processors.append(
                MinNewTokensLengthLogitsProcessor(
                    input_ids_seq_length,
                    generation_config.min_new_tokens,
                    generation_config._eos_token_tensor,
                    device=device,
                )
            )
        # 如果提供了前缀允许标记函数,则添加前缀约束 logits 处理器
        if prefix_allowed_tokens_fn is not None:
            processors.append(
                PrefixConstrainedLogitsProcessor(
                    prefix_allowed_tokens_fn,
                    generation_config.num_beams // generation_config.num_beam_groups,
                )
            )
        # 如果配置了强制起始标记 ID,则添加强制起始标记 logits 处理器
        if generation_config.forced_bos_token_id is not None:
            processors.append(
                ForcedBOSTokenLogitsProcessor(
                    generation_config.forced_bos_token_id,
                )
            )
        # 如果配置了强制结束标记 ID,则添加强制结束标记 logits 处理器
        if generation_config.forced_eos_token_id is not None:
            processors.append(
                ForcedEOSTokenLogitsProcessor(
                    generation_config.max_length,
                    generation_config.forced_eos_token_id,
                    device=device,
                )
            )
        # 如果配置了移除无效值,则添加移除无穷大和 NaN 值的 logits 处理器
        if generation_config.remove_invalid_values is True:
            processors.append(InfNanRemoveLogitsProcessor())
        # 如果配置了指数衰减长度惩罚,则添加指数衰减长度惩罚处理器
        if generation_config.exponential_decay_length_penalty is not None:
            processors.append(
                ExponentialDecayLengthPenalty(
                    generation_config.exponential_decay_length_penalty,
                    generation_config._eos_token_tensor,
                    input_ids_seq_length,
                )
            )
        # 如果配置了抑制标记,则添加抑制标记 logits 处理器
        if generation_config.suppress_tokens is not None:
            processors.append(
                SuppressTokensLogitsProcessor(
                    generation_config.suppress_tokens,
                    device=device,
                )
            )
        # 如果配置了起始抑制标记,则添加起始抑制标记 logits 处理器
        if generation_config.begin_suppress_tokens is not None:
            begin_index = input_ids_seq_length
            begin_index = (
                begin_index
                if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
                else begin_index + 1
            )
            processors.append(
                SuppressTokensAtBeginLogitsProcessor(
                    generation_config.begin_suppress_tokens,
                    begin_index,
                    device=device,
                )
            )
        # 如果配置了强制解码器 ID,则抛出异常,提示使用 input_ids 或 decoder_input_ids 代替
        if generation_config.forced_decoder_ids is not None:
            # TODO (sanchit): move this exception to GenerationConfig.validate() when TF & FLAX are aligned with PT
            raise ValueError(
                "You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument "
                "in favour of `input_ids` or `decoder_input_ids` respectively.",
            )

        # 合并自定义的 logits 处理器列表
        processors = self._merge_criteria_processor_list(processors, logits_processor)

        # 以下处理器之前被称为 `LogitsWarpers`,仅在采样策略下应用
        if generation_config.do_sample:
            # 在束搜索方法中,我们需要至少保留一个非结束标记来探索可能有更好得分的延续(即保留 len(list(generation_config._eos_token_tensor)) + 1)
            if generation_config.num_beams > 1:
                if isinstance(generation_config._eos_token_tensor, list):
                    min_tokens_to_keep = len(generation_config._eos_token_tensor) + 1
                elif isinstance(generation_config._eos_token_tensor, torch.Tensor):
                    min_tokens_to_keep = generation_config._eos_token_tensor.shape[0] + 1
                else:
                    min_tokens_to_keep = 2
            else:
                min_tokens_to_keep = 1

            # 以下思路主要借鉴自这个 PR: https://github.com/huggingface/transformers/pull/5420/files
            # 所有采样器可以在 `generation_utils_samplers.py` 中找到
            # 如果配置了温度且不为 1,则添加温度 logits 调整器
            if generation_config.temperature is not None and generation_config.temperature != 1.0:
                processors.append(TemperatureLogitsWarper(generation_config.temperature))
            # 如果配置了 top-k 采样且不为 0,则添加 top-k logits 调整器
            if generation_config.top_k is not None and generation_config.top_k != 0:
                processors.append(
                    TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep)
                )
            # 如果配置了 top-p 采样且小于 1,则添加 top-p logits 调整器
            if generation_config.top_p is not None and generation_config.top_p < 1.0:
                processors.append(
                    TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep)
                )
            # 如果配置了最小概率阈值,则添加最小概率 logits 调整器
            if generation_config.min_p is not None:
                # 在温度缩放后应用(见 https://github.com/ggerganov/llama.cpp/pull/3841#issuecomment-2073826084)
                processors.append(
                    MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep)
                )
            # 如果配置了典型概率采样且小于 1,则添加典型概率 logits 调整器
            if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
                processors.append(
                    TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
                )
            # 如果配置了 epsilon 截断且在 0 到 1 之间,则添加 epsilon logits 调整器
            if generation_config.epsilon_cutoff is not None and 0.0 < generation_config.epsilon_cutoff < 1.0:
                processors.append(
                    EpsilonLogitsWarper(
                        epsilon=generation_config.epsilon_cutoff, min_tokens_to_keep=min_tokens_to_keep
                    )
                )
            # 如果配置了 eta 截断且在 0 到 1 之间,则添加 eta logits 调整器
            if generation_config.eta_cutoff is not None and 0.0 < generation_config.eta_cutoff < 1.0:
                processors.append(
                    EtaLogitsWarper(
                        epsilon=generation_config.eta_cutoff, min_tokens_to_keep=min_tokens_to_keep, device=device
                    )
                )

        # 水印处理应该在所有 logits 处理完成后进行(见 #34630)
        if generation_config.watermarking_config is not None:
            processors.append(
                generation_config.watermarking_config.construct_processor(self.config.vocab_size, device)
            )

        # `LogitNormalization` 应该始终是最后一个 logit 处理器(如果存在)
        if generation_config.renormalize_logits is True:
            processors.append(LogitNormalization())
        return processors
transformers/src/transformers/generation/utils.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/982060.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

雷池WAF的为什么选择基于Docker

Docker 是一种开源的容器化平台&#xff0c;可以帮助开发人员将应用程序及其所有依赖项打包到一个称为容器的独立、可移植的环境中。Docker 的核心概念包括以下几点&#xff1a; 容器&#xff1a;Docker 使用容器来封装应用程序及其依赖项&#xff0c;使其能够在任何环境中都能…

解决docker认证问题 failed to authorize: failed to fetch oauth token

报错信息[bash1]解决方案 全局代理打开“buildkit”: false &#xff0c;见[图1] [bash1] >docker build -t ffpg . [] Building 71.8s (3/3) FINISHED docker:desktop-linux> [internal] load bui…

LINUX网络基础 [一] - 初识网络,理解网络协议

目录 前言 一. 计算机网络背景 1.1 发展历程 1.1.1 独立模式 1.1.2 网络互联 1.1.3 局域网LAN 1.1.4 广域网WAN 1.2 总结 二. "协议" 2.1 什么是协议 2.2 网络协议的理解 2.3 网络协议的分层结构 三. OSI七层模型&#xff08;理论标准&#xff09; …

【Docker】容器安全之非root用户运行

【Docker】容器安全之非root用户运行 1. 场景2. 原 Dockerfile 内容3. 整改结果4. 非 root 用户带来的潜在问题4.1 文件夹读写权限异常4.2 验证文件夹权限 1. 场景 最近有个项目要交付&#xff0c;第三方测试对项目源码扫描后发现一个问题&#xff0c;服务的 Dockerfile 都未指…

亚马逊云科技Marketplace(中国区)上架专业服务产品, “云生态连接器”价值凸显

近日&#xff0c;由西云数据运营的亚马逊云科技Marketplace&#xff08;中国区&#xff09;正式支持专业服务产品。此次发布将大幅简化企业对云专业服务的采购流程&#xff0c;实现云软件从规划、部署到支持的全生命周期管理&#xff0c;同时也为合作伙伴提供了更多的销售机会。…

鸿蒙启动页开发

鸿蒙启动页开发 1.1 更改应用名称和图标 1.更改应用图标 找到moudle.json5文件&#xff0c;找到应用启动的EntryAbility下面的icon,将原来的图标改成自己设置的即可 2.更改应用名称 3.效果展示 2.1 广告页面开发 3.1 详细介绍 3.1.1 启动页面 import { PrivacyDialog } fr…

HCIA—IP路由静态

一、概念及作用 1、概念&#xff1a;IP路由是指在IP网络中&#xff0c;数据从源节点到目的节点所经过的路径选择和数据转发的过程。 2、作用 ①实现网络互联&#xff1a;使不同网段的设备能够相互通信&#xff0c;构建大规模的互联网络 ②优化网络拓扑&#xff1a;根据网络…

【计算机网络入门】初学计算机网络(十一)重要

目录 1. CIDR无分类编址 1.1 CIDR的子网划分 1.1.1 定长子网划分 1.1.2 变长子网划分 2. 路由聚合 2.1 最长前缀匹配原则 3. 网络地址转换NAT 3.1 端口号 3.2 IP地址不够用&#xff1f; 3.3 公网IP和内网IP 3.4 NAT作用 4. ARP协议 4.1 如何利用IP地址找到MAC地址…

机器视觉开发教程——封装Halcon通用模板匹配工具【含免费教程源码】

目录 引言前期准备Step1 设计可序列化的输入输出集合【不支持多线程】Step2 设计程序框架1、抽象层【IProcess】2、父类【HAlgorithm】3、子类【HFindModelTool】 Step3 设计UI结果展示 引言 通过仿照VisionPro软件二次开发Halcon的模板匹配工具&#xff0c;便于在客户端软件中…

【Linux跬步积累】—— 线程池详解(有源代码)

文章目录 一、如何实现一个线程1、基本结构2、实现成员函数3、演示4、代码总汇Thread.hppMain.cc 二、如何封装线程池1、设计成员变量2、构造函数与析构函数3、初始化4、启动与回收5、主线程放入任务6、子线程读取任务7、终止线程池 三、测试四、线程池总代码1、ThreadPool.hpp…

【Linux】自定协议和序列化与反序列化

目录 一、序列化与反序列化概念 二、自定协议实现一个加法网络计算器 &#xff08;一&#xff09;TCP如何保证接收方的接收到数据是完整性呢&#xff1f; &#xff08;二&#xff09;自定义协议 &#xff08;三&#xff09;自定义协议的实现 1、基础类 2、序列化与反序列…

hive之LEAD 函数详解

1. 函数概述 LEAD 是 Hive 中的窗口函数&#xff0c;用于获取当前行之后指定偏移量处的行的值。常用于分析时间序列数据、计算相邻记录的差异或预测趋势。 2. 语法 LEAD(column, offset, default) OVER ([PARTITION BY partition_column] [ORDER BY order_column [ASC|DESC]…

ZYNQ-PL学习实践(二)按键和定时器控制LED闪烁灯

ZYNQ-PL学习实践&#xff08;二&#xff09;按键和定时器控制LED闪烁灯&#xff09; 1 创建工程2 verilog 代码3 约束4 综合5 生成bit总结 1 创建工程 2 verilog 代码 添加key_led.v 文件&#xff0c; module key_led(input sys_clk , //系统时钟50MHzinput …

【Python爬虫】利用代理IP爬取跨境电商AI选品分析

引言 随着DeepSeek的流行&#xff0c;越来越多的用户开始尝试将AI工具融入到日常工作当中&#xff0c;借助AI的强大功能提高工作效率。最近又掀起了一波企业出海的小高潮&#xff0c;那么如果是做跨境电商业务&#xff0c;怎么将AI融入工作流中呢&#xff1f;在做跨境电商的时候…

设计一个SVF下载器之一:整体思路

CPLD或者FPGA开发工具会生成SVF文件用以通过JTAG口配置CPLD或者FPGA。这里有些基本控制JTAG状态机的指令&#xff0c;其实就是主要两条SIR和SDR分别实现对IR寄存器和DR寄存器的写。 这样我们的这个下载器的基本工作变成了解析SVF文件之后对JTAG的TAP状态机进行操作实现对IR和D…

计算机视觉算法实战——图像配准(主页有源码)

✨个人主页欢迎您的访问 ✨期待您的三连 ✨ ✨个人主页欢迎您的访问 ✨期待您的三连 ✨ ✨个人主页欢迎您的访问 ✨期待您的三连✨ ​ ​​​ 1. 领域简介 图像配准&#xff08;Image Registration&#xff09;是计算机视觉中的一个重要研究方向&#xff0c;旨在将两幅或多幅…

ArcGIS操作:07 绘制矢量shp面

1、点击目录 2、右侧显示目录 3、选择要存储的文件夹&#xff0c;新建shp 4、定义名称、要素类型、坐标系 5、点击开始编辑 6、点击创建要素 7、右侧选择图层、创建面 8、开始绘制&#xff0c;双击任意位置结束绘制

靶场(二)---靶场心得小白分享

开始&#xff1a; 看一下本地IP 21有未授权访问的话&#xff0c;就从21先看起 PORT STATE SERVICE VERSION 20/tcp closed ftp-data 21/tcp open ftp vsftpd 2.0.8 or later | ftp-anon: Anonymous FTP login allowed (FTP code 230) |_Cant get dire…

一周学会Flask3 Python Web开发-WTForms表单验证

锋哥原创的Flask3 Python Web开发 Flask3视频教程&#xff1a; 2025版 Flask3 Python web开发 视频教程(无废话版) 玩命更新中~_哔哩哔哩_bilibili 我们可以通过WTForms表单类属性的validators属性来实现表单验证。 常用的WTForms验证器 验证器说明DataRequired(messageNo…

C 语 言 --- 猜 数 字 游 戏

C 语 言 --- 猜 数 字 游 戏 代 码 全 貌 与 功 能 介 绍游 戏 效 果 展 示游 戏 代 码 详 解头 文 件 引 入菜单函数游 戏 逻 辑 函 数 gamerand 函 数 详 解逻 辑 函 数 game 主 函 数 总结 &#x1f4bb;作 者 简 介&#xff1a;曾 与 你 一 样 迷 茫&#xff0c;现 以 经 验…