icefall/egs/librispeech/ASR/evaluate_chime4.py
#!/usr/bin/env python3
"""
Evaluate trained conformer_ctc model on CHiME-4 dataset.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, List

import torch

from conformer_ctc.asr_datamodule import LibriSpeechAsrDataModule
from conformer_ctc.conformer import Conformer


def setup_logging(args):
    """Setup logging configuration."""
    log_level = getattr(logging, args.log_level.upper())
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
    )


def load_model(checkpoint_path: Path, device: torch.device):
    """Load trained conformer model from checkpoint."""
    logging.info(f"Loading model from {checkpoint_path}")

    # Load checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Extract model parameters from checkpoint
    params = checkpoint.get("params", {})

    # Create model with parameters from checkpoint
    model = Conformer(
        num_features=params.get("num_features", 80),
        nhead=params.get("nhead", 8),
        d_model=params.get("d_model", 512),
        num_classes=params.get("num_classes", 5000),  # Adjust based on your vocab
        subsampling_factor=params.get("subsampling_factor", 4),
        num_decoder_layers=params.get("num_decoder_layers", 0),
        vgg_frontend=params.get("vgg_frontend", False),
        num_encoder_layers=params.get("num_encoder_layers", 12),
        att_rate=params.get("att_rate", 0.0),
        # Add other parameters as needed
    )

    # Load state dict
    if "model" in checkpoint:
        model.load_state_dict(checkpoint["model"])
    elif "state_dict" in checkpoint:
        model.load_state_dict(checkpoint["state_dict"])
    else:
        model.load_state_dict(checkpoint)

    model = model.to(device)
    model.eval()

    num_params = sum(p.numel() for p in model.parameters())
    logging.info(f"Model loaded successfully with {num_params} parameters")
    return model
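

# Evaluating an average of the last few epochs often works better than a single
# checkpoint. The sketch below assumes icefall's `average_checkpoints` helper
# and an `epoch-N.pt` naming scheme in the experiment directory; adjust both to
# match your setup.
def load_averaged_model(
    exp_dir: Path, last_epoch: int, avg: int, device: torch.device
):
    """Build the model from the newest checkpoint, then load a state dict
    averaged over the last `avg` epochs (a minimal sketch)."""
    from icefall.checkpoint import average_checkpoints

    filenames = [
        exp_dir / f"epoch-{i}.pt"
        for i in range(last_epoch - avg + 1, last_epoch + 1)
    ]
    # Construct the model (and its config) from the most recent checkpoint
    model = load_model(filenames[-1], device)
    # Overwrite the weights with the element-wise average over the checkpoints
    model.load_state_dict(average_checkpoints(filenames, device=device))
    model.eval()
    return model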


def evaluate_chime4(model, datamodule, device: torch.device):
    """Evaluate model on CHiME-4 test sets."""
    from conformer_ctc.decode import greedy_search

    # Get CHiME-4 test dataloaders
    test_loaders = datamodule.chime4_test_dataloaders()
    if not test_loaders:
        logging.error("No CHiME-4 test dataloaders found!")
        return {}

    results = {}
    for test_set_name, test_loader in test_loaders.items():
        logging.info(f"Evaluating on CHiME-4 {test_set_name}")

        total_num_tokens = 0
        total_num_errors = 0

        with torch.no_grad():
            for batch_idx, batch in enumerate(test_loader):
                if batch_idx % 10 == 0:
                    logging.info(f"Processing batch {batch_idx} of {test_set_name}")

                feature = batch["inputs"].to(device)
                # Convert supervisions to expected format
                supervisions = batch["supervisions"]

                # Forward pass
                encoder_out, encoder_out_lens = model.encode(feature, supervisions)

                # Greedy decoding
                hyps = greedy_search(
                    model=model,
                    encoder_out=encoder_out,
                    encoder_out_lens=encoder_out_lens,
                )

                # Accumulate an approximate error count. Note: this only counts
                # the length difference between reference and hypothesis, not a
                # true edit distance; see word_edit_distance() below for a
                # proper word-level count.
                for i, hyp in enumerate(hyps):
                    ref_tokens = supervisions["text"][i].split()
                    hyp_tokens = hyp.split()
                    total_num_tokens += len(ref_tokens)
                    errors = abs(len(ref_tokens) - len(hyp_tokens))
                    total_num_errors += errors

                    if batch_idx == 0 and i == 0:  # Print first example
                        logging.info(f"Reference: {supervisions['text'][i]}")
                        logging.info(f"Hypothesis: {hyp}")

        # Compute the error rate for this test set
        wer = total_num_errors / total_num_tokens if total_num_tokens > 0 else 1.0
        results[test_set_name] = {"WER": wer, "total_tokens": total_num_tokens}
        logging.info(
            f"{test_set_name} WER: {wer:.4f} ({total_num_errors}/{total_num_tokens})"
        )

    return results
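

# The loop above only approximates errors by the length difference between
# reference and hypothesis. A proper WER numerator needs a word-level edit
# distance; the helper below is a minimal, dependency-free sketch that could
# replace that count (i.e. errors = word_edit_distance(ref_tokens, hyp_tokens)).
def word_edit_distance(ref_tokens: List[str], hyp_tokens: List[str]) -> int:
    """Levenshtein distance over words (insertions + deletions + substitutions)."""
    # Classic dynamic-programming edit distance, computed one row at a time.
    prev = list(range(len(hyp_tokens) + 1))
    for i, ref_tok in enumerate(ref_tokens, start=1):
        cur = [i] + [0] * len(hyp_tokens)
        for j, hyp_tok in enumerate(hyp_tokens, start=1):
            cost = 0 if ref_tok == hyp_tok else 1
            cur[j] = min(
                prev[j] + 1,  # delete a reference word
                cur[j - 1] + 1,  # insert a hypothesis word
                prev[j - 1] + cost,  # substitute (or match)
            )
        prev = cur
    return prev[-1]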


def main():
    parser = argparse.ArgumentParser(description="Evaluate conformer CTC on CHiME-4")
    parser.add_argument(
        "--checkpoint", type=Path, required=True, help="Path to model checkpoint"
    )
    parser.add_argument(
        "--manifest-dir",
        type=Path,
        default=Path("data/fbank"),
        help="Path to directory with manifests",
    )
    parser.add_argument(
        "--max-duration", type=float, default=200.0, help="Max duration for batching"
    )
    parser.add_argument(
        "--log-level", type=str, default="INFO", help="Logging level"
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="Device to use (cuda/cpu)"
    )
    args = parser.parse_args()

    setup_logging(args)

    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device: {device}")

    # Load model
    model = load_model(args.checkpoint, device)

    # Create data module (it reads options such as --manifest-dir and
    # --max-duration from args)
    datamodule = LibriSpeechAsrDataModule(args)

    # Evaluate on CHiME-4
    results = evaluate_chime4(model, datamodule, device)

    # Print summary
    logging.info("=" * 50)
    logging.info("CHiME-4 Evaluation Results:")
    for test_set, result in results.items():
        logging.info(f"{test_set}: WER = {result['WER']:.4f}")
    logging.info("=" * 50)


if __name__ == "__main__":
    main()
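
# Optional: persist the per-test-set WERs for later comparison. A minimal
# sketch to drop at the end of main(); the output filename is illustrative.
#
#   import json
#   with open("chime4_results.json", "w") as f:
#       json.dump({name: r["WER"] for name, r in results.items()}, f, indent=2)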