Logits ProcessorでQwen(LLM)のハルシネーション対策

要約(先に結論を読みたい方へ)

QwenをファインチューニングしてJSON出力などの特定用途に使う際、LLM特有のトークン繰り返しハルシネーションが発生することがある。repetition_penaltyno_repeat_ngram_sizeという標準的なパラメータで対応できるケースもあるが、副作用が強すぎて実用に耐えない場面も多い。

本記事では、TransformersのカスタムLogits Processorを実装し、「JSONの特定フィールド(Content)内に限定して繰り返し制御を適用する」という粒度の細かいアプローチを紹介する。この手法により、必要な箇所だけにno-repeat-ngramを適用し、副作用を最小化しながらハルシネーションを抑制することができた。

本記事で得られること:

  • Qwen(およびTransformers互換モデル)でのハルシネーション発生メカニズムの理解
  • repetition_penalty / no_repeat_ngram_size の限界と副作用
  • カスタム LogitsProcessor の実装方法(再現可能なコード付き)
  • 本番環境で使えるデバッグ機能の組み込み方

1. Qwenファインチューニングで起きた繰り返しハルシネーションとは

Qwenは精度・速度のバランスに優れたLLMであり、ベースモデルとしての安定性は高い。しかし、特定のタスクに向けてファインチューニングを施すと、元のモデルが持っていた分布が崩れ、推論時に同じトークンやフレーズを延々と繰り返す「ループ状ハルシネーション」が現れることがある。

具体的な症状

たとえば、以下のようなJSON構造を出力させるタスクを想定する。

json
{
  "Title": "業務改善提案",
  "Content": "会議の効率化を図るため、..."
}

ファインチューニング後のモデルが、Contentフィールドの途中から次のような出力を生成することがある。

text
"Content": "会議の効率化を図るため、アジェンダを事前に共有し、アジェンダを事前に共有し、アジェンダを事前に共有し、アジェンダを事前に共有し..."

この現象はいくつかの要因から発生する:

  • 学習データの分布の偏り:特定フレーズが繰り返されるデータがあると、モデルがそのパターンを学習してしまう
  • 温度パラメータとの相互作用:低温度(greedy decoding寄り)の設定では、一度高スコアなトークンが選ばれると抜け出せなくなる
  • ファインチューニングによる過学習:少量データで学習すると、汎化性能が落ちて特定パターンへの収束が起きやすい

2. 標準パラメータの限界

repetition_penalty

Transformersのrepetition_penaltyは、すでに生成されたトークンのlogitにペナルティを与えることで繰り返しを抑制する。値が1.0でペナルティなし、1.3程度で効果が出始める。

python
generation_config = {
    "repetition_penalty": 1.3,
}

効果はある。しかし問題は、「どのトークンが"繰り返し"と見なされるか」のスコープが生成済みシーケンス全体であること。つまり、正当に繰り返す必要がある語(助詞、接続詞、固有名詞など)にもペナルティが入り、文章が不自然になるリスクがある。

no_repeat_ngram_size

no_repeat_ngram_size=3とすると、過去に出現したトライグラムと同じシーケンスの出力を完全に禁止する。強力な制御ができる一方で、副作用も強力だ。

python
generation_config = {
    "no_repeat_ngram_size": 3,
}

副作用の例:

  • 「ありがとうございます。よろしくお願いします。」→ 同じ挨拶を複数回使えなくなる
  • 箇条書きで同じ文末表現(「〜です」「〜ます」)を繰り返せなくなる
  • 構造化されたデータや定型文を出力したいのに、意図した繰り返しが禁止される

全体に適用するから問題が起きる。本当に繰り返しを制御したいのは特定のフィールドだけなのに、モデルの出力全体に制御がかかってしまう。


3. 解決策:カスタムLogits Processorの設計思想

TransformersにはLogits Processorという仕組みがあり、トークン生成の各ステップでlogit値を加工するカスタムクラスを差し込むことができる。

Logits Processorとは(用語解説)

用語説明
Logitsモデルが各トークンに割り当てる生のスコア(softmax前の値)
LogitsProcessor生成ステップごとにlogitを変換するクラス。-infを設定すると実質的にそのトークンの出力を禁止できる
LogitsProcessorList複数のプロセッサをチェーンするコンテナ
ngramn個の連続したトークンの列。trigram(n=3)であれば3トークンの組

設計のポイント

今回実装したContentNoRepeatNGramLogitsProcessorの設計思想は以下のとおり:

  1. 生成されたテキストを逐次デコードし、現在JSONのContentフィールド内にいるかを判定する
  2. Contentフィールド内にいる場合のみ、no-repeat-ngramを適用する
  3. フィールド外(Title、構造的なJSON記述など)では制御をかけない

「どこを生成しているか」をリアルタイムに把握して制御を切り替えるアーキテクチャは、構造化出力を行うLLMの推論制御において汎用的に使えるパターンだ。

もっと詳しくLLMの内部構造を知りたい方へ

以下の記事で、LLMの理解に役立つ書籍を紹介しています。
https://neosophie.com/ja/blog/20260311-books


4. 実装コードと解説

以下がフルの実装コードだ。Transformers互換モデルであればQwenに限らず使用できる。

python
from __future__ import annotations

from datetime import datetime
from typing import Optional

import torch
from transformers.generation.logits_process import (
    LogitsProcessor,
    LogitsProcessorList,
    _calc_banned_ngram_tokens,
)


class ContentNoRepeatNGramLogitsProcessor(LogitsProcessor):
    """JSONのContentフィールド内でのみno-repeat-ngramを適用するLogitsProcessor"""

    def __init__(
        self,
        tokenizer,
        ngram_size: int,
        decode_max_tokens: int = 2048,
        debug: bool = False,
    ):
        if not isinstance(ngram_size, int) or ngram_size <= 0:
            raise ValueError(
                f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}"
            )
        if not isinstance(decode_max_tokens, int) or decode_max_tokens <= 0:
            raise ValueError(
                "`decode_max_tokens` has to be a strictly positive integer, "
                f"but is {decode_max_tokens}"
            )
        self.tokenizer = tokenizer
        self.ngram_size = ngram_size
        self.decode_max_tokens = decode_max_tokens
        self.debug = debug
        self._debug_calls = 0
        self._debug_hits = 0

    @staticmethod
    def _extract_open_content_text(decoded_text: str) -> Optional[str]:
        """
        デコード済みテキストから、開いた状態のContentフィールドのテキストを抽出する。
        閉じていれば("が見つかれば)Noneを返す。
        """
        marker = '"Content":"'
        marker_index = decoded_text.rfind(marker)
        if marker_index == -1:
            return None

        content_start = marker_index + len(marker)
        escaped = False
        for index in range(content_start, len(decoded_text)):
            char = decoded_text[index]
            if escaped:
                escaped = False
                continue
            if char == "\\":
                escaped = True
                continue
            if char == '"':
                # 閉じ引用符が見つかった→Content終了済み
                return None
        return decoded_text[content_start:]

    @staticmethod
    def _debug_timestamp() -> str:
        return datetime.now().strftime("%H:%M:%S")

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        self._debug_calls += 1
        num_batch_hypotheses = scores.shape[0]
        scores_processed = scores.clone()

        for i in range(num_batch_hypotheses):
            # デコードするトークン数を制限してパフォーマンスを確保
            decode_ids = input_ids[i, -self.decode_max_tokens :]
            decoded_text = self.tokenizer.decode(
                decode_ids,
                skip_special_tokens=False,
                clean_up_tokenization_spaces=False,
            )

            # Contentフィールドが開いているか確認
            content_text = self._extract_open_content_text(decoded_text)
            if content_text is None:
                if self.debug:
                    print(
                        f"[{self._debug_timestamp()}][ContentNoRepeat] "
                        f"content_text is None. {decoded_text[-80:]}"
                    )
                continue

            if self.debug:
                preview = content_text[-80:].replace("\n", "\\n")
                print(
                    f"[{self._debug_timestamp()}][ContentNoRepeat] "
                    f"step={self._debug_calls} batch={i} "
                    f"content_chars={len(content_text)} tail={preview!r}"
                )

            # Contentフィールドのテキストをトークン化
            content_ids = self.tokenizer.encode(
                content_text,
                add_special_tokens=False,
            )
            if len(content_ids) + 1 < self.ngram_size:
                # まだngram_sizeに達していない→制限不要
                continue

            content_input_ids = input_ids.new_tensor(content_ids).unsqueeze(0)
            banned_batch_tokens = _calc_banned_ngram_tokens(
                self.ngram_size,
                content_input_ids,
                1,
                content_input_ids.shape[-1],
            )

            banned_tokens = banned_batch_tokens[0]
            if banned_tokens:
                self._debug_hits += 1
                if self.debug:
                    print(
                        f"[{self._debug_timestamp()}][ContentNoRepeat] "
                        f"banned {len(banned_tokens)} token(s) "
                        f"for batch={i} at step={self._debug_calls}"
                    )
                # 禁止トークンのlogitを-infに設定→実質的に出力不可にする
                scores_processed[i, banned_tokens] = -float("inf")

        return scores_processed


class ContentNoRepeatGenerationMixin:
    """CLIやモデルクラスに組み込むためのMixin"""

    @staticmethod
    def add_content_no_repeat_cli_args(parser) -> None:
        parser.add_argument(
            "--content_no_repeat_ngram_size",
            type=int,
            default=0,
            help='JSONの"Content"フィールド内にのみno-repeat-ngramを適用するサイズ(0で無効)',
        )
        parser.add_argument(
            "--content_no_repeat_decode_max_tokens",
            type=int,
            default=1024,
            help='Contentフィールドの検出に使う最大トークン数',
        )
        parser.add_argument(
            "--content_no_repeat_debug",
            action="store_true",
            help='デバッグログを出力する',
        )

    @staticmethod
    def build_content_no_repeat_logits_processor(
        tokenizer,
        content_no_repeat_ngram_size: int = 0,
        content_no_repeat_decode_max_tokens: int = 2048,
        content_no_repeat_debug: bool = False,
    ) -> Optional[LogitsProcessorList]:
        processors = LogitsProcessorList()
        if content_no_repeat_ngram_size and content_no_repeat_ngram_size > 0:
            processors.append(
                ContentNoRepeatNGramLogitsProcessor(
                    tokenizer=tokenizer,
                    ngram_size=content_no_repeat_ngram_size,
                    decode_max_tokens=content_no_repeat_decode_max_tokens,
                    debug=content_no_repeat_debug,
                )
            )
        return processors or None

コードの各部の解説

_extract_open_content_text(静的メソッド)

生成済みテキストから "Content":" というマーカーを探し、その後のテキストを返す。ポイントはエスケープシーケンスを正しく処理している点だ。JSON内では \" のようにバックスラッシュでエスケープされた引用符が登場するため、単純に " を検索するだけでは誤検知する。

decode_max_tokens

生成が長くなればなるほど、毎ステップ全トークンをデコードするコストが高くなる。decode_max_tokensで直近のトークンだけを対象にすることで、推論速度を現実的な範囲に抑えている。

_calc_banned_ngram_tokens

Transformersの内部関数で、指定したngram_sizeに基づいて禁止トークンリストを計算する。private API(アンダースコアプレフィックス)なのでバージョンによってインターフェースが変わる可能性はあるが、現時点ではこれが最もシンプルな実装だ。


5. 使い方と再現手順

インストール要件

bash
pip install transformers torch

基本的な使い方

python
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.logits_process import LogitsProcessorList

# モデルとトークナイザーの読み込み
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

# LogitsProcessorの構築
logits_processor = ContentNoRepeatGenerationMixin.build_content_no_repeat_logits_processor(
    tokenizer=tokenizer,
    content_no_repeat_ngram_size=4,       # 4-gramで繰り返しを禁止
    content_no_repeat_decode_max_tokens=1024,
    content_no_repeat_debug=True,          # 開発時はTrueで動作確認
)

# 推論
inputs = tokenizer("JSONを生成してください:", return_tensors="pt")
generation_config = {
    "max_new_tokens": 512,
    "temperature": 0.7,
    "do_sample": True,
}

output_ids = model.generate(
    **inputs,
    **generation_config,
    logits_processor=logits_processor,
)

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

ngram_sizeの目安

ngram_size特性
2同じ2トークンの繰り返しを禁止。かなり強い制約
3バランスが良い。まず試すならここから
4緩め。長めのフレーズ繰り返しのみを禁止
5以上ほぼ制約なしに近い

一般的にはngram_size=3ngram_size=4から試し始め、生成品質を見ながら調整するのがよい。


6. 実験で確認できた効果と注意点

効果

  • Contentフィールド内の繰り返しループが顕著に減少した。特に同じ文が3〜5回繰り返されるパターンが消えた
  • JSON構造部分("Title":, "Content": などのキー名)には制御がかからないため、出力の整合性に影響しない
  • debug=Trueにすることで、どのステップで何のトークンが禁止されているか確認でき、チューニングが容易

注意点

  • 毎ステップのデコードコストがあるdecode_max_tokensを小さくしすぎると検出精度が下がり、大きすぎると推論速度に影響する。1024程度が現実的なバランスだ
  • _calc_banned_ngram_tokensはTransformersのプライベートAPIであるため、バージョンアップデート時に動作確認が必要
  • バッチサイズが大きい場合は各バッチでデコードが走るため、単一バッチでの使用が前提のケースが多い

7. FAQ

Q. QwenだけでなくLlama・Mistralでも使えますか?

A. はい。TransformersのLogitsProcessor APIはモデルアーキテクチャに依存しない。model.generate()をサポートしているモデルであれば、同じコードで動作する。ただしトークナイザーの実装差異(特殊トークンの扱い等)には注意が必要だ。

Q. Content以外のフィールドにも適用したい場合は?

A. _extract_open_content_textのマーカー文字列('"Content":"')を変更すればよい。複数フィールドに対応したい場合は、マーカーのリストを受け取る形にクラスを拡張するアプローチが取れる。

Q. vLLMやTGI(Text Generation Inference)でも使えますか?

A. vLLMにはLogitsProcessorの差し込み口があり、一部バージョンではTransformers互換のインターフェースで利用可能だ。TGIは現時点でカスタムLogitsProcessorのサポートが限定的なため、要確認。

Q. repetition_penaltyと組み合わせることはできますか?

A. できる。repetition_penaltyをベースに設定しつつ、Contentフィールドには本プロセッサを追加するという組み合わせは有効だ。ただし制約が二重にかかるため、過剰制御にならないようngram_sizeを大きめにするか、repetition_penaltyを控えめ(1.1程度)にすることを推奨する。

Q. ファインチューニングのデータ品質で解決できませんか?

A. 理想はそうだが、現実には学習データのクリーニングだけで完全に解決しないケースも多い。推論時の制御と学習データの改善は補完関係にあり、両方を並行して取り組むのがベストだ。

Q. ngram_sizeを大きくすれば副作用なく使えますか?

A. 大きくするほど副作用は減るが、繰り返し制御の効果も薄れる。ngram_size=6以上ではほぼ制御が機能しなくなることが多い。具体的な繰り返しのパターンをデバッグモードで観察しながら、最小有効な値を探るのが正攻法だ。


まとめ

Qwenのファインチューニングで生じたJSON Contentフィールドの繰り返しハルシネーションに対し、カスタムLogits Processorによって対象フィールドを絞った制御を実装した。標準パラメータのno_repeat_ngram_sizeとの比較でわかるように、「全体に強制適用する」のではなく「問題が起きているスコープだけに適用する」という設計が、副作用を避けながら実用的な解決策になる。

LLMの推論制御はまだ手探りの部分が多いが、Logits ProcessorというAPIレイヤーは非常に柔軟で、同様のパターンで他の制御ロジックにも応用できる。ぜひ手元の環境で試してほしい。

関連するブログ

この記事に近いテーマのブログをピックアップしています。