How We Fine-Tuned IBM Granite Speech 1B on Japanese Audio and Improved CER by More Than 20 Points

Summary: We fine-tuned IBM Granite Speech (granite-4.0-1b-speech) on 100 hours of Japanese audio data and improved CER from 0.37 to 0.14. The official script's Projector+LoRA-only training hits a ceiling; additionally training lm_head and the last 8 layers of the Language Model was the single biggest contributor. The resulting model matches Qwen3-ASR-1.7B (CER 0.14) with only 1B parameters.

Introduction: why fine-tune Granite Speech?

From late 2025 into 2026, the accuracy race in Japanese ASR (automatic speech recognition) has accelerated sharply. While strong models such as Qwen3-ASR and ReazonSpeech are available, IBM's granite-4.0-1b-speech was trained primarily on English and does not reach adequate accuracy on Japanese out of the box.

On the other hand, Granite Speech has a comparatively simple architecture with clearly defined fine-tuning hooks: the Projector (the conversion layer connecting the speech encoder to the LLM) and LoRA (adapters inserted into the Language Model). An official Colab notebook is also provided, which makes it easy to get started.

However, the official configuration hits a ceiling in accuracy. This article explains in detail the parameter-expansion strategy we used to break through that ceiling, along with lessons learned from fine-tuning on 100 hours of Japanese audio data.

Background: the Granite Speech architecture

Granite Speech (granite-4.0-1b-speech) consists of roughly three components.

| Component | Role |
| --- | --- |
| Speech Encoder | Audio waveform → high-resolution acoustic representation (Conformer + CTC) |
| Speech Projector (Q-Former) | Compresses the speech representation, extracts meaning, and maps it into the LLM space |
| Language Model | Text generation and semantic interpretation |

**LoRA (Low-Rank Adaptation)** is a technique that inserts small low-rank matrices into each attention layer of the Language Model and fine-tunes efficiently with few parameters while keeping the original weights frozen.
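
As a generic illustration of the mechanism (not the Granite recipe; granite-4.0-1b-speech already ships with its LoRA adapters attached, and the base model and target modules below are placeholders), this is how LoRA adapters are typically attached with the peft library:

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Placeholder base model; "c_attn" is GPT-2's attention projection module.
base = AutoModelForCausalLM.from_pretrained("gpt2")
config = LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"])
model = get_peft_model(base, config)
model.print_trainable_parameters()  # only a fraction of a percent is trainable
```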

**CER (Character Error Rate)** is the character-level error rate. Because word segmentation is difficult in Japanese, CER is generally used for evaluation instead of WER (Word Error Rate). Lower is better; 0.14 means 14% of the characters are misrecognized.
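
As a concrete example, CER can be computed with the jiwer library (the same library the script below uses for its test CER):

```python
import jiwer

reference = "今日は良い天気です。"  # 10 characters
hypothesis = "今日は良い天気です"   # final period missing -> 1 deletion
print(jiwer.cer(reference, hypothesis))  # 0.1, i.e. 10% character error rate
```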

Limitations of the official fine-tuning recipe

The set of trainable parameters in the official script is simple.
https://colab.research.google.com/github/ibm-granite/granite-speech-models/blob/main/notebooks/fine_tuning_granite_speech.ipynb

```python
for n, p in model.named_parameters():
    # Train only the Projector and LoRA layers
    p.requires_grad = "projector" in n or "lora" in n
```

Running fine-tuning with this setting does improve CER at first. The problem is that the gains plateau early, even as the amount of data grows.

Intuitively, this makes sense. Projector+LoRA handle the "bridge" from audio to text, but the Language Model's own knowledge of Japanese vocabulary, grammar, and context is never updated. Even with 100 hours of data, the language model's limited Japanese ability becomes the accuracy bottleneck.

In particular, the challenges specific to Japanese, producing correct kanji, proper nouns, and punctuation, depend heavily on the Language Model side.

Expanding the set of trainable parameters

Adding lm_head and the last n layers of the Language Model

We therefore adopted the following strategy.

  1. Add lm_head (the final token-prediction layer) to the trainable parameters
  2. Add the last N Transformer layers of the Language Model to the trainable parameters

```python
def should_train_parameter(
    name: str, *, train_last_n_layers: int, total_lm_layers: int
) -> bool:
    # Projector / LoRA / lm_head are always trained
    if "projector" in name or "lora" in name or "lm_head" in name:
        return True

    if train_last_n_layers <= 0:
        return False

    # Parse names of the form language_model.model.layers.{index}.*
    layer_prefix = "language_model.model.layers."
    if not name.startswith(layer_prefix):
        return False

    remainder = name[len(layer_prefix):]
    layer_index_str = remainder.split(".", 1)[0]
    if not layer_index_str.isdigit():
        return False

    layer_index = int(layer_index_str)
    first_trainable_layer = max(total_lm_layers - train_last_n_layers, 0)
    return layer_index >= first_trainable_layer
```

After loading the model, apply it as follows.

```python
total_lm_layers = len(model.language_model.model.layers)

for name, param in model.named_parameters():
    param.requires_grad = should_train_parameter(
        name,
        train_last_n_layers=args.train_last_n_layers,
        total_lm_layers=total_lm_layers,
    )

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
```

Choosing the optimal number of layers n

The choice of n depends on data volume. Unfreezing too many layers with little data increases the risk of overfitting.

| Approx. data volume | Recommended n |
| --- | --- |
| ~10 hours | 0–2 (official setting, or minimal) |
| 10–50 hours | 2–4 |
| 50–100 hours | 4–8 |
| 100+ hours | 8 to all layers |

With our 100 hours of Japanese audio data, n=8 turned out to be the best.

Hyperparameters and training settings

```python
# Training configuration used in this run
batch_size = 8
learning_rate = 5e-5
train_last_n_layers = 8
num_epochs = 3.0
warmup_ratio = 0.0
lr_scheduler_type = "cosine"
grad_accum_steps = 2
```

With grad_accum_steps = 2, the effective batch size is 8 × 2 = 16.

The importance of punctuation output

An easily overlooked but genuinely important improvement from this fine-tuning run is punctuation output.

The original granite-4.0-1b-speech tends to output text without punctuation. Because our training data used punctuated transcripts, the model learned to output punctuation appropriately.

In a practical ASR system, unpunctuated text is hard to use as-is and previously required a post-processing step to add punctuation. This improvement simplifies the pipeline considerably.

Benchmark results

Evaluation uses the same internal dataset used in our Japanese ASR benchmark article.

| Model | CER | Notes |
| --- | --- | --- |
| ibm-granite/granite-4.0-1b-speech (baseline) | 0.37 | Before fine-tuning |
| ibm-granite/granite-4.0-1b-speech (fine-tuned) | 0.14 | Matches the larger model with 1B parameters |
| qwen/qwen3-asr-1.7b | 0.14 | Comparison target (1.7B parameters) |

Qwen3-ASR-1.7B has 1.7B parameters. Achieving the same CER with a smaller 1B-parameter model demonstrates the effectiveness of the fine-tuning.

Remaining issues

  • Misrecognition of kanji and proper nouns: person names, place names, and technical terms still see frequent errors
  • The Qwen barrier: further increasing the data volume may let us pull ahead, but we have not yet
  • Clipped word endings and long-vowel errors also appear here and there

Data quality filtering

When feeding in 100 hours of data, quality control directly affects accuracy. Our script implements filtering on the CER, audio duration, text length, and other fields recorded in the manifest file.

```bash
# Example filtering conditions (when using a local manifest)
--min-duration-sec 10.0    # Exclude short audio under 10 seconds
--max-duration-sec 30.0    # Exclude long audio over 30 seconds
--min-text-len 45          # Exclude samples whose text is too short
--max-cer 0.4              # Exclude low-quality transcripts whose CER vs. Whisper exceeds 40%
```

Fine-tuning accuracy depends strongly on data quality. We recommend filtering out noisy audio and samples with frequent transcription errors. Building a pipeline that pre-transcribes the data with an ASR model, computes CER, and removes outliers works well.
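
A minimal sketch of such a pipeline follows, assuming openai-whisper as the reference ASR and hypothetical file paths. It writes the JSONL consumed via --cer-results-jsonl, with the "audio" and "cer" keys expected by load_cer_results_index in the full script below:

```python
import json
from pathlib import Path

import jiwer
import whisper  # pip install openai-whisper

asr_model = whisper.load_model("large-v3")
manifest = Path("data/manifest.jsonl")     # hypothetical input manifest
out_path = Path("data/cer_results.jsonl")  # passed via --cer-results-jsonl

with manifest.open() as fin, out_path.open("w") as fout:
    for line in fin:
        row = json.loads(line)
        # Transcribe with the reference model and score against the label text.
        hyp = asr_model.transcribe(row["audio_path"], language="ja")["text"]
        cer = jiwer.cer(row["text"], hyp)
        fout.write(
            json.dumps({"audio": row["audio_path"], "cer": cer}, ensure_ascii=False)
            + "\n"
        )
```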

Checkpoint management with TestCerCallback

We implemented a mechanism that automatically measures test CER during training and saves only the best checkpoint. For long runs this saves storage while reliably retaining the best model.

```python
# Evaluate test CER at every checkpoint and keep only the best
trainer.add_callback(TestCerCallback(
    processor=processor,
    test_dataset=test_dataset,
    language="ja",
    save_best_test_cer_only=True,  # keep only the best checkpoint
    ...
))
```

Enabling the --save-best-test-cer-only flag evaluates test CER at every save step and automatically deletes any checkpoint that does not improve on the best so far.

The complete fine-tuning script

Running the script below requires the main branch of HuggingFace transformers (the granite_speech module may not yet be included in a stable release).

```python
"""
Finetuning script for IBM Granite Speech models.

Based on:
https://colab.research.google.com/github/ibm-granite/granite-speech-models/blob/main/notebooks/fine_tuning_granite_speech.ipynb

Requirements:
    pip install git+https://github.com/huggingface/transformers.git
    pip install -U datasets peft accelerate evaluate whisper tqdm
"""

import argparse
import json
import shutil
from pathlib import Path

import jiwer
import torch
import tqdm
import evaluate
import soundfile as sf
from datasets import load_dataset, concatenate_datasets, Audio
from torch.utils.data import DataLoader
from transformers import TrainingArguments, Trainer, TrainerCallback
from transformers.feature_extraction_utils import BatchFeature
from transformers.models.granite_speech import (
    GraniteSpeechForConditionalGeneration,
    GraniteSpeechProcessor,
)


# ---------------------------------------------------------------------------
# Data preprocessing
# ---------------------------------------------------------------------------


def process_gigaspeech_transcript(text: str) -> str:
    text = text.replace(" <COMMA>", ",")
    text = text.replace(" <PERIOD>", ".")
    text = text.replace(" <QUESTIONMARK>", "?")
    text = text.replace(" <EXCLAMATIONPOINT>", "!")
    return text.lower()


def resolve_audio_path(
    audio_path: str, manifest_path: Path, audio_root: Path | None
) -> str:
    audio = Path(audio_path)
    if audio.is_absolute():
        return str(audio)
    if audio_root is not None:
        return str((audio_root / audio).resolve())
    return str((manifest_path.parent / audio).resolve())


def resolve_row_audio_path(
    row: dict, audio_column: str, manifest_path: Path, audio_root: Path | None
) -> str:
    if audio_column not in row:
        raise ValueError(f"{manifest_path} must contain '{audio_column}' column")
    return resolve_audio_path(str(row[audio_column]), manifest_path, audio_root)


def load_cer_results_index(results_jsonl: str | None) -> dict[str, float]:
    if not results_jsonl:
        return {}

    index: dict[str, float] = {}
    results_path = Path(results_jsonl).resolve()
    with results_path.open() as fh:
        for line_no, line in enumerate(fh, start=1):
            line = line.strip()
            if not line:
                continue
            row = json.loads(line)
            audio_path = row.get("audio")
            cer = row.get("cer")
            if not audio_path or cer is None:
                continue
            index[str(Path(audio_path).resolve())] = float(cer)

    print(f"Loaded {len(index)} CER rows from {results_path}")
    return index


def normalize_transcript(text: str, *, dataset_kind: str) -> str:
    normalized = str(text).strip()
    if dataset_kind == "hf":
        return process_gigaspeech_transcript(normalized)
    return normalized


def prep_example(example, tokenizer, dataset_kind: str):
    instruction = "Please transcribe the following audio to text<|audio|>"
    chat = [{"role": "user", "content": instruction}]
    example["prompt"] = tokenizer.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=False,
    )
    example["text"] = normalize_transcript(
        example["text"],
        dataset_kind=dataset_kind,
    )
    return example


def should_keep_local_record(
    example,
    *,
    min_cer,
    max_cer,
    results_min_cer,
    results_max_cer,
    min_cer_percent,
    max_cer_percent,
    min_duration_sec,
    max_duration_sec,
    min_text_len,
    max_text_len,
    manifest_path,
    audio_root,
    cer_results_index,
):
    manifest_cer = example.get("cer")
    if min_cer is not None and (manifest_cer is None or float(manifest_cer) < min_cer):
        return False
    if max_cer is not None and (manifest_cer is None or float(manifest_cer) > max_cer):
        return False

    results_cer = None
    if cer_results_index:
        audio_path = resolve_audio_path(
            str(example["audio_path"]),
            manifest_path,
            audio_root,
        )
        results_cer = cer_results_index.get(audio_path)
    if results_min_cer is not None and (
        results_cer is None or float(results_cer) < results_min_cer
    ):
        return False
    if results_max_cer is not None and (
        results_cer is None or float(results_cer) > results_max_cer
    ):
        return False

    cer_percent = example.get("cer_percent")
    if min_cer_percent is not None and (
        cer_percent is None or float(cer_percent) < min_cer_percent
    ):
        return False
    if max_cer_percent is not None and (
        cer_percent is None or float(cer_percent) > max_cer_percent
    ):
        return False

    duration_sec = example.get("duration_sec")
    if min_duration_sec is not None and (
        duration_sec is None or float(duration_sec) < min_duration_sec
    ):
        return False
    if max_duration_sec is not None and (
        duration_sec is None or float(duration_sec) > max_duration_sec
    ):
        return False

    text = str(example.get("text", "")).strip()
    if min_text_len is not None and len(text) < min_text_len:
        return False
    if max_text_len is not None and len(text) > max_text_len:
        return False

    return True


def prepare_dataset(ds, processor, *, dataset_kind: str, language: str):
    columns_to_remove = [col for col in ds.column_names if col not in ["audio", "text"]]
    ds = ds.cast_column(
        "audio",
        Audio(
            sampling_rate=processor.audio_processor.sampling_rate,
            decode=False,
        ),
    )
    ds = ds.map(
        prep_example,
        fn_kwargs={
            "tokenizer": processor.tokenizer,
            "dataset_kind": dataset_kind,
        },
        remove_columns=columns_to_remove,
    )
    if dataset_kind == "hf":
        ds = ds.filter(
            lambda x: x["text"] not in ["<other>", "<noise>", "<music>", "<sil>"]
        )
    return ds


# ---------------------------------------------------------------------------
# Collator
# ---------------------------------------------------------------------------


def _extract_audio_array(audio):
    """Load audio without relying on datasets' torchcodec-based decoding."""
    if hasattr(audio, "get_all_samples"):
        samples = audio.get_all_samples()
        return samples.data.squeeze(0).numpy()
    if isinstance(audio, dict):
        if "array" in audio:
            return audio["array"]
        if audio.get("path"):
            audio_array, sampling_rate = sf.read(audio["path"], dtype="float32")
            if audio_array.ndim > 1:
                audio_array = audio_array.mean(axis=1)
            return audio_array, sampling_rate
    return audio


class GraniteCollator:
    def __init__(self, processor, inference_mode: bool = False):
        self.processor = processor
        self.inference_mode = inference_mode
        self.sampling_rate = processor.audio_processor.sampling_rate

    def _prepare_audio(self, audio):
        extracted = _extract_audio_array(audio)
        sampling_rate = self.sampling_rate

        if isinstance(extracted, tuple):
            audio_array, sampling_rate = extracted
        else:
            audio_array = extracted

        audio_tensor = torch.as_tensor(audio_array, dtype=torch.float32)
        if audio_tensor.ndim > 1:
            audio_tensor = audio_tensor.mean(dim=0)
        if sampling_rate != self.sampling_rate:
            audio_tensor = (
                torch.nn.functional.interpolate(
                    audio_tensor.unsqueeze(0).unsqueeze(0),
                    size=int(
                        audio_tensor.shape[-1] * self.sampling_rate / sampling_rate
                    ),
                    mode="linear",
                    align_corners=False,
                )
                .squeeze(0)
                .squeeze(0)
            )
        return audio_tensor.numpy()

    def __call__(self, examples):
        prompts = [ex["prompt"] for ex in examples]
        audios = [self._prepare_audio(ex["audio"]) for ex in examples]

        processed = self.processor(
            prompts, audios, return_tensors="pt", padding=True, padding_side="left"
        )
        input_ids = processed.input_ids
        attention_mask = processed.attention_mask
        labels = None

        if not self.inference_mode:
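            # Append the target transcript (plus EOS) after the prompt. Targets are
            # right-padded; labels mask the prompt and padding positions with -100
            # so the loss is computed only on transcript tokens.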
            targets = [
                ex["text"] + self.processor.tokenizer.eos_token for ex in examples
            ]
            targets = self.processor.tokenizer(
                targets, return_tensors="pt", padding=True, padding_side="right"
            )
            input_ids = torch.cat([input_ids, targets.input_ids], dim=1)
            attention_mask = torch.cat([attention_mask, targets.attention_mask], dim=1)
            labels = targets.input_ids.clone()
            labels[~targets.attention_mask.bool()] = -100
            labels = torch.cat(
                [torch.full_like(processed.input_ids, -100), labels], dim=1
            )

        return BatchFeature(
            data={
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels,
                "input_features": processed.input_features,
                "input_features_mask": processed.input_features_mask,
            }
        )


# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------
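

# NOTE: normalize_japanese_text is referenced by build_eval_normalizer below but
# was missing from the original listing. This is a minimal sketch of the assumed
# behavior: NFKC width normalization plus optional punctuation stripping.
# Proper kanji-numeral conversion needs a dedicated library and is omitted here.
def normalize_japanese_text(
    text: str,
    *,
    remove_punctuation: bool = True,
    convert_kanji_numbers: bool = True,
) -> str:
    import re
    import unicodedata

    normalized = unicodedata.normalize("NFKC", str(text)).strip()
    if remove_punctuation:
        # Drop Japanese/ASCII punctuation and whitespace before scoring.
        normalized = re.sub(r"[、。・「」『』,.!?\s]+", "", normalized)
    # convert_kanji_numbers is accepted for signature compatibility; see note above.
    return normalized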


def build_eval_normalizer(
    language: str,
    *,
    remove_punctuation: bool = True,
    convert_kanji_numbers: bool = True,
):
    if language == "ja":
        import functools

        return functools.partial(
            normalize_japanese_text,
            remove_punctuation=remove_punctuation,
            convert_kanji_numbers=convert_kanji_numbers,
        )
    try:
        from whisper.normalizers import EnglishTextNormalizer
    except ModuleNotFoundError:
        return lambda x: str(x).strip().lower()
    english_normalizer = EnglishTextNormalizer()
    return english_normalizer


def compute_metric(
    model,
    processor,
    dataset,
    *,
    metric_name: str,
    language: str,
    batch_size: int = 16,
    num_workers: int = 8,
):
    collator = GraniteCollator(processor, inference_mode=True)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, collate_fn=collator, num_workers=num_workers
    )
    normalizer = build_eval_normalizer(language)
    metric = evaluate.load(metric_name)

    model = model.eval().cuda()
    all_outputs = []

    for batch in tqdm.tqdm(dataloader, desc="Running inference"):
        batch = batch.to("cuda")
        with torch.inference_mode(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
            outputs = model.generate(
                **batch, max_new_tokens=400, num_beams=1, early_stopping=True
            )
        input_length = batch.input_ids.shape[1]
        outputs = outputs[:, input_length:].cpu()
        for x in outputs:
            all_outputs.append(processor.tokenizer.decode(x, skip_special_tokens=True))

    gt_texts = [normalizer(x) for x in dataset["text"]]
    all_outputs = [normalizer(x) for x in all_outputs]
    return metric.compute(references=gt_texts, predictions=all_outputs)


def compute_predictions(
    model,
    processor,
    dataset,
    *,
    language: str,
    batch_size: int = 16,
    num_workers: int = 8,
):
    collator = GraniteCollator(processor, inference_mode=True)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, collate_fn=collator, num_workers=num_workers
    )
    normalizer = build_eval_normalizer(language)

    model = model.eval().cuda()
    predictions = []

    for batch in tqdm.tqdm(dataloader, desc="Running inference"):
        batch = batch.to("cuda")
        with torch.inference_mode(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
            outputs = model.generate(
                **batch, max_new_tokens=400, num_beams=4, early_stopping=True
            )
        input_length = batch.input_ids.shape[1]
        outputs = outputs[:, input_length:].cpu()
        for x in outputs:
            predictions.append(processor.tokenizer.decode(x, skip_special_tokens=True))

    references = [normalizer(x) for x in dataset["text"]]
    predictions = [normalizer(x) for x in predictions]
    return references, predictions


def build_diff_string(reference: str, prediction: str) -> str:
    import difflib

    matcher = difflib.SequenceMatcher(None, prediction, reference)
    result = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            result.append(prediction[i1:i2])
        else:
            result.append(f"(w: {prediction[i1:i2]}|q: {reference[j1:j2]})")
    return "".join(result)


def build_prediction_rows(references, predictions):
    rows = []
    for index, (reference, prediction) in enumerate(
        zip(references, predictions), start=1
    ):
        rows.append(
            {
                "index": index,
                "reference": reference,
                "prediction": prediction,
                "cer": jiwer.cer(reference, prediction),
                "diff": build_diff_string(reference, prediction),
            }
        )
    return rows


def maybe_init_wandb(args):
    if not args.use_wandb:
        return None
    import wandb

    wandb.init(
        project=args.wandb_project,
        entity=args.wandb_entity,
        name=args.wandb_run_name,
        config=vars(args),
    )
    return wandb


def maybe_log_wandb_predictions(
    wandb_module,
    *,
    stage: str,
    metric_name: str,
    metric_value: float,
    rows,
    max_rows: int,
):
    if wandb_module is None:
        return
    table = wandb_module.Table(
        columns=["index", "reference", "prediction", "cer", "diff"]
    )
    for row in rows[:max_rows]:
        table.add_data(
            row["index"], row["reference"], row["prediction"], row["cer"], row["diff"]
        )
    wandb_module.log(
        {
            f"{stage}/{metric_name}": metric_value,
            f"{stage}/prediction_table": table,
        }
    )


def prune_checkpoints(output_dir: str, keep_checkpoint_dir: Path | None) -> None:
    keep_path = (
        keep_checkpoint_dir.resolve() if keep_checkpoint_dir is not None else None
    )
    for checkpoint_dir in Path(output_dir).glob("checkpoint-*"):
        if not checkpoint_dir.is_dir():
            continue
        if keep_path is not None and checkpoint_dir.resolve() == keep_path:
            continue
        shutil.rmtree(checkpoint_dir)


class TestCerCallback(TrainerCallback):
    def __init__(
        self,
        *,
        processor,
        test_dataset,
        language: str,
        batch_size: int,
        num_workers: int,
        output_dir: str,
        wandb_module,
        wandb_log_samples: int,
        save_best_test_cer_only: bool,
    ):
        self.processor = processor
        self.test_dataset = test_dataset
        self.language = language
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.output_dir = output_dir
        self.wandb_module = wandb_module
        self.wandb_log_samples = wandb_log_samples
        self.save_best_test_cer_only = save_best_test_cer_only
        self.best_test_cer: float | None = None
        self.best_checkpoint_dir: Path | None = None

    def on_save(self, args, state, control, model=None, **kwargs):
        references, predictions = compute_predictions(
            model,
            self.processor,
            self.test_dataset,
            language=self.language,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )
        test_cer = jiwer.cer(references, predictions)
        prediction_rows = build_prediction_rows(references, predictions)

        checkpoint_dir = Path(args.output_dir) / f"checkpoint-{state.global_step}"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        metrics = {
            "global_step": state.global_step,
            "test_cer": test_cer,
            "num_samples": len(references),
        }
        (checkpoint_dir / "test_cer.json").write_text(
            json.dumps(metrics, ensure_ascii=False, indent=2) + "\n",
            encoding="utf-8",
        )
        with (checkpoint_dir / "test_predictions.jsonl").open(
            "w", encoding="utf-8"
        ) as f:
            for row in prediction_rows:
                f.write(json.dumps(row, ensure_ascii=False) + "\n")
        with (Path(self.output_dir) / "test_cer_history.jsonl").open(
            "a", encoding="utf-8"
        ) as f:
            f.write(json.dumps(metrics, ensure_ascii=False) + "\n")

        is_best = self.best_test_cer is None or test_cer < self.best_test_cer
        if is_best:
            previous_best_dir = self.best_checkpoint_dir
            self.best_test_cer = test_cer
            self.best_checkpoint_dir = checkpoint_dir
            (Path(self.output_dir) / "best_test_cer.json").write_text(
                json.dumps(
                    {
                        "global_step": state.global_step,
                        "test_cer": test_cer,
                        "checkpoint_dir": str(checkpoint_dir),
                    },
                    ensure_ascii=False,
                    indent=2,
                )
                + "\n",
                encoding="utf-8",
            )
            if (
                self.save_best_test_cer_only
                and previous_best_dir is not None
                and previous_best_dir != checkpoint_dir
                and previous_best_dir.exists()
            ):
                shutil.rmtree(previous_best_dir)
        elif self.save_best_test_cer_only and checkpoint_dir.exists():
            shutil.rmtree(checkpoint_dir)

        maybe_log_wandb_predictions(
            self.wandb_module,
            stage=f"checkpoint_{state.global_step}",
            metric_name="cer",
            metric_value=test_cer,
            rows=prediction_rows,
            max_rows=self.wandb_log_samples,
        )
        if is_best:
            print(f"Test CER at step {state.global_step}: {test_cer * 100:.3f}% (best)")
        else:
            print(f"Test CER at step {state.global_step}: {test_cer * 100:.3f}%")
        return control


def take_first_n(ds, n: int):
    if n is None or n <= 0:
        return ds
    return ds.select(range(min(n, len(ds))))


def _load_single_manifest(manifest_path: Path, audio_root: Path | None, args):
    cer_results_index = getattr(args, "cer_results_index", {})
    dataset = load_dataset("json", data_files={"manifest": str(manifest_path)})[
        "manifest"
    ]
    before_count = len(dataset)
    dataset = dataset.filter(
        should_keep_local_record,
        fn_kwargs={
            "min_cer": args.min_cer,
            "max_cer": args.max_cer,
            "results_min_cer": args.results_min_cer,
            "results_max_cer": args.results_max_cer,
            "min_cer_percent": args.min_cer_percent,
            "max_cer_percent": args.max_cer_percent,
            "min_duration_sec": args.min_duration_sec,
            "max_duration_sec": args.max_duration_sec,
            "min_text_len": args.min_text_len,
            "max_text_len": args.max_text_len,
            "manifest_path": manifest_path,
            "audio_root": audio_root,
            "cer_results_index": cer_results_index,
        },
    )
    if cer_results_index and (
        args.results_min_cer is not None or args.results_max_cer is not None
    ):
        removed_count = before_count - len(dataset)
        print(
            f"Filtered {removed_count} rows from {manifest_path} using external CER thresholds"
        )
    dataset = dataset.map(
        lambda example: {
            "audio": resolve_audio_path(
                example["audio_path"], manifest_path, audio_root
            ),
            "text": str(example["text"]).strip(),
        },
        remove_columns=[
            col for col in dataset.column_names if col not in ["audio_path", "text"]
        ],
    )
    return dataset


def load_local_dataset(args):
    local_manifest = Path(args.local_manifest).resolve()
    audio_root = Path(args.audio_root).resolve() if args.audio_root else None
    args.cer_results_index = load_cer_results_index(args.cer_results_jsonl)

    if local_manifest.is_dir():
        manifest_paths = sorted(local_manifest.rglob("manifest.jsonl"))
        if not manifest_paths:
            raise ValueError(f"No manifest.jsonl files found under {local_manifest}")
        non_empty_manifest_paths = [p for p in manifest_paths if p.stat().st_size > 0]
        skipped_manifest_count = len(manifest_paths) - len(non_empty_manifest_paths)
        print(f"Found {len(manifest_paths)} manifest files under {local_manifest}")
        if skipped_manifest_count:
            print(f"Skipping {skipped_manifest_count} empty manifest files")
        if not non_empty_manifest_paths:
            raise ValueError(
                f"All manifest.jsonl files under {local_manifest} are empty"
            )
        dataset = concatenate_datasets(
            [
                _load_single_manifest(p, audio_root, args)
                for p in non_empty_manifest_paths
            ]
        )
    else:
        if local_manifest.stat().st_size == 0:
            raise ValueError(f"Manifest file is empty: {local_manifest}")
        dataset = _load_single_manifest(local_manifest, audio_root, args)

    if len(dataset) < 3:
        raise ValueError(
            f"Local dataset is too small after filtering: {len(dataset)} rows. Need at least 3."
        )

    split_seed = args.split_seed
    test_size = args.test_split_ratio
    val_size = args.val_split_ratio / (1.0 - test_size)
    first_split = dataset.train_test_split(test_size=test_size, seed=split_seed)
    train_and_val = first_split["train"]
    test_dataset = first_split["test"]
    second_split = train_and_val.train_test_split(test_size=val_size, seed=split_seed)
    train_dataset = second_split["train"]
    val_dataset = second_split["test"]

    train_dataset = take_first_n(train_dataset, args.train_size)
    val_dataset = take_first_n(val_dataset, args.val_size)
    test_dataset = take_first_n(test_dataset, args.test_size)
    return train_dataset, val_dataset, test_dataset


def load_eval_json_dataset(
    path: str,
    *,
    audio_root: str | None,
    audio_column: str,
    text_column: str,
    size_limit: int | None,
):
    manifest_path = Path(path).resolve()
    resolved_audio_root = Path(audio_root).resolve() if audio_root else None

    dataset = load_dataset("json", data_files={"eval": str(manifest_path)})["eval"]
    dataset = dataset.map(
        lambda example: {
            "audio": resolve_row_audio_path(
                example,
                audio_column=audio_column,
                manifest_path=manifest_path,
                audio_root=resolved_audio_root,
            ),
            "text": str(example[text_column]).strip(),
        },
        remove_columns=[
            col
            for col in dataset.column_names
            if col not in [audio_column, text_column]
        ],
    )
    return take_first_n(dataset, size_limit)


def load_training_datasets(args):
    if args.local_manifest:
        print(f"Loading local manifest: {args.local_manifest}")
        return load_local_dataset(args), "local"

    print(f"Loading dataset: {args.dataset} ({args.dataset_subset})")
    dataset = load_dataset(args.dataset, args.dataset_subset)
    train_dataset = take_first_n(dataset["train"], args.train_size)
    val_dataset = take_first_n(dataset["validation"], args.val_size)
    test_dataset = take_first_n(dataset["test"], args.test_size)
    return (train_dataset, val_dataset, test_dataset), "hf"


# ---------------------------------------------------------------------------
# Training parameter selection
# ---------------------------------------------------------------------------


def should_train_parameter(
    name: str, *, train_last_n_layers: int, total_lm_layers: int
) -> bool:
    if "projector" in name or "lora" in name or "lm_head" in name:
        return True

    if train_last_n_layers <= 0:
        return False

    layer_prefix = "language_model.model.layers."
    if not name.startswith(layer_prefix):
        return False

    remainder = name[len(layer_prefix) :]
    layer_index_str = remainder.split(".", 1)[0]
    if not layer_index_str.isdigit():
        return False

    layer_index = int(layer_index_str)
    first_trainable_layer = max(total_lm_layers - train_last_n_layers, 0)
    return layer_index >= first_trainable_layer


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def parse_args():
    parser = argparse.ArgumentParser(description="Finetune IBM Granite Speech")
    parser.add_argument(
        "--model-name",
        default="ibm-granite/granite-4.0-1b-speech",
        help="HuggingFace model ID",
    )
    parser.add_argument(
        "--dataset",
        default="speechcolab/gigaspeech",
        help="HuggingFace dataset ID",
    )
    parser.add_argument(
        "--dataset-subset", default="xs", help="Dataset subset/config name"
    )
    parser.add_argument(
        "--local-manifest",
        default=None,
        help="Path to local ASR manifest jsonl with at least 'audio_path' and 'text'",
    )
    parser.add_argument(
        "--audio-root",
        default=None,
        help="Optional root directory to prepend to relative local audio paths",
    )
    parser.add_argument(
        "--validation-json",
        default=None,
        help="Optional JSON/JSONL validation dataset path with audio/text columns",
    )
    parser.add_argument(
        "--test-json",
        default=None,
        help="Optional JSON/JSONL test dataset path with audio/text columns",
    )
    parser.add_argument(
        "--eval-audio-root",
        default=None,
        help="Optional root directory to prepend to relative evaluation audio paths",
    )
    parser.add_argument("--eval-audio-column", default="audio")
    parser.add_argument("--eval-text-column", default="text")
    parser.add_argument("--train-size", type=int, default=5000)
    parser.add_argument("--val-size", type=int, default=200)
    parser.add_argument("--test-size", type=int, default=200)
    parser.add_argument("--val-split-ratio", type=float, default=0.01)
    parser.add_argument("--test-split-ratio", type=float, default=0.01)
    parser.add_argument("--split-seed", type=int, default=42)
    parser.add_argument("--language", default=None, choices=["en", "ja"])
    parser.add_argument("--eval-metric", default=None, choices=["wer", "cer"])
    parser.add_argument("--min-cer", type=float, default=None)
    parser.add_argument("--max-cer", type=float, default=None)
    parser.add_argument("--results-min-cer", type=float, default=None)
    parser.add_argument("--results-max-cer", type=float, default=None)
    parser.add_argument(
        "--cer-results-jsonl",
        default=None,
        help="Optional JSONL with 'audio' and 'cer'; used with --results-min-cer/--results-max-cer to filter local manifest rows by external CER",
    )
    parser.add_argument("--min-cer-percent", type=float, default=None)
    parser.add_argument("--max-cer-percent", type=float, default=None)
    parser.add_argument("--min-duration-sec", type=float, default=None)
    parser.add_argument("--max-duration-sec", type=float, default=None)
    parser.add_argument("--min-text-len", type=int, default=None)
    parser.add_argument("--max-text-len", type=int, default=None)
    parser.add_argument("--output-dir", default="save_dir")
    parser.add_argument("--use-wandb", action="store_true")
    parser.add_argument("--wandb-project", default="granite-asr")
    parser.add_argument("--wandb-entity", default=None)
    parser.add_argument("--wandb-run-name", default=None)
    parser.add_argument("--wandb-log-samples", type=int, default=50)
    parser.add_argument("--num-epochs", type=float, default=1.0)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--grad-accum-steps", type=int, default=2)
    parser.add_argument("--learning-rate", type=float, default=3e-5)
    parser.add_argument("--warmup-ratio", type=float, default=0.2)
    parser.add_argument("--lr-scheduler-type", type=str, default="cosine")
    parser.add_argument(
        "--train-last-n-layers",
        type=int,
        default=0,
        help="Also train the last N language_model transformer layers in addition to projector/lm_head/LoRA params",
    )
    parser.add_argument("--dataloader-num-workers", type=int, default=16)
    parser.add_argument(
        "--save-steps",
        "--save_steps",
        dest="save_steps",
        type=int,
        default=200,
        help="Save a checkpoint every N update steps",
    )
    parser.add_argument(
        "--save-total-limit",
        "--save_total_limit",
        dest="save_total_limit",
        type=int,
        default=None,
        help="Maximum number of checkpoints to keep",
    )
    parser.add_argument(
        "--save-best-test-cer-only",
        action="store_true",
        help="After each save-step evaluation, keep only the checkpoint with the lowest test CER",
    )
    parser.add_argument("--skip-eval", action="store_true", help="Skip evaluation")
    return parser.parse_args()


def main():
    args = parse_args()
    if args.train_last_n_layers < 0:
        raise ValueError("--train-last-n-layers must be >= 0")
    if args.save_best_test_cer_only and args.skip_eval:
        raise ValueError(
            "--save-best-test-cer-only requires evaluation; remove --skip-eval"
        )
    if args.language is None:
        args.language = "ja" if args.local_manifest else "en"
    wandb_module = maybe_init_wandb(args)

    if args.local_manifest:
        if not (0.0 < args.test_split_ratio < 1.0):
            raise ValueError("--test-split-ratio must be between 0 and 1")
        if not (0.0 < args.val_split_ratio < 1.0):
            raise ValueError("--val-split-ratio must be between 0 and 1")
        if args.val_split_ratio + args.test_split_ratio >= 1.0:
            raise ValueError("--val-split-ratio + --test-split-ratio must be < 1")

    if args.eval_metric is None:
        args.eval_metric = "cer" if args.language == "ja" else "wer"

    # --- Save args ---
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "args.json").write_text(json.dumps(vars(args), indent=2, default=str))

    # --- Load dataset ---
    (train_dataset, val_dataset, test_dataset), dataset_kind = load_training_datasets(
        args
    )
    if args.validation_json:
        print(f"Loading validation dataset: {args.validation_json}")
        val_dataset = load_eval_json_dataset(
            args.validation_json,
            audio_root=args.eval_audio_root,
            audio_column=args.eval_audio_column,
            text_column=args.eval_text_column,
            size_limit=args.val_size,
        )
    if args.validation_json or args.test_json:
        test_json_path = args.test_json or args.validation_json
        print(f"Loading test dataset: {test_json_path}")
        test_dataset = load_eval_json_dataset(
            test_json_path,
            audio_root=args.eval_audio_root,
            audio_column=args.eval_audio_column,
            text_column=args.eval_text_column,
            size_limit=args.test_size,
        )
    print(
        f"Dataset sizes -> train: {len(train_dataset)}, validation: {len(val_dataset)}, test: {len(test_dataset)}"
    )

    # --- Load model & processor ---
    print(f"Loading model: {args.model_name}")
    processor = GraniteSpeechProcessor.from_pretrained(args.model_name)
    model = GraniteSpeechForConditionalGeneration.from_pretrained(
        args.model_name, dtype=torch.bfloat16
    )

    # --- Preprocess datasets ---
    print("Preprocessing datasets...")
    train_dataset = prepare_dataset(
        train_dataset,
        processor,
        dataset_kind=dataset_kind,
        language=args.language,
    )
    val_dataset = prepare_dataset(
        val_dataset,
        processor,
        dataset_kind=dataset_kind,
        language=args.language,
    )
    test_dataset = prepare_dataset(
        test_dataset,
        processor,
        dataset_kind=dataset_kind,
        language=args.language,
    )

    # --- Baseline metric ---
    if not args.skip_eval:
        references_before, predictions_before = compute_predictions(
            model,
            processor,
            test_dataset,
            language=args.language,
            batch_size=args.batch_size,
            num_workers=args.dataloader_num_workers,
        )
        metric = evaluate.load(args.eval_metric)
        metric_before = metric.compute(
            references=references_before, predictions=predictions_before
        )
        maybe_log_wandb_predictions(
            wandb_module,
            stage="baseline",
            metric_name=args.eval_metric,
            metric_value=metric_before,
            rows=build_prediction_rows(references_before, predictions_before),
            max_rows=args.wandb_log_samples,
        )
        print(
            f"{args.eval_metric.upper()} before finetuning: {metric_before * 100:.3f}%"
        )

    # --- Freeze parameters ---
    total_lm_layers = len(model.language_model.model.layers)
    if args.train_last_n_layers > total_lm_layers:
        raise ValueError(
            f"--train-last-n-layers ({args.train_last_n_layers}) exceeds total LM layers ({total_lm_layers})"
        )
    if args.train_last_n_layers > 0:
        first_trainable_layer = total_lm_layers - args.train_last_n_layers
        print(
            "Training language_model layers "
            f"{first_trainable_layer}-{total_lm_layers - 1} "
            f"(last {args.train_last_n_layers} / {total_lm_layers})"
        )
    else:
        print("Training language_model layers: none")

    for name, param in model.named_parameters():
        cond = should_train_parameter(
            name,
            train_last_n_layers=args.train_last_n_layers,
            total_lm_layers=total_lm_layers,
        )
        print(name, "requires_grad:", cond)
        param.requires_grad = cond

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(
        f"Trainable parameters: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)"
    )

    # --- Training ---
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        remove_unused_columns=False,
        report_to="wandb" if args.use_wandb else "none",
        bf16=True,
        eval_strategy="steps",
        save_strategy="steps",
        eval_steps=0.1,
        save_steps=args.save_steps,
        save_total_limit=args.save_total_limit,
        dataloader_num_workers=args.dataloader_num_workers,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum_steps,
        num_train_epochs=args.num_epochs,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.lr_scheduler_type,
        logging_steps=0.1,
        learning_rate=args.learning_rate,
        data_seed=42,
    )
    data_collator = GraniteCollator(processor)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        processing_class=processor,
    )
    test_cer_callback = None
    if not args.skip_eval:
        test_cer_callback = TestCerCallback(
            processor=processor,
            test_dataset=test_dataset,
            language=args.language,
            batch_size=args.batch_size,
            num_workers=args.dataloader_num_workers,
            output_dir=args.output_dir,
            wandb_module=wandb_module,
            wandb_log_samples=args.wandb_log_samples,
            save_best_test_cer_only=args.save_best_test_cer_only,
        )
        trainer.add_callback(test_cer_callback)
    trainer.train()
    if args.save_best_test_cer_only:
        keep_checkpoint_dir = None
        if test_cer_callback is not None:
            keep_checkpoint_dir = test_cer_callback.best_checkpoint_dir
        prune_checkpoints(args.output_dir, keep_checkpoint_dir)
    trainer.save_model(args.output_dir)
    processor.save_pretrained(args.output_dir)

    # --- Post-training metric ---
    if not args.skip_eval:
        references_after, predictions_after = compute_predictions(
            model,
            processor,
            test_dataset,
            language=args.language,
            batch_size=args.batch_size,
            num_workers=args.dataloader_num_workers,
        )
        metric = evaluate.load(args.eval_metric)
        metric_after = metric.compute(
            references=references_after, predictions=predictions_after
        )
        maybe_log_wandb_predictions(
            wandb_module,
            stage="final",
            metric_name=args.eval_metric,
            metric_value=metric_after,
            rows=build_prediction_rows(references_after, predictions_after),
            max_rows=args.wandb_log_samples,
        )
        print(
            f"{args.eval_metric.upper()} after finetuning:  {metric_after * 100:.3f}%"
        )
        print(
            f"{args.eval_metric.upper()} improvement:       {(metric_before - metric_after) * 100:.3f}%"
        )

    if wandb_module is not None:
        wandb_module.finish()


if __name__ == "__main__":
    main()

```

The manifest JSONL format is as follows.

```json
{
  "audio_path": "path/to/audio.wav",
  "text": "書き起こしテキスト",
  "duration_sec": 5.2
}
```

audio_path may be an absolute path or a path relative to --audio-root (or, if --audio-root is not set, relative to the manifest's directory).
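
As a small sketch (the manifest_row helper is ours, not part of the script), one manifest line per audio file can be generated with the duration read from the audio header via soundfile:

```python
import json

import soundfile as sf


def manifest_row(audio_path: str, text: str) -> str:
    """Build one manifest JSONL line; duration_sec comes from the audio header."""
    info = sf.info(audio_path)
    return json.dumps(
        {
            "audio_path": audio_path,
            "text": text,
            "duration_sec": info.frames / info.samplerate,
        },
        ensure_ascii=False,
    )
```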

Summary of lessons learned

What worked well

  • Unfreezing lm_head + the last 8 layers: contributed the most to accuracy
  • Using punctuated training data: simplified the post-processing pipeline
  • Automatic best-model saving via TestCerCallback: saved GPU time while reliably keeping the best result
  • CER-based data filtering: removing low-quality data made training stable

What was less effective / future work

  • Proper nouns and technical terms are still frequently misrecognized (the training data needs more domain-specific vocabulary)
  • Unfreezing more than n=8 layers tended to overfit at this data volume
  • We expect that surpassing Qwen3-ASR will require substantially more data

Frequently asked questions (FAQ)

Q. Doesn't the Audio Encoder need fine-tuning?

The Audio Encoder is already broadly adapted to multilingual speech, and for additional training like ours it is standard to keep it frozen. Unfreezing the encoder greatly increases the parameter count, raising both overfitting risk and compute cost. With large-scale data (hundreds to thousands of hours), encoder fine-tuning becomes an option.
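
If you want to confirm this, a quick sanity check (a sketch, assuming the encoder's parameter names contain "encoder") can be run right after the freezing step:

```python
# After applying should_train_parameter, no encoder weight should be trainable.
frozen_violations = [
    name for name, p in model.named_parameters()
    if "encoder" in name and p.requires_grad
]
assert not frozen_violations, frozen_violations
```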

Q. How much GPU do I need?

With batch_size=8, bfloat16, and the last 8 layers unfrozen, roughly 24 GB of VRAM (an RTX 3090/4090 class GPU) is a good guideline. Increasing grad_accum_steps to maintain the effective batch size lets you run with less VRAM.

Q. Can the same settings be used with HuggingFace datasets (GigaSpeech, etc.)?

Yes. Specify a HuggingFace dataset ID instead of a local manifest and the same script works. For English datasets, however, pass --language en and match the text normalization accordingly.

Q. How should Japanese text be normalized?

Using MeCab or a normalization library to unify full-width/half-width characters, number formats, and punctuation makes evaluation more stable. Customize the normalize_japanese_text function used by build_eval_normalizer in the script to match your data.
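
For example, the full-width/half-width unification alone is covered by Python's built-in unicodedata:

```python
import unicodedata

print(unicodedata.normalize("NFKC", "CERは14%です。"))
# -> "CERは14%です。"  (full-width letters, digits, and % become half-width)
```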
