mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
Use an n-gram LM to rescore the lattice from fast_beam_search.
This commit is contained in:
parent
2d7096dfc6
commit
9ffc77a0f2
@ -1,8 +1,8 @@
|
|||||||
#!/usr/bin/env bash
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
# This script downloads the test-clean and test-other datasets
|
# This script downloads the test-clean and test-other datasets
|
||||||
# of LibriSpeech and unzip them to the folder ~/tmp/download,
|
# of LibriSpeech and unzips them to the folder ~/tmp/download,
|
||||||
# which is cached by GitHub actions for later runs.
|
# which are cached by GitHub actions for later runs.
|
||||||
#
|
#
|
||||||
# You will find directories ~/tmp/download/LibriSpeech after running
|
# You will find directories ~/tmp/download/LibriSpeech after running
|
||||||
# this script.
|
# this script.
|
||||||
|
@ -10,6 +10,25 @@ During training, it selects either a batch from GigaSpeech with prob `giga_prob`
|
|||||||
or a batch from LibriSpeech with prob `1 - giga_prob`. All utterances within
|
or a batch from LibriSpeech with prob `1 - giga_prob`. All utterances within
|
||||||
a batch come from the same dataset.
|
a batch come from the same dataset.
|
||||||
|
|
||||||
|
#### 2022-05-10
|
||||||
|
|
||||||
|
Using commit `TODO`.
|
||||||
|
|
||||||
|
The WERs are:
|
||||||
|
|
||||||
|
| | test-clean | test-other | comment |
|
||||||
|
|-------------------------------------|------------|------------|----------------------------------------|
|
||||||
|
| greedy search (max sym per frame 1) | 2.21 | 5.09 | --epoch 27 --avg 2 --max-duration 600 |
|
||||||
|
| greedy search (max sym per frame 1) | 2.25 | 5.02 | --epoch 27 --avg 12 --max-duration 600 |
|
||||||
|
| modified beam search | 2.19 | 5.03 | --epoch 25 --avg 6 --max-duration 600 |
|
||||||
|
| modified beam search | 2.23 | 4.94 | --epoch 27 --avg 10 --max-duration 600 |
|
||||||
|
| beam search | 2.16 | 4.95 | --epoch 25 --avg 7 --max-duration 600 |
|
||||||
|
| fast beam search | 2.21 | 4.96 | --epoch 27 --avg 10 --max-duration 600 |
|
||||||
|
| fast beam search | 2.19 | 4.97 | --epoch 27 --avg 12 --max-duration 600 |
|
||||||
|
|
||||||
|
|
||||||
|
#### 2022-04-29
|
||||||
|
|
||||||
Using commit `ac84220de91dee10c00e8f4223287f937b1930b6`.
|
Using commit `ac84220de91dee10c00e8f4223287f937b1930b6`.
|
||||||
|
|
||||||
See <https://github.com/k2-fsa/icefall/pull/312>.
|
See <https://github.com/k2-fsa/icefall/pull/312>.
|
||||||
|
@ -19,6 +19,7 @@ from dataclasses import dataclass
|
|||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
from model import Transducer
|
from model import Transducer
|
||||||
|
|
||||||
@ -34,10 +35,11 @@ def fast_beam_search_one_best(
|
|||||||
beam: float,
|
beam: float,
|
||||||
max_states: int,
|
max_states: int,
|
||||||
max_contexts: int,
|
max_contexts: int,
|
||||||
|
temperature: float = 1.0,
|
||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
"""It limits the maximum number of symbols per frame to 1.
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
A lattice is first obtained using modified beam search, and then
|
A lattice is first obtained using fast beam search, and then
|
||||||
the shortest path within the lattice is used as the final output.
|
the shortest path within the lattice is used as the final output.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -56,6 +58,8 @@ def fast_beam_search_one_best(
|
|||||||
Max states per stream per frame.
|
Max states per stream per frame.
|
||||||
max_contexts:
|
max_contexts:
|
||||||
Max contexts pre stream per frame.
|
Max contexts pre stream per frame.
|
||||||
|
temperature:
|
||||||
|
Softmax temperature.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoded result.
|
Return the decoded result.
|
||||||
"""
|
"""
|
||||||
@ -67,6 +71,7 @@ def fast_beam_search_one_best(
|
|||||||
beam=beam,
|
beam=beam,
|
||||||
max_states=max_states,
|
max_states=max_states,
|
||||||
max_contexts=max_contexts,
|
max_contexts=max_contexts,
|
||||||
|
temperature=temperature,
|
||||||
)
|
)
|
||||||
|
|
||||||
best_path = one_best_decoding(lattice)
|
best_path = one_best_decoding(lattice)
|
||||||
@ -74,6 +79,85 @@ def fast_beam_search_one_best(
|
|||||||
return hyps
|
return hyps
|
||||||
|
|
||||||
|
|
||||||
|
def fast_beam_search_nbest(
|
||||||
|
model: Transducer,
|
||||||
|
decoding_graph: k2.Fsa,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
beam: float,
|
||||||
|
max_states: int,
|
||||||
|
max_contexts: int,
|
||||||
|
num_paths: int,
|
||||||
|
use_double_scores: bool = True,
|
||||||
|
nbest_scale: float = 0.5,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
) -> List[List[int]]:
|
||||||
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
|
A lattice is first obtained using fast beam search, and then
|
||||||
|
we extract `num_paths` from the lattice using k2.random_path(),
|
||||||
|
unique them, compute the total score of each path by intersecting
|
||||||
|
it with the lattice, and output the path with the largest total score.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
An instance of `Transducer`.
|
||||||
|
decoding_graph:
|
||||||
|
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
|
||||||
|
encoder_out:
|
||||||
|
A tensor of shape (N, T, C) from the encoder.
|
||||||
|
encoder_out_lens:
|
||||||
|
A tensor of shape (N,) containing the number of frames in `encoder_out`
|
||||||
|
before padding.
|
||||||
|
beam:
|
||||||
|
Beam value, similar to the beam used in Kaldi..
|
||||||
|
max_states:
|
||||||
|
Max states per stream per frame.
|
||||||
|
max_contexts:
|
||||||
|
Max contexts pre stream per frame.
|
||||||
|
num_paths:
|
||||||
|
Number of paths to extract from the decoded lattice.
|
||||||
|
use_double_scores:
|
||||||
|
True to use double precision for computation. False to use
|
||||||
|
single precision.
|
||||||
|
nbest_scale:
|
||||||
|
It's the scale applied to the lattice.scores. A smaller value
|
||||||
|
yields more unique paths.
|
||||||
|
temperature:
|
||||||
|
Softmax temperature.
|
||||||
|
Returns:
|
||||||
|
Return the decoded result.
|
||||||
|
"""
|
||||||
|
lattice = fast_beam_search(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=beam,
|
||||||
|
max_states=max_states,
|
||||||
|
max_contexts=max_contexts,
|
||||||
|
temperature=temperature,
|
||||||
|
)
|
||||||
|
|
||||||
|
nbest = Nbest.from_lattice(
|
||||||
|
lattice=lattice,
|
||||||
|
num_paths=num_paths,
|
||||||
|
use_double_scores=use_double_scores,
|
||||||
|
nbest_scale=nbest_scale,
|
||||||
|
)
|
||||||
|
# at this point, nbest.fsa.scores are all zeros.
|
||||||
|
|
||||||
|
nbest = nbest.intersect(lattice)
|
||||||
|
# Now nbest.fsa.scores contains acoustic scores
|
||||||
|
|
||||||
|
max_indexes = nbest.tot_scores().argmax()
|
||||||
|
|
||||||
|
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||||
|
|
||||||
|
hyps = get_texts(best_path)
|
||||||
|
return hyps
|
||||||
|
|
||||||
|
|
||||||
def fast_beam_search_nbest_oracle(
|
def fast_beam_search_nbest_oracle(
|
||||||
model: Transducer,
|
model: Transducer,
|
||||||
decoding_graph: k2.Fsa,
|
decoding_graph: k2.Fsa,
|
||||||
@ -86,10 +170,11 @@ def fast_beam_search_nbest_oracle(
|
|||||||
ref_texts: List[List[int]],
|
ref_texts: List[List[int]],
|
||||||
use_double_scores: bool = True,
|
use_double_scores: bool = True,
|
||||||
nbest_scale: float = 0.5,
|
nbest_scale: float = 0.5,
|
||||||
|
temperature: float = 1.0,
|
||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
"""It limits the maximum number of symbols per frame to 1.
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
A lattice is first obtained using modified beam search, and then
|
A lattice is first obtained using fast beam search, and then
|
||||||
we select `num_paths` linear paths from the lattice. The path
|
we select `num_paths` linear paths from the lattice. The path
|
||||||
that has the minimum edit distance with the given reference transcript
|
that has the minimum edit distance with the given reference transcript
|
||||||
is used as the output.
|
is used as the output.
|
||||||
@ -125,6 +210,8 @@ def fast_beam_search_nbest_oracle(
|
|||||||
nbest_scale:
|
nbest_scale:
|
||||||
It's the scale applied to the lattice.scores. A smaller value
|
It's the scale applied to the lattice.scores. A smaller value
|
||||||
yields more unique paths.
|
yields more unique paths.
|
||||||
|
temperature:
|
||||||
|
Softmax temperature.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoded result.
|
Return the decoded result.
|
||||||
@ -137,6 +224,7 @@ def fast_beam_search_nbest_oracle(
|
|||||||
beam=beam,
|
beam=beam,
|
||||||
max_states=max_states,
|
max_states=max_states,
|
||||||
max_contexts=max_contexts,
|
max_contexts=max_contexts,
|
||||||
|
temperature=temperature,
|
||||||
)
|
)
|
||||||
|
|
||||||
nbest = Nbest.from_lattice(
|
nbest = Nbest.from_lattice(
|
||||||
@ -169,6 +257,158 @@ def fast_beam_search_nbest_oracle(
|
|||||||
return hyps
|
return hyps
|
||||||
|
|
||||||
|
|
||||||
|
def fast_beam_search_with_nbest_rescoring(
|
||||||
|
model: Transducer,
|
||||||
|
decoding_graph: k2.Fsa,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
beam: float,
|
||||||
|
max_states: int,
|
||||||
|
max_contexts: int,
|
||||||
|
ngram_lm_scale_list: List[float],
|
||||||
|
num_paths: int,
|
||||||
|
G: k2.Fsa,
|
||||||
|
sp: spm.SentencePieceProcessor,
|
||||||
|
word_table: k2.SymbolTable,
|
||||||
|
oov_word: str = "<UNK>",
|
||||||
|
use_double_scores: bool = True,
|
||||||
|
nbest_scale: float = 0.5,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
) -> Dict[str, List[List[int]]]:
|
||||||
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
|
A lattice is first obtained using modified beam search, and then
|
||||||
|
the shortest path within the lattice is used as the final output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
An instance of `Transducer`.
|
||||||
|
decoding_graph:
|
||||||
|
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
|
||||||
|
encoder_out:
|
||||||
|
A tensor of shape (N, T, C) from the encoder.
|
||||||
|
encoder_out_lens:
|
||||||
|
A tensor of shape (N,) containing the number of frames in `encoder_out`
|
||||||
|
before padding.
|
||||||
|
beam:
|
||||||
|
Beam value, similar to the beam used in Kaldi..
|
||||||
|
max_states:
|
||||||
|
Max states per stream per frame.
|
||||||
|
max_contexts:
|
||||||
|
Max contexts pre stream per frame.
|
||||||
|
ngram_lm_scale_list:
|
||||||
|
A list of floats representing LM score scales.
|
||||||
|
num_paths:
|
||||||
|
Number of paths to extract from the decoded lattice.
|
||||||
|
G:
|
||||||
|
An FsaVec containing only a single FSA. It is an n-gram LM.
|
||||||
|
sp:
|
||||||
|
The BPE model.
|
||||||
|
word_table:
|
||||||
|
The word symbol table.
|
||||||
|
oov_word:
|
||||||
|
OOV words are replaced with this word.
|
||||||
|
use_double_scores:
|
||||||
|
True to use double precision for computation. False to use
|
||||||
|
single precision.
|
||||||
|
nbest_scale:
|
||||||
|
It's the scale applied to the lattice.scores. A smaller value
|
||||||
|
yields more unique paths.
|
||||||
|
temperature:
|
||||||
|
Softmax temperature.
|
||||||
|
Returns:
|
||||||
|
Return the decoded result in a dict, where the key has the form
|
||||||
|
'ngram_lm_scale_xx' and the value is the decoded results. `xx` is the
|
||||||
|
ngram LM scale value used during decoding, i.e., 0.1.
|
||||||
|
"""
|
||||||
|
lattice = fast_beam_search(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=beam,
|
||||||
|
max_states=max_states,
|
||||||
|
max_contexts=max_contexts,
|
||||||
|
temperature=temperature,
|
||||||
|
)
|
||||||
|
|
||||||
|
nbest = Nbest.from_lattice(
|
||||||
|
lattice=lattice,
|
||||||
|
num_paths=num_paths,
|
||||||
|
use_double_scores=use_double_scores,
|
||||||
|
nbest_scale=nbest_scale,
|
||||||
|
)
|
||||||
|
# at this point, nbest.fsa.scores are all zeros.
|
||||||
|
|
||||||
|
nbest = nbest.intersect(lattice)
|
||||||
|
# Now nbest.fsa.scores contains acoustic scores
|
||||||
|
|
||||||
|
am_scores = nbest.tot_scores()
|
||||||
|
|
||||||
|
# Now we need to compute the LM scores of each path.
|
||||||
|
# (1) Get the token IDs of each Path. We assume the decoding_graph
|
||||||
|
# is an acceptor, i.e., lattice is also an acceptor
|
||||||
|
tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc]
|
||||||
|
|
||||||
|
tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous())
|
||||||
|
tokens = tokens.remove_values_leq(0) # remove -1 and 0
|
||||||
|
|
||||||
|
token_list: List[List[int]] = tokens.tolist()
|
||||||
|
word_list: List[List[str]] = sp.decode(token_list)
|
||||||
|
|
||||||
|
assert isinstance(oov_word, str), oov_word
|
||||||
|
assert oov_word in word_table, oov_word
|
||||||
|
oov_word_id = word_table[oov_word]
|
||||||
|
|
||||||
|
word_ids_list: List[List[int]] = []
|
||||||
|
for words in word_list:
|
||||||
|
this_word_ids = []
|
||||||
|
for w in words:
|
||||||
|
if w in word_table:
|
||||||
|
this_word_ids.append(word_table[w])
|
||||||
|
else:
|
||||||
|
this_word_ids.append(oov_word_id)
|
||||||
|
word_ids_list.append(this_word_ids)
|
||||||
|
|
||||||
|
word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device)
|
||||||
|
word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas)
|
||||||
|
|
||||||
|
num_unique_paths = len(word_ids_list)
|
||||||
|
|
||||||
|
b_to_a_map = torch.zeros(
|
||||||
|
num_unique_paths,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=lattice.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
rescored_word_fsas = k2.intersect_device(
|
||||||
|
a_fsas=G,
|
||||||
|
b_fsas=word_fsas_with_self_loops,
|
||||||
|
b_to_a_map=b_to_a_map,
|
||||||
|
sorted_match_a=True,
|
||||||
|
ret_arc_maps=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas))
|
||||||
|
ngram_lm_scores = rescored_word_fsas.get_tot_scores(
|
||||||
|
use_double_scores=True,
|
||||||
|
log_semiring=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
ans: Dict[str, List[List[int]]] = {}
|
||||||
|
for s in ngram_lm_scale_list:
|
||||||
|
key = f"ngram_lm_scale_{s}"
|
||||||
|
tot_scores = am_scores.values + s * ngram_lm_scores
|
||||||
|
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
||||||
|
max_indexes = ragged_tot_scores.argmax()
|
||||||
|
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||||
|
hyps = get_texts(best_path)
|
||||||
|
|
||||||
|
ans[key] = hyps
|
||||||
|
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
def fast_beam_search(
|
def fast_beam_search(
|
||||||
model: Transducer,
|
model: Transducer,
|
||||||
decoding_graph: k2.Fsa,
|
decoding_graph: k2.Fsa,
|
||||||
@ -177,6 +417,7 @@ def fast_beam_search(
|
|||||||
beam: float,
|
beam: float,
|
||||||
max_states: int,
|
max_states: int,
|
||||||
max_contexts: int,
|
max_contexts: int,
|
||||||
|
temperature: float = 1.0,
|
||||||
) -> k2.Fsa:
|
) -> k2.Fsa:
|
||||||
"""It limits the maximum number of symbols per frame to 1.
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
@ -196,6 +437,8 @@ def fast_beam_search(
|
|||||||
Max states per stream per frame.
|
Max states per stream per frame.
|
||||||
max_contexts:
|
max_contexts:
|
||||||
Max contexts pre stream per frame.
|
Max contexts pre stream per frame.
|
||||||
|
temperature:
|
||||||
|
Softmax temperature.
|
||||||
Returns:
|
Returns:
|
||||||
Return an FsaVec with axes [utt][state][arc] containing the decoded
|
Return an FsaVec with axes [utt][state][arc] containing the decoded
|
||||||
lattice. Note: When the input graph is a TrivialGraph, the returned
|
lattice. Note: When the input graph is a TrivialGraph, the returned
|
||||||
@ -244,7 +487,7 @@ def fast_beam_search(
|
|||||||
project_input=False,
|
project_input=False,
|
||||||
)
|
)
|
||||||
logits = logits.squeeze(1).squeeze(1)
|
logits = logits.squeeze(1).squeeze(1)
|
||||||
log_probs = logits.log_softmax(dim=-1)
|
log_probs = (logits / temperature).log_softmax(dim=-1)
|
||||||
decoding_streams.advance(log_probs)
|
decoding_streams.advance(log_probs)
|
||||||
decoding_streams.terminate_and_flush_to_streams()
|
decoding_streams.terminate_and_flush_to_streams()
|
||||||
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
|
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
|
||||||
@ -587,6 +830,7 @@ def modified_beam_search(
|
|||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
encoder_out_lens: torch.Tensor,
|
encoder_out_lens: torch.Tensor,
|
||||||
beam: int = 4,
|
beam: int = 4,
|
||||||
|
temperature: float = 1.0,
|
||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
||||||
|
|
||||||
@ -600,6 +844,8 @@ def modified_beam_search(
|
|||||||
encoder_out before padding.
|
encoder_out before padding.
|
||||||
beam:
|
beam:
|
||||||
Number of active paths during the beam search.
|
Number of active paths during the beam search.
|
||||||
|
temperature:
|
||||||
|
Softmax temperature.
|
||||||
Returns:
|
Returns:
|
||||||
Return a list-of-list of token IDs. ans[i] is the decoding results
|
Return a list-of-list of token IDs. ans[i] is the decoding results
|
||||||
for the i-th utterance.
|
for the i-th utterance.
|
||||||
@ -683,7 +929,7 @@ def modified_beam_search(
|
|||||||
|
|
||||||
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
|
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
|
||||||
|
|
||||||
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
|
log_probs = (logits / temperature).log_softmax(dim=-1)
|
||||||
|
|
||||||
log_probs.add_(ys_log_probs)
|
log_probs.add_(ys_log_probs)
|
||||||
|
|
||||||
@ -847,6 +1093,7 @@ def beam_search(
|
|||||||
model: Transducer,
|
model: Transducer,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
beam: int = 4,
|
beam: int = 4,
|
||||||
|
temperature: float = 1.0,
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
"""
|
"""
|
||||||
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
|
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
|
||||||
@ -860,6 +1107,8 @@ def beam_search(
|
|||||||
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
||||||
beam:
|
beam:
|
||||||
Beam size.
|
Beam size.
|
||||||
|
temperature:
|
||||||
|
Softmax temperature.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoded result.
|
Return the decoded result.
|
||||||
"""
|
"""
|
||||||
@ -936,7 +1185,7 @@ def beam_search(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# TODO(fangjun): Scale the blank posterior
|
# TODO(fangjun): Scale the blank posterior
|
||||||
log_prob = logits.log_softmax(dim=-1)
|
log_prob = (logits / temperature).log_softmax(dim=-1)
|
||||||
# log_prob is (1, 1, 1, vocab_size)
|
# log_prob is (1, 1, 1, vocab_size)
|
||||||
log_prob = log_prob.squeeze()
|
log_prob = log_prob.squeeze()
|
||||||
# Now log_prob is (vocab_size,)
|
# Now log_prob is (vocab_size,)
|
||||||
|
@ -19,40 +19,67 @@
|
|||||||
Usage:
|
Usage:
|
||||||
(1) greedy search
|
(1) greedy search
|
||||||
./pruned_transducer_stateless3/decode.py \
|
./pruned_transducer_stateless3/decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method greedy_search
|
--decoding-method greedy_search
|
||||||
|
|
||||||
(2) beam search (not recommended)
|
(2) beam search (not recommended)
|
||||||
./pruned_transducer_stateless3/decode.py \
|
./pruned_transducer_stateless3/decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method beam_search \
|
--decoding-method beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./pruned_transducer_stateless3/decode.py \
|
./pruned_transducer_stateless3/decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method modified_beam_search \
|
--decoding-method modified_beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
(4) fast beam search
|
(4) fast beam search (one best)
|
||||||
./pruned_transducer_stateless3/decode.py \
|
./pruned_transducer_stateless3/decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless3/exp \
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method fast_beam_search \
|
--decoding-method fast_beam_search \
|
||||||
--beam 4 \
|
--beam 4 \
|
||||||
--max-contexts 4 \
|
--max-contexts 4 \
|
||||||
--max-states 8
|
--max-states 8
|
||||||
|
|
||||||
|
(5) fast beam search (nbest)
|
||||||
|
./pruned_transducer_stateless3/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search_nbest \
|
||||||
|
--beam 4 \
|
||||||
|
--max-contexts 4 \
|
||||||
|
--max-states 8 \
|
||||||
|
--num-paths 200 \
|
||||||
|
--nbest-scale 0.5
|
||||||
|
|
||||||
|
(6) fast beam search (nbest with n-gram LM rescoring)
|
||||||
|
./pruned_transducer_stateless3/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless3/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method fast_beam_search_with_nbest_rescoring \
|
||||||
|
--beam 4 \
|
||||||
|
--max-contexts 4 \
|
||||||
|
--max-states 8 \
|
||||||
|
--num-paths 200 \
|
||||||
|
--nbest-scale 0.5 \
|
||||||
|
--lm-dir ./data/lm
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -69,8 +96,10 @@ import torch.nn as nn
|
|||||||
from asr_datamodule import AsrDataModule
|
from asr_datamodule import AsrDataModule
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
|
fast_beam_search_nbest,
|
||||||
fast_beam_search_nbest_oracle,
|
fast_beam_search_nbest_oracle,
|
||||||
fast_beam_search_one_best,
|
fast_beam_search_one_best,
|
||||||
|
fast_beam_search_with_nbest_rescoring,
|
||||||
greedy_search,
|
greedy_search,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
@ -147,7 +176,9 @@ def get_parser():
|
|||||||
- beam_search
|
- beam_search
|
||||||
- modified_beam_search
|
- modified_beam_search
|
||||||
- fast_beam_search
|
- fast_beam_search
|
||||||
|
- fast_beam_search_nbest
|
||||||
- fast_beam_search_nbest_oracle
|
- fast_beam_search_nbest_oracle
|
||||||
|
- fast_beam_search_with_nbest_rescoring
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -168,7 +199,9 @@ def get_parser():
|
|||||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
`beam` in Kaldi.
|
`beam` in Kaldi.
|
||||||
Used only when --decoding-method is
|
Used only when --decoding-method is
|
||||||
fast_beam_search or fast_beam_search_nbest_oracle""",
|
fast_beam_search, fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_oracle, or fast_beam_search_with_nbest_rescoring
|
||||||
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -176,7 +209,9 @@ def get_parser():
|
|||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is
|
||||||
fast_beam_search or fast_beam_search_nbest_oracle""",
|
fast_beam_search, fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_oracle, or fast_beam_search_with_nbest_rescoring
|
||||||
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -184,7 +219,9 @@ def get_parser():
|
|||||||
type=int,
|
type=int,
|
||||||
default=8,
|
default=8,
|
||||||
help="""Used only when --decoding-method is
|
help="""Used only when --decoding-method is
|
||||||
fast_beam_search or fast_beam_search_nbest_oracle""",
|
fast_beam_search, fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_oracle, or fast_beam_search_with_nbest_rescoring
|
||||||
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -194,6 +231,7 @@ def get_parser():
|
|||||||
help="The context size in the decoder. 1 means bigram; "
|
help="The context size in the decoder. 1 means bigram; "
|
||||||
"2 means tri-gram",
|
"2 means tri-gram",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-sym-per-frame",
|
"--max-sym-per-frame",
|
||||||
type=int,
|
type=int,
|
||||||
@ -207,7 +245,8 @@ def get_parser():
|
|||||||
type=int,
|
type=int,
|
||||||
default=100,
|
default=100,
|
||||||
help="""Number of paths for computed nbest oracle WER
|
help="""Number of paths for computed nbest oracle WER
|
||||||
when the decoding method is fast_beam_search_nbest_oracle.
|
when the decoding method is fast_beam_search_nbest_oracle,
|
||||||
|
fast_beam_search_nbest, or fast_beam_search_with_nbest_rescoring.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -216,9 +255,40 @@ def get_parser():
|
|||||||
type=float,
|
type=float,
|
||||||
default=0.5,
|
default=0.5,
|
||||||
help="""Scale applied to lattice scores when computing nbest paths.
|
help="""Scale applied to lattice scores when computing nbest paths.
|
||||||
Used only when the decoding_method is fast_beam_search_nbest_oracle.
|
Used only when the decoding_method is fast_beam_search_nbest_oracle,
|
||||||
|
fast_beam_search_nbest, or fast_beam_search_with_nbest_rescoring.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--temperature",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="""Softmax temperature.
|
||||||
|
The output of the model is (logits / temperature).log_softmax().
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lm-dir",
|
||||||
|
type=Path,
|
||||||
|
default=Path("./data/lm"),
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search_with_nbest_rescoring.
|
||||||
|
It should contain either G_4_gram.pt or G_4_gram.fst.txt
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--words-txt",
|
||||||
|
type=Path,
|
||||||
|
default=Path("./data/lang_bpe_500/words.txt"),
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search_with_nbest_rescoring.
|
||||||
|
It is the word table.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -228,6 +298,8 @@ def decode_one_batch(
|
|||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
G: Optional[k2.Fsa] = None,
|
||||||
|
word_table: Optional[k2.SymbolTable] = None,
|
||||||
) -> Dict[str, List[List[str]]]:
|
) -> Dict[str, List[List[str]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
following format:
|
following format:
|
||||||
@ -252,8 +324,17 @@ def decode_one_batch(
|
|||||||
for the format of the `batch`.
|
for the format of the `batch`.
|
||||||
decoding_graph:
|
decoding_graph:
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
only when --decoding_method is
|
only when decoding method is fast_beam_search, fast_beam_search_nbest,
|
||||||
fast_beam_search or fast_beam_search_nbest_oracle.
|
fast_beam_search_nbest_oracle, or fast_beam_search_with_nbest_rescoring.
|
||||||
|
G:
|
||||||
|
Optional. Used only when decoding method is fast_beam_search,
|
||||||
|
fast_beam_search_nbest, fast_beam_search_nbest_oracle,
|
||||||
|
or fast_beam_search_with_nbest_rescoring.
|
||||||
|
It an FsaVec containing an acceptor.
|
||||||
|
word_table:
|
||||||
|
Optional. Used only when decoding method is
|
||||||
|
fast_beam_search_with_nbest_rescoring. It is the word symbol table
|
||||||
|
containing mappings between words and IDs.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoding result. See above description for the format of
|
Return the decoding result. See above description for the format of
|
||||||
the returned dict.
|
the returned dict.
|
||||||
@ -282,6 +363,22 @@ def decode_one_batch(
|
|||||||
beam=params.beam,
|
beam=params.beam,
|
||||||
max_contexts=params.max_contexts,
|
max_contexts=params.max_contexts,
|
||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
|
temperature=params.temperature,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
elif params.decoding_method == "fast_beam_search_nbest":
|
||||||
|
hyp_tokens = fast_beam_search_nbest(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
temperature=params.temperature,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -297,9 +394,30 @@ def decode_one_batch(
|
|||||||
num_paths=params.num_paths,
|
num_paths=params.num_paths,
|
||||||
ref_texts=sp.encode(supervisions["text"]),
|
ref_texts=sp.encode(supervisions["text"]),
|
||||||
nbest_scale=params.nbest_scale,
|
nbest_scale=params.nbest_scale,
|
||||||
|
temperature=params.temperature,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
|
elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
|
||||||
|
ngram_lm_scale_list = [-0.3, -0.2, -0.1, -0.05, -0.02, 0]
|
||||||
|
ngram_lm_scale_list += [0.01, 0.02, 0.05]
|
||||||
|
hyp_tokens = fast_beam_search_with_nbest_rescoring(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_states=params.max_states,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
ngram_lm_scale_list=ngram_lm_scale_list,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
G=G,
|
||||||
|
sp=sp,
|
||||||
|
word_table=word_table,
|
||||||
|
use_double_scores=True,
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
temperature=params.temperature,
|
||||||
|
)
|
||||||
elif (
|
elif (
|
||||||
params.decoding_method == "greedy_search"
|
params.decoding_method == "greedy_search"
|
||||||
and params.max_sym_per_frame == 1
|
and params.max_sym_per_frame == 1
|
||||||
@ -317,6 +435,7 @@ def decode_one_batch(
|
|||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
|
temperature=params.temperature,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
@ -338,6 +457,7 @@ def decode_one_batch(
|
|||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out_i,
|
encoder_out=encoder_out_i,
|
||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
|
temperature=params.temperature,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -352,7 +472,19 @@ def decode_one_batch(
|
|||||||
(
|
(
|
||||||
f"beam_{params.beam}_"
|
f"beam_{params.beam}_"
|
||||||
f"max_contexts_{params.max_contexts}_"
|
f"max_contexts_{params.max_contexts}_"
|
||||||
f"max_states_{params.max_states}"
|
f"max_states_{params.max_states}_"
|
||||||
|
f"temperature_{params.temperature}"
|
||||||
|
): hyps
|
||||||
|
}
|
||||||
|
elif params.decoding_method == "fast_beam_search_nbest":
|
||||||
|
return {
|
||||||
|
(
|
||||||
|
f"beam_{params.beam}_"
|
||||||
|
f"max_contexts_{params.max_contexts}_"
|
||||||
|
f"max_states_{params.max_states}_"
|
||||||
|
f"num_paths_{params.num_paths}_"
|
||||||
|
f"nbest_scale_{params.nbest_scale}_"
|
||||||
|
f"temperature_{params.temperature}"
|
||||||
): hyps
|
): hyps
|
||||||
}
|
}
|
||||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||||
@ -362,11 +494,31 @@ def decode_one_batch(
|
|||||||
f"max_contexts_{params.max_contexts}_"
|
f"max_contexts_{params.max_contexts}_"
|
||||||
f"max_states_{params.max_states}_"
|
f"max_states_{params.max_states}_"
|
||||||
f"num_paths_{params.num_paths}_"
|
f"num_paths_{params.num_paths}_"
|
||||||
f"nbest_scale_{params.nbest_scale}"
|
f"nbest_scale_{params.nbest_scale}_"
|
||||||
|
f"temperature_{params.temperature}"
|
||||||
): hyps
|
): hyps
|
||||||
}
|
}
|
||||||
|
elif params.decoding_method == "fast_beam_search_with_nbest_rescoring":
|
||||||
|
prefix = (
|
||||||
|
f"beam_{params.beam}_"
|
||||||
|
f"max_contexts_{params.max_contexts}_"
|
||||||
|
f"max_states_{params.max_states}_"
|
||||||
|
f"num_paths_{params.num_paths}_"
|
||||||
|
f"nbest_scale_{params.nbest_scale}_"
|
||||||
|
f"temperature_{params.temperature}_"
|
||||||
|
)
|
||||||
|
ans: Dict[str, List[List[str]]] = {}
|
||||||
|
for key, hyp in hyp_tokens.items():
|
||||||
|
t: List[str] = sp.decode(hyp)
|
||||||
|
ans[prefix + key] = [s.split() for s in t]
|
||||||
|
return ans
|
||||||
else:
|
else:
|
||||||
return {f"beam_size_{params.beam_size}": hyps}
|
return {
|
||||||
|
(
|
||||||
|
f"beam_size_{params.beam_size}_"
|
||||||
|
f"temperature_{params.temperature}"
|
||||||
|
): hyps
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def decode_dataset(
|
def decode_dataset(
|
||||||
@ -375,6 +527,8 @@ def decode_dataset(
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
decoding_graph: Optional[k2.Fsa] = None,
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
G: Optional[k2.Fsa] = None,
|
||||||
|
word_table: Optional[k2.SymbolTable] = None,
|
||||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
@ -389,7 +543,17 @@ def decode_dataset(
|
|||||||
The BPE model.
|
The BPE model.
|
||||||
decoding_graph:
|
decoding_graph:
|
||||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
only when --decoding_method is fast_beam_search.
|
only when decoding method is fast_beam_search, fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_oracle, or fast_beam_search_with_nbest_rescoring.
|
||||||
|
G:
|
||||||
|
Optional. Used only when decoding method is fast_beam_search,
|
||||||
|
fast_beam_search_nbest, fast_beam_search_nbest_oracle,
|
||||||
|
or fast_beam_search_with_nbest_rescoring.
|
||||||
|
It an FsaVec containing an acceptor.
|
||||||
|
word_table:
|
||||||
|
Optional. Used only when decoding method is
|
||||||
|
fast_beam_search_with_nbest_rescoring. It is the word symbol table
|
||||||
|
containing mappings between words and IDs.
|
||||||
Returns:
|
Returns:
|
||||||
Return a dict, whose key may be "greedy_search" if greedy search
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
@ -419,6 +583,8 @@ def decode_dataset(
|
|||||||
sp=sp,
|
sp=sp,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
|
G=G,
|
||||||
|
word_table=word_table,
|
||||||
)
|
)
|
||||||
|
|
||||||
for name, hyps in hyps_dict.items():
|
for name, hyps in hyps_dict.items():
|
||||||
@ -438,6 +604,7 @@ def decode_dataset(
|
|||||||
logging.info(
|
logging.info(
|
||||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@ -485,6 +652,68 @@ def save_results(
|
|||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
|
|
||||||
|
def load_ngram_LM(
|
||||||
|
lm_dir: Path, word_table: k2.SymbolTable, device: torch.device
|
||||||
|
) -> k2.Fsa:
|
||||||
|
"""Read a ngram model from the given directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lm_dir:
|
||||||
|
It should contain either G_4_gram.pt or G_4_gram.fst.txt
|
||||||
|
word_table:
|
||||||
|
The word table mapping words to IDs and vice versa.
|
||||||
|
device:
|
||||||
|
The resulting FSA will be moved to this device.
|
||||||
|
Returns:
|
||||||
|
Return an FsaVec containing a single acceptor.
|
||||||
|
"""
|
||||||
|
lm_dir = Path(lm_dir)
|
||||||
|
assert lm_dir.is_dir(), f"{lm_dir} does not exist"
|
||||||
|
|
||||||
|
pt_file = lm_dir / "G_4_gram.pt"
|
||||||
|
|
||||||
|
if pt_file.is_file():
|
||||||
|
logging.info(f"Loading pre-compiled {pt_file}")
|
||||||
|
d = torch.load(pt_file, map_location=device)
|
||||||
|
G = k2.Fsa.from_dict(d)
|
||||||
|
return G
|
||||||
|
|
||||||
|
txt_file = lm_dir / "G_4_gram.fst.txt"
|
||||||
|
|
||||||
|
assert txt_file.is_file(), f"{txt_file} does not exist"
|
||||||
|
logging.info(f"Loading {txt_file}")
|
||||||
|
logging.warning("It may take 8 minutes (Will be cached for later use).")
|
||||||
|
with open(txt_file) as f:
|
||||||
|
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
||||||
|
|
||||||
|
# G.aux_labels is not needed in later computations, so
|
||||||
|
# remove it here.
|
||||||
|
del G.aux_labels
|
||||||
|
# Now G is an acceptor
|
||||||
|
|
||||||
|
first_word_disambig_id = word_table["#0"]
|
||||||
|
# CAUTION: The following line is crucial.
|
||||||
|
# Arcs entering the back-off state have label equal to #0.
|
||||||
|
# We have to change it to 0 here.
|
||||||
|
G.labels[G.labels >= first_word_disambig_id] = 0
|
||||||
|
|
||||||
|
# See https://github.com/k2-fsa/k2/issues/874
|
||||||
|
# for why we need to set G.properties to None
|
||||||
|
G.__dict__["_properties"] = None
|
||||||
|
|
||||||
|
G = k2.Fsa.from_fsas([G]).to(device)
|
||||||
|
G = k2.arc_sort(G)
|
||||||
|
|
||||||
|
# Save a dummy value so that it can be loaded in C++.
|
||||||
|
# See https://github.com/pytorch/pytorch/issues/67902
|
||||||
|
# for why we need to do this.
|
||||||
|
G.dummy = 1
|
||||||
|
|
||||||
|
logging.info(f"Saving to {pt_file} for later use")
|
||||||
|
torch.save(G.as_dict(), pt_file)
|
||||||
|
return G
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
@ -499,7 +728,9 @@ def main():
|
|||||||
"greedy_search",
|
"greedy_search",
|
||||||
"beam_search",
|
"beam_search",
|
||||||
"fast_beam_search",
|
"fast_beam_search",
|
||||||
|
"fast_beam_search_nbest",
|
||||||
"fast_beam_search_nbest_oracle",
|
"fast_beam_search_nbest_oracle",
|
||||||
|
"fast_beam_search_with_nbest_rescoring",
|
||||||
"modified_beam_search",
|
"modified_beam_search",
|
||||||
)
|
)
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
@ -513,19 +744,27 @@ def main():
|
|||||||
params.suffix += f"-beam-{params.beam}"
|
params.suffix += f"-beam-{params.beam}"
|
||||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
params.suffix += f"-max-states-{params.max_states}"
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
params.suffix += f"-temperature-{params.temperature}"
|
||||||
|
elif params.decoding_method in (
|
||||||
|
"fast_beam_search_nbest",
|
||||||
|
"fast_beam_search_nbest_oracle",
|
||||||
|
"fast_beam_search_with_nbest_rescoring",
|
||||||
|
):
|
||||||
params.suffix += f"-beam-{params.beam}"
|
params.suffix += f"-beam-{params.beam}"
|
||||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
params.suffix += f"-max-states-{params.max_states}"
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
params.suffix += f"-num-paths-{params.num_paths}"
|
params.suffix += f"-num-paths-{params.num_paths}"
|
||||||
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||||
|
params.suffix += f"-temperature-{params.temperature}"
|
||||||
elif "beam_search" in params.decoding_method:
|
elif "beam_search" in params.decoding_method:
|
||||||
params.suffix += (
|
params.suffix += (
|
||||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||||
)
|
)
|
||||||
|
params.suffix += f"-temperature-{params.temperature}"
|
||||||
else:
|
else:
|
||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
|
params.suffix += f"-temperature-{params.temperature}"
|
||||||
|
|
||||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||||
logging.info("Decoding started")
|
logging.info("Decoding started")
|
||||||
@ -585,12 +824,26 @@ def main():
|
|||||||
|
|
||||||
if params.decoding_method in (
|
if params.decoding_method in (
|
||||||
"fast_beam_search",
|
"fast_beam_search",
|
||||||
|
"fast_beam_search_nbest",
|
||||||
"fast_beam_search_nbest_oracle",
|
"fast_beam_search_nbest_oracle",
|
||||||
|
"fast_beam_search_with_nbest_rescoring",
|
||||||
):
|
):
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
else:
|
else:
|
||||||
decoding_graph = None
|
decoding_graph = None
|
||||||
|
|
||||||
|
if params.decoding_method == "fast_beam_search_with_nbest_rescoring":
|
||||||
|
logging.info(f"Loading word symbol table from {params.words_txt}")
|
||||||
|
word_table = k2.SymbolTable.from_file(params.words_txt)
|
||||||
|
G = load_ngram_LM(
|
||||||
|
lm_dir=params.lm_dir,
|
||||||
|
word_table=word_table,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
word_table = None
|
||||||
|
G = None
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
@ -613,6 +866,8 @@ def main():
|
|||||||
model=model,
|
model=model,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
decoding_graph=decoding_graph,
|
decoding_graph=decoding_graph,
|
||||||
|
G=G,
|
||||||
|
word_table=word_table,
|
||||||
)
|
)
|
||||||
|
|
||||||
save_results(
|
save_results(
|
||||||
|
@ -308,9 +308,7 @@ class Nbest(object):
|
|||||||
del word_fsa.aux_labels
|
del word_fsa.aux_labels
|
||||||
|
|
||||||
word_fsa.scores.zero_()
|
word_fsa.scores.zero_()
|
||||||
word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops(
|
word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
|
||||||
word_fsa
|
|
||||||
)
|
|
||||||
|
|
||||||
path_to_utt_map = self.shape.row_ids(1)
|
path_to_utt_map = self.shape.row_ids(1)
|
||||||
|
|
||||||
@ -609,7 +607,7 @@ def rescore_with_n_best_list(
|
|||||||
num_paths:
|
num_paths:
|
||||||
Size of nbest list.
|
Size of nbest list.
|
||||||
lm_scale_list:
|
lm_scale_list:
|
||||||
A list of float representing LM score scales.
|
A list of floats representing LM score scales.
|
||||||
nbest_scale:
|
nbest_scale:
|
||||||
Scale to be applied to ``lattice.score`` when sampling paths
|
Scale to be applied to ``lattice.score`` when sampling paths
|
||||||
using ``k2.random_paths``.
|
using ``k2.random_paths``.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user