diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index c54d57a79..35e64c743 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -30,7 +30,7 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
-from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.checkpoint import load_checkpoint
 from icefall.decode import (
     get_lattice,
     nbest_decoding,
@@ -46,8 +46,10 @@ from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     get_texts,
+    load_averaged_model,
     setup_logger,
     store_transcripts,
+    str2bool,
     write_error_stats,
 )
 
@@ -174,19 +176,47 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--rnn-lm-embedding-dim",
+        type=int,
+        default=2048,
+        help="Embedding dim of the model",
+    )
+
+    parser.add_argument(
+        "--rnn-lm-hidden-dim",
+        type=int,
+        default=2048,
+        help="Hidden dim of the model",
+    )
+
+    parser.add_argument(
+        "--rnn-lm-num-layers",
+        type=int,
+        default=4,
+        help="Number of RNN layers the model",
+    )
+    parser.add_argument(
+        "--rnn-lm-tie-weights",
+        type=str2bool,
+        default=False,
+        help="""True share the weights between the input embedding layer and the
+        last output linear layer
+        """,
+    )
+
     return parser
 
 
 def get_rnn_lm_model(params: AttributeDict):
     from rnn_lm.model import RnnLmModel
 
-    # TODO: Pass the following options from command-line
     rnn_lm_model = RnnLmModel(
         vocab_size=params.num_classes,
-        embedding_dim=1024,
-        hidden_dim=1024,
-        num_layers=2,
-        tie_weights=False,
+        embedding_dim=params.rnn_lm_embedding_dim,
+        hidden_dim=params.rnn_lm_hidden_dim,
+        num_layers=params.rnn_lm_num_layers,
+        tie_weights=params.rnn_lm_tie_weights,
     )
     return rnn_lm_model
 
@@ -727,20 +757,16 @@ def main():
     if params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     else:
-        start = params.epoch - params.avg + 1
-        filenames = []
-        for i in range(start, params.epoch + 1):
-            if start >= 0:
-                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
-        logging.info(f"averaging {filenames}")
-        model.to(device)
-        model.load_state_dict(average_checkpoints(filenames, device=device))
+        model = load_averaged_model(
+            params.exp_dir, model, params.epoch, params.avg, device
+        )
 
     model.to(device)
     model.eval()
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
+    rnn_lm_model = None
     if params.method == "rnn-lm":
         rnn_lm_model = get_rnn_lm_model(params)
         if params.rnn_lm_avg == 1:
@@ -748,19 +774,16 @@ def main():
                 f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
                 rnn_lm_model,
             )
-        else:
-            start = params.rnn_lm_epoch - params.rnn_lm_avg + 1
-            filenames = []
-            for i in range(start, params.rnn_lm_epoch + 1):
-                if start >= 0:
-                    filenames.append(f"{params.rnn_lm_exp_dir}/epoch-{i}.pt")
-            logging.info(f"averaging {filenames}")
             rnn_lm_model.to(device)
-            rnn_lm_model.load_state_dict(
-                average_checkpoints(filenames, device=device)
+        else:
+            rnn_lm_model = load_averaged_model(
+                params.rnn_lm_exp_dir,
+                rnn_lm_model,
+                params.rnn_lm_epoch,
+                params.rnn_lm_avg,
+                device,
             )
-    else:
-        rnn_lm_model = None
+
         rnn_lm_model.eval()
 
     librispeech = LibriSpeechAsrDataModule(args)
diff --git a/egs/librispeech/ASR/rnn_lm/train.py b/egs/librispeech/ASR/rnn_lm/train.py
index 27fc237b4..5df9cfdb2 100755
--- a/egs/librispeech/ASR/rnn_lm/train.py
+++ b/egs/librispeech/ASR/rnn_lm/train.py
@@ -206,7 +206,7 @@ def get_params() -> AttributeDict:
             "batch_idx_train": 0,
             "log_interval": 200,
             "reset_interval": 2000,
-            "valid_interval": 30000,
+            "valid_interval": 10000,
             "env_info": get_env_info(),
         }
     )
@@ -539,6 +539,7 @@ def run(rank, world_size, args):
         embedding_dim=params.embedding_dim,
         hidden_dim=params.hidden_dim,
         num_layers=params.num_layers,
+        tie_weights=params.tie_weights,
     )
 
     checkpoints = load_checkpoint_if_available(params=params, model=model)
diff --git a/icefall/decode.py b/icefall/decode.py
index d9a5df453..cc366c3e1 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -22,6 +22,33 @@ import torch
 
 from icefall.utils import add_eos, add_sos, get_texts
 
+DEFAULT_LM_SCALE = [
+    0.01,
+    0.05,
+    0.08,
+    0.1,
+    0.3,
+    0.5,
+    0.6,
+    0.7,
+    0.9,
+    1.0,
+    1.1,
+    1.2,
+    1.3,
+    1.5,
+    1.7,
+    1.9,
+    2.0,
+    2.1,
+    2.2,
+    2.3,
+    2.5,
+    3.0,
+    4.0,
+    5.0,
+]
+
 
 def _intersect_device(
     a_fsas: k2.Fsa,
@@ -1082,28 +1109,19 @@ def rescore_with_rnn_lm(
 
     rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1)
 
-    if ngram_lm_scale is None:
-        ngram_lm_scale_list = [0.01, 0.05, 0.08]
-        ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
-        ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
-        ngram_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
-    else:
+    ngram_lm_scale_list = DEFAULT_LM_SCALE
+    attention_scale_list = DEFAULT_LM_SCALE
+    rnn_lm_scale_list = DEFAULT_LM_SCALE
+
+    # Test against None, not truthiness: an explicit scale of 0 must be
+    # used as given instead of falling back to the default sweep.
+    if ngram_lm_scale is not None:
         ngram_lm_scale_list = [ngram_lm_scale]
 
-    if attention_scale is None:
-        attention_scale_list = [0.01, 0.05, 0.08]
-        attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
-        attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
-        attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
-    else:
+    if attention_scale is not None:
         attention_scale_list = [attention_scale]
 
-    if rnn_lm_scale is None:
-        rnn_lm_scale_list = [0.01, 0.05, 0.08]
-        rnn_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
-        rnn_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
-        rnn_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
-    else:
+    if rnn_lm_scale is not None:
         rnn_lm_scale_list = [rnn_lm_scale]
 
     ans = dict()
diff --git a/icefall/utils.py b/icefall/utils.py
index 61854847c..10a2e6301 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -35,6 +35,8 @@ import torch.distributed as dist
 import torch.nn as nn
 from torch.utils.tensorboard import SummaryWriter
 
+from icefall.checkpoint import average_checkpoints
+
 Pathlike = Union[str, Path]
 
 
@@ -806,4 +808,35 @@ def optim_step_and_measure_param_change(
             p_orig = old_parameters[n]
             delta = l2_norm(p_orig - p_new) / l2_norm(p_orig)
             relative_change[n] = delta.item()
-    return relative_change
\ No newline at end of file
+    return relative_change
+
+
+def load_averaged_model(
+    model_dir: str,
+    model: torch.nn.Module,
+    epoch: int,
+    avg: int,
+    device: torch.device,
+):
+    """
+    Load a model which is the average of all checkpoints
+
+    :param model_dir: a str of the experiment directory
+    :param model: a torch.nn.Module instance
+
+    :param epoch: the last epoch to load from
+    :param avg: how many models to average from
+    :param device: move model to this device
+
+    :return: A model averaged
+    """
+
+    # start cannot be negative
+    start = max(epoch - avg + 1, 0)
+    filenames = [f"{model_dir}/epoch-{i}.pt" for i in range(start, epoch + 1)]
+
+    logging.info(f"averaging {filenames}")
+    model.to(device)
+    model.load_state_dict(average_checkpoints(filenames, device=device))
+
+    return model