Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

- Added CHiME-4 dataset integration in asr_datamodule.py
- Added Hugging Face upload script
- Added RIR augmentation
- Added Self-Distillation Training
#!/usr/bin/env python3
"""
Evaluate trained conformer_ctc model on CHiME-4 dataset.
"""

import argparse
import logging
from pathlib import Path
from typing import Dict, List

import torch

from conformer_ctc.asr_datamodule import LibriSpeechAsrDataModule
from conformer_ctc.conformer import Conformer

def setup_logging(args):
    """Setup logging configuration."""
    log_level = getattr(logging, args.log_level.upper())
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
    )


def load_model(checkpoint_path: Path, device: torch.device):
    """Load trained conformer model from checkpoint."""
    logging.info(f"Loading model from {checkpoint_path}")

    # Load checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Extract model parameters from checkpoint
    params = checkpoint.get("params", {})

    # Create model with parameters from checkpoint
    model = Conformer(
        num_features=params.get("num_features", 80),
        nhead=params.get("nhead", 8),
        d_model=params.get("d_model", 512),
        num_classes=params.get("num_classes", 5000),  # Adjust based on your vocab
        subsampling_factor=params.get("subsampling_factor", 4),
        num_decoder_layers=params.get("num_decoder_layers", 0),
        vgg_frontend=params.get("vgg_frontend", False),
        num_encoder_layers=params.get("num_encoder_layers", 12),
        att_rate=params.get("att_rate", 0.0),
        # Add other parameters as needed
    )

    # Load state dict
    if "model" in checkpoint:
        model.load_state_dict(checkpoint["model"])
    elif "state_dict" in checkpoint:
        model.load_state_dict(checkpoint["state_dict"])
    else:
        model.load_state_dict(checkpoint)

    model = model.to(device)
    model.eval()

    logging.info(f"Model loaded successfully with {sum(p.numel() for p in model.parameters())} parameters")
    return model
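

# NOTE (added helper, not part of the original recipe): a standard
# dynamic-programming word-level Levenshtein distance, used below so that the
# reported WER counts substitutions, insertions, and deletions rather than the
# original abs(len(ref) - len(hyp)) approximation.
def _levenshtein(ref: List[str], hyp: List[str]) -> int:
    """Return the word-level edit distance between two token sequences."""
    # prev[j] = distance between the reference prefix seen so far and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cur[j] = min(
                prev[j] + 1,  # deletion
                cur[j - 1] + 1,  # insertion
                prev[j - 1] + (r != h),  # substitution (0 if tokens match)
            )
        prev = cur
    return prev[len(hyp)]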


def evaluate_chime4(model, datamodule, device: torch.device):
    """Evaluate model on CHiME-4 test sets."""
    from conformer_ctc.decode import greedy_search
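    # NOTE (assumption): greedy_search is expected to live in conformer_ctc/decode.py
    # and to return one decoded text string per utterance in the batch; adjust the
    # import if your decoding helper lives elsewhere.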

    # Get CHiME-4 test dataloaders
    test_loaders = datamodule.chime4_test_dataloaders()

    if not test_loaders:
        logging.error("No CHiME-4 test dataloaders found!")
        return {}

    results = {}

    for test_set_name, test_loader in test_loaders.items():
        logging.info(f"Evaluating on CHiME-4 {test_set_name}")

        total_num_tokens = 0
        total_num_errors = 0

        with torch.no_grad():
            for batch_idx, batch in enumerate(test_loader):
                if batch_idx % 10 == 0:
                    logging.info(f"Processing batch {batch_idx} of {test_set_name}")

                feature = batch["inputs"].to(device)
                # Convert supervisions to expected format
                supervisions = batch["supervisions"]

                # Forward pass
                encoder_out, encoder_out_lens = model.encode(feature, supervisions)
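                # NOTE (assumption): model.encode() is taken to be a wrapper that
                # returns (encoder_out, encoder_out_lens); the stock icefall
                # Conformer does not define encode(), so adapt this call to your
                # model's encoder interface if needed.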

                # Greedy decoding
                hyps = greedy_search(
                    model=model,
                    encoder_out=encoder_out,
                    encoder_out_lens=encoder_out_lens,
                )

                # Accumulate word-level statistics for WER
                for i, hyp in enumerate(hyps):
                    ref_tokens = supervisions["text"][i].split()
                    hyp_tokens = hyp.split()

                    total_num_tokens += len(ref_tokens)
                    # Word-level edit distance between reference and hypothesis
                    errors = _levenshtein(ref_tokens, hyp_tokens)
                    total_num_errors += errors

                    if batch_idx == 0 and i == 0:  # Print first example
                        logging.info(f"Reference: {supervisions['text'][i]}")
                        logging.info(f"Hypothesis: {hyp}")

        # Calculate WER for this test set
        wer = total_num_errors / total_num_tokens if total_num_tokens > 0 else 1.0
        results[test_set_name] = {"WER": wer, "total_tokens": total_num_tokens}

        logging.info(f"{test_set_name} WER: {wer:.4f} ({total_num_errors}/{total_num_tokens})")

    return results


def main():
    parser = argparse.ArgumentParser(description="Evaluate conformer CTC on CHiME-4")
    parser.add_argument(
        "--checkpoint", type=Path, required=True, help="Path to model checkpoint"
    )
    parser.add_argument(
        "--manifest-dir",
        type=Path,
        default=Path("data/fbank"),
        help="Path to directory with manifests",
    )
    parser.add_argument(
        "--max-duration", type=float, default=200.0, help="Max duration for batching"
    )
    parser.add_argument(
        "--log-level", type=str, default="INFO", help="Logging level"
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="Device to use (cuda/cpu)"
    )

    args = parser.parse_args()

    setup_logging(args)

    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device: {device}")

    # Load model
    model = load_model(args.checkpoint, device)

    # Create data module
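    # NOTE (assumption): only --manifest-dir and --max-duration are defined above;
    # LibriSpeechAsrDataModule normally registers its full option set via its
    # add_arguments() classmethod, so it may expect additional attributes on args.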
    datamodule = LibriSpeechAsrDataModule(args)

    # Evaluate on CHiME-4
    results = evaluate_chime4(model, datamodule, device)

    # Print summary
    logging.info("=" * 50)
    logging.info("CHiME-4 Evaluation Results:")
    for test_set, result in results.items():
        logging.info(f"{test_set}: WER = {result['WER']:.4f}")
    logging.info("=" * 50)


if __name__ == "__main__":
    main()