From babcfd4b68a0f6729161eb1aa0c10e2c2aea2764 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Wed, 2 Nov 2022 17:27:31 +0800 Subject: [PATCH] update author info --- .../ASR/lstm_transducer_stateless2/decode.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index c43328e08..fc077f062 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 # # Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) +# Zengwei Yao, +# Xiaoyu Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -91,7 +92,7 @@ Usage: --beam 20.0 \ --max-contexts 8 \ --max-states 64 - + (8) modified beam search (with RNNLM shallow fusion) ./lstm_transducer_stateless2/decode.py \ --epoch 35 \ @@ -105,7 +106,7 @@ Usage: --rnn-lm-epoch 99 \ --rnn-lm-avg 1 \ --rnn-lm-num-layers 3 \ - --rnn-lm-tie-weights 1 + --rnn-lm-tie-weights 1 """ @@ -131,7 +132,6 @@ from beam_search import ( greedy_search_batch, modified_beam_search, modified_beam_search_rnnlm_shallow_fusion, - ) from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model @@ -386,11 +386,7 @@ def get_parser(): last output linear layer """, ) - parser.add_argument( - "--ilm-scale", - type=float, - default=-0.1 - ) + parser.add_argument("--ilm-scale", type=float, default=-0.1) add_model_arguments(parser) return parser @@ -642,9 +638,13 @@ def decode_dataset( for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - total_duration = sum([cut.duration for cut in batch["supervisions"]["cut"]]) - - logging.info(f"Decoding {batch_idx}-th batch, batch size is {len(cut_ids)}, total duration is {total_duration}") + total_duration = sum( + [cut.duration for cut in batch["supervisions"]["cut"]] + ) + + logging.info( + f"Decoding {batch_idx}-th batch, batch size is {len(cut_ids)}, total duration is {total_duration}" + ) hyps_dict = decode_one_batch( params=params, @@ -765,10 +765,10 @@ def main(): else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - + if "rnnlm" in params.decoding_method: params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" - + if "ILME" in params.decoding_method: params.suffix += f"-ILME-scale={params.ilm_scale}" @@ -903,7 +903,7 @@ def main(): ) rnn_lm_model.to(device) rnn_lm_model.eval() - + else: rnn_lm_model = None