diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
index 18012d241..17b63a659 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -526,6 +526,8 @@ def fast_beam_search(
                 project_input=False,
             )
             ilme_logits = ilme_logits.squeeze(1).squeeze(1)
+            if blank_penalty != 0:
+                ilme_logits[:, 0] -= blank_penalty
             ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1)
             log_probs -= ilme_scale * ilme_log_probs
 
diff --git a/egs/wenetspeech/ASR/zipformer/decode.py b/egs/wenetspeech/ASR/zipformer/decode.py
index aa22213a6..9a752b7ca 100755
--- a/egs/wenetspeech/ASR/zipformer/decode.py
+++ b/egs/wenetspeech/ASR/zipformer/decode.py
@@ -100,7 +100,7 @@ from beam_search import (
     modified_beam_search,
 )
 from lhotse.cut import Cut
-from train import add_model_arguments, get_params, get_transducer_model
+from train import add_model_arguments, get_model, get_params
 
 from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import (
@@ -227,6 +227,16 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--ilme-scale",
+        type=float,
+        default=0.2,
+        help="""
+        Used only when --decoding_method is fast_beam_search_LG.
+        It specifies the scale for the internal language model estimation.
+        """,
+    )
+
     parser.add_argument(
         "--max-contexts",
         type=int,
@@ -381,6 +391,7 @@ def decode_one_batch(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
             blank_penalty=params.blank_penalty,
+            ilme_scale=params.ilme_scale,
         )
         for hyp in hyp_tokens:
             sentence = "".join([lexicon.word_table[i] for i in hyp])
@@ -458,6 +469,7 @@ def decode_one_batch(
             key += f"_num_paths_{params.num_paths}_"
             key += f"nbest_scale_{params.nbest_scale}"
         if "LG" in params.decoding_method:
+            key += f"_ilme_scale_{params.ilme_scale}"
             key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
 
         return {key: hyps}
@@ -624,6 +636,7 @@ def main():
             params.suffix += f"-nbest-scale-{params.nbest_scale}"
             params.suffix += f"-num-paths-{params.num_paths}"
             if "LG" in params.decoding_method:
+                params.suffix += f"_ilme_scale_{params.ilme_scale}"
                 params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
@@ -656,7 +669,7 @@ def main():
     logging.info(params)
 
     logging.info("About to create model")
-    model = get_transducer_model(params)
+    model = get_model(params)
 
     if not params.use_averaged_model:
         if params.iter > 0:
@@ -739,7 +752,7 @@ def main():
     model.eval()
 
     if "fast_beam_search" in params.decoding_method:
-        if params.decoding_method == "fast_beam_search_nbest_LG":
+        if "LG" in params.decoding_method:
             lexicon = Lexicon(params.lang_dir)
             lg_filename = params.lang_dir / "LG.pt"
             logging.info(f"Loading {lg_filename}")
@@ -782,6 +795,9 @@ def main():
     test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
     test_dl = [dev_dl, test_net_dl, test_meeting_dl]
 
+    test_sets = ["TEST_NET"]
+    test_dl = [test_net_dl]
+
     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(
             dl=test_dl,
diff --git a/egs/wenetspeech/ASR/zipformer/streaming_decode.py b/egs/wenetspeech/ASR/zipformer/streaming_decode.py
index 01ce86d98..94c5fae5f 100755
--- a/egs/wenetspeech/ASR/zipformer/streaming_decode.py
+++ b/egs/wenetspeech/ASR/zipformer/streaming_decode.py
@@ -50,7 +50,7 @@ from streaming_beam_search import (
 )
 from torch import Tensor, nn
 from torch.nn.utils.rnn import pad_sequence
-from train import add_model_arguments, get_params, get_transducer_model
+from train import add_model_arguments, get_model, get_params
 
 from icefall.checkpoint import (
     average_checkpoints,
@@ -761,7 +761,7 @@ def main():
     logging.info(params)
 
     logging.info("About to create model")
-    model = get_transducer_model(params)
+    model = get_model(params)
 
     if not params.use_averaged_model:
         if params.iter > 0:
diff --git a/egs/wenetspeech/ASR/zipformer/train.py b/egs/wenetspeech/ASR/zipformer/train.py
index 3b37b208a..83dbfa22f 100755
--- a/egs/wenetspeech/ASR/zipformer/train.py
+++ b/egs/wenetspeech/ASR/zipformer/train.py
@@ -66,7 +66,7 @@ from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
-from model import Transducer
+from model import AsrModel
 from optim import Eden, ScaledAdam
 from scaling import ScheduledFloat
 from subsampling import Conv2dSubsampling
@@ -578,20 +578,19 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
     return joiner
 
 
-def get_transducer_model(params: AttributeDict) -> nn.Module:
+def get_model(params: AttributeDict) -> nn.Module:
     encoder_embed = get_encoder_embed(params)
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
 
-    model = Transducer(
+    model = AsrModel(
         encoder_embed=encoder_embed,
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
         encoder_dim=int(max(params.encoder_dim.split(","))),
         decoder_dim=params.decoder_dim,
-        joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
     )
     return model
@@ -758,7 +757,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        simple_loss, pruned_loss, _ = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -1086,7 +1085,7 @@ def run(rank, world_size, args):
     logging.info(params)
 
     logging.info("About to create model")
-    model = get_transducer_model(params)
+    model = get_model(params)
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
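Note on the beam_search.py hunk at the top of this patch: the two added lines apply the blank penalty to the internal-language-model (ILME) logits, mirroring what fast_beam_search already does for the regular joiner logits, before the scaled ILME log-probabilities are subtracted from the fused log-probabilities. Below is a minimal, self-contained sketch of that combination; it is not part of the patch, and the tensor shapes and random logits are invented purely for illustration.

import torch

# Stand-ins for the two joiner outputs in fast_beam_search (values made up):
# `logits` comes from the real encoder output; `ilme_logits` is the joiner
# output used as an internal-LM estimate (in icefall it is obtained by feeding
# the joiner a zeroed-out encoder output).
num_hyps, vocab_size = 3, 6
blank_penalty = 2.0
ilme_scale = 0.2
temperature = 1.0

logits = torch.randn(num_hyps, vocab_size)
ilme_logits = torch.randn(num_hyps, vocab_size)

if blank_penalty != 0:
    logits[:, 0] -= blank_penalty       # existing handling of the main logits
    ilme_logits[:, 0] -= blank_penalty  # the handling this patch adds

log_probs = (logits / temperature).log_softmax(dim=-1)
ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1)

# Subtracting the scaled internal-LM estimate discounts the label prior learned
# by the decoder/joiner, which matters when an external LM (the LG graph) is
# fused during fast_beam_search_LG decoding.
log_probs -= ilme_scale * ilme_log_probs
print(log_probs)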