Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-03 06:04:18 +00:00)

Merge adb54aea91abe211b19ec75eeb422b15a3867405 into 6a091da0b0543befb0492848d3583700c274d111
Commit 1ce4349c17
egs/librispeech/ASR/local/compile_lg.py (new executable file, 144 lines)
@@ -0,0 +1,144 @@
#!/usr/bin/env python3
# Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script takes as input lang_dir and generates LG from

    - L, the lexicon, built from lang_dir/L_disambig.pt

      Caution: We use a lexicon that contains disambiguation symbols

    - G, the LM, built from data/lm/G_3_gram.fst.txt

The generated LG is saved in $lang_dir/LG.pt
"""

import argparse
import logging
from pathlib import Path

import k2
import torch


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        """,
    )

    return parser.parse_args()


def compile_LG(lang_dir: str) -> k2.Fsa:
    """
    Args:
      lang_dir:
        The language directory, e.g., data/lang_phone or data/lang_bpe_500.

    Return:
      An FST representing LG.
    """

    tokens = k2.SymbolTable.from_file(f"{lang_dir}/tokens.txt")

    assert "#0" in tokens

    first_token_disambig_id = tokens["#0"]
    logging.info(f"first token disambig ID: {first_token_disambig_id}")

    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

    if Path("data/lm/G_3_gram.pt").is_file():
        logging.info("Loading pre-compiled G_3_gram")
        d = torch.load("data/lm/G_3_gram.pt")
        G = k2.Fsa.from_dict(d)
    else:
        logging.info("Loading G_3_gram.fst.txt")
        with open("data/lm/G_3_gram.fst.txt") as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
            del G.aux_labels
            torch.save(G.as_dict(), "data/lm/G_3_gram.pt")

    L = k2.arc_sort(L)
    G = k2.arc_sort(G)

    logging.info("Composing L and G")
    LG = k2.compose(L, G)
    logging.info(f"LG shape: {LG.shape}, num_arcs: {LG.num_arcs}")

    del LG.aux_labels

    logging.info("Connecting LG")
    LG = k2.connect(LG)
    logging.info(
        f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}"
    )

    logging.info("Determinizing LG")
    LG = k2.determinize(LG)
    logging.info(
        f"LG shape after k2.determinize: {LG.shape}, num_arcs: {LG.num_arcs}"
    )

    logging.info("Connecting LG after k2.determinize")
    LG = k2.connect(LG)
    logging.info(
        f"LG shape after k2.connect: {LG.shape}, num_arcs: {LG.num_arcs}"
    )

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.info(f"LG properties: {LG.properties_str}")
    # Possible properties are:
    # "Valid|Nonempty|ArcSorted|ArcSortedAndDeterministic|EpsilonFree|MaybeAccessible|MaybeCoaccessible"  # noqa
    logging.info(
        "Caution: LG is deterministic and contains disambig symbols!!!"
    )

    return LG


def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)
    out_filename = lang_dir / "LG.pt"

    if out_filename.is_file():
        logging.info(f"{out_filename} already exists - skipping")
        return

    logging.info(f"Processing {lang_dir}")

    LG = compile_LG(lang_dir)
    logging.info(f"Saving LG to {out_filename}")
    torch.save(LG.as_dict(), out_filename)


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
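For reference, a typical invocation of this script looks like the following (a sketch; it assumes data/lang_bpe_500 and data/lm/G_3_gram.fst.txt were already produced by the recipe's data preparation stage):

    cd egs/librispeech/ASR
    ./local/compile_lg.py --lang-dir data/lang_bpe_500
    # on success, the result is written to data/lang_bpe_500/LG.pt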
egs/librispeech/ASR/local/test_compile_lg.py (new executable file, 96 lines)
@@ -0,0 +1,96 @@
#!/usr/bin/env python3
# Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:

    cd icefall/egs/librispeech/ASR
    python ./local/test_compile_lg.py
"""

import os

from pathlib import Path

import k2
import torch

lang_dir = Path("./data/lang_bpe_500")
corpus = "test_compile_lg_corpus.txt"
arpa = "test_compile_lg_3_gram.arpa"
G_fst_txt = "test_compile_lg_3_gram.fst.txt"


def generate_corpus():
    s = """HELLO WORLD
HELLOA WORLDER
HELLOA WORLDER HELLO
HELLOA WORLDER"""
    with open(corpus, "w") as f:
        f.write(s)


def generate_arpa():
    cmd = f"""
    ./shared/make_kn_lm.py \
        -ngram-order 3 \
        -text {corpus} \
        -lm {arpa}
    """
    os.system(cmd)


def generate_G():
    cmd = f"""
    python3 -m kaldilm \
        --read-symbol-table="{lang_dir}/words.txt" \
        --disambig-symbol='#0' \
        {arpa} > {G_fst_txt}
    """
    os.system(cmd)


def main():
    generate_corpus()
    generate_arpa()
    generate_G()
    with open(G_fst_txt) as f:
        G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        del G.aux_labels
    G.labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")
    G.draw("G.pdf", title="G")

    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
    L.labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/tokens.txt")
    L.aux_labels_sym = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")

    L = k2.arc_sort(L)
    G = k2.arc_sort(G)

    LG = k2.compose(L, G)
    del LG.aux_labels

    LG = k2.determinize(LG)
    LG = k2.connect(LG)
    LG = k2.arc_sort(LG)
    print(LG.properties_str)
    LG.draw("LG.pdf", title="LG")
    # You can have a look at G.pdf and LG.pdf to get a feel
    # for what they look like


if __name__ == "__main__":
    main()
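Note: running this test assumes the kaldilm package is installed (pip install kaldilm) and that ./shared/make_kn_lm.py from icefall is available. The test builds a tiny trigram LM from the four-line corpus above, converts it to an FST with kaldilm, composes it with L_disambig, and writes G.pdf and LG.pdf for visual inspection.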
egs/librispeech/ASR/transducer_stateless/beam_search.py
@@ -14,11 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from dataclasses import dataclass
 from typing import Dict, List, Optional
 
+import k2
 import torch
 from model import Transducer
+from shallow_fusion import shallow_fusion
+from utils import Hypothesis, HypothesisList
 
 
 def greedy_search(
@@ -101,132 +103,6 @@ def greedy_search(
     return hyp
 
 
-@dataclass
-class Hypothesis:
-    # The predicted tokens so far.
-    # Newly predicted tokens are appended to `ys`.
-    ys: List[int]
-
-    # The log prob of ys.
-    # It contains only one entry.
-    log_prob: torch.Tensor
-
-    @property
-    def key(self) -> str:
-        """Return a string representation of self.ys"""
-        return "_".join(map(str, self.ys))
-
-
-class HypothesisList(object):
-    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
-        """
-        Args:
-          data:
-            A dict of Hypotheses. Its key is its `value.key`.
-        """
-        if data is None:
-            self._data = {}
-        else:
-            self._data = data
-
-    @property
-    def data(self) -> Dict[str, Hypothesis]:
-        return self._data
-
-    def add(self, hyp: Hypothesis) -> None:
-        """Add a Hypothesis to `self`.
-
-        If `hyp` already exists in `self`, its probability is updated using
-        `log-sum-exp` with the existing one.
-
-        Args:
-          hyp:
-            The hypothesis to be added.
-        """
-        key = hyp.key
-        if key in self:
-            old_hyp = self._data[key]  # shallow copy
-            torch.logaddexp(
-                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
-            )
-        else:
-            self._data[key] = hyp
-
-    def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
-        """Get the most probable hypothesis, i.e., the one with
-        the largest `log_prob`.
-
-        Args:
-          length_norm:
-            If True, the `log_prob` of a hypothesis is normalized by the
-            number of tokens in it.
-        Returns:
-          Return the hypothesis that has the largest `log_prob`.
-        """
-        if length_norm:
-            return max(
-                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
-            )
-        else:
-            return max(self._data.values(), key=lambda hyp: hyp.log_prob)
-
-    def remove(self, hyp: Hypothesis) -> None:
-        """Remove a given hypothesis.
-
-        Caution:
-          `self` is modified **in-place**.
-
-        Args:
-          hyp:
-            The hypothesis to be removed from `self`.
-            Note: It must be contained in `self`. Otherwise,
-            an exception is raised.
-        """
-        key = hyp.key
-        assert key in self, f"{key} does not exist"
-        del self._data[key]
-
-    def filter(self, threshold: torch.Tensor) -> "HypothesisList":
-        """Remove all Hypotheses whose log_prob is less than threshold.
-
-        Caution:
-          `self` is not modified. Instead, a new HypothesisList is returned.
-
-        Returns:
-          Return a new HypothesisList containing all hypotheses from `self`
-          with `log_prob` being greater than the given `threshold`.
-        """
-        ans = HypothesisList()
-        for _, hyp in self._data.items():
-            if hyp.log_prob > threshold:
-                ans.add(hyp)  # shallow copy
-        return ans
-
-    def topk(self, k: int) -> "HypothesisList":
-        """Return the top-k hypotheses."""
-        hyps = list(self._data.items())
-
-        hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
-
-        ans = HypothesisList(dict(hyps))
-        return ans
-
-    def __contains__(self, key: str):
-        return key in self._data
-
-    def __iter__(self):
-        return iter(self._data.values())
-
-    def __len__(self) -> int:
-        return len(self._data)
-
-    def __str__(self) -> str:
-        s = []
-        for key in self:
-            s.append(key)
-        return ", ".join(s)
-
-
 def run_decoder(
     ys: List[int],
     model: Transducer,
@@ -421,6 +297,161 @@ def modified_beam_search(
     return ys
 
 
+def modified_beam_search_with_shallow_fusion(
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    beam: int = 4,
+    LG: Optional[k2.Fsa] = None,
+    ngram_lm_scale: float = 0.1,
+) -> List[int]:
+    """It limits the maximum number of symbols per frame to 1.
+
+    Args:
+      model:
+        An instance of `Transducer`.
+      encoder_out:
+        A tensor of shape (N, T, C) from the encoder. Only N==1 is
+        supported for now.
+      beam:
+        Beam size.
+      LG:
+        Optional. Used for shallow fusion.
+      ngram_lm_scale:
+        Used only when LG is not None. The total score of a path is
+        am_score + ngram_lm_scale * ngram_lm_score
+    Returns:
+      Return the decoded result.
+    """
+    enable_shallow_fusion = LG is not None
+
+    assert encoder_out.ndim == 3
+
+    # support only batch_size == 1 for now
+    assert encoder_out.size(0) == 1, encoder_out.size(0)
+    blank_id = model.decoder.blank_id
+    context_size = model.decoder.context_size
+
+    device = model.device
+
+    decoder_input = torch.tensor(
+        [blank_id] * context_size, device=device
+    ).reshape(1, context_size)
+
+    decoder_out = model.decoder(decoder_input, need_pad=False)
+
+    T = encoder_out.size(1)
+
+    B = HypothesisList()
+
+    if enable_shallow_fusion:
+        ngram_state_and_scores = {
+            0: torch.zeros(1, dtype=torch.float32, device=device)
+        }
+    else:
+        ngram_state_and_scores = None
+
+    B.add(
+        Hypothesis(
+            ys=[blank_id] * context_size,
+            log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+            ngram_state_and_scores=ngram_state_and_scores,
+        )
+    )
+
+    encoder_out_len = torch.tensor([1])
+    decoder_out_len = torch.tensor([1])
+
+    for t in range(T):
+        # fmt: off
+        current_encoder_out = encoder_out[:, t:t+1, :]
+        # current_encoder_out is of shape (1, 1, encoder_out_dim)
+        # fmt: on
+        A = list(B)
+        B = HypothesisList()
+
+        # ys_log_probs contains both AM scores and LM scores
+        ys_log_probs = torch.cat(
+            [
+                hyp.log_prob.reshape(1, 1)
+                + ngram_lm_scale * max(hyp.ngram_state_and_scores.values())
+                for hyp in A
+            ]
+        )
+        # ys_log_probs is of shape (num_hyps, 1)
+
+        decoder_input = torch.tensor(
+            [hyp.ys[-context_size:] for hyp in A],
+            device=device,
+        )
+        # decoder_input is of shape (num_hyps, context_size)
+
+        decoder_out = model.decoder(decoder_input, need_pad=False)
+        # decoder_out is of shape (num_hyps, 1, decoder_output_dim)
+
+        current_encoder_out = current_encoder_out.expand(
+            decoder_out.size(0), 1, -1
+        )
+
+        logits = model.joiner(
+            current_encoder_out,
+            decoder_out,
+            encoder_out_len.expand(decoder_out.size(0)),
+            decoder_out_len.expand(decoder_out.size(0)),
+        )
+        vocab_size = logits.size(-1)
+        # logits is of shape (num_hyps, vocab_size)
+        log_probs = logits.log_softmax(dim=-1)
+
+        tot_log_probs = log_probs + ys_log_probs
+
+        _, topk_indexes = tot_log_probs.reshape(-1).topk(beam)
+        topk_log_probs = log_probs.reshape(-1)[topk_indexes]
+
+        # topk_hyp_indexes are indexes into `A`
+        topk_hyp_indexes = topk_indexes // logits.size(-1)
+        topk_token_indexes = topk_indexes % logits.size(-1)
+
+        topk_hyp_indexes, indexes = torch.sort(topk_hyp_indexes)
+        topk_token_indexes = topk_token_indexes[indexes]
+        topk_log_probs = topk_log_probs[indexes]
+
+        shape = k2.ragged.create_ragged_shape2(
+            row_ids=topk_hyp_indexes.to(torch.int32),
+            cached_tot_size=topk_hyp_indexes.numel(),
+        )
+        blank_log_probs = log_probs[topk_hyp_indexes, 0]
+
+        row_splits = shape.row_splits(1).tolist()
+        num_rows = len(row_splits) - 1
+        for i in range(num_rows):
+            start = row_splits[i]
+            end = row_splits[i + 1]
+            if start >= end:
+                # Discard A[i] as other hyps have higher log_probs
+                continue
+            tokens = topk_token_indexes[start:end]
+
+            hyps = shallow_fusion(
+                LG,
+                A[i],
+                tokens,
+                topk_log_probs[start:end],
+                vocab_size,
+                blank_log_probs[i],
+            )
+            for h in hyps:
+                B.add(h)
+
+        if len(B) > beam:
+            B = B.topk(beam, ngram_lm_scale=ngram_lm_scale)
+
+    best_hyp = B.get_most_probable(
+        length_norm=True, ngram_lm_scale=ngram_lm_scale
+    )
+    ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks
+
+    return ys
+
+
 def beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
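To make the combined score concrete: the total score of a hypothesis is am_score + ngram_lm_scale * ngram_lm_score, where the LM part is the best score over the hypothesis's active LG states. A minimal sketch with hypothetical numbers:

    import torch

    am_score = torch.tensor([-3.2])  # hyp.log_prob (acoustic part only)
    lm_scores = {  # hyp.ngram_state_and_scores
        0: torch.tensor([-1.5]),  # LM score of reaching LG state 0
        42: torch.tensor([-0.7]),  # LM score of reaching LG state 42
    }
    ngram_lm_scale = 0.1

    total = am_score + ngram_lm_scale * max(lm_scores.values())
    print(total)  # tensor([-3.2700]) = -3.2 + 0.1 * (-0.7)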
egs/librispeech/ASR/transducer_stateless/decode.py
@@ -49,13 +49,19 @@ import argparse
 import logging
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple
 
+import k2
 import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from beam_search import beam_search, greedy_search, modified_beam_search
+from beam_search import (
+    beam_search,
+    greedy_search,
+    modified_beam_search,
+    modified_beam_search_with_shallow_fusion,
+)
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
@@ -140,6 +146,22 @@ def get_parser():
         Used only when --decoding-method is greedy_search""",
     )
 
+    parser.add_argument(
+        "--LG",
+        type=str,
+        help="""Path to LG.pt for shallow fusion.
+        Used only when --decoding-method is modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.1,
+        help="""Used only when --LG is provided.
+        The total score of a path is am_score + ngram_lm_scale * ngram_lm_score.
+        """,
+    )
+
     return parser
@@ -212,6 +234,7 @@ def decode_one_batch(
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
     batch: dict,
+    LG: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -234,6 +257,9 @@ def decode_one_batch(
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
+      LG:
+        Optional. Used for shallow fusion. Used only when
+        params.decoding_method is modified_beam_search.
     Returns:
       Return the decoding result. See above description for the format of
       the returned dict.
@@ -266,12 +292,25 @@
         )
     elif params.decoding_method == "beam_search":
         hyp = beam_search(
-            model=model, encoder_out=encoder_out_i, beam=params.beam_size
+            model=model,
+            encoder_out=encoder_out_i,
+            beam=params.beam_size,
         )
     elif params.decoding_method == "modified_beam_search":
-        hyp = modified_beam_search(
-            model=model, encoder_out=encoder_out_i, beam=params.beam_size
-        )
+        if LG is None:
+            hyp = modified_beam_search(
+                model=model,
+                encoder_out=encoder_out_i,
+                beam=params.beam_size,
+            )
+        else:
+            hyp = modified_beam_search_with_shallow_fusion(
+                model=model,
+                encoder_out=encoder_out_i,
+                beam=params.beam_size,
+                LG=LG,
+                ngram_lm_scale=params.ngram_lm_scale,
+            )
     else:
         raise ValueError(
             f"Unsupported decoding method: {params.decoding_method}"
@@ -289,6 +328,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
+    LG: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
 
@@ -301,6 +341,9 @@ def decode_dataset(
         The neural model.
       sp:
         The BPE model.
+      LG:
+        Optional. Used for shallow fusion. Used only when
+        params.decoding_method is modified_beam_search.
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.
@@ -329,6 +372,7 @@ def decode_dataset(
             model=model,
             sp=sp,
             batch=batch,
+            LG=LG,
         )
 
     for name, hyps in hyps_dict.items():
@@ -428,6 +472,25 @@ def main():
 
     logging.info(f"Device: {device}")
 
+    if params.LG is not None:
+        assert (
+            params.decoding_method == "modified_beam_search"
+        ), "--LG is used only when --decoding-method=modified_beam_search"
+        logging.info(f"Loading LG from {params.LG}")
+        LG = k2.Fsa.from_dict(torch.load(params.LG, map_location=device))
+        logging.info(
+            f"max: {LG.scores.max()}, min: {LG.scores.min()}, mean: {LG.scores.mean()}"
+        )
+        logging.info(f"LG properties: {LG.properties_str}")
+        logging.info(f"LG num_states: {LG.shape[0]}, num_arcs: {LG.num_arcs}")
+        # If LG is created by local/compile_lg.py, then it should be epsilon
+        # free, deterministic, and arc sorted
+        assert "ArcSorted" in LG.properties_str
+        assert "EpsilonFree" in LG.properties_str
+        assert "Deterministic" in LG.properties_str
+    else:
+        LG = None
+
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
 
@@ -476,6 +539,7 @@ def main():
         params=params,
         model=model,
         sp=sp,
+        LG=LG,
    )
 
     save_results(
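Once LG.pt exists, shallow fusion can be enabled at decoding time roughly as follows (a sketch; the remaining flags are the usual ones for this recipe and are omitted here):

    ./transducer_stateless/decode.py \
        --decoding-method modified_beam_search \
        --LG data/lang_bpe_500/LG.pt \
        --ngram-lm-scale 0.1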
egs/librispeech/ASR/transducer_stateless/shallow_fusion.py (new file, 153 lines)
@@ -0,0 +1,153 @@
# Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict

import k2
import torch
import copy
from utils import Hypothesis, HypothesisList


def shallow_fusion(
    LG: k2.Fsa,
    hyp: Hypothesis,
    tokens: torch.Tensor,
    log_probs: torch.Tensor,
    vocab_size: int,
    blank_log_prob: torch.Tensor,
) -> HypothesisList:
    """
    Args:
      LG:
        An n-gram LM. It should be arc sorted, deterministic, and epsilon
        free. It contains disambig IDs and back-off arcs.
      hyp:
        The current hypothesis.
      tokens:
        The possible tokens that will be expanded from the given `hyp`.
        It is a 1-D tensor of dtype torch.int32.
      log_probs:
        It contains the acoustic log probabilities of each path that
        is extended from `hyp.ys` with `tokens`.
        log_probs.shape == tokens.shape.
      vocab_size:
        Vocabulary size, including the blank symbol. We assume that
        token IDs >= vocab_size are disambig IDs (including the backoff
        symbol #0).
      blank_log_prob:
        The log_prob for the blank token at this frame. It is from
        the output of the joiner.
    Returns:
      Return new hypotheses by extending the given `hyp` with tokens in the
      given `tokens`.
    """
    row_splits = LG.arcs.row_splits(1)
    arcs = LG.arcs.values()

    state_and_scores = copy.deepcopy(hyp.ngram_state_and_scores)

    current_states = list(state_and_scores.keys())

    # Process out-going arcs with label equal to disambig tokens or #0
    while len(current_states) > 0:
        s = current_states.pop()
        labels_begin = row_splits[s]
        labels_end = row_splits[s + 1]
        labels = LG.labels[labels_begin:labels_end].contiguous()

        for i in reversed(range(labels.numel())):
            lab = labels[i]
            if lab == -1:
                # Note: When sorting arcs, k2 treats arc labels as
                # unsigned types
                continue

            if lab < vocab_size:
                # Since LG is arc sorted, we can exit
                # the for loop as soon as we have a label
                # with ID less than vocab_size
                break

            # This is a disambig token or #0
            idx = labels_begin + i
            next_state = arcs[idx][1].item()
            score = LG.scores[idx] + state_and_scores[s]
            if next_state not in state_and_scores:
                state_and_scores[next_state] = score
                current_states.append(next_state)
            else:
                state_and_scores[next_state] = max(
                    score, state_and_scores[next_state]
                )

    current_states = list(state_and_scores.keys())
    ans = HypothesisList()

    device = log_probs.device
    for s in current_states:
        labels_begin = row_splits[s]
        labels_end = row_splits[s + 1]
        labels = LG.labels[labels_begin:labels_end].contiguous()

        if labels[-1] == -1:
            labels = labels[:-1]

        if s != 0:
            # We add a backoff arc to the start state. Otherwise,
            # all active states may die due to an out-of-vocabulary word.
            new_hyp = Hypothesis(
                ys=hyp.ys[:],
                log_prob=hyp.log_prob + blank_log_prob,
                ngram_state_and_scores={
                    # -20 is the cost on the backoff arc to the start state.
                    # As LG.scores.min() is about -16.6, we choose -20 here.
                    # You may need to tune this value.
                    0: torch.full((1,), -20, dtype=torch.float32, device=device)
                },
            )
            ans.add(new_hyp)

        pos = torch.searchsorted(labels, tokens)
        for i in range(pos.numel()):
            if tokens[i] == 0:
                # blank ID
                new_hyp = Hypothesis(
                    ys=hyp.ys[:],
                    log_prob=hyp.log_prob + log_probs[i],
                    ngram_state_and_scores=hyp.ngram_state_and_scores,
                )
                ans.add(new_hyp)
                continue
            elif pos[i] >= labels.numel() or labels[pos[i]] != tokens[i]:
                # No out-going arcs from this state have labels
                # equal to tokens[i]
                continue

            # Found one arc

            idx = labels_begin + pos[i]
            next_state = arcs[idx][1].item()
            score = LG.scores[idx] + state_and_scores[s]
            new_hyp = Hypothesis(
                ys=hyp.ys + [tokens[i].item()],
                log_prob=hyp.log_prob + log_probs[i],
                ngram_state_and_scores={next_state: score},
            )
            ans.add(new_hyp)

    return ans
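The arc lookup above exploits the fact that LG is arc sorted: the out-going labels of a state are in ascending order, so torch.searchsorted can find a matching arc, or prove there is none, in logarithmic time. A minimal sketch with hypothetical labels and tokens:

    import torch

    labels = torch.tensor([3, 7, 9, 15])  # sorted out-going arc labels of one state
    tokens = torch.tensor([7, 8])  # candidate tokens for extending a hypothesis

    pos = torch.searchsorted(labels, tokens)  # -> tensor([1, 2])
    for i in range(pos.numel()):
        if pos[i] < labels.numel() and labels[pos[i]] == tokens[i]:
            print(f"token {tokens[i].item()}: matching arc at index {pos[i].item()}")
        else:
            print(f"token {tokens[i].item()}: no matching arc, skipped")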
egs/librispeech/ASR/transducer_stateless/utils.py (new file, 219 lines)
@@ -0,0 +1,219 @@
# Copyright    2021-2022  Xiaomi Corp.        (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Dict, List, Optional

import torch


@dataclass
class Hypothesis:
    # The predicted tokens so far.
    # Newly predicted tokens are appended to `ys`.
    ys: List[int]

    # The log prob of ys.
    # It contains only one entry.
    # Note: It contains only the acoustic part.
    log_prob: torch.Tensor

    # Used for shallow fusion.
    # The key of the dict is a state index into LG,
    # while the corresponding value is the LM score for
    # reaching this state from the start state.
    # Note: The value tensor contains only a single entry
    # and it contains only the LM part.
    ngram_state_and_scores: Optional[Dict[int, torch.Tensor]] = None

    @property
    def key(self) -> str:
        """Return a string representation of self.ys"""
        return "_".join(map(str, self.ys))


class HypothesisList(object):
    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
        """
        Args:
          data:
            A dict of Hypotheses. Its key is its `value.key`.
        """
        if data is None:
            self._data = {}
        else:
            self._data = data

    @property
    def data(self) -> Dict[str, Hypothesis]:
        return self._data

    def add(self, hyp: Hypothesis) -> None:
        """Add a Hypothesis to `self`.

        If `hyp` already exists in `self`, its probability is updated to
        the maximum of the existing one and the new one.

        Args:
          hyp:
            The hypothesis to be added.
        """
        key = hyp.key
        if key in self:
            old_hyp = self._data[key]  # shallow copy

            # We keep the better score (Viterbi style). To combine the two
            # probabilities with log-sum-exp instead, use:
            #   old_hyp.log_prob = torch.logaddexp(
            #       old_hyp.log_prob, hyp.log_prob
            #   )
            old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob)

            if hyp.ngram_state_and_scores is not None:
                for state, score in hyp.ngram_state_and_scores.items():
                    if (
                        state not in old_hyp.ngram_state_and_scores
                        or score > old_hyp.ngram_state_and_scores[state]
                    ):
                        old_hyp.ngram_state_and_scores[state] = score
        else:
            self._data[key] = hyp

    def get_most_probable(
        self, length_norm: bool = False, ngram_lm_scale: Optional[float] = None
    ) -> Hypothesis:
        """Get the most probable hypothesis, i.e., the one with
        the largest `log_prob`.

        Args:
          length_norm:
            If True, the `log_prob` of a hypothesis is normalized by the
            number of tokens in it.
          ngram_lm_scale:
            If not None, it specifies the scale applied to the LM score.
        Returns:
          Return the hypothesis that has the largest `log_prob`.
        """
        if length_norm:
            if ngram_lm_scale is None:
                return max(
                    self._data.values(),
                    key=lambda hyp: hyp.log_prob / len(hyp.ys),
                )
            else:
                return max(
                    self._data.values(),
                    key=lambda hyp: (
                        hyp.log_prob
                        + ngram_lm_scale
                        * max(hyp.ngram_state_and_scores.values())
                    )
                    / len(hyp.ys),
                )
        else:
            if ngram_lm_scale is None:
                return max(self._data.values(), key=lambda hyp: hyp.log_prob)
            else:
                return max(
                    self._data.values(),
                    key=lambda hyp: hyp.log_prob
                    + ngram_lm_scale * max(hyp.ngram_state_and_scores.values()),
                )

    def remove(self, hyp: Hypothesis) -> None:
        """Remove a given hypothesis.

        Caution:
          `self` is modified **in-place**.

        Args:
          hyp:
            The hypothesis to be removed from `self`.
            Note: It must be contained in `self`. Otherwise,
            an exception is raised.
        """
        key = hyp.key
        assert key in self, f"{key} does not exist"
        del self._data[key]

    def filter(
        self, threshold: torch.Tensor, ngram_lm_scale: Optional[float] = None
    ) -> "HypothesisList":
        """Remove all Hypotheses whose log_prob is less than threshold.

        Caution:
          `self` is not modified. Instead, a new HypothesisList is returned.

        Args:
          threshold:
            Hypotheses with log_prob less than this value are removed.
          ngram_lm_scale:
            If not None, it specifies the scale applied to the LM score.

        Returns:
          Return a new HypothesisList containing all hypotheses from `self`
          with `log_prob` being greater than the given `threshold`.
        """
        ans = HypothesisList()
        if ngram_lm_scale is None:
            for _, hyp in self._data.items():
                if hyp.log_prob > threshold:
                    ans.add(hyp)  # shallow copy
        else:
            for _, hyp in self._data.items():
                if (
                    hyp.log_prob
                    + ngram_lm_scale * max(hyp.ngram_state_and_scores.values())
                    > threshold
                ):
                    ans.add(hyp)  # shallow copy
        return ans

    def topk(
        self, k: int, ngram_lm_scale: Optional[float] = None
    ) -> "HypothesisList":
        """Return the top-k hypotheses."""
        hyps = list(self._data.items())

        if ngram_lm_scale is None:
            hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
        else:
            hyps = sorted(
                hyps,
                key=lambda h: h[1].log_prob
                + ngram_lm_scale * max(h[1].ngram_state_and_scores.values()),
                reverse=True,
            )[:k]

        ans = HypothesisList(dict(hyps))
        return ans

    def __contains__(self, key: str):
        return key in self._data

    def __iter__(self):
        return iter(self._data.values())

    def __len__(self) -> int:
        return len(self._data)

    def __str__(self) -> str:
        s = []
        for key in self:
            s.append(key)
        return ", ".join(s)
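A minimal usage sketch of these two classes (hypothetical token IDs; run it from within transducer_stateless/ so that utils is importable):

    import torch
    from utils import Hypothesis, HypothesisList

    hyps = HypothesisList()
    hyps.add(Hypothesis(ys=[0, 0, 5], log_prob=torch.tensor([-1.0])))
    hyps.add(Hypothesis(ys=[0, 0, 7], log_prob=torch.tensor([-2.5])))
    # Adding a duplicate key keeps the better (max) log_prob:
    hyps.add(Hypothesis(ys=[0, 0, 5], log_prob=torch.tensor([-0.5])))

    best = hyps.get_most_probable()
    print(best.ys, best.log_prob)  # [0, 0, 5] tensor([-0.5000])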