diff --git a/egs/librispeech/ASR/zipformer_ctc/export.py b/egs/librispeech/ASR/zipformer_ctc/export.py
index fbcbd7b29..0ff50f128 100755
--- a/egs/librispeech/ASR/zipformer_ctc/export.py
+++ b/egs/librispeech/ASR/zipformer_ctc/export.py
@@ -24,11 +24,17 @@ import logging
 from pathlib import Path
 
 import torch
-from conformer import Conformer
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments, get_ctc_model, get_params
 
-from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.checkpoint import (
+    average_checkpoints,
+    average_checkpoints_with_averaged_model,
+    find_checkpoints,
+    load_checkpoint,
+)
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict, str2bool
+from icefall.utils import str2bool
 
 
 def get_parser():
@@ -39,24 +45,45 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=34,
+        default=30,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
 
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
     parser.add_argument(
         "--avg",
         type=int,
-        default=20,
+        default=9,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
     )
 
+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=True,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="conformer_ctc/exp",
+        default="zipformer_ctc/exp",
         help="""It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
         """,
     )
@@ -78,23 +105,11 @@ def get_parser():
         """,
     )
 
+    add_model_arguments(parser)
+
     return parser
 
 
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            "feature_dim": 80,
-            "subsampling_factor": 4,
-            "use_feat_batchnorm": True,
-            "attention_dim": 512,
-            "nhead": 8,
-            "num_decoder_layers": 6,
-        }
-    )
-    return params
-
-
 def main():
     args = get_parser().parse_args()
     args.exp_dir = Path(args.exp_dir)
@@ -108,6 +123,7 @@ def main():
     lexicon = Lexicon(params.lang_dir)
     max_token_id = max(lexicon.tokens)
     num_classes = max_token_id + 1  # +1 for the blank
+    params.vocab_size = num_classes
 
     device = torch.device("cpu")
     if torch.cuda.is_available():
@@ -115,34 +131,95 @@ def main():
 
     logging.info(f"device: {device}")
 
-    model = Conformer(
-        num_features=params.feature_dim,
-        nhead=params.nhead,
-        d_model=params.attention_dim,
-        num_classes=num_classes,
-        subsampling_factor=params.subsampling_factor,
-        num_decoder_layers=params.num_decoder_layers,
-        vgg_frontend=False,
-        use_feat_batchnorm=params.use_feat_batchnorm,
-    )
-    model.to(device)
+    model = get_ctc_model(params)
 
-    if params.avg == 1:
-        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
-        start = params.epoch - params.avg + 1
-        filenames = []
-        for i in range(start, params.epoch + 1):
-            if start >= 0:
-                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
-        logging.info(f"averaging {filenames}")
-        model.load_state_dict(average_checkpoints(filenames))
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
 
     model.to("cpu")
     model.eval()
 
     if params.jit:
         logging.info("Using torch.jit.script")
+        # We won't use the forward() method of the model in C++, so just ignore
+        # it here.
+        # Otherwise, one of its arguments is a ragged tensor and is not
+        # torch scriptable.
+        convert_scaled_to_non_scaled(model, inplace=True)
         model = torch.jit.script(model)
         filename = params.exp_dir / "cpu_jit.pt"
         model.save(str(filename))
diff --git a/egs/librispeech/ASR/zipformer_ctc/model.py b/egs/librispeech/ASR/zipformer_ctc/model.py
index 13efc8d75..560845339 100644
--- a/egs/librispeech/ASR/zipformer_ctc/model.py
+++ b/egs/librispeech/ASR/zipformer_ctc/model.py
@@ -60,6 +60,7 @@ class CTCModel(nn.Module):
         )
         self.decoder = decoder
 
+    @torch.jit.ignore
     def forward(
         self,
         x: torch.Tensor,