Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 01:52:41 +00:00)
Commit: update author info

This commit is contained in:
  parent 0a46a39e24
  commit babcfd4b68
@@ -1,7 +1,8 @@
 #!/usr/bin/env python3
 #
 # Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
-#                                                  Zengwei Yao)
+#                                                  Zengwei Yao,
+#                                                  Xiaoyu Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -91,7 +92,7 @@ Usage:
     --beam 20.0 \
     --max-contexts 8 \
     --max-states 64
-
+
 (8) modified beam search (with RNNLM shallow fusion)
 ./lstm_transducer_stateless2/decode.py \
     --epoch 35 \
@@ -105,7 +106,7 @@ Usage:
     --rnn-lm-epoch 99 \
     --rnn-lm-avg 1 \
     --rnn-lm-num-layers 3 \
-    --rnn-lm-tie-weights 1
+    --rnn-lm-tie-weights 1
 """
@@ -131,7 +132,6 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
     modified_beam_search_rnnlm_shallow_fusion,
-
 )
 from librispeech import LibriSpeech
 from train import add_model_arguments, get_params, get_transducer_model
@@ -386,11 +386,7 @@ def get_parser():
             last output linear layer
         """,
     )
-    parser.add_argument(
-        "--ilm-scale",
-        type=float,
-        default=-0.1
-    )
+    parser.add_argument("--ilm-scale", type=float, default=-0.1)
     add_model_arguments(parser)

     return parser
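For reference, the condensed call above parses exactly like the multi-line form it replaces; a minimal, standalone argparse sketch (not part of the patch, the parse_args value is illustrative):

# Standalone sketch: the one-line add_argument is equivalent to the
# multi-line version removed above; only the source formatting differs.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--ilm-scale", type=float, default=-0.1)

args = parser.parse_args(["--ilm-scale", "-0.2"])
print(args.ilm_scale)  # prints -0.2; falls back to -0.1 when the flag is omitted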
@@ -642,9 +638,13 @@ def decode_dataset(
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
-        total_duration = sum([cut.duration for cut in batch["supervisions"]["cut"]])
-
-        logging.info(f"Decoding {batch_idx}-th batch, batch size is {len(cut_ids)}, total duration is {total_duration}")
+        total_duration = sum(
+            [cut.duration for cut in batch["supervisions"]["cut"]]
+        )
+
+        logging.info(
+            f"Decoding {batch_idx}-th batch, batch size is {len(cut_ids)}, total duration is {total_duration}"
+        )

         hyps_dict = decode_one_batch(
             params=params,
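The change above only re-wraps the duration sum and the log call; for reference, a minimal standalone sketch of the same per-batch logging. The Cut dataclass is a hypothetical stand-in for a lhotse cut, modeling only the .id and .duration attributes the patch uses:

import logging
from dataclasses import dataclass

logging.basicConfig(level=logging.INFO)

@dataclass
class Cut:           # stand-in for a lhotse cut; not the real class
    id: str
    duration: float  # seconds

batch = {"supervisions": {"cut": [Cut("cut-0", 4.2), Cut("cut-1", 6.8)]}}
batch_idx = 0

cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
total_duration = sum(
    [cut.duration for cut in batch["supervisions"]["cut"]]
)

logging.info(
    f"Decoding {batch_idx}-th batch, batch size is {len(cut_ids)}, total duration is {total_duration}"
)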
@@ -765,10 +765,10 @@ def main():
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
-
+
     if "rnnlm" in params.decoding_method:
         params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
-
+
     if "ILME" in params.decoding_method:
         params.suffix += f"-ILME-scale={params.ilm_scale}"
@@ -903,7 +903,7 @@ def main():
        )
        rnn_lm_model.to(device)
        rnn_lm_model.eval()
-
+
    else:
        rnn_lm_model = None