mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Modified beam search with RNNLM rescoring (#1002)
* add RNNLM rescore * add shallow fusion and lm rescore for streaming zipformer * minor fix * update RESULTS.md * fix yesno workflow, change from ubuntu-18.04 to ubuntu-latest
This commit is contained in:
parent
e32658e620
commit
34d1b07c3d
2
.github/workflows/run-yesno-recipe.yml
vendored
2
.github/workflows/run-yesno-recipe.yml
vendored
@ -35,7 +35,7 @@ jobs:
|
|||||||
matrix:
|
matrix:
|
||||||
# os: [ubuntu-18.04, macos-10.15]
|
# os: [ubuntu-18.04, macos-10.15]
|
||||||
# TODO: enable macOS for CPU testing
|
# TODO: enable macOS for CPU testing
|
||||||
os: [ubuntu-18.04]
|
os: [ubuntu-latest]
|
||||||
python-version: [3.8]
|
python-version: [3.8]
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
@ -76,6 +76,64 @@ for m in greedy_search modified_beam_search fast_beam_search; do
|
|||||||
--num-decode-streams 2000
|
--num-decode-streams 2000
|
||||||
done
|
done
|
||||||
```
|
```
|
||||||
|
We also support decoding with neural network LMs. After combining with language models, the WERs are
|
||||||
|
| decoding method | chunk size | test-clean | test-other | comment | decoding mode |
|
||||||
|
|----------------------|------------|------------|------------|---------------------|----------------------|
|
||||||
|
| modified beam search | 320ms | 3.11 | 7.93 | --epoch 30 --avg 9 | simulated streaming |
|
||||||
|
| modified beam search + RNNLM shallow fusion | 320ms | 2.58 | 6.65 | --epoch 30 --avg 9 | simulated streaming |
|
||||||
|
| modified beam search + RNNLM nbest rescore | 320ms | 2.59 | 6.86 | --epoch 30 --avg 9 | simulated streaming |
|
||||||
|
|
||||||
|
Please use the following command for RNNLM shallow fusion:
|
||||||
|
```bash
|
||||||
|
for lm_scale in $(seq 0.15 0.01 0.38); do
|
||||||
|
for beam_size in 4 8 12; do
|
||||||
|
./pruned_transducer_stateless7_streaming/decode.py \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--use-averaged-model False \
|
||||||
|
--beam-size $beam_size \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp-large-LM \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
--decoding-method modified_beam_search_lm_shallow_fusion \
|
||||||
|
--use-shallow-fusion 1 \
|
||||||
|
--lm-type rnn \
|
||||||
|
--lm-exp-dir rnn_lm/exp \
|
||||||
|
--lm-epoch 99 \
|
||||||
|
--lm-scale $lm_scale \
|
||||||
|
--lm-avg 1 \
|
||||||
|
--rnn-lm-embedding-dim 2048 \
|
||||||
|
--rnn-lm-hidden-dim 2048 \
|
||||||
|
--rnn-lm-num-layers 3 \
|
||||||
|
--lm-vocab-size 500
|
||||||
|
done
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
|
Please use the following command for RNNLM rescore:
|
||||||
|
```bash
|
||||||
|
./pruned_transducer_stateless7_streaming/decode.py \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 9 \
|
||||||
|
--use-averaged-model True \
|
||||||
|
--beam-size 8 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
--decoding-method modified_beam_search_lm_rescore \
|
||||||
|
--use-shallow-fusion 0 \
|
||||||
|
--lm-type rnn \
|
||||||
|
--lm-exp-dir rnn_lm/exp \
|
||||||
|
--lm-epoch 99 \
|
||||||
|
--lm-avg 1 \
|
||||||
|
--rnn-lm-embedding-dim 2048 \
|
||||||
|
--rnn-lm-hidden-dim 2048 \
|
||||||
|
--rnn-lm-num-layers 3 \
|
||||||
|
--lm-vocab-size 500
|
||||||
|
```
|
||||||
|
|
||||||
|
A well-trained RNNLM can be found here: <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main>.
|
||||||
|
|
||||||
|
|
||||||
#### Smaller model
|
#### Smaller model
|
||||||
|
|
||||||
|
@ -925,7 +925,6 @@ def main():
|
|||||||
)
|
)
|
||||||
LM.to(device)
|
LM.to(device)
|
||||||
LM.eval()
|
LM.eval()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
LM = None
|
LM = None
|
||||||
|
|
||||||
|
@ -1059,6 +1059,204 @@ def modified_beam_search(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def modified_beam_search_lm_rescore(
|
||||||
|
model: Transducer,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
LM: LmScorer,
|
||||||
|
lm_scale_list: List[int],
|
||||||
|
beam: int = 4,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
return_timestamps: bool = False,
|
||||||
|
) -> Union[List[List[int]], DecodingResults]:
|
||||||
|
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
||||||
|
Rescore the final results with RNNLM and return the one with the highest score
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
The transducer model.
|
||||||
|
encoder_out:
|
||||||
|
Output from the encoder. Its shape is (N, T, C).
|
||||||
|
encoder_out_lens:
|
||||||
|
A 1-D tensor of shape (N,), containing number of valid frames in
|
||||||
|
encoder_out before padding.
|
||||||
|
beam:
|
||||||
|
Number of active paths during the beam search.
|
||||||
|
temperature:
|
||||||
|
Softmax temperature.
|
||||||
|
LM:
|
||||||
|
A neural network language model
|
||||||
|
return_timestamps:
|
||||||
|
Whether to return timestamps.
|
||||||
|
Returns:
|
||||||
|
If return_timestamps is False, return the decoded result.
|
||||||
|
Else, return a DecodingResults object containing
|
||||||
|
decoded result and corresponding timestamps.
|
||||||
|
"""
|
||||||
|
assert encoder_out.ndim == 3, encoder_out.shape
|
||||||
|
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||||
|
|
||||||
|
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||||
|
input=encoder_out,
|
||||||
|
lengths=encoder_out_lens.cpu(),
|
||||||
|
batch_first=True,
|
||||||
|
enforce_sorted=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
blank_id = model.decoder.blank_id
|
||||||
|
unk_id = getattr(model, "unk_id", blank_id)
|
||||||
|
context_size = model.decoder.context_size
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
|
||||||
|
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||||
|
N = encoder_out.size(0)
|
||||||
|
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||||
|
assert N == batch_size_list[0], (N, batch_size_list)
|
||||||
|
|
||||||
|
B = [HypothesisList() for _ in range(N)]
|
||||||
|
for i in range(N):
|
||||||
|
B[i].add(
|
||||||
|
Hypothesis(
|
||||||
|
ys=[blank_id] * context_size,
|
||||||
|
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||||
|
timestamp=[],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
finalized_B = []
|
||||||
|
for (t, batch_size) in enumerate(batch_size_list):
|
||||||
|
start = offset
|
||||||
|
end = offset + batch_size
|
||||||
|
current_encoder_out = encoder_out.data[start:end]
|
||||||
|
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
|
||||||
|
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
|
||||||
|
offset = end
|
||||||
|
|
||||||
|
finalized_B = B[batch_size:] + finalized_B
|
||||||
|
B = B[:batch_size]
|
||||||
|
|
||||||
|
hyps_shape = get_hyps_shape(B).to(device)
|
||||||
|
|
||||||
|
A = [list(b) for b in B]
|
||||||
|
B = [HypothesisList() for _ in range(batch_size)]
|
||||||
|
|
||||||
|
ys_log_probs = torch.cat(
|
||||||
|
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
|
||||||
|
) # (num_hyps, 1)
|
||||||
|
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int64,
|
||||||
|
) # (num_hyps, context_size)
|
||||||
|
|
||||||
|
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
|
||||||
|
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||||
|
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
|
||||||
|
|
||||||
|
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
|
||||||
|
# as index, so we use `to(torch.int64)` below.
|
||||||
|
current_encoder_out = torch.index_select(
|
||||||
|
current_encoder_out,
|
||||||
|
dim=0,
|
||||||
|
index=hyps_shape.row_ids(1).to(torch.int64),
|
||||||
|
) # (num_hyps, 1, 1, encoder_out_dim)
|
||||||
|
|
||||||
|
logits = model.joiner(
|
||||||
|
current_encoder_out,
|
||||||
|
decoder_out,
|
||||||
|
project_input=False,
|
||||||
|
) # (num_hyps, 1, 1, vocab_size)
|
||||||
|
|
||||||
|
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
|
||||||
|
|
||||||
|
log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size)
|
||||||
|
|
||||||
|
log_probs.add_(ys_log_probs)
|
||||||
|
|
||||||
|
vocab_size = log_probs.size(-1)
|
||||||
|
|
||||||
|
log_probs = log_probs.reshape(-1)
|
||||||
|
|
||||||
|
row_splits = hyps_shape.row_splits(1) * vocab_size
|
||||||
|
log_probs_shape = k2.ragged.create_ragged_shape2(
|
||||||
|
row_splits=row_splits, cached_tot_size=log_probs.numel()
|
||||||
|
)
|
||||||
|
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
||||||
|
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
||||||
|
|
||||||
|
for k in range(len(topk_hyp_indexes)):
|
||||||
|
hyp_idx = topk_hyp_indexes[k]
|
||||||
|
hyp = A[i][hyp_idx]
|
||||||
|
|
||||||
|
new_ys = hyp.ys[:]
|
||||||
|
new_token = topk_token_indexes[k]
|
||||||
|
new_timestamp = hyp.timestamp[:]
|
||||||
|
if new_token not in (blank_id, unk_id):
|
||||||
|
new_ys.append(new_token)
|
||||||
|
new_timestamp.append(t)
|
||||||
|
|
||||||
|
new_log_prob = topk_log_probs[k]
|
||||||
|
new_hyp = Hypothesis(
|
||||||
|
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
|
||||||
|
)
|
||||||
|
B[i].add(new_hyp)
|
||||||
|
|
||||||
|
B = B + finalized_B
|
||||||
|
|
||||||
|
# get the am_scores for n-best list
|
||||||
|
hyps_shape = get_hyps_shape(B)
|
||||||
|
am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b])
|
||||||
|
am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device)
|
||||||
|
|
||||||
|
# now LM rescore
|
||||||
|
# prepare input data to LM
|
||||||
|
candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b]
|
||||||
|
possible_seqs = k2.RaggedTensor(candidate_seqs)
|
||||||
|
row_splits = possible_seqs.shape.row_splits(1)
|
||||||
|
sentence_token_lengths = row_splits[1:] - row_splits[:-1]
|
||||||
|
possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1)
|
||||||
|
possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1)
|
||||||
|
sentence_token_lengths += 1
|
||||||
|
|
||||||
|
x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id)
|
||||||
|
y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id)
|
||||||
|
x = x.to(device).to(torch.int64)
|
||||||
|
y = y.to(device).to(torch.int64)
|
||||||
|
sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64)
|
||||||
|
|
||||||
|
lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths)
|
||||||
|
assert lm_scores.ndim == 2
|
||||||
|
lm_scores = -1 * lm_scores.sum(dim=1)
|
||||||
|
|
||||||
|
ans = {}
|
||||||
|
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||||
|
|
||||||
|
# get the best hyp with different lm_scale
|
||||||
|
for lm_scale in lm_scale_list:
|
||||||
|
key = f"nnlm_scale_{lm_scale}"
|
||||||
|
tot_scores = am_scores.values + lm_scores * lm_scale
|
||||||
|
ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores)
|
||||||
|
max_indexes = ragged_tot_scores.argmax().tolist()
|
||||||
|
unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes]
|
||||||
|
hyps = []
|
||||||
|
for idx in unsorted_indices:
|
||||||
|
hyps.append(unsorted_hyps[idx])
|
||||||
|
|
||||||
|
ans[key] = hyps
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
def _deprecated_modified_beam_search(
|
def _deprecated_modified_beam_search(
|
||||||
model: Transducer,
|
model: Transducer,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
|
@ -122,6 +122,8 @@ from beam_search import (
|
|||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
|
modified_beam_search_lm_rescore,
|
||||||
|
modified_beam_search_lm_shallow_fusion,
|
||||||
)
|
)
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
@ -132,6 +134,7 @@ from icefall.checkpoint import (
|
|||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.lm_wrapper import LmScorer
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
@ -307,6 +310,32 @@ def get_parser():
|
|||||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-shallow-fusion",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Use neural network LM for shallow fusion.
|
||||||
|
If you want to use LODR, you will also need to set this to true
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lm-type",
|
||||||
|
type=str,
|
||||||
|
default="rnn",
|
||||||
|
help="Type of NN lm",
|
||||||
|
choices=["rnn", "transformer"],
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.3,
|
||||||
|
help="""The scale of the neural network LM
|
||||||
|
Used only when `--use-shallow-fusion` is set to True.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -319,6 +348,7 @@ def decode_one_batch(
|
|||||||
batch: dict,
|
batch: dict,
|
||||||
word_table: Optional[k2.SymbolTable] = None,
|
word_table: Optional[k2.SymbolTable] = None,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
LM: Optional[LmScorer] = None,
|
||||||
) -> Dict[str, List[List[str]]]:
|
) -> Dict[str, List[List[str]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
following format:
|
following format:
|
||||||
@ -443,6 +473,26 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
|
elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
|
||||||
|
hyp_tokens = modified_beam_search_lm_shallow_fusion(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam_size,
|
||||||
|
LM=LM,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
elif params.decoding_method == "modified_beam_search_lm_rescore":
|
||||||
|
lm_scale_list = [0.01 * i for i in range(10, 50)]
|
||||||
|
ans_dict = modified_beam_search_lm_rescore(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam_size,
|
||||||
|
LM=LM,
|
||||||
|
lm_scale_list=lm_scale_list,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
batch_size = encoder_out.size(0)
|
batch_size = encoder_out.size(0)
|
||||||
|
|
||||||
@ -481,6 +531,13 @@ def decode_one_batch(
|
|||||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||||
|
|
||||||
return {key: hyps}
|
return {key: hyps}
|
||||||
|
elif params.decoding_method == "modified_beam_search_lm_rescore":
|
||||||
|
ans = dict()
|
||||||
|
assert ans_dict is not None
|
||||||
|
for key, hyps in ans_dict.items():
|
||||||
|
hyps = [sp.decode(hyp).split() for hyp in hyps]
|
||||||
|
ans[f"beam_size_{params.beam_size}_{key}"] = hyps
|
||||||
|
return ans
|
||||||
else:
|
else:
|
||||||
return {f"beam_size_{params.beam_size}": hyps}
|
return {f"beam_size_{params.beam_size}": hyps}
|
||||||
|
|
||||||
@ -492,6 +549,7 @@ def decode_dataset(
|
|||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
word_table: Optional[k2.SymbolTable] = None,
|
word_table: Optional[k2.SymbolTable] = None,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
LM: Optional[LmScorer] = None,
|
||||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
@ -541,6 +599,7 @@ def decode_dataset(
|
|||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
word_table=word_table,
|
word_table=word_table,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
|
LM=LM,
|
||||||
)
|
)
|
||||||
|
|
||||||
for name, hyps in hyps_dict.items():
|
for name, hyps in hyps_dict.items():
|
||||||
@ -603,6 +662,7 @@ def save_results(
|
|||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||||
|
LmScorer.add_arguments(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
@ -617,6 +677,8 @@ def main():
|
|||||||
"fast_beam_search_nbest_LG",
|
"fast_beam_search_nbest_LG",
|
||||||
"fast_beam_search_nbest_oracle",
|
"fast_beam_search_nbest_oracle",
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
|
"modified_beam_search_lm_shallow_fusion",
|
||||||
|
"modified_beam_search_lm_rescore",
|
||||||
)
|
)
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
@ -642,6 +704,14 @@ def main():
|
|||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
|
|
||||||
|
if params.use_shallow_fusion:
|
||||||
|
params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
|
||||||
|
|
||||||
|
if "LODR" in params.decoding_method:
|
||||||
|
params.suffix += (
|
||||||
|
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
|
||||||
|
)
|
||||||
|
|
||||||
if params.use_averaged_model:
|
if params.use_averaged_model:
|
||||||
params.suffix += "-use-averaged-model"
|
params.suffix += "-use-averaged-model"
|
||||||
|
|
||||||
@ -751,6 +821,19 @@ def main():
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
# only load the neural network LM if required
|
||||||
|
if params.use_shallow_fusion or "lm" in params.decoding_method:
|
||||||
|
LM = LmScorer(
|
||||||
|
lm_type=params.lm_type,
|
||||||
|
params=params,
|
||||||
|
device=device,
|
||||||
|
lm_scale=params.lm_scale,
|
||||||
|
)
|
||||||
|
LM.to(device)
|
||||||
|
LM.eval()
|
||||||
|
else:
|
||||||
|
LM = None
|
||||||
|
|
||||||
if "fast_beam_search" in params.decoding_method:
|
if "fast_beam_search" in params.decoding_method:
|
||||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||||
lexicon = Lexicon(params.lang_dir)
|
lexicon = Lexicon(params.lang_dir)
|
||||||
@ -792,6 +875,7 @@ def main():
|
|||||||
sp=sp,
|
sp=sp,
|
||||||
word_table=word_table,
|
word_table=word_table,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
|
LM=LM,
|
||||||
)
|
)
|
||||||
|
|
||||||
save_results(
|
save_results(
|
||||||
|
@ -154,17 +154,18 @@ class RnnLmModel(torch.nn.Module):
|
|||||||
self.cache = {}
|
self.cache = {}
|
||||||
|
|
||||||
def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None):
|
def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None):
|
||||||
"""Score a batch of tokens
|
"""Score a batch of tokens, i.e each sample in the batch should be a
|
||||||
|
single token. For example, x = torch.tensor([[5],[10],[20]])
|
||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (torch.Tensor):
|
x (torch.Tensor):
|
||||||
A batch of tokens
|
A batch of tokens
|
||||||
x_lens (torch.Tensor):
|
x_lens (torch.Tensor):
|
||||||
The length of tokens in the batch before padding
|
The length of tokens in the batch before padding
|
||||||
state (_type_, optional):
|
state (optional):
|
||||||
Either None or a tuple of two torch.Tensor. Each tensor has
|
Either None or a tuple of two torch.Tensor. Each tensor has
|
||||||
the shape of (hidden_dim)
|
the shape of (num_layers, bs, hidden_dim)
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
_type_: _description_
|
_type_: _description_
|
||||||
|
Loading…
x
Reference in New Issue
Block a user