diff --git a/icefall/rnn_lm/check-onnx-streaming.py b/icefall/rnn_lm/check-onnx-streaming.py index d51a4b76b..28b908f82 100755 --- a/icefall/rnn_lm/check-onnx-streaming.py +++ b/icefall/rnn_lm/check-onnx-streaming.py @@ -112,7 +112,6 @@ def main(): for torch_v, onnx_v in zip( (torch_log_prob, torch_h0, torch_c0), (onnx_log_prob, onnx_h0, onnx_c0) ): - assert torch.allclose(torch_v, onnx_v, atol=1e-5), ( torch_v.shape, onnx_v.shape, diff --git a/icefall/rnn_lm/export-onnx.py b/icefall/rnn_lm/export-onnx.py index dfede708b..1070d443a 100755 --- a/icefall/rnn_lm/export-onnx.py +++ b/icefall/rnn_lm/export-onnx.py @@ -5,16 +5,16 @@ import argparse import logging from pathlib import Path +from typing import Dict import onnx import torch from model import RnnLmModel from onnxruntime.quantization import QuantType, quantize_dynamic +from train import get_params from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import AttributeDict, str2bool -from typing import Dict -from train import get_params def add_meta_data(filename: str, meta_data: Dict[str, str]): diff --git a/icefall/rnn_lm/export.py b/icefall/rnn_lm/export.py index dadf23009..78a914cf4 100644 --- a/icefall/rnn_lm/export.py +++ b/icefall/rnn_lm/export.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 # +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Yifan Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -25,7 +27,12 @@ from pathlib import Path import torch from model import RnnLmModel -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) from icefall.utils import AttributeDict, str2bool @@ -37,18 +44,10 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=29, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=5, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", ) parser.add_argument( @@ -61,6 +60,35 @@ def get_parser(): """, ) + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="rnnlm/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + parser.add_argument( "--vocab-size", type=int, @@ -98,20 +126,14 @@ def get_parser(): """, ) - parser.add_argument( - "--exp-dir", - type=str, - default="rnn_lm/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - parser.add_argument( "--jit", type=str2bool, - default=True, + default=False, help="""True to save a model after applying torch.jit.script. + It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. """, ) @@ -144,37 +166,82 @@ def main(): model.to(device) - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) model.to("cpu") model.eval()