Merge pull request #645 from marcoyang1998/master

Support RNNLM shallow fusion in modified beam search
This commit is contained in:
marcoyang1998 2022-11-04 11:39:12 +08:00 committed by GitHub
commit 7c50a019b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 2511 additions and 30 deletions

View File

@ -101,6 +101,7 @@ The WERs are:
|-------------------------------------|------------|------------|-------------------------|
| greedy search (max sym per frame 1) | 2.78 | 7.36 | --iter 468000 --avg 16 |
| modified_beam_search | 2.73 | 7.15 | --iter 468000 --avg 16 |
| modified_beam_search + RNNLM shallow fusion | 2.42 | 6.46 | --iter 468000 --avg 16 |
| fast_beam_search | 2.76 | 7.31 | --iter 468000 --avg 16 |
| greedy search (max sym per frame 1) | 2.77 | 7.35 | --iter 472000 --avg 18 |
| modified_beam_search | 2.75 | 7.08 | --iter 472000 --avg 18 |
@ -155,6 +156,27 @@ for m in greedy_search fast_beam_search modified_beam_search; do
done
```
To decode with RNNLM shallow fusion, use the following decoding command. A well-trained RNNLM
can be found here: <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main>
```bash
for iter in 472000; do
  for avg in 8 10 12 14 16 18; do
    ./lstm_transducer_stateless2/decode.py \
      --iter $iter \
      --avg $avg \
      --exp-dir ./lstm_transducer_stateless2/exp \
      --max-duration 600 \
      --decoding-method modified_beam_search_rnnlm_shallow_fusion \
      --beam 4 \
      --rnn-lm-scale 0.3 \
      --rnn-lm-exp-dir /path/to/RNNLM \
      --rnn-lm-epoch 99 \
      --rnn-lm-avg 1 \
      --rnn-lm-num-layers 3 \
      --rnn-lm-tie-weights 1
  done
done
```
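At each decoding step, shallow fusion boosts the transducer's log-probability for every candidate non-blank token by the scaled RNNLM log-probability before pruning to the beam width. A minimal sketch of the scoring rule, assuming log-domain scores (the full implementation is `modified_beam_search_rnnlm_shallow_fusion` in `beam_search.py`):

```python
# Hedged sketch of the shallow-fusion score update (illustrative only).
def fused_score(
    transducer_log_prob: float, lm_log_prob: float, rnn_lm_scale: float = 0.3
) -> float:
    # Both inputs are log-probabilities; the RNNLM term is weighted by
    # --rnn-lm-scale before being added to the hypothesis score.
    return transducer_log_prob + rnn_lm_scale * lm_log_prob
```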
Pretrained models, training logs, decoding logs, and decoding results
are available at
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03>
@ -1311,6 +1333,7 @@ layers (24 vs. 12) but a narrower model (1536 feedforward dim and 384 encoder dim
|-------------------------------------|------------|------------|-----------------------------------------|
| greedy search (max sym per frame 1) | 2.54 | 5.72 | --epoch 30 --avg 10 --max-duration 600 |
| modified beam search | 2.47 | 5.71 | --epoch 30 --avg 10 --max-duration 600 |
| modified beam search + RNNLM shallow fusion | 2.27 | 5.24 | --epoch 30 --avg 10 --max-duration 600 |
| fast beam search | 2.5 | 5.72 | --epoch 30 --avg 10 --max-duration 600 |
```bash
@ -1356,6 +1379,36 @@ for method in greedy_search modified_beam_search fast_beam_search; do
done
```
To decode with RNNLM shallow fusion, use the following decoding command. A well-trained RNNLM
can be found here: <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main>
```bash
./pruned_transducer_stateless5/decode.py \
  --epoch 30 \
  --avg 10 \
  --exp-dir ./pruned_transducer_stateless5/exp-B \
  --max-duration 600 \
  --decoding-method modified_beam_search_rnnlm_shallow_fusion \
  --max-sym-per-frame 1 \
  --num-encoder-layers 24 \
  --dim-feedforward 1536 \
  --nhead 8 \
  --encoder-dim 384 \
  --decoder-dim 512 \
  --joiner-dim 512 \
  --use-averaged-model True \
  --beam 4 \
  --max-contexts 4 \
  --rnn-lm-scale 0.4 \
  --rnn-lm-exp-dir /path/to/RNNLM/exp \
  --rnn-lm-epoch 99 \
  --rnn-lm-avg 1 \
  --rnn-lm-num-layers 3 \
  --rnn-lm-tie-weights 1
```
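Note that `--rnn-lm-embedding-dim`, `--rnn-lm-hidden-dim`, `--rnn-lm-num-layers`, and `--rnn-lm-tie-weights` must match the configuration the RNNLM was trained with. Weight tying shares one matrix between the input embedding and the output projection, roughly as in the sketch below (illustrative class name; see `icefall/rnn_lm/model.py` for the actual model):

```python
import torch

# Illustrative sketch of weight tying in an RNN LM; tying assumes
# embedding_dim == hidden_dim so one matrix fits both layers.
class TiedRnnLm(torch.nn.Module):
    def __init__(self, vocab_size: int, dim: int, num_layers: int, tie_weights: bool):
        super().__init__()
        self.input_embedding = torch.nn.Embedding(vocab_size, dim)
        self.rnn = torch.nn.LSTM(dim, dim, num_layers=num_layers, batch_first=True)
        self.output_linear = torch.nn.Linear(dim, vocab_size)
        if tie_weights:
            # The output projection reuses the embedding matrix.
            self.output_linear.weight = self.input_embedding.weight
```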
You can find a pretrained model, training logs, decoding logs, and decoding
results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-B-2022-07-07>

File diff suppressed because it is too large

View File

@ -1,7 +1,8 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
#                                                 Zengwei Yao,
#                                                 Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -91,6 +92,21 @@ Usage:
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(8) modified beam search (with RNNLM shallow fusion)
./lstm_transducer_stateless2/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./lstm_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method modified_beam_search_rnnlm_shallow_fusion \
--beam 4 \
--rnn-lm-scale 0.3 \
--rnn-lm-exp-dir /path/to/RNNLM \
--rnn-lm-epoch 99 \
--rnn-lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1
""" """
@ -116,6 +132,7 @@ from beam_search import (
greedy_search_batch,
modified_beam_search,
modified_beam_search_ngram_rescoring,
modified_beam_search_rnnlm_shallow_fusion,
)
from librispeech import LibriSpeech
from train import add_model_arguments, get_params, get_transducer_model
@ -128,6 +145,7 @@ from icefall.checkpoint import (
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.rnn_lm.model import RnnLmModel
from icefall.utils import (
AttributeDict,
setup_logger,
@ -217,6 +235,7 @@ def get_parser():
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
- modified_beam_search_ngram_rescoring
- modified_beam_search_rnnlm_shallow_fusion # for RNN LM shallow fusion
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
@ -306,6 +325,71 @@ def get_parser():
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--rnn-lm-scale",
type=float,
default=0.0,
help="""Used only when --method is modified-beam-search_rnnlm_shallow_fusion.
It specifies the path to RNN LM exp dir.
""",
)
parser.add_argument(
"--rnn-lm-exp-dir",
type=str,
default="rnn_lm/exp",
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the path to RNN LM exp dir.
""",
)
parser.add_argument(
"--rnn-lm-epoch",
type=int,
default=7,
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the checkpoint to use.
""",
)
parser.add_argument(
"--rnn-lm-avg",
type=int,
default=2,
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the number of checkpoints to average.
""",
)
parser.add_argument(
"--rnn-lm-embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--rnn-lm-hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--rnn-lm-num-layers",
type=int,
default=4,
help="Number of RNN layers the model",
)
parser.add_argument(
"--rnn-lm-tie-weights",
type=str2bool,
default=False,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
parser.add_argument(
"--tokens-ngram",
type=int,
@ -336,6 +420,8 @@ def decode_one_batch(
decoding_graph: Optional[k2.Fsa] = None,
ngram_lm: Optional[NgramLm] = None,
ngram_lm_scale: float = 1.0,
rnnlm: Optional[RnnLmModel] = None,
rnnlm_scale: float = 1.0,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -480,6 +566,18 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
sp=sp,
rnnlm=rnnlm,
rnnlm_scale=rnnlm_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
@ -531,6 +629,8 @@ def decode_dataset(
decoding_graph: Optional[k2.Fsa] = None,
ngram_lm: Optional[NgramLm] = None,
ngram_lm_scale: float = 1.0,
rnnlm: Optional[RnnLmModel] = None,
rnnlm_scale: float = 1.0,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
@ -582,6 +682,8 @@ def decode_dataset(
batch=batch,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
rnnlm=rnnlm,
rnnlm_scale=rnnlm_scale,
)
for name, hyps in hyps_dict.items():
@ -668,6 +770,7 @@ def main():
"fast_beam_search_nbest_oracle", "fast_beam_search_nbest_oracle",
"modified_beam_search", "modified_beam_search",
"modified_beam_search_ngram_rescoring", "modified_beam_search_ngram_rescoring",
"modified_beam_search_rnnlm_shallow_fusion",
) )
params.res_dir = params.exp_dir / params.decoding_method params.res_dir = params.exp_dir / params.decoding_method
@ -693,6 +796,8 @@ def main():
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
if "rnnlm" in params.decoding_method:
params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
if params.use_averaged_model: if params.use_averaged_model:
params.suffix += "-use-averaged-model" params.suffix += "-use-averaged-model"
@ -806,6 +911,8 @@ def main():
model.to(device)
model.eval()
# only load N-gram LM when needed
if "ngram" in params.decoding_method:
lm_filename = f"{params.tokens_ngram}gram.fst.txt" lm_filename = f"{params.tokens_ngram}gram.fst.txt"
logging.info(f"lm filename: {lm_filename}") logging.info(f"lm filename: {lm_filename}")
ngram_lm = NgramLm( ngram_lm = NgramLm(
@ -814,6 +921,33 @@ def main():
is_binary=False,
)
logging.info(f"num states: {ngram_lm.lm.num_states}")
else:
ngram_lm = None
ngram_lm_scale = None
# only load rnnlm if used
if "rnnlm" in params.decoding_method:
rnn_lm_scale = params.rnn_lm_scale
rnn_lm_model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
tie_weights=params.rnn_lm_tie_weights,
)
assert params.rnn_lm_avg == 1
load_checkpoint(
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
rnn_lm_model,
)
rnn_lm_model.to(device)
rnn_lm_model.eval()
else:
rnn_lm_model = None
rnn_lm_scale = 0.0
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG": if params.decoding_method == "fast_beam_search_nbest_LG":
@ -860,7 +994,9 @@ def main():
word_table=word_table,
decoding_graph=decoding_graph,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
rnnlm=rnn_lm_model,
rnnlm_scale=rnn_lm_scale,
)
save_results(

View File

@ -1,4 +1,5 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
#                                        Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -16,7 +17,7 @@
import warnings
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union
import k2
import sentencepiece as spm
@ -25,6 +26,7 @@ from model import Transducer
from icefall import NgramLm, NgramLmStateCost
from icefall.decode import Nbest, one_best_decoding
from icefall.rnn_lm.model import RnnLmModel
from icefall.utils import (
DecodingResults,
add_eos,
@ -729,6 +731,13 @@ class Hypothesis:
# on which ys[i] is decoded
timestamp: List[int] = field(default_factory=list)
# the lm score for next token given the current ys
lm_score: Optional[torch.Tensor] = None
# the RNNLM states (h and c of the LSTM), each of shape (num_layers, 1, hidden_size)
state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
# N-gram LM state
state_cost: Optional[NgramLmStateCost] = None

@property
@ -1851,3 +1860,249 @@ def modified_beam_search_ngram_rescoring(
ans.append(sorted_ans[unsorted_indices[i]])
return ans
def modified_beam_search_rnnlm_shallow_fusion(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
sp: spm.SentencePieceProcessor,
rnnlm: RnnLmModel,
rnnlm_scale: float,
beam: int = 4,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""Modified_beam_search + RNNLM shallow fusion
Args:
model (Transducer):
The transducer model
encoder_out (torch.Tensor):
Encoder output in (N,T,C)
encoder_out_lens (torch.Tensor):
A 1-D tensor of shape (N,), containing the number of
valid frames in encoder_out before padding.
sp:
Sentence piece generator.
rnnlm (RnnLmModel):
RNNLM
rnnlm_scale (float):
scale of RNNLM in shallow fusion
beam (int, optional):
Beam size. Defaults to 4.
Returns:
Return a list-of-list of token IDs. ans[i] is the decoding results
for the i-th utterance. If return_timestamps is True, return a
DecodingResults object holding both tokens and timestamps.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
assert rnnlm is not None
lm_scale = rnnlm_scale
vocab_size = rnnlm.vocab_size
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
blank_id = model.decoder.blank_id
sos_id = sp.piece_to_id("<sos/eos>")
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
# get initial lm score and lm state by scoring the "sos" token
sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
init_score, init_states = rnnlm.score_token(sos_token)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
B[i].add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
state=init_states,
lm_score=init_score.reshape(-1),
timestamp=[],
)
)
rnnlm.clean_cache()
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
finalized_B = []
for (t, batch_size) in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end] # get batch
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
offset = end
finalized_B = B[batch_size:] + finalized_B
B = B[:batch_size]
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.cat(
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out,
decoder_out,
project_input=False,
) # (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(
shape=log_probs_shape, value=log_probs
)
"""
for all hyps with a non-blank new token, score this token.
It is a little confusing here because this for-loop
looks very similar to the one below. Here, we go through all
top-k tokens and only add the non-blanks ones to the token_list.
The RNNLM will score those tokens given the LM states. Note that
the variable `scores` is the LM score after seeing the new
non-blank token.
"""
token_list = []
hs = []
cs = []
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_token = topk_token_indexes[k]
if new_token not in (blank_id, unk_id):
assert new_token != 0, new_token
token_list.append([new_token])
# store the LSTM states
hs.append(hyp.state[0])
cs.append(hyp.state[1])
# forward RNNLM to get new states and scores
if len(token_list) != 0:
tokens_to_score = (
torch.tensor(token_list)
.to(torch.int64)
.to(device)
.reshape(-1, 1)
)
hs = torch.cat(hs, dim=1).to(device)
cs = torch.cat(cs, dim=1).to(device)
scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs))
count = 0 # index, used to locate score and lm states
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
ys = hyp.ys[:]
lm_score = hyp.lm_score
state = hyp.state
hyp_log_prob = topk_log_probs[k] # get score of current hyp
new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:]
if new_token not in (blank_id, unk_id):
ys.append(new_token)
new_timestamp.append(t)
hyp_log_prob += (
lm_score[new_token] * lm_scale
) # add the lm score
lm_score = scores[count]
state = (
lm_states[0][:, count, :].unsqueeze(1),
lm_states[1][:, count, :].unsqueeze(1),
)
count += 1
new_hyp = Hypothesis(
ys=ys,
log_prob=hyp_log_prob,
state=state,
lm_score=lm_score,
timestamp=new_timestamp,
)
B[i].add(new_hyp)
B = B + finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
sorted_ans = [h.ys[context_size:] for h in best_hyps]
sorted_timestamps = [h.timestamp for h in best_hyps]
ans = []
ans_timestamps = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
if not return_timestamps:
return ans
else:
return DecodingResults(
tokens=ans,
timestamps=ans_timestamps,
)

View File

@ -1,7 +1,8 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
#                                                 Zengwei Yao,
#                                                 Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -25,7 +26,6 @@ Usage:
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
@ -34,7 +34,6 @@ Usage:
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
@ -43,7 +42,6 @@ Usage:
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search (one best)
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
@ -54,7 +52,6 @@ Usage:
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest)
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
@ -67,7 +64,6 @@ Usage:
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
@ -80,7 +76,6 @@ Usage:
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
@ -91,6 +86,24 @@ Usage:
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(8) modified beam search with RNNLM shallow fusion
./pruned_transducer_stateless5/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method modified_beam_search_rnnlm_shallow_fusion \
--beam 4 \
--rnn-lm-scale 0.4 \
--rnn-lm-exp-dir /path/to/RNNLM/exp \
--rnn-lm-epoch 99 \
--rnn-lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1
""" """
@ -115,6 +128,7 @@ from beam_search import (
greedy_search,
greedy_search_batch,
modified_beam_search,
modified_beam_search_rnnlm_shallow_fusion,
)
from train import add_model_arguments, get_params, get_transducer_model
@ -125,6 +139,7 @@ from icefall.checkpoint import (
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.rnn_lm.model import RnnLmModel
from icefall.utils import (
AttributeDict,
setup_logger,
@ -214,6 +229,7 @@ def get_parser():
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
- modified_beam_search_rnnlm_shallow_fusion # for RNN LM shallow fusion
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
@ -251,6 +267,20 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--max-contexts",
type=int,
@ -276,6 +306,7 @@ def get_parser():
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram", "2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
type=int, type=int,
@ -312,19 +343,69 @@ def get_parser():
)
parser.add_argument(
"--rnn-lm-scale",
type=float,
default=0.0,
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the scale of the RNN LM in shallow fusion.
""",
)
parser.add_argument(
"--rnn-lm-exp-dir",
type=str,
default="rnn_lm/exp",
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the path to RNN LM exp dir.
""",
)
parser.add_argument(
"--rnn-lm-epoch",
type=int,
default=7,
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the checkpoint to use.
""",
)
parser.add_argument(
"--rnn-lm-avg",
type=int,
default=2,
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the number of checkpoints to average.
""",
)
parser.add_argument(
"--rnn-lm-embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--rnn-lm-hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--rnn-lm-num-layers",
type=int,
default=4,
help="Number of RNN layers the model",
)
parser.add_argument(
"--rnn-lm-tie-weights",
type=str2bool,
default=False,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
add_model_arguments(parser)
return parser
@ -337,6 +418,8 @@ def decode_one_batch(
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
rnnlm: Optional[RnnLmModel] = None,
rnnlm_scale: float = 1.0,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -482,6 +565,18 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
sp=sp,
rnnlm=rnnlm,
rnnlm_scale=rnnlm_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
@ -531,6 +626,8 @@ def decode_dataset(
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
rnnlm: Optional[RnnLmModel] = None,
rnnlm_scale: float = 1.0,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
@ -572,6 +669,7 @@ def decode_dataset(
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
logging.info(f"Decoding {batch_idx}-th batch")
hyps_dict = decode_one_batch(
params=params,
@ -580,6 +678,8 @@ def decode_dataset(
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
rnnlm=rnnlm,
rnnlm_scale=rnnlm_scale,
)
for name, hyps in hyps_dict.items():
@ -666,6 +766,7 @@ def main():
"fast_beam_search_nbest_LG", "fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle", "fast_beam_search_nbest_oracle",
"modified_beam_search", "modified_beam_search",
"modified_beam_search_rnnlm_shallow_fusion",
) )
params.res_dir = params.exp_dir / params.decoding_method params.res_dir = params.exp_dir / params.decoding_method
@ -673,11 +774,9 @@ def main():
params.suffix = f"iter-{params.iter}-avg-{params.avg}" params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else: else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming: if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}" params.suffix += f"-left-context-{params.left_context}"
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}" params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
@ -695,6 +794,8 @@ def main():
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
if params.use_averaged_model: if params.use_averaged_model:
params.suffix += "-use-averaged-model" params.suffix += "-use-averaged-model"
@ -805,6 +906,25 @@ def main():
model.to(device)
model.eval()
rnn_lm_model = None
rnn_lm_scale = params.rnn_lm_scale
if params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
rnn_lm_model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
tie_weights=params.rnn_lm_tie_weights,
)
assert params.rnn_lm_avg == 1
load_checkpoint(
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
rnn_lm_model,
)
rnn_lm_model.to(device)
rnn_lm_model.eval()
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
if "LG" in params.decoding_method: if "LG" in params.decoding_method:
lexicon = Lexicon(params.lang_dir) lexicon = Lexicon(params.lang_dir)
@ -848,6 +968,8 @@ def main():
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
rnnlm=rnn_lm_model,
rnnlm_scale=rnn_lm_scale,
)
save_results(

View File

@ -19,7 +19,7 @@ import logging
import torch
import torch.nn.functional as F
from icefall.utils import add_eos, add_sos, make_pad_mask
class RnnLmModel(torch.nn.Module):
@ -72,6 +72,8 @@ class RnnLmModel(torch.nn.Module):
else:
logging.info("Not tying weights")
self.cache = {}
def forward(
self, x: torch.Tensor, y: torch.Tensor, lengths: torch.Tensor
) -> torch.Tensor:
@ -118,3 +120,95 @@ class RnnLmModel(torch.nn.Module):
nll_loss = nll_loss.reshape(batch_size, -1)
return nll_loss
def predict_batch(self, tokens, token_lens, sos_id, eos_id, blank_id):
device = next(self.parameters()).device
batch_size = len(token_lens)
sos_tokens = add_sos(tokens, sos_id)
tokens_eos = add_eos(tokens, eos_id)
sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
sentence_lengths = (
sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
)
x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
x_tokens = x_tokens.to(torch.int64).to(device)
y_tokens = y_tokens.to(torch.int64).to(device)
sentence_lengths = sentence_lengths.to(torch.int64).to(device)
embedding = self.input_embedding(x_tokens)
# Note: We use batch_first==True
rnn_out, states = self.rnn(embedding)
logits = self.output_linear(rnn_out)
mask = torch.zeros(logits.shape).bool().to(device)
for i in range(batch_size):
mask[i, token_lens[i], :] = True
logits = logits[mask].reshape(batch_size, -1)
return logits[:, :].log_softmax(-1), states
def clean_cache(self):
self.cache = {}
def score_token(self, tokens: torch.Tensor, state=None):
device = next(self.parameters()).device
batch_size = tokens.size(0)
if state:
h, c = state
else:
# LSTM states have shape (num_layers, batch, hidden_size).
h = torch.zeros(
self.rnn.num_layers, batch_size, self.rnn.hidden_size
).to(device)
c = torch.zeros(
self.rnn.num_layers, batch_size, self.rnn.hidden_size
).to(device)
embedding = self.input_embedding(tokens)
rnn_out, states = self.rnn(embedding, (h, c))
logits = self.output_linear(rnn_out)
return logits[:, 0].log_softmax(-1), states
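# Usage sketch (illustrative; not part of the original file). score_token
# consumes one token per hypothesis and returns per-token log-probs together
# with the updated LSTM states, which the beam search threads from step to step:
#
#   sos = torch.tensor([[sos_id]], dtype=torch.int64, device=device)
#   log_probs, states = rnnlm.score_token(sos)           # states start at zeros
#   nxt = torch.tensor([[next_token]], dtype=torch.int64, device=device)
#   log_probs, states = rnnlm.score_token(nxt, states)   # continue from states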
def forward_with_state(
self, tokens, token_lens, sos_id, eos_id, blank_id, state=None
):
device = next(self.parameters()).device
batch_size = len(token_lens)
if state:
h, c = state
else:
# LSTM states have shape (num_layers, batch, hidden_size) and must
# live on the same device as the model.
h = torch.zeros(
self.rnn.num_layers, batch_size, self.rnn.hidden_size
).to(device)
c = torch.zeros(
self.rnn.num_layers, batch_size, self.rnn.hidden_size
).to(device)
sos_tokens = add_sos(tokens, sos_id)
tokens_eos = add_eos(tokens, eos_id)
sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
sentence_lengths = (
sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
)
x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
x_tokens = x_tokens.to(torch.int64).to(device)
y_tokens = y_tokens.to(torch.int64).to(device)
sentence_lengths = sentence_lengths.to(torch.int64).to(device)
embedding = self.input_embedding(x_tokens)
# Note: We use batch_first==True
rnn_out, states = self.rnn(embedding, (h, c))
logits = self.output_linear(rnn_out)
return logits, states