mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)
import logging
from pathlib import Path
from typing import Iterable, List, Tuple, Union

import k2
import torch

from icefall.lexicon import UniqLexicon


class MmiTrainingGraphCompiler(object):
    def __init__(
        self,
        lang_dir: Path,
        uniq_filename: str = "uniq_lexicon.txt",
        device: Union[str, torch.device] = "cpu",
        oov: str = "<UNK>",
        sos_id: int = 1,
        eos_id: int = 1,
    ):
        """
        Args:
          lang_dir:
            Path to the lang directory. It is expected to contain the
            following files::

                - tokens.txt
                - words.txt
                - P.fst.txt

            The above files are generated by the script `prepare.sh`. You
            should have run it before running the training code.
          uniq_filename:
            File name of the lexicon in which every word has exactly one
            pronunciation. We assume this file is inside the given `lang_dir`.
          device:
            The device on which the graphs are constructed. It can be
            either a CPU or a CUDA device.
          oov:
            Out-of-vocabulary word. When a word in the transcript
            does not exist in the lexicon, it is replaced with `oov`.
          sos_id:
            ID of the start-of-sentence symbol.
          eos_id:
            ID of the end-of-sentence symbol.
        """
        self.lang_dir = Path(lang_dir)
        self.lexicon = UniqLexicon(lang_dir, uniq_filename=uniq_filename)
        self.device = torch.device(device)

        self.L_inv = self.lexicon.L_inv.to(self.device)

        self.oov_id = self.lexicon.word_table[oov]
        self.sos_id = sos_id
        self.eos_id = eos_id

        self.build_ctc_topo_P()

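    # A minimal construction sketch (not part of the original file). The
    # directory name "data/lang_bpe" is hypothetical; any directory produced
    # by prepare.sh that contains tokens.txt, words.txt, P.fst.txt and
    # uniq_lexicon.txt should work:
    #
    #   graph_compiler = MmiTrainingGraphCompiler(
    #       lang_dir="data/lang_bpe",
    #       device="cuda:0" if torch.cuda.is_available() else "cpu",
    #   )
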
    def build_ctc_topo_P(self):
        """Build ctc_topo_P, the composition result of
        ctc_topo and P, where P is a pre-trained bigram
        word piece LM.
        """
        # Note: there is no need to save a pre-compiled P and ctc_topo
        # as it is very fast to generate them.
        logging.info(f"Loading P from {self.lang_dir/'P.fst.txt'}")
        with open(self.lang_dir / "P.fst.txt") as f:
            # P is not an acceptor because there is
            # a back-off state, whose incoming arcs
            # have label #0 and aux_label 0 (i.e., <eps>).
            P = k2.Fsa.from_openfst(f.read(), acceptor=False)

        first_token_disambig_id = self.lexicon.token_table["#0"]

        # P.aux_labels is not needed in later computations, so
        # remove it here.
        del P.aux_labels
        # CAUTION: The following line is crucial.
        # Arcs entering the back-off state have label equal to #0.
        # We have to change it to 0 here.
        P.labels[P.labels >= first_token_disambig_id] = 0

        P = k2.remove_epsilon(P)
        P = k2.arc_sort(P)
        P = P.to(self.device)
        # Add epsilon self-loops to P because we want the
        # following operation "k2.intersect" to run on GPU.
        P_with_self_loops = k2.add_epsilon_self_loops(P)

        max_token_id = max(self.lexicon.tokens)
        logging.info(
            f"Building ctc_topo (modified=False). max_token_id: {max_token_id}"
        )
        ctc_topo = k2.ctc_topo(max_token_id, modified=False, device=self.device)

        ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())

        logging.info("Building ctc_topo_P")
        ctc_topo_P = k2.intersect(
            ctc_topo_inv, P_with_self_loops, treat_epsilons_specially=False
        ).invert()

        self.ctc_topo_P = k2.arc_sort(ctc_topo_P)
        logging.info(f"ctc_topo_P num_arcs: {self.ctc_topo_P.num_arcs}")

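    # A toy illustration (not part of the original file) of the label trick
    # in build_ctc_topo_P above: setting every label that is >= the ID of
    # the first disambiguation symbol "#0" to 0 turns those arcs into
    # epsilon arcs, which k2.remove_epsilon() then eliminates. The ID 500
    # below is hypothetical:
    #
    #   labels = torch.tensor([5, 17, 500, 501], dtype=torch.int32)
    #   labels[labels >= 500] = 0
    #   # labels is now tensor([5, 17, 0, 0], dtype=torch.int32)
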
    def compile(
        self, texts: Iterable[str], replicate_den: bool = True
    ) -> Tuple[k2.Fsa, k2.Fsa]:
        """Create numerator and denominator graphs from transcripts
        and the bigram word piece LM.

        Args:
          texts:
            A list of transcripts. Within a transcript, words are
            separated by spaces. An example `texts` is given below::

                ["Hello icefall", "LF-MMI training with icefall using k2"]

          replicate_den:
            If True, the returned den_graph is replicated to match the number
            of FSAs in the returned num_graph; if False, the returned den_graph
            contains only a single FSA.
        Returns:
          A tuple (num_graph, den_graph), where

            - `num_graph` is the numerator graph. It is an FsaVec with
              shape `(len(texts), None, None)`.

            - `den_graph` is the denominator graph. It is an FsaVec
              with the same shape as `num_graph` if replicate_den is
              True; otherwise, it is an FsaVec containing only a single FSA.
        """
        transcript_fsa = self.build_transcript_fsa(texts)

        # Remove word IDs from transcript_fsa since they are not needed.
        del transcript_fsa.aux_labels
        # NOTE: You can comment out the above statement
        # if you want to run test/test_mmi_graph_compiler.py

        transcript_fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(
            transcript_fsa
        )

        transcript_fsa_with_self_loops = k2.arc_sort(
            transcript_fsa_with_self_loops
        )

        num = k2.compose(
            self.ctc_topo_P,
            transcript_fsa_with_self_loops,
            treat_epsilons_specially=False,
        )

        # CAUTION: Due to the presence of P,
        # the resulting `num` may not be connected.
        num = k2.connect(num)

        num = k2.arc_sort(num)

        ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P])
        if replicate_den:
            indexes = torch.zeros(
                len(texts), dtype=torch.int32, device=self.device
            )
            den = k2.index_fsa(ctc_topo_P_vec, indexes)
        else:
            den = ctc_topo_P_vec

        return num, den

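    # A minimal usage sketch of compile() (not part of the original file),
    # assuming `graph_compiler` was constructed as in the sketch after
    # __init__ above:
    #
    #   texts = ["Hello icefall", "LF-MMI training with icefall using k2"]
    #   num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
    #   # num_graphs contains len(texts) == 2 FSAs; with replicate_den=True,
    #   # den_graphs also contains 2 (identical) copies of ctc_topo_P.
    #   # With replicate_den=False, den_graphs would contain a single FSA.
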
    def build_transcript_fsa(self, texts: List[str]) -> k2.Fsa:
        """Convert transcripts to an FsaVec with the help of a lexicon
        and word symbol table.

        Args:
          texts:
            Each element is a transcript containing words separated by space(s).
            For instance, it may be 'HELLO icefall', which contains
            two words.

        Returns:
          Return an FST (FsaVec) corresponding to the transcripts.
          Its `labels` are token IDs and its `aux_labels` are word IDs.
        """
        word_ids_list = []
        for text in texts:
            word_ids = []
            for word in text.split():
                if word in self.lexicon.word_table:
                    word_ids.append(self.lexicon.word_table[word])
                else:
                    word_ids.append(self.oov_id)
            word_ids_list.append(word_ids)

        fsa = k2.linear_fsa(word_ids_list, self.device)
        fsa = k2.add_epsilon_self_loops(fsa)

        # The reason to use `invert_()` at the end is as follows:
        #
        # (1) The `labels` of L_inv are word IDs and its `aux_labels` are
        #     token IDs.
        # (2) `fsa.labels` are word IDs.
        # (3) After the intersection, the `labels` are still word IDs.
        # (4) After `invert_()`, the `labels` are token IDs
        #     and the `aux_labels` are word IDs.
        transcript_fsa = k2.intersect(
            self.L_inv, fsa, treat_epsilons_specially=False
        ).invert_()
        transcript_fsa = k2.arc_sort(transcript_fsa)
        return transcript_fsa

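    # A minimal usage sketch of build_transcript_fsa() (not part of the
    # original file), reusing the hypothetical `graph_compiler` from above:
    #
    #   fsas = graph_compiler.build_transcript_fsa(["HELLO icefall"])
    #   # fsas is an FsaVec with one FSA per transcript; fsas.labels holds
    #   # token IDs and fsas.aux_labels holds the corresponding word IDs.
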
    def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
        """Convert a list of texts to a list-of-list of word piece IDs.

        Args:
          texts:
            It is a list of strings. Each string consists of space-separated
            words. An example containing two strings is given below::

                ['HELLO ICEFALL', 'HELLO k2']

            We assume it contains no OOVs. Otherwise, it will raise an
            exception.
        Returns:
          Return a list-of-list of word piece IDs.
        """
        return self.lexicon.texts_to_token_ids(texts).tolist()
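    # A minimal usage sketch of texts_to_ids() (not part of the original
    # file); the returned IDs below are placeholders, the real values depend
    # on tokens.txt:
    #
    #   ids = graph_compiler.texts_to_ids(["HELLO ICEFALL"])
    #   # e.g. ids == [[12, 7, 41]]   # hypothetical token IDs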