How I Finetuned IBM Granite Speech 1B on Japanese Audio and Improved CER from 0.37 to 0.14
Summary: I finetuned IBM Granite Speech (granite-4.0-1b-speech) on 100 hours of Japanese speech data and reduced CER from 0.37 to 0.14. The official script's Projector+LoRA-only training has a ceiling on accuracy gains. The key breakthrough was additionally training lm_head and the last 8 layers of the Language Model. The result matches Qwen3-ASR-1.7B (CER 0.14) with only 1B parameters.
Why Finetune Granite Speech?
From late 2025 into 2026, competition in Japanese ASR (Automatic Speech Recognition) has intensified rapidly. While strong models like Qwen3-ASR and ReazonSpeech are available, IBM's granite-4.0-1b-speech was primarily trained on English, meaning it doesn't perform well on Japanese out of the box.
That said, Granite Speech has a relatively clean architecture with two well-defined finetuning entry points: the Projector (the bridge layer connecting the audio encoder to the LLM) and LoRA (adapters inserted into the Language Model). IBM also provides an official Colab notebook, making it an accessible starting point.
However, sticking to the official configuration leaves performance on the table. This post walks through the parameter expansion strategy I used to break through that ceiling, along with practical lessons from training on 100 hours of Japanese speech.
Background: Granite Speech Architecture
granite-4.0-1b-speech is built around three components:
| Component | Role |
|---|---|
| Speech Encoder | Converts raw audio waveforms into high-resolution acoustic representations (Conformer + CTC) |
| Speech Projector (Q-Former) | Compresses and extracts meaning from audio representations, then maps them into the LLM's embedding space |
| Language Model | Handles text generation and semantic understanding |
LoRA (Low-Rank Adaptation) inserts small low-rank matrices into each Attention layer of the Language Model, allowing efficient finetuning with few parameters while keeping the original weights frozen.
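To make the "few parameters" claim concrete, here is a back-of-the-envelope comparison (the layer dimensions are illustrative assumptions, not Granite's actual sizes):

```python
# Back-of-the-envelope LoRA parameter count (dimensions are illustrative
# assumptions, not Granite's actual layer sizes).
d_in, d_out, rank = 1024, 1024, 16

full = d_in * d_out           # parameters if we updated the weight matrix W directly
lora = rank * (d_in + d_out)  # low-rank factors A (rank x d_in) and B (d_out x rank)

print(full, lora, f"{lora / full:.2%}")  # LoRA trains ~3% of this layer's parameters
```

Because only A and B receive gradients, optimizer state and gradient memory shrink by the same factor.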
CER (Character Error Rate) measures the proportion of incorrectly recognized characters. Because Japanese word segmentation is non-trivial, CER is the standard evaluation metric rather than WER (Word Error Rate). Lower is better — a CER of 0.14 means 14% of characters are misrecognized.
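`jiwer.cer` (used later in the script) computes this metric; as a minimal stdlib sketch, CER is just the character-level Levenshtein distance divided by the reference length:

```python
def cer(reference: str, hypothesis: str) -> float:
    # Character-level Levenshtein distance divided by reference length.
    m, n = len(reference), len(hypothesis)
    prev = list(range(n + 1))
    for i in range(1, m + 1):
        curr = [i] + [0] * n
        for j in range(1, n + 1):
            cost = 0 if reference[i - 1] == hypothesis[j - 1] else 1
            curr[j] = min(prev[j] + 1,         # deletion
                          curr[j - 1] + 1,     # insertion
                          prev[j - 1] + cost)  # substitution
        prev = curr
    return prev[n] / m if m else 0.0

print(cer("abcd", "abxd"))  # one substitution out of four characters -> 0.25
```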
The Limits of Official Finetuning
The official script keeps the trainable parameter set minimal:
```python
for n, p in model.named_parameters():
    # Only train Projector and LoRA layers
    p.requires_grad = "projector" in n or "lora" in n
```
This works — CER does improve early on. But accuracy gains plateau quickly even as data volume grows.
The intuition is straightforward: Projector+LoRA handles the "bridge" between audio and text, but the Language Model's core understanding of Japanese vocabulary, grammar, and context is never updated. Even with 100 hours of data, the LM's limited Japanese capability becomes the bottleneck.
This is especially true for kanji, proper nouns, and punctuation — all of which depend heavily on what the Language Model already knows.
The Fix: Expanding the Set of Trainable Parameters
Adding lm_head and the Last N Transformer Layers
The strategy I adopted:
- Add `lm_head` (the final token prediction layer) to the trainable set
- Unfreeze the last N Transformer layers of the Language Model
```python
def should_train_parameter(
    name: str, *, train_last_n_layers: int, total_lm_layers: int
) -> bool:
    # Projector / LoRA / lm_head are always trained
    if "projector" in name or "lora" in name or "lm_head" in name:
        return True
    if train_last_n_layers <= 0:
        return False
    # Parse names of the form: language_model.model.layers.{index}.*
    layer_prefix = "language_model.model.layers."
    if not name.startswith(layer_prefix):
        return False
    remainder = name[len(layer_prefix):]
    layer_index_str = remainder.split(".", 1)[0]
    if not layer_index_str.isdigit():
        return False
    layer_index = int(layer_index_str)
    first_trainable_layer = max(total_lm_layers - train_last_n_layers, 0)
    return layer_index >= first_trainable_layer
```
Apply this after loading the model:
```python
total_lm_layers = len(model.language_model.model.layers)
for name, param in model.named_parameters():
    param.requires_grad = should_train_parameter(
        name,
        train_last_n_layers=args.train_last_n_layers,
        total_lm_layers=total_lm_layers,
    )

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
```
Choosing the Right Number of Layers
The optimal N depends on how much data you have. Unfreezing too many layers with limited data increases overfitting risk.
| Data volume | Recommended N layers |
|---|---|
| Up to 10 hours | 0–2 (official defaults or minimal) |
| 10–50 hours | 2–4 |
| 50–100 hours | 4–8 |
| 100+ hours | 8 to full unfreezing |
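If you want to encode that table directly in a training script, a hypothetical helper might look like the following (the thresholds are the rough guidance above, not a tuned rule; `recommended_last_n_layers` is a name I made up for this sketch):

```python
def recommended_last_n_layers(data_hours: float, total_lm_layers: int) -> int:
    # Hypothetical helper encoding the data-volume table above.
    # Values are rough guidance, not a tuned rule.
    if data_hours <= 10:
        return 2                  # official defaults or minimal
    if data_hours <= 50:
        return 4
    if data_hours <= 100:
        return 8
    return total_lm_layers        # consider full unfreezing past 100 hours

print(recommended_last_n_layers(100, 24))
```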
With 100 hours of Japanese speech, N=8 produced the best results.
Training Hyperparameters
```python
batch_size = 8
learning_rate = 5e-5
train_last_n_layers = 8
num_epochs = 3.0
warmup_ratio = 0.0
lr_scheduler_type = "cosine"
grad_accum_steps = 2
```
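Put together, these settings translate into an invocation like the following (the script filename is a placeholder; every flag shown exists in the argument parser of the full script at the end of this post):

```shell
python finetune_granite_speech.py \
  --local-manifest data/manifests \
  --audio-root data/audio \
  --language ja \
  --train-last-n-layers 8 \
  --batch-size 8 \
  --grad-accum-steps 2 \
  --learning-rate 5e-5 \
  --num-epochs 3.0 \
  --warmup-ratio 0.0 \
  --lr-scheduler-type cosine \
  --save-best-test-cer-only
```

Note that the effective batch size is 8 × 2 = 16 via gradient accumulation.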
Why Punctuation Output Matters
One underappreciated benefit of this finetuning run is that the model learned to output punctuation.
The base granite-4.0-1b-speech tends to produce unpunctuated text. By training on transcripts that include punctuation, the finetuned model now outputs commas and periods naturally.
In production ASR pipelines, punctuation-free output often requires a separate post-processing step to re-insert it. This improvement eliminates that step entirely, meaningfully simplifying the pipeline.
Benchmark Results
Evaluation used the same internal dataset referenced in this Japanese ASR benchmark post.
| Model | CER | Notes |
|---|---|---|
| ibm-granite/granite-4.0-1b-speech (baseline) | 0.37 | Before finetuning |
| ibm-granite/granite-4.0-1b-speech (finetuned) | 0.14 | 1B parameters, matched top performance |
| qwen/qwen3-asr-1.7b | 0.14 | Reference model (1.7B parameters) |
Qwen3-ASR-1.7B has 1.7 billion parameters. Matching its CER with a 1B parameter model demonstrates the effectiveness of targeted finetuning.
Remaining Challenges
- Kanji and proper nouns: Person names, place names, and domain-specific terminology still show elevated error rates
- Beating Qwen: More data could push past it, but that threshold hasn't been crossed yet
- Occasional errors at utterance boundaries and with long vowels
Data Quality Filtering
With 100 hours of data, quality control directly affects the outcome. The script supports filtering manifest files by CER, audio duration, and transcript length:
```shell
# Example filtering configuration (local manifest)
--min-duration-sec 10.0  # Exclude clips shorter than 10 seconds
--max-duration-sec 30.0  # Exclude clips longer than 30 seconds
--min-text-len 45        # Exclude samples with very short transcripts
--max-cer 0.4            # Exclude samples where Whisper CER exceeds 40%
```
Finetuning accuracy is strongly data-quality dependent. Noisy audio and transcripts with many errors should be filtered out. A practical approach: run ASR inference on your raw data to compute CER, then remove outliers before training.
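As a standalone sketch of that outlier-removal step (simplified relative to the script's `--max-cer` handling: here, rows without a precomputed CER are kept rather than dropped):

```python
def filter_rows_by_cer(rows, cer_index, max_cer=0.4):
    """Keep manifest rows whose first-pass ASR CER is at or below max_cer.

    cer_index maps audio paths to CER values from an initial inference run;
    rows with no recorded CER are kept in this simplified sketch.
    """
    kept = []
    for row in rows:
        cer = cer_index.get(row["audio_path"])
        if cer is None or cer <= max_cer:
            kept.append(row)
    return kept

rows = [{"audio_path": "a.wav"}, {"audio_path": "b.wav"}]
cer_index = {"a.wav": 0.10, "b.wav": 0.80}
kept = filter_rows_by_cer(rows, cer_index)
print([r["audio_path"] for r in kept])  # b.wav (CER 0.80) is filtered out
```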
Checkpoint Management with TestCerCallback
The script includes a TestCerCallback that evaluates test CER at each checkpoint and optionally keeps only the best one — saving both storage and GPU time.
```python
trainer.add_callback(TestCerCallback(
    processor=processor,
    test_dataset=test_dataset,
    language="ja",
    save_best_test_cer_only=True,  # Keep only the best checkpoint
    ...
))
```
With --save-best-test-cer-only enabled, each checkpoint that doesn't improve test CER is automatically deleted.
Full Finetuning Script
Running this script requires the main branch of HuggingFace Transformers, as the granite_speech module may not yet be in the stable release.
```shell
pip install git+https://github.com/huggingface/transformers.git
pip install -U datasets peft accelerate evaluate jiwer soundfile tqdm
```
"""
Finetuning script for IBM Granite Speech models.
Based on:
https://colab.research.google.com/github/ibm-granite/granite-speech-models/blob/main/notebooks/fine_tuning_granite_speech.ipynb
Requirements:
pip install git+https://github.com/huggingface/transformers.git
pip install -U datasets peft accelerate evaluate jiwer soundfile tqdm
"""
import argparse
import json
import shutil
from pathlib import Path
import jiwer
import torch
import tqdm
import evaluate
import soundfile as sf
from datasets import load_dataset, concatenate_datasets, Audio
from torch.utils.data import DataLoader
from transformers import TrainingArguments, Trainer, TrainerCallback
from transformers.feature_extraction_utils import BatchFeature
from transformers.models.granite_speech import (
GraniteSpeechForConditionalGeneration,
GraniteSpeechProcessor,
)
# ---------------------------------------------------------------------------
# Data preprocessing
# ---------------------------------------------------------------------------
def process_gigaspeech_transcript(text: str) -> str:
text = text.replace(" <COMMA>", ",")
text = text.replace(" <PERIOD>", ".")
text = text.replace(" <QUESTIONMARK>", "?")
text = text.replace(" <EXCLAMATIONPOINT>", "!")
return text.lower()
def resolve_audio_path(
audio_path: str, manifest_path: Path, audio_root: Path | None
) -> str:
audio = Path(audio_path)
if audio.is_absolute():
return str(audio)
if audio_root is not None:
return str((audio_root / audio).resolve())
return str((manifest_path.parent / audio).resolve())
def resolve_row_audio_path(
row: dict, audio_column: str, manifest_path: Path, audio_root: Path | None
) -> str:
if audio_column not in row:
raise ValueError(f"{manifest_path} must contain '{audio_column}' column")
return resolve_audio_path(str(row[audio_column]), manifest_path, audio_root)
def load_cer_results_index(results_jsonl: str | None) -> dict[str, float]:
if not results_jsonl:
return {}
index: dict[str, float] = {}
results_path = Path(results_jsonl).resolve()
with results_path.open() as fh:
for line_no, line in enumerate(fh, start=1):
line = line.strip()
if not line:
continue
row = json.loads(line)
audio_path = row.get("audio")
cer = row.get("cer")
if not audio_path or cer is None:
continue
index[str(Path(audio_path).resolve())] = float(cer)
print(f"Loaded {len(index)} CER rows from {results_path}")
return index
def normalize_transcript(text: str, *, dataset_kind: str) -> str:
normalized = str(text).strip()
if dataset_kind == "hf":
return process_gigaspeech_transcript(normalized)
return normalized
def prep_example(example, tokenizer, dataset_kind: str):
instruction = "Please transcribe the following audio to text<|audio|>"
chat = [{"role": "user", "content": instruction}]
example["prompt"] = tokenizer.apply_chat_template(
chat,
add_generation_prompt=True,
tokenize=False,
)
example["text"] = normalize_transcript(
example["text"],
dataset_kind=dataset_kind,
)
return example
def should_keep_local_record(
example,
*,
min_cer,
max_cer,
results_min_cer,
results_max_cer,
min_cer_percent,
max_cer_percent,
min_duration_sec,
max_duration_sec,
min_text_len,
max_text_len,
manifest_path,
audio_root,
cer_results_index,
):
manifest_cer = example.get("cer")
if min_cer is not None and (manifest_cer is None or float(manifest_cer) < min_cer):
return False
if max_cer is not None and (manifest_cer is None or float(manifest_cer) > max_cer):
return False
results_cer = None
if cer_results_index:
audio_path = resolve_audio_path(
str(example["audio_path"]),
manifest_path,
audio_root,
)
results_cer = cer_results_index.get(audio_path)
if results_min_cer is not None and (
results_cer is None or float(results_cer) < results_min_cer
):
return False
if results_max_cer is not None and (
results_cer is None or float(results_cer) > results_max_cer
):
return False
cer_percent = example.get("cer_percent")
if min_cer_percent is not None and (
cer_percent is None or float(cer_percent) < min_cer_percent
):
return False
if max_cer_percent is not None and (
cer_percent is None or float(cer_percent) > max_cer_percent
):
return False
duration_sec = example.get("duration_sec")
if min_duration_sec is not None and (
duration_sec is None or float(duration_sec) < min_duration_sec
):
return False
if max_duration_sec is not None and (
duration_sec is None or float(duration_sec) > max_duration_sec
):
return False
text = str(example.get("text", "")).strip()
if min_text_len is not None and len(text) < min_text_len:
return False
if max_text_len is not None and len(text) > max_text_len:
return False
return True
def prepare_dataset(ds, processor, *, dataset_kind: str, language: str):
columns_to_remove = [col for col in ds.column_names if col not in ["audio", "text"]]
ds = ds.cast_column(
"audio",
Audio(
sampling_rate=processor.audio_processor.sampling_rate,
decode=False,
),
)
ds = ds.map(
prep_example,
fn_kwargs={
"tokenizer": processor.tokenizer,
"dataset_kind": dataset_kind,
},
remove_columns=columns_to_remove,
)
if dataset_kind == "hf":
ds = ds.filter(
lambda x: x["text"] not in ["<other>", "<noise>", "<music>", "<sil>"]
)
return ds
# ---------------------------------------------------------------------------
# Collator
# ---------------------------------------------------------------------------
def _extract_audio_array(audio):
"""Load audio without relying on datasets' torchcodec-based decoding."""
if hasattr(audio, "get_all_samples"):
samples = audio.get_all_samples()
return samples.data.squeeze(0).numpy()
if isinstance(audio, dict):
if "array" in audio:
return audio["array"]
if audio.get("path"):
audio_array, sampling_rate = sf.read(audio["path"], dtype="float32")
if audio_array.ndim > 1:
audio_array = audio_array.mean(axis=1)
return audio_array, sampling_rate
return audio
class GraniteCollator:
def __init__(self, processor, inference_mode: bool = False):
self.processor = processor
self.inference_mode = inference_mode
self.sampling_rate = processor.audio_processor.sampling_rate
def _prepare_audio(self, audio):
extracted = _extract_audio_array(audio)
sampling_rate = self.sampling_rate
if isinstance(extracted, tuple):
audio_array, sampling_rate = extracted
else:
audio_array = extracted
audio_tensor = torch.as_tensor(audio_array, dtype=torch.float32)
if audio_tensor.ndim > 1:
audio_tensor = audio_tensor.mean(dim=0)
if sampling_rate != self.sampling_rate:
audio_tensor = (
torch.nn.functional.interpolate(
audio_tensor.unsqueeze(0).unsqueeze(0),
size=int(
audio_tensor.shape[-1] * self.sampling_rate / sampling_rate
),
mode="linear",
align_corners=False,
)
.squeeze(0)
.squeeze(0)
)
return audio_tensor.numpy()
def __call__(self, examples):
prompts = [ex["prompt"] for ex in examples]
audios = [self._prepare_audio(ex["audio"]) for ex in examples]
processed = self.processor(
prompts, audios, return_tensors="pt", padding=True, padding_side="left"
)
input_ids = processed.input_ids
attention_mask = processed.attention_mask
labels = None
if not self.inference_mode:
targets = [
ex["text"] + self.processor.tokenizer.eos_token for ex in examples
]
targets = self.processor.tokenizer(
targets, return_tensors="pt", padding=True, padding_side="right"
)
input_ids = torch.cat([input_ids, targets.input_ids], dim=1)
attention_mask = torch.cat([attention_mask, targets.attention_mask], dim=1)
labels = targets.input_ids.clone()
labels[~targets.attention_mask.bool()] = -100
labels = torch.cat(
[torch.full_like(processed.input_ids, -100), labels], dim=1
)
return BatchFeature(
data={
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
"input_features": processed.input_features,
"input_features_mask": processed.input_features_mask,
}
)
# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------
def build_eval_normalizer(language: str):
if language == "ja":
# Implement Japanese-specific normalization (e.g., MeCab-based)
return lambda x: str(x).strip()
try:
from whisper.normalizers import EnglishTextNormalizer
return EnglishTextNormalizer()
except ModuleNotFoundError:
return lambda x: str(x).strip().lower()
def compute_predictions(
model,
processor,
dataset,
*,
language: str,
batch_size: int = 16,
num_workers: int = 8,
):
collator = GraniteCollator(processor, inference_mode=True)
dataloader = DataLoader(
dataset, batch_size=batch_size, collate_fn=collator, num_workers=num_workers
)
normalizer = build_eval_normalizer(language)
model = model.eval().cuda()
predictions = []
for batch in tqdm.tqdm(dataloader, desc="Running inference"):
batch = batch.to("cuda")
with torch.inference_mode(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
outputs = model.generate(
**batch, max_new_tokens=400, num_beams=4, early_stopping=True
)
input_length = batch.input_ids.shape[1]
outputs = outputs[:, input_length:].cpu()
for x in outputs:
predictions.append(processor.tokenizer.decode(x, skip_special_tokens=True))
references = [normalizer(x) for x in dataset["text"]]
predictions = [normalizer(x) for x in predictions]
return references, predictions
def build_diff_string(reference: str, prediction: str) -> str:
import difflib
matcher = difflib.SequenceMatcher(None, prediction, reference)
result = []
for tag, i1, i2, j1, j2 in matcher.get_opcodes():
if tag == "equal":
result.append(prediction[i1:i2])
else:
result.append(f"(w: {prediction[i1:i2]}|q: {reference[j1:j2]})")
return "".join(result)
def build_prediction_rows(references, predictions):
rows = []
for index, (reference, prediction) in enumerate(
zip(references, predictions), start=1
):
rows.append(
{
"index": index,
"reference": reference,
"prediction": prediction,
"cer": jiwer.cer(reference, prediction),
"diff": build_diff_string(reference, prediction),
}
)
return rows
def maybe_init_wandb(args):
if not args.use_wandb:
return None
import wandb
wandb.init(
project=args.wandb_project,
entity=args.wandb_entity,
name=args.wandb_run_name,
config=vars(args),
)
return wandb
def maybe_log_wandb_predictions(
wandb_module,
*,
stage: str,
metric_name: str,
metric_value: float,
rows,
max_rows: int,
):
if wandb_module is None:
return
table = wandb_module.Table(
columns=["index", "reference", "prediction", "cer", "diff"]
)
for row in rows[:max_rows]:
table.add_data(
row["index"], row["reference"], row["prediction"], row["cer"], row["diff"]
)
wandb_module.log(
{
f"{stage}/{metric_name}": metric_value,
f"{stage}/prediction_table": table,
}
)
def prune_checkpoints(output_dir: str, keep_checkpoint_dir: Path | None) -> None:
keep_path = (
keep_checkpoint_dir.resolve() if keep_checkpoint_dir is not None else None
)
for checkpoint_dir in Path(output_dir).glob("checkpoint-*"):
if not checkpoint_dir.is_dir():
continue
if keep_path is not None and checkpoint_dir.resolve() == keep_path:
continue
shutil.rmtree(checkpoint_dir)
class TestCerCallback(TrainerCallback):
def __init__(
self,
*,
processor,
test_dataset,
language: str,
batch_size: int,
num_workers: int,
output_dir: str,
wandb_module,
wandb_log_samples: int,
save_best_test_cer_only: bool,
):
self.processor = processor
self.test_dataset = test_dataset
self.language = language
self.batch_size = batch_size
self.num_workers = num_workers
self.output_dir = output_dir
self.wandb_module = wandb_module
self.wandb_log_samples = wandb_log_samples
self.save_best_test_cer_only = save_best_test_cer_only
self.best_test_cer: float | None = None
self.best_checkpoint_dir: Path | None = None
def on_save(self, args, state, control, model=None, **kwargs):
references, predictions = compute_predictions(
model,
self.processor,
self.test_dataset,
language=self.language,
batch_size=self.batch_size,
num_workers=self.num_workers,
)
test_cer = jiwer.cer(references, predictions)
prediction_rows = build_prediction_rows(references, predictions)
checkpoint_dir = Path(args.output_dir) / f"checkpoint-{state.global_step}"
checkpoint_dir.mkdir(parents=True, exist_ok=True)
metrics = {
"global_step": state.global_step,
"test_cer": test_cer,
"num_samples": len(references),
}
(checkpoint_dir / "test_cer.json").write_text(
json.dumps(metrics, ensure_ascii=False, indent=2) + "\n",
encoding="utf-8",
)
with (checkpoint_dir / "test_predictions.jsonl").open(
"w", encoding="utf-8"
) as f:
for row in prediction_rows:
f.write(json.dumps(row, ensure_ascii=False) + "\n")
with (Path(self.output_dir) / "test_cer_history.jsonl").open(
"a", encoding="utf-8"
) as f:
f.write(json.dumps(metrics, ensure_ascii=False) + "\n")
is_best = self.best_test_cer is None or test_cer < self.best_test_cer
if is_best:
previous_best_dir = self.best_checkpoint_dir
self.best_test_cer = test_cer
self.best_checkpoint_dir = checkpoint_dir
(Path(self.output_dir) / "best_test_cer.json").write_text(
json.dumps(
{
"global_step": state.global_step,
"test_cer": test_cer,
"checkpoint_dir": str(checkpoint_dir),
},
ensure_ascii=False,
indent=2,
)
+ "\n",
encoding="utf-8",
)
if (
self.save_best_test_cer_only
and previous_best_dir is not None
and previous_best_dir != checkpoint_dir
and previous_best_dir.exists()
):
shutil.rmtree(previous_best_dir)
elif self.save_best_test_cer_only and checkpoint_dir.exists():
shutil.rmtree(checkpoint_dir)
maybe_log_wandb_predictions(
self.wandb_module,
stage=f"checkpoint_{state.global_step}",
metric_name="cer",
metric_value=test_cer,
rows=prediction_rows,
max_rows=self.wandb_log_samples,
)
if is_best:
print(f"Test CER at step {state.global_step}: {test_cer * 100:.3f}% (best)")
else:
print(f"Test CER at step {state.global_step}: {test_cer * 100:.3f}%")
return control
def take_first_n(ds, n: int):
if n is None or n <= 0:
return ds
return ds.select(range(min(n, len(ds))))
def _load_single_manifest(manifest_path: Path, audio_root: Path | None, args):
cer_results_index = getattr(args, "cer_results_index", {})
dataset = load_dataset("json", data_files={"manifest": str(manifest_path)})[
"manifest"
]
before_count = len(dataset)
dataset = dataset.filter(
should_keep_local_record,
fn_kwargs={
"min_cer": args.min_cer,
"max_cer": args.max_cer,
"results_min_cer": args.results_min_cer,
"results_max_cer": args.results_max_cer,
"min_cer_percent": args.min_cer_percent,
"max_cer_percent": args.max_cer_percent,
"min_duration_sec": args.min_duration_sec,
"max_duration_sec": args.max_duration_sec,
"min_text_len": args.min_text_len,
"max_text_len": args.max_text_len,
"manifest_path": manifest_path,
"audio_root": audio_root,
"cer_results_index": cer_results_index,
},
)
if cer_results_index and (
args.results_min_cer is not None or args.results_max_cer is not None
):
removed_count = before_count - len(dataset)
print(
f"Filtered {removed_count} rows from {manifest_path} using external CER thresholds"
)
dataset = dataset.map(
lambda example: {
"audio": resolve_audio_path(
example["audio_path"], manifest_path, audio_root
),
"text": str(example["text"]).strip(),
},
remove_columns=[
col for col in dataset.column_names if col not in ["audio_path", "text"]
],
)
return dataset
def load_local_dataset(args):
local_manifest = Path(args.local_manifest).resolve()
audio_root = Path(args.audio_root).resolve() if args.audio_root else None
args.cer_results_index = load_cer_results_index(args.cer_results_jsonl)
if local_manifest.is_dir():
manifest_paths = sorted(local_manifest.rglob("manifest.jsonl"))
if not manifest_paths:
raise ValueError(f"No manifest.jsonl files found under {local_manifest}")
non_empty_manifest_paths = [p for p in manifest_paths if p.stat().st_size > 0]
skipped_manifest_count = len(manifest_paths) - len(non_empty_manifest_paths)
print(f"Found {len(manifest_paths)} manifest files under {local_manifest}")
if skipped_manifest_count:
print(f"Skipping {skipped_manifest_count} empty manifest files")
if not non_empty_manifest_paths:
raise ValueError(
f"All manifest.jsonl files under {local_manifest} are empty"
)
dataset = concatenate_datasets(
[
_load_single_manifest(p, audio_root, args)
for p in non_empty_manifest_paths
]
)
else:
if local_manifest.stat().st_size == 0:
raise ValueError(f"Manifest file is empty: {local_manifest}")
dataset = _load_single_manifest(local_manifest, audio_root, args)
if len(dataset) < 3:
raise ValueError(
f"Local dataset is too small after filtering: {len(dataset)} rows. Need at least 3."
)
split_seed = args.split_seed
test_size = args.test_split_ratio
val_size = args.val_split_ratio / (1.0 - test_size)
first_split = dataset.train_test_split(test_size=test_size, seed=split_seed)
train_and_val = first_split["train"]
test_dataset = first_split["test"]
second_split = train_and_val.train_test_split(test_size=val_size, seed=split_seed)
train_dataset = second_split["train"]
val_dataset = second_split["test"]
train_dataset = take_first_n(train_dataset, args.train_size)
val_dataset = take_first_n(val_dataset, args.val_size)
test_dataset = take_first_n(test_dataset, args.test_size)
return train_dataset, val_dataset, test_dataset
def load_eval_json_dataset(
path: str,
*,
audio_root: str | None,
audio_column: str,
text_column: str,
size_limit: int | None,
):
manifest_path = Path(path).resolve()
resolved_audio_root = Path(audio_root).resolve() if audio_root else None
dataset = load_dataset("json", data_files={"eval": str(manifest_path)})["eval"]
dataset = dataset.map(
lambda example: {
"audio": resolve_row_audio_path(
example,
audio_column=audio_column,
manifest_path=manifest_path,
audio_root=resolved_audio_root,
),
"text": str(example[text_column]).strip(),
},
remove_columns=[
col
for col in dataset.column_names
if col not in [audio_column, text_column]
],
)
return take_first_n(dataset, size_limit)
def load_training_datasets(args):
if args.local_manifest:
print(f"Loading local manifest: {args.local_manifest}")
return load_local_dataset(args), "local"
print(f"Loading dataset: {args.dataset} ({args.dataset_subset})")
dataset = load_dataset(args.dataset, args.dataset_subset)
train_dataset = take_first_n(dataset["train"], args.train_size)
val_dataset = take_first_n(dataset["validation"], args.val_size)
test_dataset = take_first_n(dataset["test"], args.test_size)
return (train_dataset, val_dataset, test_dataset), "hf"
# ---------------------------------------------------------------------------
# Training parameter selection
# ---------------------------------------------------------------------------
def should_train_parameter(
name: str, *, train_last_n_layers: int, total_lm_layers: int
) -> bool:
if "projector" in name or "lora" in name or "lm_head" in name:
return True
if train_last_n_layers <= 0:
return False
layer_prefix = "language_model.model.layers."
if not name.startswith(layer_prefix):
return False
remainder = name[len(layer_prefix):]
layer_index_str = remainder.split(".", 1)[0]
if not layer_index_str.isdigit():
return False
layer_index = int(layer_index_str)
first_trainable_layer = max(total_lm_layers - train_last_n_layers, 0)
return layer_index >= first_trainable_layer
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def parse_args():
parser = argparse.ArgumentParser(description="Finetune IBM Granite Speech")
parser.add_argument("--model-name", default="ibm-granite/granite-4.0-1b-speech")
parser.add_argument("--dataset", default="speechcolab/gigaspeech")
parser.add_argument("--dataset-subset", default="xs")
parser.add_argument("--local-manifest", default=None)
parser.add_argument("--audio-root", default=None)
parser.add_argument("--validation-json", default=None)
parser.add_argument("--test-json", default=None)
parser.add_argument("--eval-audio-root", default=None)
parser.add_argument("--eval-audio-column", default="audio")
parser.add_argument("--eval-text-column", default="text")
parser.add_argument("--train-size", type=int, default=5000)
parser.add_argument("--val-size", type=int, default=200)
parser.add_argument("--test-size", type=int, default=200)
parser.add_argument("--val-split-ratio", type=float, default=0.01)
parser.add_argument("--test-split-ratio", type=float, default=0.01)
parser.add_argument("--split-seed", type=int, default=42)
parser.add_argument("--language", default=None, choices=["en", "ja"])
parser.add_argument("--eval-metric", default=None, choices=["wer", "cer"])
parser.add_argument("--min-cer", type=float, default=None)
parser.add_argument("--max-cer", type=float, default=None)
parser.add_argument("--results-min-cer", type=float, default=None)
parser.add_argument("--results-max-cer", type=float, default=None)
parser.add_argument("--cer-results-jsonl", default=None)
parser.add_argument("--min-cer-percent", type=float, default=None)
parser.add_argument("--max-cer-percent", type=float, default=None)
parser.add_argument("--min-duration-sec", type=float, default=None)
parser.add_argument("--max-duration-sec", type=float, default=None)
parser.add_argument("--min-text-len", type=int, default=None)
parser.add_argument("--max-text-len", type=int, default=None)
parser.add_argument("--output-dir", default="save_dir")
parser.add_argument("--use-wandb", action="store_true")
parser.add_argument("--wandb-project", default="granite-asr")
parser.add_argument("--wandb-entity", default=None)
parser.add_argument("--wandb-run-name", default=None)
parser.add_argument("--wandb-log-samples", type=int, default=50)
parser.add_argument("--num-epochs", type=float, default=1.0)
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--grad-accum-steps", type=int, default=2)
parser.add_argument("--learning-rate", type=float, default=3e-5)
parser.add_argument("--warmup-ratio", type=float, default=0.2)
parser.add_argument("--lr-scheduler-type", type=str, default="cosine")
parser.add_argument("--train-last-n-layers", type=int, default=0)
parser.add_argument("--dataloader-num-workers", type=int, default=16)
parser.add_argument("--save-steps", dest="save_steps", type=int, default=200)
parser.add_argument("--save-total-limit", dest="save_total_limit", type=int, default=None)
parser.add_argument("--save-best-test-cer-only", action="store_true")
parser.add_argument("--skip-eval", action="store_true")
return parser.parse_args()
def main():
args = parse_args()
if args.train_last_n_layers < 0:
        raise ValueError("--train-last-n-layers must be >= 0")
    if args.save_best_test_cer_only and args.skip_eval:
        raise ValueError("--save-best-test-cer-only requires evaluation; remove --skip-eval")
    if args.language is None:
        args.language = "ja" if args.local_manifest else "en"
    wandb_module = maybe_init_wandb(args)
    if args.local_manifest:
        if not (0.0 < args.test_split_ratio < 1.0):
            raise ValueError("--test-split-ratio must be between 0 and 1")
        if not (0.0 < args.val_split_ratio < 1.0):
            raise ValueError("--val-split-ratio must be between 0 and 1")
        if args.val_split_ratio + args.test_split_ratio >= 1.0:
            raise ValueError("--val-split-ratio + --test-split-ratio must be < 1")
    if args.eval_metric is None:
        args.eval_metric = "cer" if args.language == "ja" else "wer"
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "args.json").write_text(json.dumps(vars(args), indent=2, default=str))
    (train_dataset, val_dataset, test_dataset), dataset_kind = load_training_datasets(args)
    if args.validation_json:
        val_dataset = load_eval_json_dataset(
            args.validation_json,
            audio_root=args.eval_audio_root,
            audio_column=args.eval_audio_column,
            text_column=args.eval_text_column,
            size_limit=args.val_size,
        )
    if args.validation_json or args.test_json:
        test_json_path = args.test_json or args.validation_json
        test_dataset = load_eval_json_dataset(
            test_json_path,
            audio_root=args.eval_audio_root,
            audio_column=args.eval_audio_column,
            text_column=args.eval_text_column,
            size_limit=args.test_size,
        )
    print(f"Dataset sizes -> train: {len(train_dataset)}, val: {len(val_dataset)}, test: {len(test_dataset)}")
    processor = GraniteSpeechProcessor.from_pretrained(args.model_name)
    model = GraniteSpeechForConditionalGeneration.from_pretrained(
        args.model_name, dtype=torch.bfloat16
    )
    train_dataset = prepare_dataset(train_dataset, processor, dataset_kind=dataset_kind, language=args.language)
    val_dataset = prepare_dataset(val_dataset, processor, dataset_kind=dataset_kind, language=args.language)
    test_dataset = prepare_dataset(test_dataset, processor, dataset_kind=dataset_kind, language=args.language)
    if not args.skip_eval:
        references_before, predictions_before = compute_predictions(
            model, processor, test_dataset,
            language=args.language, batch_size=args.batch_size, num_workers=args.dataloader_num_workers,
        )
        metric = evaluate.load(args.eval_metric)
        metric_before = metric.compute(references=references_before, predictions=predictions_before)
        print(f"{args.eval_metric.upper()} before finetuning: {metric_before * 100:.3f}%")
    total_lm_layers = len(model.language_model.model.layers)
    if args.train_last_n_layers > total_lm_layers:
        raise ValueError(f"--train-last-n-layers exceeds total LM layers ({total_lm_layers})")
    for name, param in model.named_parameters():
        param.requires_grad = should_train_parameter(
            name,
            train_last_n_layers=args.train_last_n_layers,
            total_lm_layers=total_lm_layers,
        )
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        remove_unused_columns=False,
        report_to="wandb" if args.use_wandb else "none",
        bf16=True,
        eval_strategy="steps",
        save_strategy="steps",
        eval_steps=0.1,
        save_steps=args.save_steps,
        save_total_limit=args.save_total_limit,
        dataloader_num_workers=args.dataloader_num_workers,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum_steps,
        num_train_epochs=args.num_epochs,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.lr_scheduler_type,
        logging_steps=0.1,
        learning_rate=args.learning_rate,
        data_seed=42,
    )
    data_collator = GraniteCollator(processor)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
        processing_class=processor,
    )
    test_cer_callback = None
    if not args.skip_eval:
        test_cer_callback = TestCerCallback(
            processor=processor,
            test_dataset=test_dataset,
            language=args.language,
            batch_size=args.batch_size,
            num_workers=args.dataloader_num_workers,
            output_dir=args.output_dir,
            wandb_module=wandb_module,
            wandb_log_samples=args.wandb_log_samples,
            save_best_test_cer_only=args.save_best_test_cer_only,
        )
        trainer.add_callback(test_cer_callback)
    trainer.train()
    if args.save_best_test_cer_only:
        keep_checkpoint_dir = test_cer_callback.best_checkpoint_dir if test_cer_callback else None
        prune_checkpoints(args.output_dir, keep_checkpoint_dir)
    trainer.save_model(args.output_dir)
    processor.save_pretrained(args.output_dir)
    if not args.skip_eval:
        references_after, predictions_after = compute_predictions(
            model, processor, test_dataset,
            language=args.language, batch_size=args.batch_size, num_workers=args.dataloader_num_workers,
        )
        metric = evaluate.load(args.eval_metric)
        metric_after = metric.compute(references=references_after, predictions=predictions_after)
        print(f"{args.eval_metric.upper()} after finetuning: {metric_after * 100:.3f}%")
        print(f"{args.eval_metric.upper()} improvement: {(metric_before - metric_after) * 100:.3f}%")
    if wandb_module is not None:
        wandb_module.finish()


if __name__ == "__main__":
    main()
The manifest JSONL format expected by the script (one JSON object per line):

```json
{
  "audio_path": "path/to/audio.wav",
  "text": "Transcription text here",
  "duration_sec": 5.2
}
```
`audio_path` can be an absolute path or a path relative to `--audio-root`.
Summary of Lessons Learned
What worked well
- Unfreezing lm_head + last 8 LM layers: The single biggest driver of accuracy improvement
- Training data with punctuation: Eliminated a post-processing step and simplified the pipeline
- TestCerCallback for automatic best-checkpoint saving: Preserved the best result without manual monitoring
- CER-based data filtering: Removing low-quality samples stabilized training noticeably
What didn't pan out / remaining challenges
- Proper nouns and domain-specific terms still show high error rates; adding domain-specific training data is the next step
- Going beyond `n=8` unfrozen layers caused overfitting at the current data volume
- Surpassing Qwen3-ASR will likely require more training data
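The selective unfreezing discussed above boils down to a name filter applied over `model.named_parameters()`. A minimal sketch of what `should_train_parameter` could look like (my own illustration; the script's actual implementation may differ in details):

```python
import re

def should_train_parameter(name, train_last_n_layers, total_lm_layers):
    """Keep a parameter trainable if it belongs to the projector, a LoRA
    adapter, lm_head, or one of the last N language-model layers."""
    if "projector" in name or "lora" in name:
        return True
    if "lm_head" in name:
        return True
    match = re.search(r"language_model\.model\.layers\.(\d+)\.", name)
    if match:
        layer_idx = int(match.group(1))
        # Layers are numbered from 0, so the last N start at total - N.
        return layer_idx >= total_lm_layers - train_last_n_layers
    return False  # everything else (e.g. the audio encoder) stays frozen
```

Filtering by parameter name keeps the policy in one place and makes it trivial to sweep `train_last_n_layers` as a hyperparameter.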
FAQ
Q: Does the Audio Encoder need to be finetuned?
The Audio Encoder is already well-adapted to multilingual speech in a general sense, so freezing it is standard practice for this scale of finetuning. Unfreezing it adds a large number of parameters and significantly increases overfitting risk and compute cost. At hundreds or thousands of hours of training data, encoder finetuning becomes worth considering.
Q: How much GPU VRAM is needed?
With batch_size=8, bfloat16, and the last 8 layers unfrozen, around 24GB VRAM (RTX 3090/4090 range) is a reasonable estimate. Increasing grad_accum_steps can compensate for smaller VRAM by maintaining the effective batch size.
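The trade-off can be sanity-checked with simple arithmetic: the optimizer sees `per_device_batch_size * grad_accum_steps * num_devices` samples per step, while activation memory scales only with the per-device batch size. A tiny helper (hypothetical, not part of the script) makes the equivalence explicit:

```python
def effective_batch_size(per_device_batch_size, grad_accum_steps, num_devices=1):
    """Number of samples contributing to each optimizer update."""
    return per_device_batch_size * grad_accum_steps * num_devices

# Both configurations update on 32 samples per step, but the second
# needs roughly a quarter of the activation memory per forward pass.
high_vram = effective_batch_size(8, 4)   # e.g. 24GB-class GPU
low_vram = effective_batch_size(2, 16)   # e.g. 12GB-class GPU
```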
Q: Does this approach work with HuggingFace datasets like GigaSpeech?
Yes. Instead of a local manifest, specify a HuggingFace dataset ID and it will use the same training pipeline. For English datasets, pass --language en and adjust the text normalization in build_eval_normalizer accordingly.
Q: How should Japanese text normalization be handled for evaluation?
Using a MeCab-based normalizer to unify full-width/half-width characters, standardize number representations, and normalize punctuation improves CER evaluation stability. The build_eval_normalizer function in the script is the place to plug in your normalization logic.
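Before reaching for MeCab, a lightweight starting point is Unicode NFKC normalization plus punctuation stripping. The `normalize_ja_for_cer` function below is a hypothetical sketch of what could be plugged into `build_eval_normalizer`, not the normalizer used in my experiments:

```python
import re
import unicodedata

def normalize_ja_for_cer(text):
    """Normalize Japanese text before CER scoring: NFKC folds full-width
    ASCII to half-width, then punctuation and whitespace are stripped."""
    text = unicodedata.normalize("NFKC", text)
    text = re.sub(r"[、。,.!?!?・「」『』()()\s]", "", text)
    return text
```

Applying the same normalizer to both references and predictions is what matters; an asymmetric pipeline will silently inflate or deflate CER.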