#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)

# (still working in progress)

"""Decoding script for a Conformer CTC model (work in progress).

At present the script parses options, builds the Conformer model and loads
its weights from one checkpoint or from an average of several checkpoints.
The visible body of ``main`` stops after putting the model in eval mode.
"""

import argparse
import logging
from pathlib import Path

import torch
from conformer import Conformer

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.utils import (
    AttributeDict,
    get_texts,
    setup_logger,
    store_transcripts,
    write_error_stats,
)


def get_params() -> AttributeDict:
    """Return the hard-coded default parameters used for decoding.

    ``main`` merges the parsed command-line arguments on top of these with
    ``params.update(vars(args))``. A command-line option with the same name
    therefore overrides the value given here.
    """
    params = AttributeDict(
        {
            "exp_dir": Path("conformer_ctc/exp"),
            "lang_dir": Path("data/lang/bpe"),
            "lm_dir": Path("data/lm"),
            "feature_dim": 80,
            "nhead": 8,
            "attention_dim": 512,
            "num_classes": 5000,
            "subsampling_factor": 4,
            "num_decoder_layers": 6,
            "vgg_frontend": False,
            "is_espnet_structure": True,
            "mmi_loss": False,
            "use_feat_batchnorm": True,
            "search_beam": 20,
            "output_beam": 5,
            "min_active_states": 30,
            "max_active_states": 10000,
            "use_double_scores": True,
            # Possible values for method:
            #  - 1best
            #  - nbest
            #  - nbest-rescoring
            #  - whole-lattice-rescoring
            "method": "whole-lattice-rescoring",
            # num_paths is used when method is "nbest" and "nbest-rescoring"
            "num_paths": 30,
        }
    )
    return params


def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for checkpoint selection.

    Dataset-related options are added separately in ``main`` through
    ``LibriSpeechAsrDataModule.add_arguments``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=9,
        # Adjacent literals are concatenated verbatim, so the trailing space
        # is needed to separate the two sentences.
        help="It specifies the checkpoint to use for decoding. "
        "Note: Epoch counts from 0.",
    )
    parser.add_argument(
        "--avg",
        type=int,
        default=1,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'. ",
    )
    return parser


def _averaged_checkpoint_filenames(exp_dir, epoch: int, avg: int) -> list:
    """Return the checkpoint paths for epochs ``epoch - avg + 1`` .. ``epoch``.

    Epochs count from 0, so negative epoch numbers are skipped. Requesting
    more checkpoints than exist averages every epoch from 0 up to ``epoch``.

    Args:
      exp_dir: Directory holding the ``epoch-<N>.pt`` files.
      epoch: Last (most recent) epoch to include.
      avg: Number of consecutive epochs to include, ending at ``epoch``.

    Returns:
      The paths as strings, in ascending epoch order.

    Raises:
      ValueError: If no epoch falls in the requested range (for example
        ``avg < 1`` or ``epoch < 0``).
    """
    start = epoch - avg + 1
    # Clamp at epoch 0. The range may then be shorter than ``avg``, but it
    # still covers every checkpoint that exists.
    filenames = [
        f"{exp_dir}/epoch-{i}.pt" for i in range(max(start, 0), epoch + 1)
    ]
    if not filenames:
        raise ValueError(
            f"No checkpoints to average for epoch={epoch}, avg={avg}"
        )
    return filenames


@torch.no_grad()
def main():
    """Parse options, build the Conformer model and load its weights.

    With ``--avg 1`` the weights come from ``epoch-<epoch>.pt``. Otherwise
    the state dicts of up to ``--avg`` consecutive checkpoints ending at
    ``--epoch`` are averaged.
    """
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()

    params = get_params()
    params.update(vars(args))

    setup_logger(f"{params.exp_dir}/log/log-decode")
    logging.info("Decoding started")
    logging.info(params)

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    model = Conformer(
        num_features=params.feature_dim,
        nhead=params.nhead,
        d_model=params.attention_dim,
        num_classes=params.num_classes,
        subsampling_factor=params.subsampling_factor,
        num_decoder_layers=params.num_decoder_layers,
        vgg_frontend=params.vgg_frontend,
        is_espnet_structure=params.is_espnet_structure,
        mmi_loss=params.mmi_loss,
        use_feat_batchnorm=params.use_feat_batchnorm,
    )

    if params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        filenames = _averaged_checkpoint_filenames(
            params.exp_dir, params.epoch, params.avg
        )
        logging.info(f"averaging {filenames}")
        model.load_state_dict(average_checkpoints(filenames))

    # Weights are loaded before the move, so the loaded values are what end
    # up on ``device``.
    model.to(device)
    model.eval()

    # NOTE(review): not used yet in this work-in-progress script. It is
    # presumably intended for the decoding step that follows.
    token_ids_with_blank = list(range(params.num_classes))


if __name__ == "__main__":
    main()