mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 16:44:20 +00:00
WIP: Use shallow fusion in modified beam search.
This commit is contained in:
parent
27fa5f05d3
commit
954b4efff3
160
egs/librispeech/ASR/local/compile_lg.py
Executable file
160
egs/librispeech/ASR/local/compile_lg.py
Executable file
@ -0,0 +1,160 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script takes as input lang_dir and generates LG from
|
||||||
|
|
||||||
|
- L, the lexicon, built from lang_dir/L_disambig.pt
|
||||||
|
|
||||||
|
Caution: We use a lexicon that contains disambiguation symbols
|
||||||
|
|
||||||
|
- G, the LM, built from data/lm/G_3_gram.fst.txt
|
||||||
|
|
||||||
|
The generated LG is saved in $lang_dir/LG.pt
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
    """Parse the command-line arguments of this script.

    Returns:
      Return an argparse.Namespace whose attribute ``lang_dir`` holds the
      value passed via ``--lang-dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        # main() passes this value straight to Path(); if it is missing,
        # Path(None) raises a TypeError. Let argparse reject the invocation
        # with a proper usage message instead.
        required=True,
        help="""Input and output directory.
        """,
    )

    return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def compile_LG(lang_dir: str) -> k2.Fsa:
    """Build LG by composing the lexicon L with the 3-gram LM G.

    After composition the graph is connected, determinized, connected again,
    stripped of disambiguation symbols and epsilons, connected once more and
    finally arc sorted.

    G is loaded from data/lm/G_3_gram.pt if that file exists. Otherwise it
    is built from data/lm/G_3_gram.fst.txt and cached to
    data/lm/G_3_gram.pt. Both paths are hard-coded and relative to the
    current working directory.

    Args:
      lang_dir:
        The language directory, e.g., data/lang_phone or data/lang_bpe_500.
        tokens.txt and L_disambig.pt are read from it.

    Return:
      An FST representing LG. Its aux_labels are removed before it is
      returned.
    """
    tokens = k2.SymbolTable.from_file(f"{lang_dir}/tokens.txt")

    assert "#0" in tokens

    # Every label ID >= the ID of "#0" is treated as a disambiguation symbol
    # further below.
    first_token_disambig_id = tokens["#0"]
    logging.info(f"first token disambig ID: {first_token_disambig_id}")

    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

    if Path("data/lm/G_3_gram.pt").is_file():
        logging.info("Loading pre-compiled G_3_gram")
        d = torch.load("data/lm/G_3_gram.pt")
        G = k2.Fsa.from_dict(d)
    else:
        logging.info("Loading G_3_gram.fst.txt")
        # NOTE(review): the nesting of the following statements was
        # reconstructed; confirm that dropping aux_labels and caching G are
        # meant to happen only when G is built from the text file.
        with open("data/lm/G_3_gram.fst.txt") as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
            del G.aux_labels
            torch.save(G.as_dict(), "data/lm/G_3_gram.pt")

    L = k2.arc_sort(L)
    G = k2.arc_sort(G)

    logging.info("Composing L and G")
    LG = k2.compose(L, G)
    logging.info(f"LG shape: {LG.shape}, num_arcs: {LG.num_arcs}")

    del LG.aux_labels

    logging.info("Connecting LG")
    LG = k2.connect(LG)
    logging.info(
        f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}"
    )

    logging.info("Determinizing LG")
    LG = k2.determinize(LG)
    logging.info(
        f"LG shape after k2.determinize: {LG.shape}, num_arcs: {LG.num_arcs}"
    )

    logging.info("Connecting LG after k2.determinize")
    LG = k2.connect(LG)
    logging.info(
        f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}"
    )

    logging.info("Removing disambiguation symbols on LG")
    # Disambiguation symbols are replaced by label 0 so that the subsequent
    # k2.remove_epsilon() call gets rid of them.
    LG.labels[LG.labels >= first_token_disambig_id] = 0
    # See https://github.com/k2-fsa/k2/issues/874
    # for why we need to set LG.properties to None
    LG.__dict__["_properties"] = None

    logging.info("Removing epsilons")
    LG = k2.remove_epsilon(LG)
    logging.info(
        f"LG shape after k2.remove_epsilon: {LG.shape}, num_arcs: {LG.num_arcs}"
    )

    logging.info("Connecting")
    LG = k2.connect(LG)
    logging.info(
        f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}"
    )

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.info(f"LG properties: {LG.properties_str}")
    # The properties are expected to be:
    # "Valid|Nonempty|ArcSorted|EpsilonFree|MaybeAccessible|MaybeCoaccessible"
    logging.info("Caution: LG is not deterministic!!!")

    return LG
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Compile LG for the directory given by --lang-dir and save it.

    The result is written to <lang-dir>/LG.pt. If that file already exists,
    nothing is recomputed.
    """
    args = get_args()
    lang_dir = Path(args.lang_dir)
    out_filename = lang_dir / "LG.pt"

    if out_filename.is_file():
        logging.info(f"{out_filename} already exists - skipping")
        return

    logging.info(f"Processing {lang_dir}")

    LG = compile_LG(lang_dir)
    logging.info(f"Saving LG to {out_filename}")
    torch.save(LG.as_dict(), out_filename)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Log timestamp, level and source location in front of every message.
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
|
79
egs/librispeech/ASR/local/test_compile_lg.py
Executable file
79
egs/librispeech/ASR/local/test_compile_lg.py
Executable file
@ -0,0 +1,79 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
To run this file, do:
|
||||||
|
|
||||||
|
cd icefall/egs/librispeech/ASR
|
||||||
|
python ./local/test_compile_lg.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Language directory (relative to the current working directory) from which
# main() loads LG.pt, bpe.model and words.txt.
lang_dir = Path("./data/lang_bpe_500")
|
||||||
|
|
||||||
|
|
||||||
|
def get_word_ids(word_table: k2.SymbolTable, s: str) -> List[int]:
    """Map every word of a sentence to its ID.

    Args:
      word_table:
        Word symbol table. It is indexed with each word, i.e.,
        ``word_table[word]``.
      s:
        A string consisting of space(s) separated words.
    Returns:
      Return a list of word IDs, one per word and in the order in which the
      words appear in `s`. It is empty if `s` contains no words.
    """
    # str.split() without arguments collapses runs of whitespace, so
    # leading, trailing and repeated spaces do not produce empty words.
    return [word_table[w] for w in s.split()]
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Intersect LG with the BPE tokens of "HELLO WORLD" and print the result.

    LG.pt, bpe.model and words.txt are loaded from the module-level
    `lang_dir`. The resulting FSA is printed in Graphviz dot format,
    followed by the properties of the FSA and of LG.
    """
    assert lang_dir.exists(), f"{lang_dir} does not exist!"
    LG = k2.Fsa.from_dict(torch.load(f"{lang_dir}/LG.pt", map_location="cpu"))

    sp = spm.SentencePieceProcessor()
    sp.load(f"{lang_dir}/bpe.model")

    # NOTE(review): word_table is loaded but not used below -- presumably
    # kept for upcoming work; confirm or remove.
    word_table = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")
    s = "HELLO WORLD"
    token_ids = sp.encode(s)

    token_fsa = k2.linear_fsa(token_ids)

    fsa = k2.intersect(LG, token_fsa)
    fsa = k2.connect(fsa)
    print(k2.to_dot(fsa))
    print(fsa.properties_str)
    print(LG.properties_str)
    # You can use https://dreampuf.github.io/GraphvizOnline/
    # to visualize the output.
    #
    # You can see that the resulting fsa is not deterministic
    # Note: LG is non-deterministic
    #
    # See https://shorturl.at/uIL69
    # for visualization of the above fsa.
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    main()
|
@ -17,8 +17,10 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from model import Transducer
|
from model import Transducer
|
||||||
|
from shallow_fusion import shallow_fusion
|
||||||
|
|
||||||
|
|
||||||
def greedy_search(
|
def greedy_search(
|
||||||
@ -111,6 +113,13 @@ class Hypothesis:
|
|||||||
# It contains only one entry.
|
# It contains only one entry.
|
||||||
log_prob: torch.Tensor
|
log_prob: torch.Tensor
|
||||||
|
|
||||||
|
# Used for shallow fusion
# The key of the dict is a state index into LG
# while the corresponding value is the LM score
# reaching this state.
# Note: The value tensor contains only a single entry
#
# The default has to be None, not the tuple (None,): code using this field
# checks `is not None` and then calls `.items()` on it.
ngram_state_and_scores: Optional[Dict[int, torch.Tensor]] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def key(self) -> str:
|
def key(self) -> str:
|
||||||
"""Return a string representation of self.ys"""
|
"""Return a string representation of self.ys"""
|
||||||
@ -149,6 +158,15 @@ class HypothesisList(object):
|
|||||||
torch.logaddexp(
|
torch.logaddexp(
|
||||||
old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
|
old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
|
||||||
)
|
)
|
||||||
|
if hyp.ngram_state_and_scores is not None:
|
||||||
|
for state, score in hyp.ngram_state_and_scores.items():
|
||||||
|
if (
|
||||||
|
state in old_hyp.ngram_state_and_scores
|
||||||
|
and score > old_hyp.ngram_state_and_scores[state]
|
||||||
|
):
|
||||||
|
old_hyp.ngram_state_and_scores[state] = score
|
||||||
|
else:
|
||||||
|
old_hyp.ngram_state_and_scores[state] = score
|
||||||
else:
|
else:
|
||||||
self._data[key] = hyp
|
self._data[key] = hyp
|
||||||
|
|
||||||
@ -318,6 +336,7 @@ def modified_beam_search(
|
|||||||
model: Transducer,
|
model: Transducer,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
beam: int = 4,
|
beam: int = 4,
|
||||||
|
LG: Optional[k2.Fsa] = None,
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
"""It limits the maximum number of symbols per frame to 1.
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
@ -328,9 +347,13 @@ def modified_beam_search(
|
|||||||
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
||||||
beam:
|
beam:
|
||||||
Beam size.
|
Beam size.
|
||||||
|
LG:
|
||||||
|
Optional. Used for shallow fusion.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoded result.
|
Return the decoded result.
|
||||||
"""
|
"""
|
||||||
|
enable_shallow_fusion = LG is not None
|
||||||
|
ngram_lm_scale = 0.8
|
||||||
|
|
||||||
assert encoder_out.ndim == 3
|
assert encoder_out.ndim == 3
|
||||||
|
|
||||||
@ -350,10 +373,19 @@ def modified_beam_search(
|
|||||||
T = encoder_out.size(1)
|
T = encoder_out.size(1)
|
||||||
|
|
||||||
B = HypothesisList()
|
B = HypothesisList()
|
||||||
|
|
||||||
|
if enable_shallow_fusion:
|
||||||
|
ngram_state_and_scores = {
|
||||||
|
0: torch.zeros(1, dtype=torch.float32, device=device)
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
ngram_state_and_scores = None
|
||||||
|
|
||||||
B.add(
|
B.add(
|
||||||
Hypothesis(
|
Hypothesis(
|
||||||
ys=[blank_id] * context_size,
|
ys=[blank_id] * context_size,
|
||||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||||
|
ngram_state_and_scores=ngram_state_and_scores,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -411,9 +443,33 @@ def modified_beam_search(
|
|||||||
new_token = topk_token_indexes[i]
|
new_token = topk_token_indexes[i]
|
||||||
if new_token != blank_id:
|
if new_token != blank_id:
|
||||||
new_ys.append(new_token)
|
new_ys.append(new_token)
|
||||||
|
else:
|
||||||
|
ngram_state_and_scores = hyp.ngram_state_and_scores
|
||||||
|
|
||||||
new_log_prob = topk_log_probs[i]
|
new_log_prob = topk_log_probs[i]
|
||||||
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
|
||||||
|
if enable_shallow_fusion and new_token != blank_id:
|
||||||
|
ngram_state_and_scores = shallow_fusion(
|
||||||
|
LG, new_token, hyp.ngram_state_and_scores
|
||||||
|
)
|
||||||
|
if len(ngram_state_and_scores) == 0:
|
||||||
|
continue
|
||||||
|
max_ngram_score = max(ngram_state_and_scores.values())
|
||||||
|
new_log_prob += ngram_lm_scale * max_ngram_score
|
||||||
|
|
||||||
|
# TODO: Get the maximum scores in ngram_state_and_scores
|
||||||
|
# and add it to new_log_prob
|
||||||
|
|
||||||
|
new_hyp = Hypothesis(
|
||||||
|
ys=new_ys,
|
||||||
|
log_prob=new_log_prob,
|
||||||
|
ngram_state_and_scores=ngram_state_and_scores,
|
||||||
|
)
|
||||||
|
|
||||||
B.add(new_hyp)
|
B.add(new_hyp)
|
||||||
|
if len(B) == 0:
|
||||||
|
for h in A:
|
||||||
|
B.add(h)
|
||||||
|
|
||||||
best_hyp = B.get_most_probable(length_norm=True)
|
best_hyp = B.get_most_probable(length_norm=True)
|
||||||
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
|
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
|
||||||
|
@ -40,8 +40,9 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -131,6 +132,13 @@ def get_parser():
|
|||||||
Used only when --decoding_method is greedy_search""",
|
Used only when --decoding_method is greedy_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--LG",
|
||||||
|
type=str,
|
||||||
|
help="""Path to LG.pt for shallow fusion.
|
||||||
|
Used only when --decoding-method is modified_beam_search.""",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -203,6 +211,7 @@ def decode_one_batch(
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
|
LG: Optional[k2.Fsa] = None,
|
||||||
) -> Dict[str, List[List[str]]]:
|
) -> Dict[str, List[List[str]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
following format:
|
following format:
|
||||||
@ -225,6 +234,9 @@ def decode_one_batch(
|
|||||||
It is the return value from iterating
|
It is the return value from iterating
|
||||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||||
for the format of the `batch`.
|
for the format of the `batch`.
|
||||||
|
LG:
|
||||||
|
Optional. Used for shallow fusion. Used only when params.decoding_method
|
||||||
|
is modified_beam_search.
|
||||||
Returns:
|
Returns:
|
||||||
Return the decoding result. See above description for the format of
|
Return the decoding result. See above description for the format of
|
||||||
the returned dict.
|
the returned dict.
|
||||||
@ -257,17 +269,24 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
elif params.decoding_method == "beam_search":
|
elif params.decoding_method == "beam_search":
|
||||||
hyp = beam_search(
|
hyp = beam_search(
|
||||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
model=model,
|
||||||
|
encoder_out=encoder_out_i,
|
||||||
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
elif params.decoding_method == "modified_beam_search":
|
elif params.decoding_method == "modified_beam_search":
|
||||||
hyp = modified_beam_search(
|
hyp = modified_beam_search(
|
||||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
model=model,
|
||||||
|
encoder_out=encoder_out_i,
|
||||||
|
beam=params.beam_size,
|
||||||
|
LG=LG,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported decoding method: {params.decoding_method}"
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
)
|
)
|
||||||
hyps.append(sp.decode(hyp).split())
|
hyps.append(sp.decode(hyp).split())
|
||||||
|
for h in hyps:
|
||||||
|
print(" ".join(h))
|
||||||
|
|
||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
return {"greedy_search": hyps}
|
return {"greedy_search": hyps}
|
||||||
@ -280,6 +299,7 @@ def decode_dataset(
|
|||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
|
LG: Optional[k2.Fsa] = None,
|
||||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
|
|
||||||
@ -292,6 +312,9 @@ def decode_dataset(
|
|||||||
The neural model.
|
The neural model.
|
||||||
sp:
|
sp:
|
||||||
The BPE model.
|
The BPE model.
|
||||||
|
LG:
|
||||||
|
Optional. Used for shallow fusion. Used only when params.decoding_method
|
||||||
|
is modified_beam_search.
|
||||||
Returns:
|
Returns:
|
||||||
Return a dict, whose key may be "greedy_search" if greedy search
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
@ -320,6 +343,7 @@ def decode_dataset(
|
|||||||
model=model,
|
model=model,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
|
LG=LG,
|
||||||
)
|
)
|
||||||
|
|
||||||
for name, hyps in hyps_dict.items():
|
for name, hyps in hyps_dict.items():
|
||||||
@ -419,6 +443,21 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"Device: {device}")
|
logging.info(f"Device: {device}")
|
||||||
|
|
||||||
|
if params.LG is not None:
|
||||||
|
assert (
|
||||||
|
params.decoding_method == "modified_beam_search"
|
||||||
|
), "--LG is used only when --decoding_method=modified_beam_search"
|
||||||
|
logging.info(f"Loading LG from {params.LG}")
|
||||||
|
LG = k2.Fsa.from_dict(torch.load(params.LG, map_location=device))
|
||||||
|
logging.info(f"LG properties: {LG.properties_str}")
|
||||||
|
logging.info(f"LG num_states: {LG.shape[0]}, num_arcs: {LG.num_arcs}")
|
||||||
|
# If LG is created by local/compile_lg.py, then it should be epsilon
|
||||||
|
# free as well as arc sorted
|
||||||
|
assert "ArcSorted" in LG.properties_str
|
||||||
|
assert "EpsilonFree" in LG.properties_str
|
||||||
|
else:
|
||||||
|
LG = None
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
sp = spm.SentencePieceProcessor()
|
||||||
sp.load(params.bpe_model)
|
sp.load(params.bpe_model)
|
||||||
|
|
||||||
@ -467,6 +506,7 @@ def main():
|
|||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
|
LG=LG,
|
||||||
)
|
)
|
||||||
|
|
||||||
save_results(
|
save_results(
|
||||||
|
78
egs/librispeech/ASR/transducer_stateless/shallow_fusion.py
Normal file
78
egs/librispeech/ASR/transducer_stateless/shallow_fusion.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def shallow_fusion(
    LG: k2.Fsa,
    token: int,
    state_and_scores: Dict[int, torch.Tensor],
) -> Dict[int, torch.Tensor]:
    """Advance every active state of LG by one input token.

    Args:
      LG:
        An n-gram. It should be arc sorted and epsilon free.
      token:
        The input token ID.
      state_and_scores:
        The keys contain the current states we are in. Only the keys are
        read by this function; the values are not used.
    Returns:
      Return a new state_and_scores. Its keys are the states that can be
      reached from the given states by following an arc whose label equals
      `token`. Each value is the largest score among the matching arcs
      entering that state, taken from ``LG.scores``. The returned dict is
      empty if no arc matches.

      NOTE(review): the returned values are scores of single arcs; they are
      not accumulated with the scores in `state_and_scores`. Confirm this is
      what the caller expects.
    """
    row_splits = LG.arcs.row_splits(1)
    arcs = LG.arcs.values()
    all_labels = LG.labels
    all_scores = LG.scores

    ans = dict()
    for s in state_and_scores:
        # Arcs leaving state s occupy the range [labels_begin, labels_end).
        labels_begin = int(row_splits[s])
        labels_end = int(row_splits[s + 1])
        if labels_begin == labels_end:
            # This state has no out-going arcs at all.
            continue

        labels = all_labels[labels_begin:labels_end].contiguous()

        # As LG is not deterministic, there may be multiple
        # out-going arcs with label equal to "token".
        #
        # Note: LG is arc sorted, so all arcs with the same label are
        # adjacent and can be located with a binary search.
        left = int(torch.bucketize(token, labels, right=False))
        right = int(torch.bucketize(token, labels, right=True))

        # Now we have
        #    labels[i] == token
        # for
        #    left <= i < right
        #
        # If left >= right, there are no out-going arcs from this state
        # that have label equal to "token" and the loop body never runs.
        for i in range(labels_begin + left, labels_begin + right):
            # Column 1 of an arc is read as its destination state.
            next_state = arcs[i][1].item()
            score = all_scores[i]
            # Keep only the best score for each destination state.
            if next_state not in ans or score > ans[next_state]:
                ans[next_state] = score

    return ans
|
Loading…
x
Reference in New Issue
Block a user