How I Finetuned IBM Granite Speech 1B on Japanese Audio and Improved CER from 0.37 to 0.14

Summary: I finetuned IBM Granite Speech (granite-4.0-1b-speech) on 100 hours of Japanese speech data and reduced CER from 0.37 to 0.14. The official script's Projector+LoRA-only training hits an accuracy ceiling; the key breakthrough was additionally training lm_head and the last 8 layers of the Language Model. The result matches Qwen3-ASR-1.7B (CER 0.14) with only 1B parameters.

Why Finetune Granite Speech?

From late 2025 into 2026, competition in Japanese ASR (Automatic Speech Recognition) has intensified rapidly. While strong models like Qwen3-ASR and ReazonSpeech are available, IBM's granite-4.0-1b-speech was primarily trained on English, meaning it doesn't perform well on Japanese out of the box.

That said, Granite Speech has a relatively clean architecture with two well-defined finetuning entry points: the Projector (the bridge layer connecting the audio encoder to the LLM) and LoRA (adapters inserted into the Language Model). IBM also provides an official Colab notebook, making it an accessible starting point.

However, sticking to the official configuration leaves performance on the table. This post walks through the parameter expansion strategy I used to break through that ceiling, along with practical lessons from training on 100 hours of Japanese speech.

Background: Granite Speech Architecture

granite-4.0-1b-speech is built around three components:

| Component | Role |
|---|---|
| Speech Encoder | Converts raw audio waveforms into high-resolution acoustic representations (Conformer + CTC) |
| Speech Projector (Q-Former) | Compresses and extracts meaning from audio representations, then maps them into the LLM's embedding space |
| Language Model | Handles text generation and semantic understanding |

LoRA (Low-Rank Adaptation) inserts small low-rank matrices into each Attention layer of the Language Model, allowing efficient finetuning with few parameters while keeping the original weights frozen.

CER (Character Error Rate) is the character-level edit distance (substitutions, deletions, and insertions) divided by the number of reference characters. Because Japanese word segmentation is non-trivial, CER is the standard evaluation metric rather than WER (Word Error Rate). Lower is better: a CER of 0.14 means roughly 14 character errors per 100 reference characters.
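
Concretely, CER is just Levenshtein distance over characters divided by the reference length; a minimal self-contained illustration (the training script below uses jiwer.cer for the real thing):

```python
def char_error_rate(reference: str, hypothesis: str) -> float:
    """Character-level Levenshtein distance divided by reference length."""
    m, n = len(reference), len(hypothesis)
    prev = list(range(n + 1))  # edit distances against the empty reference prefix
    for i in range(1, m + 1):
        curr = [i] + [0] * n
        for j in range(1, n + 1):
            cost = 0 if reference[i - 1] == hypothesis[j - 1] else 1
            curr[j] = min(prev[j] + 1,         # deletion
                          curr[j - 1] + 1,     # insertion
                          prev[j - 1] + cost)  # substitution
        prev = curr
    return prev[n] / m if m else 0.0

# One dropped character out of ten reference characters -> CER 0.10
print(char_error_rate("今日は良い天気です。", "今日は良い天気です"))  # 0.1
```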

The Limits of Official Finetuning

The official script keeps the trainable parameter set minimal:

python
for n, p in model.named_parameters():
    # Only train Projector and LoRA layers
    p.requires_grad = "projector" in n or "lora" in n

https://colab.research.google.com/github/ibm-granite/granite-speech-models/blob/main/notebooks/fine_tuning_granite_speech.ipynb

This works — CER does improve early on. But accuracy gains plateau quickly even as data volume grows.

The intuition is straightforward: Projector+LoRA handles the "bridge" between audio and text, but the Language Model's core understanding of Japanese vocabulary, grammar, and context is never updated. Even with 100 hours of data, the LM's limited Japanese capability becomes the bottleneck.

This is especially true for kanji, proper nouns, and punctuation — all of which depend heavily on what the Language Model already knows.

The Fix: Expanding the Set of Trainable Parameters

Adding lm_head and the Last N Transformer Layers

The strategy I adopted:

  1. Add lm_head (the final token prediction layer) to the trainable set
  2. Unfreeze the last N Transformer layers of the Language Model
python
def should_train_parameter(
    name: str, *, train_last_n_layers: int, total_lm_layers: int
) -> bool:
    # Projector / LoRA / lm_head are always trained
    if "projector" in name or "lora" in name or "lm_head" in name:
        return True

    if train_last_n_layers <= 0:
        return False

    # Parse names of the form: language_model.model.layers.{index}.*
    layer_prefix = "language_model.model.layers."
    if not name.startswith(layer_prefix):
        return False

    remainder = name[len(layer_prefix):]
    layer_index_str = remainder.split(".", 1)[0]
    if not layer_index_str.isdigit():
        return False

    layer_index = int(layer_index_str)
    first_trainable_layer = max(total_lm_layers - train_last_n_layers, 0)
    return layer_index >= first_trainable_layer

Apply this after loading the model:

python
total_lm_layers = len(model.language_model.model.layers)

for name, param in model.named_parameters():
    param.requires_grad = should_train_parameter(
        name,
        train_last_n_layers=args.train_last_n_layers,
        total_lm_layers=total_lm_layers,
    )

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")

Choosing the Right Number of Layers

The optimal N depends on how much data you have. Unfreezing too many layers with limited data increases overfitting risk.

| Data volume | Recommended N layers |
|---|---|
| Up to 10 hours | 0–2 (official defaults or minimal) |
| 10–50 hours | 2–4 |
| 50–100 hours | 4–8 |
| 100+ hours | 8 to full unfreezing |

With 100 hours of Japanese speech, n=8 produced the best results.
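
To see what a given N means in practice, the cutoff computed by should_train_parameter is simply total_lm_layers − N. The 20-layer depth below is a hypothetical example; check len(model.language_model.model.layers) for your checkpoint's real value:

```python
def first_trainable_layer(total_lm_layers: int, train_last_n_layers: int) -> int:
    """Index of the first unfrozen LM layer; this layer and above are trained."""
    return max(total_lm_layers - train_last_n_layers, 0)

print(first_trainable_layer(20, 8))   # 12 -> layers 12..19 are trained
print(first_trainable_layer(20, 32))  # 0  -> n >= depth means full unfreezing
```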

Training Hyperparameters

python
batch_size = 8
learning_rate = 5e-5
train_last_n_layers = 8
num_epochs = 3.0
warmup_ratio = 0.0
lr_scheduler_type = "cosine"
grad_accum_steps = 2
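
A note on how these interact: the effective batch size seen by the optimizer is per-device batch size × gradient accumulation steps. Sketched as plain Python, with field names following the transformers.TrainingArguments API:

```python
# Sketch: the hyperparameters above expressed as TrainingArguments kwargs
training_kwargs = dict(
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=5e-5,
    num_train_epochs=3.0,
    warmup_ratio=0.0,
    lr_scheduler_type="cosine",
)

effective_batch = (
    training_kwargs["per_device_train_batch_size"]
    * training_kwargs["gradient_accumulation_steps"]
)
print(effective_batch)  # 16
```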

Why Punctuation Output Matters

One underappreciated benefit of this finetuning run is that the model learned to output punctuation.

The base granite-4.0-1b-speech tends to produce unpunctuated text. By training on transcripts that include punctuation, the finetuned model now outputs commas and periods naturally.

In production ASR pipelines, punctuation-free output often requires a separate post-processing step to re-insert it. This improvement eliminates that step entirely, meaningfully simplifying the pipeline.

Benchmark Results

Evaluation used the same internal dataset referenced in this Japanese ASR benchmark post.

| Model | CER | Notes |
|---|---|---|
| ibm-granite/granite-4.0-1b-speech (baseline) | 0.37 | Before finetuning |
| ibm-granite/granite-4.0-1b-speech (finetuned) | 0.14 | 1B parameters, matched top performance |
| qwen/qwen3-asr-1.7b | 0.14 | Reference model (1.7B parameters) |

Qwen3-ASR-1.7B has 1.7 billion parameters. Matching its CER with a 1B parameter model demonstrates the effectiveness of targeted finetuning.

Remaining Challenges

  • Kanji and proper nouns: Person names, place names, and domain-specific terminology still show elevated error rates
  • Beating Qwen: More data could push past it, but that threshold hasn't been crossed yet
  • Occasional errors at utterance boundaries and with long vowels

Data Quality Filtering

With 100 hours of data, quality control directly affects the outcome. The script supports filtering manifest files by CER, audio duration, and transcript length:

bash
# Example filtering configuration (local manifest)
--min-duration-sec 10.0    # Exclude clips shorter than 10 seconds
--max-duration-sec 30.0    # Exclude clips longer than 30 seconds
--min-text-len 45          # Exclude samples with very short transcripts
--max-cer 0.4              # Exclude samples where Whisper CER exceeds 40%

Finetuning accuracy is strongly data-quality dependent. Noisy audio and transcripts with many errors should be filtered out. A practical approach: run ASR inference on your raw data to compute CER, then remove outliers before training.
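
That outlier pass can be sketched as a small standalone filter over a JSONL manifest. This assumes rows carry audio_path, text, duration_sec, and a precomputed cer column from a reference ASR run (column names are illustrative):

```python
import json

def filter_manifest_rows(lines, *, max_cer=0.4, min_dur=10.0, max_dur=30.0,
                         min_text_len=45):
    """Keep rows that pass the CER, duration, and transcript-length thresholds."""
    kept = []
    for line in lines:
        row = json.loads(line)
        if row.get("cer", 1.0) > max_cer:            # drop noisy transcripts
            continue
        if not min_dur <= row.get("duration_sec", 0.0) <= max_dur:
            continue
        if len(row.get("text", "")) < min_text_len:  # drop very short transcripts
            continue
        kept.append(row)
    return kept

rows = [
    json.dumps({"audio_path": "a.wav", "text": "あ" * 50, "cer": 0.1, "duration_sec": 12.0}),
    json.dumps({"audio_path": "b.wav", "text": "あ" * 50, "cer": 0.8, "duration_sec": 12.0}),
]
print(len(filter_manifest_rows(rows)))  # 1
```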

Checkpoint Management with TestCerCallback

The script includes a TestCerCallback that evaluates test CER at each checkpoint and optionally keeps only the best one — saving both storage and GPU time.

python
trainer.add_callback(TestCerCallback(
    processor=processor,
    test_dataset=test_dataset,
    language="ja",
    save_best_test_cer_only=True,  # Keep only the best checkpoint
    ...
))

With --save-best-test-cer-only enabled, each checkpoint that doesn't improve test CER is automatically deleted.

Full Finetuning Script

Running this script requires the main branch of HuggingFace Transformers, as the granite_speech module may not yet be in the stable release.

bash
pip install git+https://github.com/huggingface/transformers.git
pip install -U datasets peft accelerate evaluate jiwer soundfile tqdm
python
"""
Finetuning script for IBM Granite Speech models.

Based on:
https://colab.research.google.com/github/ibm-granite/granite-speech-models/blob/main/notebooks/fine_tuning_granite_speech.ipynb

Requirements:
    pip install git+https://github.com/huggingface/transformers.git
    pip install -U datasets peft accelerate evaluate jiwer soundfile tqdm
"""

import argparse
import json
import shutil
from pathlib import Path

import jiwer
import torch
import tqdm
import evaluate
import soundfile as sf
from datasets import load_dataset, concatenate_datasets, Audio
from torch.utils.data import DataLoader
from transformers import TrainingArguments, Trainer, TrainerCallback
from transformers.feature_extraction_utils import BatchFeature
from transformers.models.granite_speech import (
    GraniteSpeechForConditionalGeneration,
    GraniteSpeechProcessor,
)


# ---------------------------------------------------------------------------
# Data preprocessing
# ---------------------------------------------------------------------------


def process_gigaspeech_transcript(text: str) -> str:
    text = text.replace(" <COMMA>", ",")
    text = text.replace(" <PERIOD>", ".")
    text = text.replace(" <QUESTIONMARK>", "?")
    text = text.replace(" <EXCLAMATIONPOINT>", "!")
    return text.lower()


def resolve_audio_path(
    audio_path: str, manifest_path: Path, audio_root: Path | None
) -> str:
    audio = Path(audio_path)
    if audio.is_absolute():
        return str(audio)
    if audio_root is not None:
        return str((audio_root / audio).resolve())
    return str((manifest_path.parent / audio).resolve())


def resolve_row_audio_path(
    row: dict, audio_column: str, manifest_path: Path, audio_root: Path | None
) -> str:
    if audio_column not in row:
        raise ValueError(f"{manifest_path} must contain '{audio_column}' column")
    return resolve_audio_path(str(row[audio_column]), manifest_path, audio_root)


def load_cer_results_index(results_jsonl: str | None) -> dict[str, float]:
    if not results_jsonl:
        return {}

    index: dict[str, float] = {}
    results_path = Path(results_jsonl).resolve()
    with results_path.open() as fh:
        for line_no, line in enumerate(fh, start=1):
            line = line.strip()
            if not line:
                continue
            row = json.loads(line)
            audio_path = row.get("audio")
            cer = row.get("cer")
            if not audio_path or cer is None:
                continue
            index[str(Path(audio_path).resolve())] = float(cer)

    print(f"Loaded {len(index)} CER rows from {results_path}")
    return index


def normalize_transcript(text: str, *, dataset_kind: str) -> str:
    normalized = str(text).strip()
    if dataset_kind == "hf":
        return process_gigaspeech_transcript(normalized)
    return normalized


def prep_example(example, tokenizer, dataset_kind: str):
    instruction = "Please transcribe the following audio to text<|audio|>"
    chat = [{"role": "user", "content": instruction}]
    example["prompt"] = tokenizer.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=False,
    )
    example["text"] = normalize_transcript(
        example["text"],
        dataset_kind=dataset_kind,
    )
    return example


def should_keep_local_record(
    example,
    *,
    min_cer,
    max_cer,
    results_min_cer,
    results_max_cer,
    min_cer_percent,
    max_cer_percent,
    min_duration_sec,
    max_duration_sec,
    min_text_len,
    max_text_len,
    manifest_path,
    audio_root,
    cer_results_index,
):
    manifest_cer = example.get("cer")
    if min_cer is not None and (manifest_cer is None or float(manifest_cer) < min_cer):
        return False
    if max_cer is not None and (manifest_cer is None or float(manifest_cer) > max_cer):
        return False

    results_cer = None
    if cer_results_index:
        audio_path = resolve_audio_path(
            str(example["audio_path"]),
            manifest_path,
            audio_root,
        )
        results_cer = cer_results_index.get(audio_path)
    if results_min_cer is not None and (
        results_cer is None or float(results_cer) < results_min_cer
    ):
        return False
    if results_max_cer is not None and (
        results_cer is None or float(results_cer) > results_max_cer
    ):
        return False

    cer_percent = example.get("cer_percent")
    if min_cer_percent is not None and (
        cer_percent is None or float(cer_percent) < min_cer_percent
    ):
        return False
    if max_cer_percent is not None and (
        cer_percent is None or float(cer_percent) > max_cer_percent
    ):
        return False

    duration_sec = example.get("duration_sec")
    if min_duration_sec is not None and (
        duration_sec is None or float(duration_sec) < min_duration_sec
    ):
        return False
    if max_duration_sec is not None and (
        duration_sec is None or float(duration_sec) > max_duration_sec
    ):
        return False

    text = str(example.get("text", "")).strip()
    if min_text_len is not None and len(text) < min_text_len:
        return False
    if max_text_len is not None and len(text) > max_text_len:
        return False

    return True


def prepare_dataset(ds, processor, *, dataset_kind: str, language: str):
    columns_to_remove = [col for col in ds.column_names if col not in ["audio", "text"]]
    ds = ds.cast_column(
        "audio",
        Audio(
            sampling_rate=processor.audio_processor.sampling_rate,
            decode=False,
        ),
    )
    ds = ds.map(
        prep_example,
        fn_kwargs={
            "tokenizer": processor.tokenizer,
            "dataset_kind": dataset_kind,
        },
        remove_columns=columns_to_remove,
    )
    if dataset_kind == "hf":
        ds = ds.filter(
            lambda x: x["text"] not in ["<other>", "<noise>", "<music>", "<sil>"]
        )
    return ds


# ---------------------------------------------------------------------------
# Collator
# ---------------------------------------------------------------------------


def _extract_audio_array(audio):
    """Load audio without relying on datasets' torchcodec-based decoding."""
    if hasattr(audio, "get_all_samples"):
        samples = audio.get_all_samples()
        return samples.data.squeeze(0).numpy()
    if isinstance(audio, dict):
        if "array" in audio:
            return audio["array"]
        if audio.get("path"):
            audio_array, sampling_rate = sf.read(audio["path"], dtype="float32")
            if audio_array.ndim > 1:
                audio_array = audio_array.mean(axis=1)
            return audio_array, sampling_rate
    return audio


class GraniteCollator:
    def __init__(self, processor, inference_mode: bool = False):
        self.processor = processor
        self.inference_mode = inference_mode
        self.sampling_rate = processor.audio_processor.sampling_rate

    def _prepare_audio(self, audio):
        extracted = _extract_audio_array(audio)
        sampling_rate = self.sampling_rate

        if isinstance(extracted, tuple):
            audio_array, sampling_rate = extracted
        else:
            audio_array = extracted

        audio_tensor = torch.as_tensor(audio_array, dtype=torch.float32)
        if audio_tensor.ndim > 1:
            audio_tensor = audio_tensor.mean(dim=0)
        if sampling_rate != self.sampling_rate:
            audio_tensor = (
                torch.nn.functional.interpolate(
                    audio_tensor.unsqueeze(0).unsqueeze(0),
                    size=int(
                        audio_tensor.shape[-1] * self.sampling_rate / sampling_rate
                    ),
                    mode="linear",
                    align_corners=False,
                )
                .squeeze(0)
                .squeeze(0)
            )
        return audio_tensor.numpy()

    def __call__(self, examples):
        prompts = [ex["prompt"] for ex in examples]
        audios = [self._prepare_audio(ex["audio"]) for ex in examples]

        processed = self.processor(
            prompts, audios, return_tensors="pt", padding=True, padding_side="left"
        )
        input_ids = processed.input_ids
        attention_mask = processed.attention_mask
        labels = None

        if not self.inference_mode:
            targets = [
                ex["text"] + self.processor.tokenizer.eos_token for ex in examples
            ]
            targets = self.processor.tokenizer(
                targets, return_tensors="pt", padding=True, padding_side="right"
            )
            input_ids = torch.cat([input_ids, targets.input_ids], dim=1)
            attention_mask = torch.cat([attention_mask, targets.attention_mask], dim=1)
            labels = targets.input_ids.clone()
            labels[~targets.attention_mask.bool()] = -100
            labels = torch.cat(
                [torch.full_like(processed.input_ids, -100), labels], dim=1
            )

        return BatchFeature(
            data={
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels,
                "input_features": processed.input_features,
                "input_features_mask": processed.input_features_mask,
            }
        )


# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------


def build_eval_normalizer(language: str):
    if language == "ja":
        # Implement Japanese-specific normalization (e.g., MeCab-based)
        return lambda x: str(x).strip()
    try:
        from whisper.normalizers import EnglishTextNormalizer
        return EnglishTextNormalizer()
    except ModuleNotFoundError:
        return lambda x: str(x).strip().lower()


def compute_predictions(
    model,
    processor,
    dataset,
    *,
    language: str,
    batch_size: int = 16,
    num_workers: int = 8,
):
    collator = GraniteCollator(processor, inference_mode=True)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, collate_fn=collator, num_workers=num_workers
    )
    normalizer = build_eval_normalizer(language)

    model = model.eval().cuda()
    predictions = []

    for batch in tqdm.tqdm(dataloader, desc="Running inference"):
        batch = batch.to("cuda")
        with torch.inference_mode(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
            outputs = model.generate(
                **batch, max_new_tokens=400, num_beams=4, early_stopping=True
            )
        input_length = batch.input_ids.shape[1]
        outputs = outputs[:, input_length:].cpu()
        for x in outputs:
            predictions.append(processor.tokenizer.decode(x, skip_special_tokens=True))

    references = [normalizer(x) for x in dataset["text"]]
    predictions = [normalizer(x) for x in predictions]
    return references, predictions


def build_diff_string(reference: str, prediction: str) -> str:
    import difflib

    matcher = difflib.SequenceMatcher(None, prediction, reference)
    result = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            result.append(prediction[i1:i2])
        else:
            result.append(f"(w: {prediction[i1:i2]}|q: {reference[j1:j2]})")
    return "".join(result)


def build_prediction_rows(references, predictions):
    rows = []
    for index, (reference, prediction) in enumerate(
        zip(references, predictions), start=1
    ):
        rows.append(
            {
                "index": index,
                "reference": reference,
                "prediction": prediction,
                "cer": jiwer.cer(reference, prediction),
                "diff": build_diff_string(reference, prediction),
            }
        )
    return rows


def maybe_init_wandb(args):
    if not args.use_wandb:
        return None
    import wandb

    wandb.init(
        project=args.wandb_project,
        entity=args.wandb_entity,
        name=args.wandb_run_name,
        config=vars(args),
    )
    return wandb


def maybe_log_wandb_predictions(
    wandb_module,
    *,
    stage: str,
    metric_name: str,
    metric_value: float,
    rows,
    max_rows: int,
):
    if wandb_module is None:
        return
    table = wandb_module.Table(
        columns=["index", "reference", "prediction", "cer", "diff"]
    )
    for row in rows[:max_rows]:
        table.add_data(
            row["index"], row["reference"], row["prediction"], row["cer"], row["diff"]
        )
    wandb_module.log(
        {
            f"{stage}/{metric_name}": metric_value,
            f"{stage}/prediction_table": table,
        }
    )


def prune_checkpoints(output_dir: str, keep_checkpoint_dir: Path | None) -> None:
    keep_path = (
        keep_checkpoint_dir.resolve() if keep_checkpoint_dir is not None else None
    )
    for checkpoint_dir in Path(output_dir).glob("checkpoint-*"):
        if not checkpoint_dir.is_dir():
            continue
        if keep_path is not None and checkpoint_dir.resolve() == keep_path:
            continue
        shutil.rmtree(checkpoint_dir)


class TestCerCallback(TrainerCallback):
    def __init__(
        self,
        *,
        processor,
        test_dataset,
        language: str,
        batch_size: int,
        num_workers: int,
        output_dir: str,
        wandb_module,
        wandb_log_samples: int,
        save_best_test_cer_only: bool,
    ):
        self.processor = processor
        self.test_dataset = test_dataset
        self.language = language
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.output_dir = output_dir
        self.wandb_module = wandb_module
        self.wandb_log_samples = wandb_log_samples
        self.save_best_test_cer_only = save_best_test_cer_only
        self.best_test_cer: float | None = None
        self.best_checkpoint_dir: Path | None = None

    def on_save(self, args, state, control, model=None, **kwargs):
        references, predictions = compute_predictions(
            model,
            self.processor,
            self.test_dataset,
            language=self.language,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )
        test_cer = jiwer.cer(references, predictions)
        prediction_rows = build_prediction_rows(references, predictions)

        checkpoint_dir = Path(args.output_dir) / f"checkpoint-{state.global_step}"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        metrics = {
            "global_step": state.global_step,
            "test_cer": test_cer,
            "num_samples": len(references),
        }
        (checkpoint_dir / "test_cer.json").write_text(
            json.dumps(metrics, ensure_ascii=False, indent=2) + "\n",
            encoding="utf-8",
        )
        with (checkpoint_dir / "test_predictions.jsonl").open(
            "w", encoding="utf-8"
        ) as f:
            for row in prediction_rows:
                f.write(json.dumps(row, ensure_ascii=False) + "\n")
        with (Path(self.output_dir) / "test_cer_history.jsonl").open(
            "a", encoding="utf-8"
        ) as f:
            f.write(json.dumps(metrics, ensure_ascii=False) + "\n")

        is_best = self.best_test_cer is None or test_cer < self.best_test_cer
        if is_best:
            previous_best_dir = self.best_checkpoint_dir
            self.best_test_cer = test_cer
            self.best_checkpoint_dir = checkpoint_dir
            (Path(self.output_dir) / "best_test_cer.json").write_text(
                json.dumps(
                    {
                        "global_step": state.global_step,
                        "test_cer": test_cer,
                        "checkpoint_dir": str(checkpoint_dir),
                    },
                    ensure_ascii=False,
                    indent=2,
                )
                + "\n",
                encoding="utf-8",
            )
            if (
                self.save_best_test_cer_only
                and previous_best_dir is not None
                and previous_best_dir != checkpoint_dir
                and previous_best_dir.exists()
            ):
                shutil.rmtree(previous_best_dir)
        elif self.save_best_test_cer_only and checkpoint_dir.exists():
            shutil.rmtree(checkpoint_dir)

        maybe_log_wandb_predictions(
            self.wandb_module,
            stage=f"checkpoint_{state.global_step}",
            metric_name="cer",
            metric_value=test_cer,
            rows=prediction_rows,
            max_rows=self.wandb_log_samples,
        )
        if is_best:
            print(f"Test CER at step {state.global_step}: {test_cer * 100:.3f}% (best)")
        else:
            print(f"Test CER at step {state.global_step}: {test_cer * 100:.3f}%")
        return control


def take_first_n(ds, n: int):
    if n is None or n <= 0:
        return ds
    return ds.select(range(min(n, len(ds))))


def _load_single_manifest(manifest_path: Path, audio_root: Path | None, args):
    cer_results_index = getattr(args, "cer_results_index", {})
    dataset = load_dataset("json", data_files={"manifest": str(manifest_path)})[
        "manifest"
    ]
    before_count = len(dataset)
    dataset = dataset.filter(
        should_keep_local_record,
        fn_kwargs={
            "min_cer": args.min_cer,
            "max_cer": args.max_cer,
            "results_min_cer": args.results_min_cer,
            "results_max_cer": args.results_max_cer,
            "min_cer_percent": args.min_cer_percent,
            "max_cer_percent": args.max_cer_percent,
            "min_duration_sec": args.min_duration_sec,
            "max_duration_sec": args.max_duration_sec,
            "min_text_len": args.min_text_len,
            "max_text_len": args.max_text_len,
            "manifest_path": manifest_path,
            "audio_root": audio_root,
            "cer_results_index": cer_results_index,
        },
    )
    if cer_results_index and (
        args.results_min_cer is not None or args.results_max_cer is not None
    ):
        removed_count = before_count - len(dataset)
        print(
            f"Filtered {removed_count} rows from {manifest_path} using external CER thresholds"
        )
    dataset = dataset.map(
        lambda example: {
            "audio": resolve_audio_path(
                example["audio_path"], manifest_path, audio_root
            ),
            "text": str(example["text"]).strip(),
        },
        remove_columns=[
            col for col in dataset.column_names if col not in ["audio_path", "text"]
        ],
    )
    return dataset


def load_local_dataset(args):
    local_manifest = Path(args.local_manifest).resolve()
    audio_root = Path(args.audio_root).resolve() if args.audio_root else None
    args.cer_results_index = load_cer_results_index(args.cer_results_jsonl)

    if local_manifest.is_dir():
        manifest_paths = sorted(local_manifest.rglob("manifest.jsonl"))
        if not manifest_paths:
            raise ValueError(f"No manifest.jsonl files found under {local_manifest}")
        non_empty_manifest_paths = [p for p in manifest_paths if p.stat().st_size > 0]
        skipped_manifest_count = len(manifest_paths) - len(non_empty_manifest_paths)
        print(f"Found {len(manifest_paths)} manifest files under {local_manifest}")
        if skipped_manifest_count:
            print(f"Skipping {skipped_manifest_count} empty manifest files")
        if not non_empty_manifest_paths:
            raise ValueError(
                f"All manifest.jsonl files under {local_manifest} are empty"
            )
        dataset = concatenate_datasets(
            [
                _load_single_manifest(p, audio_root, args)
                for p in non_empty_manifest_paths
            ]
        )
    else:
        if local_manifest.stat().st_size == 0:
            raise ValueError(f"Manifest file is empty: {local_manifest}")
        dataset = _load_single_manifest(local_manifest, audio_root, args)

    if len(dataset) < 3:
        raise ValueError(
            f"Local dataset is too small after filtering: {len(dataset)} rows. Need at least 3."
        )

    split_seed = args.split_seed
    test_size = args.test_split_ratio
    val_size = args.val_split_ratio / (1.0 - test_size)
    first_split = dataset.train_test_split(test_size=test_size, seed=split_seed)
    train_and_val = first_split["train"]
    test_dataset = first_split["test"]
    second_split = train_and_val.train_test_split(test_size=val_size, seed=split_seed)
    train_dataset = second_split["train"]
    val_dataset = second_split["test"]

    train_dataset = take_first_n(train_dataset, args.train_size)
    val_dataset = take_first_n(val_dataset, args.val_size)
    test_dataset = take_first_n(test_dataset, args.test_size)
    return train_dataset, val_dataset, test_dataset


def load_eval_json_dataset(
    path: str,
    *,
    audio_root: str | None,
    audio_column: str,
    text_column: str,
    size_limit: int | None,
):
    manifest_path = Path(path).resolve()
    resolved_audio_root = Path(audio_root).resolve() if audio_root else None

    dataset = load_dataset("json", data_files={"eval": str(manifest_path)})["eval"]
    dataset = dataset.map(
        lambda example: {
            "audio": resolve_row_audio_path(
                example,
                audio_column=audio_column,
                manifest_path=manifest_path,
                audio_root=resolved_audio_root,
            ),
            "text": str(example[text_column]).strip(),
        },
        remove_columns=[
            col
            for col in dataset.column_names
            if col not in [audio_column, text_column]
        ],
    )
    return take_first_n(dataset, size_limit)


def load_training_datasets(args):
    if args.local_manifest:
        print(f"Loading local manifest: {args.local_manifest}")
        return load_local_dataset(args), "local"

    print(f"Loading dataset: {args.dataset} ({args.dataset_subset})")
    dataset = load_dataset(args.dataset, args.dataset_subset)
    train_dataset = take_first_n(dataset["train"], args.train_size)
    val_dataset = take_first_n(dataset["validation"], args.val_size)
    test_dataset = take_first_n(dataset["test"], args.test_size)
    return (train_dataset, val_dataset, test_dataset), "hf"


# ---------------------------------------------------------------------------
# Training parameter selection
# ---------------------------------------------------------------------------


def should_train_parameter(
    name: str, *, train_last_n_layers: int, total_lm_layers: int
) -> bool:
    if "projector" in name or "lora" in name or "lm_head" in name:
        return True

    if train_last_n_layers <= 0:
        return False

    layer_prefix = "language_model.model.layers."
    if not name.startswith(layer_prefix):
        return False

    remainder = name[len(layer_prefix):]
    layer_index_str = remainder.split(".", 1)[0]
    if not layer_index_str.isdigit():
        return False

    layer_index = int(layer_index_str)
    first_trainable_layer = max(total_lm_layers - train_last_n_layers, 0)
    return layer_index >= first_trainable_layer


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def parse_args():
    parser = argparse.ArgumentParser(description="Finetune IBM Granite Speech")
    parser.add_argument("--model-name", default="ibm-granite/granite-4.0-1b-speech")
    parser.add_argument("--dataset", default="speechcolab/gigaspeech")
    parser.add_argument("--dataset-subset", default="xs")
    parser.add_argument("--local-manifest", default=None)
    parser.add_argument("--audio-root", default=None)
    parser.add_argument("--validation-json", default=None)
    parser.add_argument("--test-json", default=None)
    parser.add_argument("--eval-audio-root", default=None)
    parser.add_argument("--eval-audio-column", default="audio")
    parser.add_argument("--eval-text-column", default="text")
    parser.add_argument("--train-size", type=int, default=5000)
    parser.add_argument("--val-size", type=int, default=200)
    parser.add_argument("--test-size", type=int, default=200)
    parser.add_argument("--val-split-ratio", type=float, default=0.01)
    parser.add_argument("--test-split-ratio", type=float, default=0.01)
    parser.add_argument("--split-seed", type=int, default=42)
    parser.add_argument("--language", default=None, choices=["en", "ja"])
    parser.add_argument("--eval-metric", default=None, choices=["wer", "cer"])
    parser.add_argument("--min-cer", type=float, default=None)
    parser.add_argument("--max-cer", type=float, default=None)
    parser.add_argument("--results-min-cer", type=float, default=None)
    parser.add_argument("--results-max-cer", type=float, default=None)
    parser.add_argument("--cer-results-jsonl", default=None)
    parser.add_argument("--min-cer-percent", type=float, default=None)
    parser.add_argument("--max-cer-percent", type=float, default=None)
    parser.add_argument("--min-duration-sec", type=float, default=None)
    parser.add_argument("--max-duration-sec", type=float, default=None)
    parser.add_argument("--min-text-len", type=int, default=None)
    parser.add_argument("--max-text-len", type=int, default=None)
    parser.add_argument("--output-dir", default="save_dir")
    parser.add_argument("--use-wandb", action="store_true")
    parser.add_argument("--wandb-project", default="granite-asr")
    parser.add_argument("--wandb-entity", default=None)
    parser.add_argument("--wandb-run-name", default=None)
    parser.add_argument("--wandb-log-samples", type=int, default=50)
    parser.add_argument("--num-epochs", type=float, default=1.0)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--grad-accum-steps", type=int, default=2)
    parser.add_argument("--learning-rate", type=float, default=3e-5)
    parser.add_argument("--warmup-ratio", type=float, default=0.2)
    parser.add_argument("--lr-scheduler-type", type=str, default="cosine")
    parser.add_argument("--train-last-n-layers", type=int, default=0)
    parser.add_argument("--dataloader-num-workers", type=int, default=16)
    parser.add_argument("--save-steps", dest="save_steps", type=int, default=200)
    parser.add_argument("--save-total-limit", dest="save_total_limit", type=int, default=None)
    parser.add_argument("--save-best-test-cer-only", action="store_true")
    parser.add_argument("--skip-eval", action="store_true")
    return parser.parse_args()


def main():
    args = parse_args()
    if args.train_last_n_layers < 0:
        raise ValueError("--train-last-n-layers must be >= 0")
    if args.save_best_test_cer_only and args.skip_eval:
        raise ValueError("--save-best-test-cer-only requires evaluation; remove --skip-eval")
    if args.language is None:
        args.language = "ja" if args.local_manifest else "en"
    wandb_module = maybe_init_wandb(args)

    if args.local_manifest:
        if not (0.0 < args.test_split_ratio < 1.0):
            raise ValueError("--test-split-ratio must be between 0 and 1")
        if not (0.0 < args.val_split_ratio < 1.0):
            raise ValueError("--val-split-ratio must be between 0 and 1")
        if args.val_split_ratio + args.test_split_ratio >= 1.0:
            raise ValueError("--val-split-ratio + --test-split-ratio must be < 1")

    if args.eval_metric is None:
        args.eval_metric = "cer" if args.language == "ja" else "wer"

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "args.json").write_text(json.dumps(vars(args), indent=2, default=str))

    (train_dataset, val_dataset, test_dataset), dataset_kind = load_training_datasets(args)

    if args.validation_json:
        val_dataset = load_eval_json_dataset(
            args.validation_json,
            audio_root=args.eval_audio_root,
            audio_column=args.eval_audio_column,
            text_column=args.eval_text_column,
            size_limit=args.val_size,
        )
    if args.validation_json or args.test_json:
        test_json_path = args.test_json or args.validation_json
        test_dataset = load_eval_json_dataset(
            test_json_path,
            audio_root=args.eval_audio_root,
            audio_column=args.eval_audio_column,
            text_column=args.eval_text_column,
            size_limit=args.test_size,
        )

    print(f"Dataset sizes -> train: {len(train_dataset)}, val: {len(val_dataset)}, test: {len(test_dataset)}")

    processor = GraniteSpeechProcessor.from_pretrained(args.model_name)
    model = GraniteSpeechForConditionalGeneration.from_pretrained(
        args.model_name, dtype=torch.bfloat16
    )

    train_dataset = prepare_dataset(train_dataset, processor, dataset_kind=dataset_kind, language=args.language)
    val_dataset = prepare_dataset(val_dataset, processor, dataset_kind=dataset_kind, language=args.language)
    test_dataset = prepare_dataset(test_dataset, processor, dataset_kind=dataset_kind, language=args.language)

    if not args.skip_eval:
        references_before, predictions_before = compute_predictions(
            model, processor, test_dataset,
            language=args.language, batch_size=args.batch_size, num_workers=args.dataloader_num_workers,
        )
        metric = evaluate.load(args.eval_metric)
        metric_before = metric.compute(references=references_before, predictions=predictions_before)
        print(f"{args.eval_metric.upper()} before finetuning: {metric_before * 100:.3f}%")

    total_lm_layers = len(model.language_model.model.layers)
    if args.train_last_n_layers > total_lm_layers:
        raise ValueError(f"--train-last-n-layers exceeds total LM layers ({total_lm_layers})")

    for name, param in model.named_parameters():
        param.requires_grad = should_train_parameter(
            name,
            train_last_n_layers=args.train_last_n_layers,
            total_lm_layers=total_lm_layers,
        )

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        remove_unused_columns=False,
        report_to="wandb" if args.use_wandb else "none",
        bf16=True,
        eval_strategy="steps",
        save_strategy="steps",
        eval_steps=0.1,
        save_steps=args.save_steps,
        save_total_limit=args.save_total_limit,
        dataloader_num_workers=args.dataloader_num_workers,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum_steps,
        num_train_epochs=args.num_epochs,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.lr_scheduler_type,
        logging_steps=0.1,
        learning_rate=args.learning_rate,
        data_seed=42,
    )

    data_collator = GraniteCollator(processor)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        processing_class=processor,
    )

    test_cer_callback = None
    if not args.skip_eval:
        test_cer_callback = TestCerCallback(
            processor=processor,
            test_dataset=test_dataset,
            language=args.language,
            batch_size=args.batch_size,
            num_workers=args.dataloader_num_workers,
            output_dir=args.output_dir,
            wandb_module=wandb_module,
            wandb_log_samples=args.wandb_log_samples,
            save_best_test_cer_only=args.save_best_test_cer_only,
        )
        trainer.add_callback(test_cer_callback)

    trainer.train()

    if args.save_best_test_cer_only:
        keep_checkpoint_dir = test_cer_callback.best_checkpoint_dir if test_cer_callback else None
        prune_checkpoints(args.output_dir, keep_checkpoint_dir)

    trainer.save_model(args.output_dir)
    processor.save_pretrained(args.output_dir)

    if not args.skip_eval:
        references_after, predictions_after = compute_predictions(
            model, processor, test_dataset,
            language=args.language, batch_size=args.batch_size, num_workers=args.dataloader_num_workers,
        )
        metric = evaluate.load(args.eval_metric)
        metric_after = metric.compute(references=references_after, predictions=predictions_after)
        print(f"{args.eval_metric.upper()} after finetuning:  {metric_after * 100:.3f}%")
        print(f"{args.eval_metric.upper()} improvement:       {(metric_before - metric_after) * 100:.3f}%")

    if wandb_module is not None:
        wandb_module.finish()


if __name__ == "__main__":
    main()

The manifest JSONL format expected by the script:

```json
{
  "audio_path": "path/to/audio.wav",
  "text": "Transcription text here",
  "duration_sec": 5.2
}
```

audio_path can be an absolute path or relative to --audio-root.
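As a concrete illustration of this format, here is a minimal, dependency-free sketch of a manifest reader that validates the required fields and resolves relative paths against an audio root. The helper name `read_manifest` is hypothetical; the script's own loader is `load_local_dataset`, which is not shown in this excerpt.

```python
import json
from pathlib import Path


def read_manifest(path, audio_root=None):
    """Read a JSONL manifest, validate required fields, and resolve paths.

    Hypothetical helper illustrating the manifest contract, not part of
    the training script itself.
    """
    entries = []
    for line_no, line in enumerate(Path(path).read_text().splitlines(), 1):
        if not line.strip():
            continue  # tolerate blank lines
        entry = json.loads(line)
        if "audio_path" not in entry or "text" not in entry:
            raise ValueError(f"manifest line {line_no}: missing audio_path/text")
        # audio_path may be absolute, or relative to --audio-root
        if audio_root and not Path(entry["audio_path"]).is_absolute():
            entry["audio_path"] = str(Path(audio_root) / entry["audio_path"])
        entries.append(entry)
    return entries
```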

Summary of Lessons Learned

What worked well

  • Unfreezing lm_head + last 8 LM layers: The single biggest driver of accuracy improvement
  • Training data with punctuation: Eliminated a post-processing step and simplified the pipeline
  • TestCerCallback for automatic best-checkpoint saving: Preserved the best result without manual monitoring
  • CER-based data filtering: Removing low-quality samples stabilized training noticeably
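To make the CER-filtering idea concrete: the script exposes `--min-cer`/`--max-cer` flags and uses the `evaluate` library for scoring, but the core mechanism can be sketched with a self-contained, pure-Python character error rate. The function names below are illustrative, and each sample is assumed to carry a reference `text` and a `baseline_pred` transcription from a baseline model.

```python
def char_error_rate(reference: str, hypothesis: str) -> float:
    """Character-level Levenshtein distance divided by reference length."""
    if not reference:
        return 0.0 if not hypothesis else 1.0
    prev = list(range(len(hypothesis) + 1))
    for i, r in enumerate(reference, 1):
        curr = [i]
        for j, h in enumerate(hypothesis, 1):
            curr.append(min(
                prev[j] + 1,              # deletion
                curr[j - 1] + 1,          # insertion
                prev[j - 1] + (r != h),   # substitution (0 cost on match)
            ))
        prev = curr
    return prev[-1] / len(reference)


def filter_by_cer(samples, max_cer=0.5):
    """Drop samples whose baseline transcription disagrees too much with
    the reference text -- a proxy for mislabeled or noisy audio."""
    return [s for s in samples
            if char_error_rate(s["text"], s["baseline_pred"]) <= max_cer]
```

Filtering against a baseline model's own transcriptions removes samples where the label and the audio likely disagree, which is what stabilized training in practice.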

What didn't pan out / remaining challenges

  • Proper nouns and domain-specific terms still show high error rates — adding domain-specific training data is the next step
  • Unfreezing more than the last 8 LM layers caused overfitting at the current data volume
  • Surpassing Qwen3-ASR will likely require more training data

FAQ

Q: Does the Audio Encoder need to be finetuned?

The Audio Encoder is already well-adapted to multilingual speech in a general sense, so freezing it is standard practice for this scale of finetuning. Unfreezing it adds a large number of parameters and significantly increases overfitting risk and compute cost. At hundreds or thousands of hours of training data, encoder finetuning becomes worth considering.

Q: How much GPU VRAM is needed?

With batch_size=8, bfloat16, and the last 8 layers unfrozen, around 24GB VRAM (RTX 3090/4090 range) is a reasonable estimate. Increasing grad_accum_steps can compensate for smaller VRAM by maintaining the effective batch size.
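The VRAM tradeoff comes down to simple arithmetic: the optimizer sees the product of the per-device batch size, the accumulation steps, and the device count. A tiny sketch, with names mirroring the script's `--batch-size` and `--grad-accum-steps` flags:

```python
def effective_batch_size(per_device_batch: int,
                         grad_accum_steps: int,
                         num_devices: int = 1) -> int:
    """Batch size as seen by each optimizer step."""
    return per_device_batch * grad_accum_steps * num_devices


# The script's defaults (--batch-size 8, --grad-accum-steps 2) give an
# effective batch of 16. Halving the per-device batch to fit a smaller GPU
# while doubling accumulation keeps the optimizer's view unchanged:
assert effective_batch_size(8, 2) == effective_batch_size(4, 4) == 16
```

The cost of trading batch size for accumulation is wall-clock time, not model quality, since gradients are mathematically equivalent up to batch-norm-free architectures like this one.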

Q: Does this approach work with HuggingFace datasets like GigaSpeech?

Yes. Instead of a local manifest, specify a HuggingFace dataset ID and it will use the same training pipeline. For English datasets, pass --language en and adjust the text normalization in build_eval_normalizer accordingly.

Q: How should Japanese text normalization be handled for evaluation?

Using a MeCab-based normalizer to unify full-width/half-width characters, standardize number representations, and normalize punctuation improves CER evaluation stability. The build_eval_normalizer function in the script is the place to plug in your normalization logic.
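For reference, even without MeCab, a dependency-free normalizer built on Unicode NFKC covers the full-width/half-width unification and punctuation stripping. This is a simplified sketch of what could be plugged into `build_eval_normalizer`; the function name is hypothetical and the MeCab-based number standardization is omitted.

```python
import re
import unicodedata


def normalize_ja_for_cer(text: str) -> str:
    """Minimal Japanese eval normalizer.

    NFKC unifies full-width/half-width forms (e.g. ＡＢＣ -> ABC,
    １２３ -> 123), then punctuation and whitespace are stripped so that
    CER reflects transcription content rather than formatting choices.
    """
    text = unicodedata.normalize("NFKC", text)
    text = re.sub(r"[、。，．,.!?・…「」『』()\s]", "", text)
    return text
```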
