update decoding commands

This commit is contained in:
marcoyang 2022-11-02 17:25:31 +08:00
parent 86662f0b97
commit 0a46a39e24
3 changed files with 88 additions and 145 deletions

View File

@ -91,6 +91,21 @@ Usage:
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(8) modified beam search (with RNNLM shallow fusion)
./lstm_transducer_stateless2/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./lstm_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method modified_beam_search_rnnlm_shallow_fusion \
--beam-size 4 \
--rnn-lm-scale 0.3 \
--rnn-lm-exp-dir /path/to/RNNLM \
--rnn-lm-epoch 99 \
--rnn-lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1
"""
@ -121,7 +136,6 @@ from beam_search import (
from librispeech import LibriSpeech
from train import add_model_arguments, get_params, get_transducer_model
from icefall import NgramLm
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
@ -389,8 +403,6 @@ def decode_one_batch(
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
ngram_lm: Optional[NgramLm] = None,
ngram_lm_scale: float = 1.0,
rnnlm: Optional[RnnLmModel] = None,
rnnlm_scale: float = 1.0,
) -> Dict[str, List[List[str]]]:
@ -526,11 +538,12 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_sf_rnnlm":
hyp_tokens = modified_beam_search_sf_rnnlm_batched(
elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
sp=sp,
rnnlm=rnnlm,
rnnlm_scale=rnnlm_scale,
@ -586,9 +599,7 @@ def decode_dataset(
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
ngram_lm: Optional[NgramLm] = None,
ngram_lm_scale: float = 1.0,
rnnlm: Optional[NgramLm] = None,
rnnlm: Optional[RnnLmModel] = None,
rnnlm_scale: float = 1.0,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@ -642,8 +653,6 @@ def decode_dataset(
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
rnnlm=rnnlm,
rnnlm_scale=rnnlm_scale,
)
@ -731,7 +740,7 @@ def main():
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
"modified_beam_search_sf_rnnlm",
"modified_beam_search_rnnlm_shallow_fusion",
)
params.res_dir = params.exp_dir / params.decoding_method
@ -942,8 +951,6 @@ def main():
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
ngram_lm=ngram_lm,
ngram_lm_scale=params.ngram_lm_scale,
rnnlm=rnn_lm_model,
rnnlm_scale=rnn_lm_scale,
)

View File

@ -1,4 +1,5 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -656,6 +657,7 @@ class Hypothesis:
# The log prob of ys.
# It contains only one entry.
log_prob: torch.Tensor
state: Optional=None
lm_score: Optional=None
@ -1542,107 +1544,6 @@ def fast_beam_search_with_nbest_rnn_rescoring(
ans[key] = hyps
return ans
def modified_beam_search_sf_rnnlm(
    model: Transducer,
    encoder_out: torch.Tensor,
    sp,
    rnnlm: RnnLmModel,
    rnnlm_scale: float,
    beam: int = 4,
):
    """Beam search over a single utterance with RNN LM shallow fusion.

    For every encoder frame the joiner scores all live hypotheses, the
    ``beam`` best (hypothesis, token) pairs are picked from the flattened
    scores, and every picked token that is neither blank nor unk also
    receives ``rnnlm_scale`` times the score ``rnnlm.predict`` gives it.

    Args:
      model:
        Transducer providing ``decoder``, ``joiner`` and the parameters
        from which the device is taken.
      encoder_out:
        Encoder output of one utterance. It must be 2-D after
        ``model.joiner.encoder_proj``; its first axis is walked frame by
        frame.
      sp:
        Object exposing ``piece_to_id``; used to look up "<sos/eos>".
      rnnlm:
        The RNN language model. Must not be None. Its cache is cleared
        before decoding starts.
      rnnlm_scale:
        Multiplier applied to the LM score before it is added to the
        hypothesis score.
      beam:
        Number of (hypothesis, token) pairs selected per frame.

    Returns:
      The ``ys`` of the hypothesis chosen by
      ``get_most_probable(length_norm=True)``, without the
      ``context_size`` blank symbols it was seeded with.
    """
    encoder_out = model.joiner.encoder_proj(encoder_out)

    assert rnnlm is not None
    assert encoder_out.ndim == 2, encoder_out.shape
    rnnlm.clean_cache()

    blank_id = model.decoder.blank_id
    # The same piece serves as both start- and end-of-sentence symbol.
    sos_id = eos_id = sp.piece_to_id("<sos/eos>")
    unk_id = getattr(model, "unk_id", blank_id)
    context_size = model.decoder.context_size
    device = next(model.parameters()).device

    # Seed the search with one hypothesis made only of blank context.
    beam_hyps = HypothesisList()
    beam_hyps.add(
        Hypothesis(
            ys=[blank_id] * context_size,
            log_prob=torch.zeros(1, dtype=torch.float32, device=device),
        )
    )

    num_frames = encoder_out.shape[0]
    for t in range(num_frames):
        prev_hyps = list(beam_hyps)
        beam_hyps = HypothesisList()
        num_hyps = len(prev_hyps)

        # (num_hyps, 1): running score of every hypothesis
        prev_scores = torch.cat(
            [h.log_prob.reshape(1, 1) for h in prev_hyps]
        )

        # (num_hyps, context_size): last tokens fed to the decoder
        decoder_input = torch.tensor(
            [h.ys[-context_size:] for h in prev_hyps],
            device=device,
            dtype=torch.int64,
        )
        decoder_out = model.decoder(decoder_input, need_pad=False)
        decoder_out = model.joiner.decoder_proj(decoder_out.squeeze(1))

        # Pair the current frame with every hypothesis.
        frame = encoder_out[t : t + 1].repeat(num_hyps, 1)
        logits = model.joiner(frame, decoder_out, project_input=False)

        # (num_hyps, vocab_size): hypothesis score + token score
        log_probs = logits.log_softmax(dim=-1)
        log_probs.add_(prev_scores)
        vocab_size = log_probs.size(-1)

        # Pick the best candidates over the flattened score table.
        topk_log_probs, topk_indexes = log_probs.reshape(-1).topk(beam)

        # Split each flat index back into (hypothesis row, token column).
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            picked_hyps = (topk_indexes // vocab_size).tolist()
            picked_tokens = (topk_indexes % vocab_size).tolist()

        for k, (hyp_idx, new_token) in enumerate(
            zip(picked_hyps, picked_tokens)
        ):
            new_ys = prev_hyps[hyp_idx].ys[:]

            # The LM is queried with the history *before* the candidate
            # token is appended; `state` is a string built from that
            # history and passed to rnnlm.predict alongside the tokens.
            state = "ys=" + "+".join(map(str, new_ys))
            tokens = k2.RaggedTensor([new_ys[context_size:]])
            lm_score = rnnlm.predict(tokens, state, sos_id, eos_id, blank_id)

            hyp_log_prob = topk_log_probs[k]
            if new_token not in (blank_id, unk_id):
                # Only real symbols extend the hypothesis and get LM credit.
                new_ys.append(new_token)
                hyp_log_prob += lm_score[new_token] * rnnlm_scale

            beam_hyps.add(Hypothesis(ys=new_ys, log_prob=hyp_log_prob))

    best_hyp = beam_hyps.get_most_probable(length_norm=True)
    return best_hyp.ys[context_size:]
def modified_beam_search_rnnlm_shallow_fusion(
model: Transducer,

View File

@ -1,7 +1,8 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
# Zengwei Yao,
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -19,47 +20,43 @@
"""
Usage:
(1) greedy search
./lstm_transducer_stateless2/decode.py \
--epoch 35 \
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./lstm_transducer_stateless2/exp \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./lstm_transducer_stateless2/decode.py \
--epoch 35 \
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./lstm_transducer_stateless2/exp \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./lstm_transducer_stateless2/decode.py \
--epoch 35 \
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./lstm_transducer_stateless2/exp \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search (one best)
./lstm_transducer_stateless2/decode.py \
--epoch 35 \
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./lstm_transducer_stateless2/exp \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest)
./lstm_transducer_stateless2/decode.py \
--epoch 30 \
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
@ -67,12 +64,11 @@ Usage:
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./lstm_transducer_stateless2/decode.py \
--epoch 35 \
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./lstm_transducer_stateless2/exp \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
@ -80,17 +76,34 @@ Usage:
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./lstm_transducer_stateless2/decode.py \
--epoch 35 \
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./lstm_transducer_stateless2/exp \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(8) modified beam search with RNNLM shallow fusion
./pruned_transducer_stateless5/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method modified_beam_search_rnnlm_shallow_fusion \
--beam-size 4 \
--max-contexts 4 \
--rnn-lm-scale 0.4 \
--rnn-lm-exp-dir /path/to/RNNLM/exp \
--rnn-lm-epoch 99 \
--rnn-lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1
"""
@ -243,6 +256,16 @@ def get_parser():
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
@ -294,6 +317,15 @@ def get_parser():
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="""Whether to simulate streaming in decoding, this is a good way to
test a streaming model.
""",
)
parser.add_argument(
"--rnn-lm-scale",
@ -517,6 +549,9 @@ def decode_one_batch(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
sp=sp,
rnnlm=rnnlm,
rnnlm_scale=rnnlm_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -708,7 +743,7 @@ def main():
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
"modified_beam_search_sf_rnnlm",
"modified_beam_search_rnnlm_shallow_fusion",
)
params.res_dir = params.exp_dir / params.decoding_method
@ -843,7 +878,7 @@ def main():
rnn_lm_model = None
rnn_lm_scale = params.rnn_lm_scale
if params.decoding_method == "modified_beam_search3":
if params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
rnn_lm_model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.rnn_lm_embedding_dim,