update author info

marcoyang 2022-11-02 17:27:31 +08:00
parent 0a46a39e24
commit babcfd4b68

lstm_transducer_stateless2/decode.py

@@ -1,7 +1,8 @@
 #!/usr/bin/env python3
 #
 # Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
-#                                                  Zengwei Yao)
+#                                                  Zengwei Yao,
+#                                                  Xiaoyu Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -91,7 +92,7 @@ Usage:
     --beam 20.0 \
     --max-contexts 8 \
     --max-states 64

 (8) modified beam search (with RNNLM shallow fusion)
 ./lstm_transducer_stateless2/decode.py \
     --epoch 35 \
@@ -105,7 +106,7 @@ Usage:
     --rnn-lm-epoch 99 \
     --rnn-lm-avg 1 \
     --rnn-lm-num-layers 3 \
     --rnn-lm-tie-weights 1

 """
@@ -131,7 +132,6 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
     modified_beam_search_rnnlm_shallow_fusion,
 )
 from librispeech import LibriSpeech
 from train import add_model_arguments, get_params, get_transducer_model
@@ -386,11 +386,7 @@ def get_parser():
         last output linear layer
         """,
     )
-    parser.add_argument(
-        "--ilm-scale",
-        type=float,
-        default=-0.1
-    )
+    parser.add_argument("--ilm-scale", type=float, default=-0.1)
     add_model_arguments(parser)

     return parser
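For background on the flag touched above: a negative --ilm-scale follows the usual internal language model estimation (ILME) convention, where the estimated internal LM score is subtracted so the external RNNLM is not double-counted during shallow fusion. A minimal sketch of that scoring is below; the function and tensor names are illustrative and not taken from this file.

import torch

def fused_score(
    am_logp: torch.Tensor,     # transducer (joiner) log-probs per token
    rnnlm_logp: torch.Tensor,  # external RNNLM log-probs per token
    ilm_logp: torch.Tensor,    # estimated internal-LM log-probs per token
    rnn_lm_scale: float,       # e.g. the value passed as --rnn-lm-scale
    ilm_scale: float = -0.1,   # negative, matching the default added here
) -> torch.Tensor:
    # ILME-style shallow fusion: add the external LM contribution and,
    # because ilm_scale is negative, subtract the internal-LM estimate.
    return am_logp + rnn_lm_scale * rnnlm_logp + ilm_scale * ilm_logp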
@@ -642,9 +638,13 @@ def decode_dataset(
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

-        total_duration = sum([cut.duration for cut in batch["supervisions"]["cut"]])
-        logging.info(f"Decoding {batch_idx}-th batch, batch size is {len(cut_ids)}, total duration is {total_duration}")
+        total_duration = sum(
+            [cut.duration for cut in batch["supervisions"]["cut"]]
+        )
+        logging.info(
+            f"Decoding {batch_idx}-th batch, batch size is {len(cut_ids)}, total duration is {total_duration}"
+        )

         hyps_dict = decode_one_batch(
             params=params,
@@ -765,10 +765,10 @@ def main():
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

     if "rnnlm" in params.decoding_method:
         params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"

     if "ILME" in params.decoding_method:
         params.suffix += f"-ILME-scale={params.ilm_scale}"
@@ -903,7 +903,7 @@ def main():
         )
         rnn_lm_model.to(device)
         rnn_lm_model.eval()
     else:
         rnn_lm_model = None