Merge pull request #645 from marcoyang1998/master

Support RNNLM shallow fusion in modified beam search
This commit is contained in:
marcoyang1998 2022-11-04 11:39:12 +08:00 committed by GitHub
commit 7c50a019b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 2511 additions and 30 deletions

View File

@ -101,6 +101,7 @@ The WERs are:
|-------------------------------------|------------|------------|-------------------------|
| greedy search (max sym per frame 1) | 2.78 | 7.36 | --iter 468000 --avg 16 |
| modified_beam_search | 2.73 | 7.15 | --iter 468000 --avg 16 |
| modified_beam_search + RNNLM shallow fusion | 2.42 | 6.46 | --iter 468000 --avg 16 |
| fast_beam_search | 2.76 | 7.31 | --iter 468000 --avg 16 |
| greedy search (max sym per frame 1) | 2.77 | 7.35 | --iter 472000 --avg 18 |
| modified_beam_search | 2.75 | 7.08 | --iter 472000 --avg 18 |
@ -155,6 +156,27 @@ for m in greedy_search fast_beam_search modified_beam_search; do
done
```
To decode with RNNLM shallow fusion, use the following decoding command. A well-trained RNNLM
can be found here: <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main>
for iter in 472000; do
for avg in 8 10 12 14 16 18; do
./lstm_transducer_stateless2/decode.py \
--iter $iter \
--avg $avg \
--exp-dir ./lstm_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method modified_beam_search_rnnlm_shallow_fusion \
--beam 4 \
--rnn-lm-scale 0.3 \
--rnn-lm-exp-dir /path/to/RNNLM \
--rnn-lm-epoch 99 \
--rnn-lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1
done
done
Pretrained models, training logs, decoding logs, and decoding results
are available at
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03>
@ -1311,6 +1333,7 @@ layers (24 v.s 12) but a narrower model (1536 feedforward dim and 384 encoder di
|-------------------------------------|------------|------------|-----------------------------------------|
| greedy search (max sym per frame 1) | 2.54 | 5.72 | --epoch 30 --avg 10 --max-duration 600 |
| modified beam search | 2.47 | 5.71 | --epoch 30 --avg 10 --max-duration 600 |
| modified beam search + RNNLM shallow fusion | 2.27 | 5.24 | --epoch 30 --avg 10 --max-duration 600 |
| fast beam search | 2.5 | 5.72 | --epoch 30 --avg 10 --max-duration 600 |
```bash
@ -1356,6 +1379,36 @@ for method in greedy_search modified_beam_search fast_beam_search; do
done
```
To decode with RNNLM shallow fusion, use the following decoding command. A well-trained RNNLM
can be found here: <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main>
```bash
for method in greedy_search modified_beam_search fast_beam_search; do
./pruned_transducer_stateless5/decode.py \
--epoch 30 \
--avg 10 \
--exp-dir ./pruned_transducer_stateless5/exp-B \
--max-duration 600 \
--decoding-method modified_beam_search_rnnlm_shallow_fusion \
--max-sym-per-frame 1 \
--num-encoder-layers 24 \
--dim-feedforward 1536 \
--nhead 8 \
--encoder-dim 384 \
--decoder-dim 512 \
--joiner-dim 512 \
--use-averaged-model True
--beam 4 \
--max-contexts 4 \
--rnn-lm-scale 0.4 \
--rnn-lm-exp-dir /path/to/RNNLM/exp \
--rnn-lm-epoch 99 \
--rnn-lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1
done
```
You can find a pretrained model, training logs, decoding logs, and decoding
results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-B-2022-07-07>

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +1,8 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
# Zengwei Yao,
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -91,6 +92,21 @@ Usage:
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(8) modified beam search (with RNNLM shallow fusion)
./lstm_transducer_stateless2/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./lstm_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method modified_beam_search_rnnlm_shallow_fusion \
--beam 4 \
--rnn-lm-scale 0.3 \
--rnn-lm-exp-dir /path/to/RNNLM \
--rnn-lm-epoch 99 \
--rnn-lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1
"""
@ -116,6 +132,7 @@ from beam_search import (
greedy_search_batch,
modified_beam_search,
modified_beam_search_ngram_rescoring,
modified_beam_search_rnnlm_shallow_fusion,
)
from librispeech import LibriSpeech
from train import add_model_arguments, get_params, get_transducer_model
@ -128,6 +145,7 @@ from icefall.checkpoint import (
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.rnn_lm.model import RnnLmModel
from icefall.utils import (
AttributeDict,
setup_logger,
@ -217,6 +235,7 @@ def get_parser():
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
- modified_beam_search_ngram_rescoring
- modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
@ -306,6 +325,71 @@ def get_parser():
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--rnn-lm-scale",
type=float,
default=0.0,
help="""Used only when --method is modified-beam-search_rnnlm_shallow_fusion.
It specifies the path to RNN LM exp dir.
""",
)
parser.add_argument(
"--rnn-lm-exp-dir",
type=str,
default="rnn_lm/exp",
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the path to RNN LM exp dir.
""",
)
parser.add_argument(
"--rnn-lm-epoch",
type=int,
default=7,
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the checkpoint to use.
""",
)
parser.add_argument(
"--rnn-lm-avg",
type=int,
default=2,
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the number of checkpoints to average.
""",
)
parser.add_argument(
"--rnn-lm-embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--rnn-lm-hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--rnn-lm-num-layers",
type=int,
default=4,
help="Number of RNN layers the model",
)
parser.add_argument(
"--rnn-lm-tie-weights",
type=str2bool,
default=False,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
parser.add_argument(
"--tokens-ngram",
type=int,
@ -336,6 +420,8 @@ def decode_one_batch(
decoding_graph: Optional[k2.Fsa] = None,
ngram_lm: Optional[NgramLm] = None,
ngram_lm_scale: float = 1.0,
rnnlm: Optional[RnnLmModel] = None,
rnnlm_scale: float = 1.0,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -480,6 +566,18 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
sp=sp,
rnnlm=rnnlm,
rnnlm_scale=rnnlm_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
@ -531,6 +629,8 @@ def decode_dataset(
decoding_graph: Optional[k2.Fsa] = None,
ngram_lm: Optional[NgramLm] = None,
ngram_lm_scale: float = 1.0,
rnnlm: Optional[RnnLmModel] = None,
rnnlm_scale: float = 1.0,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
@ -582,6 +682,8 @@ def decode_dataset(
batch=batch,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
rnnlm=rnnlm,
rnnlm_scale=rnnlm_scale,
)
for name, hyps in hyps_dict.items():
@ -668,6 +770,7 @@ def main():
"fast_beam_search_nbest_oracle",
"modified_beam_search",
"modified_beam_search_ngram_rescoring",
"modified_beam_search_rnnlm_shallow_fusion",
)
params.res_dir = params.exp_dir / params.decoding_method
@ -693,6 +796,8 @@ def main():
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
if "rnnlm" in params.decoding_method:
params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
@ -806,6 +911,8 @@ def main():
model.to(device)
model.eval()
# only load N-gram LM when needed
if "ngram" in params.decoding_method:
lm_filename = f"{params.tokens_ngram}gram.fst.txt"
logging.info(f"lm filename: {lm_filename}")
ngram_lm = NgramLm(
@ -814,6 +921,33 @@ def main():
is_binary=False,
)
logging.info(f"num states: {ngram_lm.lm.num_states}")
else:
ngram_lm = None
ngram_lm_scale = None
# only load rnnlm if used
if "rnnlm" in params.decoding_method:
rnn_lm_scale = params.rnn_lm_scale
rnn_lm_model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
tie_weights=params.rnn_lm_tie_weights,
)
assert params.rnn_lm_avg == 1
load_checkpoint(
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
rnn_lm_model,
)
rnn_lm_model.to(device)
rnn_lm_model.eval()
else:
rnn_lm_model = None
rnn_lm_scale = 0.0
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
@ -860,7 +994,9 @@ def main():
word_table=word_table,
decoding_graph=decoding_graph,
ngram_lm=ngram_lm,
ngram_lm_scale=params.ngram_lm_scale,
ngram_lm_scale=ngram_lm_scale,
rnnlm=rnn_lm_model,
rnnlm_scale=rnn_lm_scale,
)
save_results(

View File

@ -1,4 +1,5 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -16,7 +17,7 @@
import warnings
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union
import k2
import sentencepiece as spm
@ -25,6 +26,7 @@ from model import Transducer
from icefall import NgramLm, NgramLmStateCost
from icefall.decode import Nbest, one_best_decoding
from icefall.rnn_lm.model import RnnLmModel
from icefall.utils import (
DecodingResults,
add_eos,
@ -729,6 +731,13 @@ class Hypothesis:
# on which ys[i] is decoded
timestamp: List[int] = field(default_factory=list)
# the lm score for next token given the current ys
lm_score: Optional[torch.Tensor] = None
# the RNNLM states (h and c in LSTM)
state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
# N-gram LM state
state_cost: Optional[NgramLmStateCost] = None
@property
@ -1851,3 +1860,249 @@ def modified_beam_search_ngram_rescoring(
ans.append(sorted_ans[unsorted_indices[i]])
return ans
def modified_beam_search_rnnlm_shallow_fusion(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
sp: spm.SentencePieceProcessor,
rnnlm: RnnLmModel,
rnnlm_scale: float,
beam: int = 4,
return_timestamps: bool = False,
) -> List[List[int]]:
"""Modified_beam_search + RNNLM shallow fusion
Args:
model (Transducer):
The transducer model
encoder_out (torch.Tensor):
Encoder output in (N,T,C)
encoder_out_lens (torch.Tensor):
A 1-D tensor of shape (N,), containing the number of
valid frames in encoder_out before padding.
sp:
Sentence piece generator.
rnnlm (RnnLmModel):
RNNLM
rnnlm_scale (float):
scale of RNNLM in shallow fusion
beam (int, optional):
Beam size. Defaults to 4.
Returns:
Return a list-of-list of token IDs. ans[i] is the decoding results
for the i-th utterance.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
assert rnnlm is not None
lm_scale = rnnlm_scale
vocab_size = rnnlm.vocab_size
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
blank_id = model.decoder.blank_id
sos_id = sp.piece_to_id("<sos/eos>")
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
# get initial lm score and lm state by scoring the "sos" token
sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
init_score, init_states = rnnlm.score_token(sos_token)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
B[i].add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
state=init_states,
lm_score=init_score.reshape(-1),
timestamp=[],
)
)
rnnlm.clean_cache()
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
finalized_B = []
for (t, batch_size) in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end] # get batch
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
offset = end
finalized_B = B[batch_size:] + finalized_B
B = B[:batch_size]
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.cat(
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out,
decoder_out,
project_input=False,
) # (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(
shape=log_probs_shape, value=log_probs
)
"""
for all hyps with a non-blank new token, score this token.
It is a little confusing here because this for-loop
looks very similar to the one below. Here, we go through all
top-k tokens and only add the non-blanks ones to the token_list.
The RNNLM will score those tokens given the LM states. Note that
the variable `scores` is the LM score after seeing the new
non-blank token.
"""
token_list = []
hs = []
cs = []
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_token = topk_token_indexes[k]
if new_token not in (blank_id, unk_id):
assert new_token != 0, new_token
token_list.append([new_token])
# store the LSTM states
hs.append(hyp.state[0])
cs.append(hyp.state[1])
# forward RNNLM to get new states and scores
if len(token_list) != 0:
tokens_to_score = (
torch.tensor(token_list)
.to(torch.int64)
.to(device)
.reshape(-1, 1)
)
hs = torch.cat(hs, dim=1).to(device)
cs = torch.cat(cs, dim=1).to(device)
scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs))
count = 0 # index, used to locate score and lm states
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
ys = hyp.ys[:]
lm_score = hyp.lm_score
state = hyp.state
hyp_log_prob = topk_log_probs[k] # get score of current hyp
new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:]
if new_token not in (blank_id, unk_id):
ys.append(new_token)
new_timestamp.append(t)
hyp_log_prob += (
lm_score[new_token] * lm_scale
) # add the lm score
lm_score = scores[count]
state = (
lm_states[0][:, count, :].unsqueeze(1),
lm_states[1][:, count, :].unsqueeze(1),
)
count += 1
new_hyp = Hypothesis(
ys=ys,
log_prob=hyp_log_prob,
state=state,
lm_score=lm_score,
timestampe=new_timestamp,
)
B[i].add(new_hyp)
B = B + finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
sorted_ans = [h.ys[context_size:] for h in best_hyps]
sorted_timestamps = [h.timestamp for h in best_hyps]
ans = []
ans_timestamps = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
if not return_timestamps:
return ans
else:
return DecodingResults(
tokens=ans,
timestamps=ans_timestamps,
)

View File

@ -1,7 +1,8 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
# Zengwei Yao,
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -25,7 +26,6 @@ Usage:
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
@ -34,7 +34,6 @@ Usage:
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
@ -43,7 +42,6 @@ Usage:
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search (one best)
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
@ -54,7 +52,6 @@ Usage:
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest)
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
@ -67,7 +64,6 @@ Usage:
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
@ -80,7 +76,6 @@ Usage:
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
@ -91,6 +86,24 @@ Usage:
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(8) modified beam search with RNNLM shallow fusion (with LG)
./pruned_transducer_stateless5/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 4 \
--max-contexts 4 \
--rnn-lm-scale 0.4 \
--rnn-lm-exp-dir /path/to/RNNLM/exp \
--rnn-lm-epoch 99 \
--rnn-lm-avg 1 \
--rnn-lm-num-layers 3 \
--rnn-lm-tie-weights 1
"""
@ -115,6 +128,7 @@ from beam_search import (
greedy_search,
greedy_search_batch,
modified_beam_search,
modified_beam_search_rnnlm_shallow_fusion,
)
from train import add_model_arguments, get_params, get_transducer_model
@ -125,6 +139,7 @@ from icefall.checkpoint import (
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.rnn_lm.model import RnnLmModel
from icefall.utils import (
AttributeDict,
setup_logger,
@ -214,6 +229,7 @@ def get_parser():
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
- modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
@ -251,6 +267,20 @@ def get_parser():
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--max-contexts",
type=int,
@ -276,6 +306,7 @@ def get_parser():
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
@ -312,19 +343,69 @@ def get_parser():
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
"--rnn-lm-scale",
type=float,
default=0.0,
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the path to RNN LM exp dir.
""",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
"--rnn-lm-exp-dir",
type=str,
default="rnn_lm/exp",
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the path to RNN LM exp dir.
""",
)
parser.add_argument(
"--rnn-lm-epoch",
type=int,
default=7,
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the checkpoint to use.
""",
)
parser.add_argument(
"--rnn-lm-avg",
type=int,
default=2,
help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion.
It specifies the number of checkpoints to average.
""",
)
parser.add_argument(
"--rnn-lm-embedding-dim",
type=int,
default=2048,
help="Embedding dim of the model",
)
parser.add_argument(
"--rnn-lm-hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
)
parser.add_argument(
"--rnn-lm-num-layers",
type=int,
default=4,
help="Number of RNN layers the model",
)
parser.add_argument(
"--rnn-lm-tie-weights",
type=str2bool,
default=False,
help="""True to share the weights between the input embedding layer and the
last output linear layer
""",
)
add_model_arguments(parser)
return parser
@ -337,6 +418,8 @@ def decode_one_batch(
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
rnnlm: Optional[RnnLmModel] = None,
rnnlm_scale: float = 1.0,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -482,6 +565,18 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
sp=sp,
rnnlm=rnnlm,
rnnlm_scale=rnnlm_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
@ -531,6 +626,8 @@ def decode_dataset(
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
rnnlm: Optional[RnnLmModel] = None,
rnnlm_scale: float = 1.0,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
@ -572,6 +669,7 @@ def decode_dataset(
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
logging.info(f"Decoding {batch_idx}-th batch")
hyps_dict = decode_one_batch(
params=params,
@ -580,6 +678,8 @@ def decode_dataset(
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
rnnlm=rnnlm,
rnnlm_scale=rnnlm_scale,
)
for name, hyps in hyps_dict.items():
@ -666,6 +766,7 @@ def main():
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
"modified_beam_search_rnnlm_shallow_fusion",
)
params.res_dir = params.exp_dir / params.decoding_method
@ -673,11 +774,9 @@ def main():
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
@ -695,6 +794,8 @@ def main():
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
@ -805,6 +906,25 @@ def main():
model.to(device)
model.eval()
rnn_lm_model = None
rnn_lm_scale = params.rnn_lm_scale
if params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
rnn_lm_model = RnnLmModel(
vocab_size=params.vocab_size,
embedding_dim=params.rnn_lm_embedding_dim,
hidden_dim=params.rnn_lm_hidden_dim,
num_layers=params.rnn_lm_num_layers,
tie_weights=params.rnn_lm_tie_weights,
)
assert params.rnn_lm_avg == 1
load_checkpoint(
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
rnn_lm_model,
)
rnn_lm_model.to(device)
rnn_lm_model.eval()
if "fast_beam_search" in params.decoding_method:
if "LG" in params.decoding_method:
lexicon = Lexicon(params.lang_dir)
@ -848,6 +968,8 @@ def main():
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
rnnlm=rnn_lm_model,
rnnlm_scale=rnn_lm_scale,
)
save_results(

View File

@ -19,7 +19,7 @@ import logging
import torch
import torch.nn.functional as F
from icefall.utils import make_pad_mask
from icefall.utils import add_eos, add_sos, make_pad_mask
class RnnLmModel(torch.nn.Module):
@ -72,6 +72,8 @@ class RnnLmModel(torch.nn.Module):
else:
logging.info("Not tying weights")
self.cache = {}
def forward(
self, x: torch.Tensor, y: torch.Tensor, lengths: torch.Tensor
) -> torch.Tensor:
@ -118,3 +120,95 @@ class RnnLmModel(torch.nn.Module):
nll_loss = nll_loss.reshape(batch_size, -1)
return nll_loss
def predict_batch(self, tokens, token_lens, sos_id, eos_id, blank_id):
device = next(self.parameters()).device
batch_size = len(token_lens)
sos_tokens = add_sos(tokens, sos_id)
tokens_eos = add_eos(tokens, eos_id)
sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
sentence_lengths = (
sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
)
x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
x_tokens = x_tokens.to(torch.int64).to(device)
y_tokens = y_tokens.to(torch.int64).to(device)
sentence_lengths = sentence_lengths.to(torch.int64).to(device)
embedding = self.input_embedding(x_tokens)
# Note: We use batch_first==True
rnn_out, states = self.rnn(embedding)
logits = self.output_linear(rnn_out)
mask = torch.zeros(logits.shape).bool().to(device)
for i in range(batch_size):
mask[i, token_lens[i], :] = True
logits = logits[mask].reshape(batch_size, -1)
return logits[:, :].log_softmax(-1), states
def clean_cache(self):
self.cache = {}
def score_token(self, tokens: torch.Tensor, state=None):
device = next(self.parameters()).device
batch_size = tokens.size(0)
if state:
h, c = state
else:
h = torch.zeros(
self.rnn.num_layers, batch_size, self.rnn.input_size
).to(device)
c = torch.zeros(
self.rnn.num_layers, batch_size, self.rnn.input_size
).to(device)
embedding = self.input_embedding(tokens)
rnn_out, states = self.rnn(embedding, (h, c))
logits = self.output_linear(rnn_out)
return logits[:, 0].log_softmax(-1), states
def forward_with_state(
self, tokens, token_lens, sos_id, eos_id, blank_id, state=None
):
batch_size = len(token_lens)
if state:
h, c = state
else:
h = torch.zeros(
self.rnn.num_layers, batch_size, self.rnn.input_size
)
c = torch.zeros(
self.rnn.num_layers, batch_size, self.rnn.input_size
)
device = next(self.parameters()).device
sos_tokens = add_sos(tokens, sos_id)
tokens_eos = add_eos(tokens, eos_id)
sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
sentence_lengths = (
sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
)
x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
x_tokens = x_tokens.to(torch.int64).to(device)
y_tokens = y_tokens.to(torch.int64).to(device)
sentence_lengths = sentence_lengths.to(torch.int64).to(device)
embedding = self.input_embedding(x_tokens)
# Note: We use batch_first==True
rnn_out, states = self.rnn(embedding, (h, c))
logits = self.output_linear(rnn_out)
return logits, states