mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
update decoding commands
This commit is contained in:
parent
86662f0b97
commit
0a46a39e24
@ -91,6 +91,21 @@ Usage:
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
|
||||
(8) modified beam search (with RNNLM shallow fusion)
|
||||
./lstm_transducer_stateless2/decode.py \
|
||||
--epoch 35 \
|
||||
--avg 15 \
|
||||
--exp-dir ./lstm_transducer_stateless2/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search_rnnlm_shallow_fusion \
|
||||
--beam-size 4 \
|
||||
--rnn-lm-scale 0.3 \
|
||||
--rnn-lm-exp-dir /path/to/RNNLM \
|
||||
--rnn-lm-epoch 99 \
|
||||
--rnn-lm-avg 1 \
|
||||
--rnn-lm-num-layers 3 \
|
||||
--rnn-lm-tie-weights 1
|
||||
"""
|
||||
|
||||
|
||||
@ -121,7 +136,6 @@ from beam_search import (
|
||||
from librispeech import LibriSpeech
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall import NgramLm
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
@ -389,8 +403,6 @@ def decode_one_batch(
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
ngram_lm: Optional[NgramLm] = None,
|
||||
ngram_lm_scale: float = 1.0,
|
||||
rnnlm: Optional[RnnLmModel] = None,
|
||||
rnnlm_scale: float = 1.0,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
@ -526,11 +538,12 @@ def decode_one_batch(
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search_sf_rnnlm":
|
||||
hyp_tokens = modified_beam_search_sf_rnnlm_batched(
|
||||
elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
|
||||
hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
sp=sp,
|
||||
rnnlm=rnnlm,
|
||||
rnnlm_scale=rnnlm_scale,
|
||||
@ -586,9 +599,7 @@ def decode_dataset(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
ngram_lm: Optional[NgramLm] = None,
|
||||
ngram_lm_scale: float = 1.0,
|
||||
rnnlm: Optional[NgramLm] = None,
|
||||
rnnlm: Optional[RnnLmModel] = None,
|
||||
rnnlm_scale: float = 1.0,
|
||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
@ -642,8 +653,6 @@ def decode_dataset(
|
||||
decoding_graph=decoding_graph,
|
||||
word_table=word_table,
|
||||
batch=batch,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=ngram_lm_scale,
|
||||
rnnlm=rnnlm,
|
||||
rnnlm_scale=rnnlm_scale,
|
||||
)
|
||||
@ -731,7 +740,7 @@ def main():
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
"modified_beam_search_sf_rnnlm",
|
||||
"modified_beam_search_rnnlm_shallow_fusion",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
@ -942,8 +951,6 @@ def main():
|
||||
sp=sp,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=params.ngram_lm_scale,
|
||||
rnnlm=rnn_lm_model,
|
||||
rnnlm_scale=rnn_lm_scale,
|
||||
)
|
||||
|
@ -1,4 +1,5 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Xiaoyu Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
@ -656,6 +657,7 @@ class Hypothesis:
|
||||
# The log prob of ys.
|
||||
# It contains only one entry.
|
||||
log_prob: torch.Tensor
|
||||
state: Optional=None
|
||||
|
||||
lm_score: Optional=None
|
||||
|
||||
@ -1542,107 +1544,6 @@ def fast_beam_search_with_nbest_rnn_rescoring(
|
||||
ans[key] = hyps
|
||||
|
||||
return ans
|
||||
|
||||
def modified_beam_search_sf_rnnlm(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
sp,
|
||||
rnnlm: RnnLmModel,
|
||||
rnnlm_scale: float,
|
||||
beam: int = 4,
|
||||
):
|
||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||
lm_scale = rnnlm_scale
|
||||
|
||||
assert rnnlm is not None
|
||||
assert encoder_out.ndim == 2, encoder_out.shape
|
||||
rnnlm.clean_cache()
|
||||
blank_id = model.decoder.blank_id
|
||||
sos_id = sp.piece_to_id("<sos/eos>")
|
||||
eos_id = sp.piece_to_id("<sos/eos>")
|
||||
unk_id = getattr(model, "unk_id", blank_id)
|
||||
context_size = model.decoder.context_size
|
||||
device = next(model.parameters()).device
|
||||
|
||||
B = HypothesisList()
|
||||
B.add(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
)
|
||||
)
|
||||
|
||||
T = encoder_out.shape[0]
|
||||
for t in range(T):
|
||||
current_encoder_out = encoder_out[t : t + 1]
|
||||
A = list(B)
|
||||
B = HypothesisList()
|
||||
|
||||
ys_log_probs = torch.cat(
|
||||
[hyp.log_prob.reshape(1, 1) for hyp in A]
|
||||
) # (num_hyps, 1)
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[hyp.ys[-context_size:] for hyp in A],
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
) # (num_hyps, context_size)
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
|
||||
# decoder_out is of shape (num_hyps, joiner_dim)
|
||||
current_encoder_out = current_encoder_out.repeat(len(A), 1)
|
||||
# current_encoder_out is of shape (num_hyps, encoder_out_dim)
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
decoder_out,
|
||||
project_input=False,
|
||||
) # (num_hyps, vocab_size)
|
||||
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
|
||||
log_probs.add_(ys_log_probs)
|
||||
|
||||
vocab_size = log_probs.size(-1)
|
||||
log_probs = log_probs.reshape(-1)
|
||||
topk_log_probs, topk_indexes = log_probs.topk(
|
||||
beam
|
||||
) # get topk tokens and scores
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
||||
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
||||
|
||||
for k in range(len(topk_hyp_indexes)):
|
||||
hyp_idx = topk_hyp_indexes[k]
|
||||
hyp = A[hyp_idx] # get hyp
|
||||
new_ys = hyp.ys[:]
|
||||
state = "ys=" + "+".join(list(map(str, new_ys)))
|
||||
tokens = k2.RaggedTensor([new_ys[context_size:]])
|
||||
|
||||
lm_score = rnnlm.predict(
|
||||
tokens, state, sos_id, eos_id, blank_id
|
||||
) # get rnnlm score
|
||||
|
||||
hyp_log_prob = topk_log_probs[k] # get score of current hyp
|
||||
new_token = topk_token_indexes[k] # get token
|
||||
if new_token not in (blank_id, unk_id):
|
||||
new_ys.append(new_token)
|
||||
# state_cost = hyp.state_cost.forward_one_step(new_token)
|
||||
hyp_log_prob += (
|
||||
lm_score[new_token] * lm_scale
|
||||
) # add the lm score
|
||||
else:
|
||||
new_ys = new_ys
|
||||
new_log_prob = hyp_log_prob
|
||||
|
||||
new_hyp = Hypothesis(
|
||||
ys=new_ys,
|
||||
log_prob=new_log_prob,
|
||||
)
|
||||
B.add(new_hyp)
|
||||
|
||||
best_hyp = B.get_most_probable(length_norm=True)
|
||||
return best_hyp.ys[context_size:]
|
||||
|
||||
def modified_beam_search_rnnlm_shallow_fusion(
|
||||
model: Transducer,
|
||||
|
@ -1,7 +1,8 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
# Zengwei Yao,
|
||||
# Xiaoyu Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
@ -19,47 +20,43 @@
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./lstm_transducer_stateless2/decode.py \
|
||||
--epoch 35 \
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./lstm_transducer_stateless2/exp \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search (not recommended)
|
||||
./lstm_transducer_stateless2/decode.py \
|
||||
--epoch 35 \
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./lstm_transducer_stateless2/exp \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./lstm_transducer_stateless2/decode.py \
|
||||
--epoch 35 \
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./lstm_transducer_stateless2/exp \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(4) fast beam search (one best)
|
||||
./lstm_transducer_stateless2/decode.py \
|
||||
--epoch 35 \
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./lstm_transducer_stateless2/exp \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
|
||||
(5) fast beam search (nbest)
|
||||
./lstm_transducer_stateless2/decode.py \
|
||||
--epoch 30 \
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest \
|
||||
--beam 20.0 \
|
||||
@ -67,12 +64,11 @@ Usage:
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(6) fast beam search (nbest oracle WER)
|
||||
./lstm_transducer_stateless2/decode.py \
|
||||
--epoch 35 \
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./lstm_transducer_stateless2/exp \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_oracle \
|
||||
--beam 20.0 \
|
||||
@ -80,17 +76,34 @@ Usage:
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(7) fast beam search (with LG)
|
||||
./lstm_transducer_stateless2/decode.py \
|
||||
--epoch 35 \
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./lstm_transducer_stateless2/exp \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_LG \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
|
||||
(8) modified beam search (with RNNLM shallow fusion)
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--epoch 35 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search_rnnlm_shallow_fusion \
|
||||
--beam-size 4 \
|
||||
--max-contexts 4 \
|
||||
--rnn-lm-scale 0.4 \
|
||||
--rnn-lm-exp-dir /path/to/RNNLM/exp \
|
||||
--rnn-lm-epoch 99 \
|
||||
--rnn-lm-avg 1 \
|
||||
--rnn-lm-num-layers 3 \
|
||||
--rnn-lm-tie-weights 1
|
||||
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@ -243,6 +256,16 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
@ -294,6 +317,15 @@ def get_parser():
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--simulate-streaming",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to simulate streaming in decoding, this is a good way to
|
||||
test a streaming model.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--rnn-lm-scale",
|
||||
@ -517,6 +549,9 @@ def decode_one_batch(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
sp=sp,
|
||||
rnnlm=rnnlm,
|
||||
rnnlm_scale=rnnlm_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
@ -708,7 +743,7 @@ def main():
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
"modified_beam_search_sf_rnnlm",
|
||||
"modified_beam_search_rnnlm_shallow_fusion",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
@ -843,7 +878,7 @@ def main():
|
||||
|
||||
rnn_lm_model = None
|
||||
rnn_lm_scale = params.rnn_lm_scale
|
||||
if params.decoding_method == "modified_beam_search3":
|
||||
if params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
|
||||
rnn_lm_model = RnnLmModel(
|
||||
vocab_size=params.vocab_size,
|
||||
embedding_dim=params.rnn_lm_embedding_dim,
|
||||
|
Loading…
x
Reference in New Issue
Block a user