update author info

This commit is contained in:
marcoyang 2022-11-02 17:27:31 +08:00
parent 0a46a39e24
commit babcfd4b68

View File

@ -1,7 +1,8 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
# Zengwei Yao,
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -91,7 +92,7 @@ Usage:
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(8) modified beam search (with RNNLM shallow fusion)
./lstm_transducer_stateless2/decode.py \
--epoch 35 \
@ -105,7 +106,7 @@ Usage:
--rnn-lm-epoch 99 \
--rnn-lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1
--rnn-lm-tie-weights 1
"""
@ -131,7 +132,6 @@ from beam_search import (
greedy_search_batch,
modified_beam_search,
modified_beam_search_rnnlm_shallow_fusion,
)
from librispeech import LibriSpeech
from train import add_model_arguments, get_params, get_transducer_model
@ -386,11 +386,7 @@ def get_parser():
last output linear layer
""",
)
parser.add_argument(
"--ilm-scale",
type=float,
default=-0.1
)
parser.add_argument("--ilm-scale", type=float, default=-0.1)
add_model_arguments(parser)
return parser
@ -642,9 +638,13 @@ def decode_dataset(
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
total_duration = sum([cut.duration for cut in batch["supervisions"]["cut"]])
logging.info(f"Decoding {batch_idx}-th batch, batch size is {len(cut_ids)}, total duration is {total_duration}")
total_duration = sum(
[cut.duration for cut in batch["supervisions"]["cut"]]
)
logging.info(
f"Decoding {batch_idx}-th batch, batch size is {len(cut_ids)}, total duration is {total_duration}"
)
hyps_dict = decode_one_batch(
params=params,
@ -765,10 +765,10 @@ def main():
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if "rnnlm" in params.decoding_method:
params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
if "ILME" in params.decoding_method:
params.suffix += f"-ILME-scale={params.ilm_scale}"
@ -903,7 +903,7 @@ def main():
)
rnn_lm_model.to(device)
rnn_lm_model.eval()
else:
rnn_lm_model = None