Remove optional silence (SIL).

Fangjun Kuang 2021-09-10 16:24:49 +08:00
parent 31b3e5b27a
commit 5390ced2d1
15 changed files with 413 additions and 399 deletions

View File

@ -61,29 +61,6 @@ def get_parser():
help="Should various information be logged in tensorboard.", help="Should various information be logged in tensorboard.",
) )
parser.add_argument(
"--use-ali-model",
type=str2bool,
default=False,
help="If true, we assume that you have run tdnn_lstm_ctc/train_bpe.py "
"and you have some checkpoints inside the directory "
"tdnn_lstm_ctc/exp_bpe_500 ."
"It will use tdnn_lstm_ctc/exp_bpe_500/epoch-{ali-model-epoch}.pt "
"as the pre-trained alignment model",
)
parser.add_argument(
"--ali-model-epoch",
type=int,
default=19,
help="If --use-ali-model is True, load "
"tdnn_lstm_ctc/exp_bpe_500/epoch-{ali-model-epoch}.pt as "
"the alignment model."
"Used only if --use-ali-model is True.",
)
# TODO: add extra arguments and support DDP training.
# Currently, only single GPU training is implemented. Will add
# DDP training once single GPU training is finished.
return parser
@ -158,7 +135,7 @@ def get_params() -> AttributeDict:
"use_pruned_intersect": False, "use_pruned_intersect": False,
"den_scale": 1.0, "den_scale": 1.0,
# #
"att_rate": 0, "att_rate": 0, # If not zero, use attention decoder
"attention_dim": 512, "attention_dim": 512,
"nhead": 8, "nhead": 8,
"num_decoder_layers": 0, "num_decoder_layers": 0,
@ -166,7 +143,6 @@ def get_params() -> AttributeDict:
"use_feat_batchnorm": True, "use_feat_batchnorm": True,
"lr_factor": 5.0, "lr_factor": 5.0,
"warm_step": 80000, "warm_step": 80000,
# "warm_step": 10000,
} }
) )
@ -260,10 +236,9 @@ def save_checkpoint(
copyfile(src=filename, dst=best_valid_filename)
- def compute_loss_impl(
+ def compute_loss(
params: AttributeDict,
model: nn.Module,
- ali_model: Optional[nn.Module],
batch: dict,
graph_compiler: MmiTrainingGraphCompiler,
is_training: bool,
@ -296,22 +271,6 @@ def compute_loss_impl(
with torch.set_grad_enabled(is_training):
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
# nnet_output is [N, T, C]
if ali_model is not None and params.batch_idx_train < 4000:
feature = feature.permute(0, 2, 1) # [N, T, C]->[N, C, T]
ali_model_output = ali_model(feature)
# subsampling is done slightly differently, may be small length
# differences.
min_len = min(ali_model_output.shape[1], nnet_output.shape[1])
# scale less than one so it will be encouraged
# to mimic ali_model's output
ali_model_scale = 500.0 / (params.batch_idx_train + 500)
# Use clone() here or log-softmax backprop will fail.
nnet_output = nnet_output.clone()
nnet_output[:, :min_len, :] += (
ali_model_scale * ali_model_output[:, :min_len, :]
)
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
@ -374,58 +333,9 @@ def compute_loss_impl(
return loss, mmi_loss.detach(), att_loss.detach()
def compute_loss(
params: AttributeDict,
model: nn.Module,
ali_model: Optional[nn.Module],
batch: dict,
graph_compiler: MmiTrainingGraphCompiler,
is_training: bool,
):
try:
return compute_loss_impl(
params=params,
model=model,
ali_model=ali_model,
batch=batch,
graph_compiler=graph_compiler,
is_training=is_training,
)
except RuntimeError as ex:
if "out of memory" not in str(ex):
raise ex
logging.exception(ex)
s = f"\nCaught exception: {str(ex)}\n"
total_duration = 0.0
max_cut_duration = 0.0
for cut in batch["supervisions"]["cut"]:
s += f" id: {cut.id}, duration: {cut.duration} seconds\n"
total_duration += cut.duration
max_cut_duration = max(max_cut_duration, cut.duration)
s += f" total duration: {total_duration:.3f} s\n"
s += f" max duration: {max_cut_duration:.3f} s \n"
logging.info(s)
torch.cuda.empty_cache()
gc.collect()
# See https://github.com/pytorch/pytorch/issues/18853#issuecomment-583779161
return compute_loss_impl(
params=params,
model=model,
ali_model=ali_model,
batch=params.saved_batch,
graph_compiler=graph_compiler,
is_training=is_training,
)
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
- ali_model: Optional[nn.Module],
graph_compiler: MmiTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
@ -443,7 +353,6 @@ def compute_validation_loss(
loss, mmi_loss, att_loss = compute_loss(
params=params,
model=model,
- ali_model=ali_model,
batch=batch,
graph_compiler=graph_compiler,
is_training=False,
@ -484,7 +393,6 @@ def compute_validation_loss(
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
- ali_model: Optional[nn.Module],
optimizer: torch.optim.Optimizer,
graph_compiler: MmiTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
@ -503,9 +411,6 @@ def train_one_epoch(
It is returned by :func:`get_params`.
model:
The model for training.
- ali_model:
- The force alignment model for training. It is from
- tdnn_lstm_ctc/train_bpe.py
optimizer:
The optimizer we are using.
graph_compiler:
@ -529,18 +434,12 @@ def train_one_epoch(
params.tot_loss = 0.0
params.tot_frames = 0.0
for batch_idx, batch in enumerate(train_dl):
- if batch_idx == 0:
- logging.info("save a batch for OOM handling")
- # Use this batch to replace the batch that's causing OOM
- params.saved_batch = batch
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, mmi_loss, att_loss = compute_loss(
params=params,
model=model,
- ali_model=ali_model,
batch=batch,
graph_compiler=graph_compiler,
is_training=True,
@ -632,7 +531,6 @@ def train_one_epoch(
compute_validation_loss(
params=params,
model=model,
- ali_model=ali_model,
graph_compiler=graph_compiler,
valid_dl=valid_dl,
world_size=world_size,
@ -669,9 +567,6 @@ def train_one_epoch(
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
- if "saved_batch" in params:
- del params["saved_batch"]
def run(rank, world_size, args):
"""
@ -745,35 +640,6 @@ def run(rank, world_size, args):
if checkpoints and checkpoints["optimizer"]: if checkpoints and checkpoints["optimizer"]:
optimizer.load_state_dict(checkpoints["optimizer"]) optimizer.load_state_dict(checkpoints["optimizer"])
assert args.use_ali_model is False
if args.use_ali_model:
ali_model = TdnnLstm(
num_features=params.feature_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
)
# TODO: add an option to switch among
# bpe_500, bpe_1000, and bpe_5000
ali_model_fname = Path(
f"tdnn_lstm_ctc/exp_bpe_500/epoch-{args.ali_model_epoch}.pt"
)
assert (
ali_model_fname.is_file()
), f"ali model filename {ali_model_fname} does not exist!"
ali_model.load_state_dict(
torch.load(ali_model_fname, map_location="cpu")["model"]
)
ali_model.to(device)
ali_model.eval()
ali_model.requires_grad_(False)
logging.info(f"Use ali_model: {ali_model_fname}")
else:
ali_model = None
logging.info("No ali_model")
librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders()
@ -796,7 +662,6 @@ def run(rank, world_size, args):
train_one_epoch(
params=params,
model=model,
- ali_model=ali_model,
optimizer=optimizer,
graph_compiler=graph_compiler,
train_dl=train_dl,

egs/librispeech/ASR/local/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
tmp_lang

View File

@ -8,8 +8,8 @@ for LM training with the help of a lexicon.
If the lexicon contains phones, the resulting LM will be a phone LM; If the
lexicon contains word pieces, the resulting LM will be a word piece LM.
- If a word has multiple pronunciations, the one that appears last in the lexicon
- is used.
+ If a word has multiple pronunciations, the one that appears first in the lexicon
+ is kept; others are removed.
If the input transcript is:
@ -20,8 +20,8 @@ If the input transcript is:
and if the lexicon is
<UNK> SPN
- hello h e l l o
hello h e l l o 2
+ hello h e l l o
world w o r l d
zoo z o o
@ -36,6 +36,8 @@ import argparse
from pathlib import Path
from typing import Dict, List
+ from generate_unique_lexicon import filter_multiple_pronunications
from icefall.lexicon import read_lexicon
@ -87,8 +89,10 @@ def main():
assert Path(args.transcript).is_file()
assert len(args.oov) > 0
- # Only the last pronunciation of a word is kept
- lexicon = dict(read_lexicon(args.lexicon))
+ # Only the first pronunciation of a word is kept
+ lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon))
+ lexicon = dict(lexicon)
assert args.oov in lexicon
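To see what the script computes, here is a minimal sketch of the word-to-token conversion (the toy lexicon dict and the transcript_to_tokens helper below are illustrative only, not part of the repository):

    # A toy lexicon in which each word already has a single pronunciation,
    # mirroring the output of generate_unique_lexicon.py.
    lexicon = {
        "<UNK>": ["SPN"],
        "hello": ["h", "e", "l", "l", "o", "2"],
        "world": ["w", "o", "r", "l", "d"],
        "zoo": ["z", "o", "o"],
    }

    def transcript_to_tokens(line: str, oov: str = "<UNK>") -> str:
        # Map every word to its token sequence; unknown words fall back
        # to the pronunciation of the OOV word.
        tokens = []
        for word in line.split():
            tokens.extend(lexicon.get(word, lexicon[oov]))
        return " ".join(tokens)

    print(transcript_to_tokens("hello zoo foo"))  # h e l l o 2 z o o SPN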

View File

@ -25,7 +25,7 @@ This file downloads the following LibriSpeech LM files:
- librispeech-lexicon.txt
from http://www.openslr.org/resources/11
- and save them in the user provided directory.
+ and saves them in the user provided directory.
Files are not re-downloaded if they already exist.

View File

@ -0,0 +1,100 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file takes as input a lexicon.txt and outputs a new lexicon,
in which each word has a unique pronunciation.
The way to do this is to keep only the first pronunciation of a word
in lexicon.txt.
"""
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
from icefall.lexicon import read_lexicon, write_lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
It should contain a file lexicon.txt.
This file will generate a new file uniq_lexicon.txt
in it.
""",
)
return parser.parse_args()
def filter_multiple_pronunications(
lexicon: List[Tuple[str, List[str]]]
) -> List[Tuple[str, List[str]]]:
"""Remove multiple pronunciations of words from a lexicon.
If a word has more than one pronunciation in the lexicon, only
the first one is kept, while other pronunciations are removed
from the lexicon.
Args:
lexicon:
The input lexicon, containing a list of (word, [p1, p2, ..., pn]),
where "p1, p2, ..., pn" are the pronunciations of the "word".
Returns:
Return a new lexicon where each word has a unique pronunciation.
"""
seen = set()
ans = []
for word, tokens in lexicon:
if word in seen:
continue
seen.add(word)
ans.append((word, tokens))
return ans
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
lexicon_filename = lang_dir / "lexicon.txt"
in_lexicon = read_lexicon(lexicon_filename)
out_lexicon = filter_multiple_pronunications(in_lexicon)
write_lexicon(lang_dir / "uniq_lexicon.txt", out_lexicon)
logging.info(f"Number of entries in lexicon.txt: {len(in_lexicon)}")
logging.info(f"Number of entries in uniq_lexicon.txt: {len(out_lexicon)}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()
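The filtering is easy to try out interactively; a small example (run from egs/librispeech/ASR/local so the import resolves; the toy lexicon is made up):

    from generate_unique_lexicon import filter_multiple_pronunications

    lexicon = [
        ("hello", ["h", "e", "l", "l", "o", "2"]),
        ("hello", ["h", "e", "l", "l", "o"]),
        ("zoo", ["z", "o", "o"]),
    ]
    # Only the first entry for "hello" survives.
    print(filter_multiple_pronunications(lexicon))
    # [('hello', ['h', 'e', 'l', 'l', 'o', '2']), ('zoo', ['z', 'o', 'o'])]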

View File

@ -17,8 +17,9 @@
""" """
This script takes as input a lexicon file "data/lang_phone/lexicon.txt" This script takes as input a `lang_dir`, which is expected to contain
consisting of words and tokens (i.e., phones) and does the following: a lexicon file "lexicon.txt" consisting of words and tokens (i.e., phones)
and does the following:
1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
@ -36,11 +37,12 @@ consisting of words and tokens (i.e., phones) and does the following:
The generated files are saved into `lang_dir`.
"""
import argparse
- import math
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple
+ from icefall.utils import str2bool
import k2
import torch
@ -55,7 +57,22 @@ def get_args():
"--lang-dir", "--lang-dir",
type=str, type=str,
help="""Input and output directory. help="""Input and output directory.
It should contain a file lexicon.txt It should contain a file lexicon.txt.
Generated files by this script are saved into this directory.
""",
)
parser.add_argument(
"--debug",
type=str2bool,
default=False,
help="""True for debugging, which will generate
a visualization of the lexicon FST.
Caution: If your lexicon contains hundreds of thousands
of lines, please set it to False!
See "local/test_prepare_lang.sh" for usage.
""", """,
) )
@ -85,6 +102,10 @@ def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
def get_tokens(lexicon: Lexicon) -> List[str]:
"""Get tokens from a lexicon.
+ If pronunciations are phones, then tokens are phones.
+ If pronunciations are word pieces, then tokens are word pieces.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
@ -208,6 +229,9 @@ def add_self_loops(
The input label of a self-loop is `disambig_token`, while the output
label is `disambig_word`.
+ Caution:
+ Don't be confused with :func:`k2.add_epsilon_self_loops`.
Args:
arcs:
A list-of-list. The sublist contains
@ -237,12 +261,9 @@ def lexicon_to_fst(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
- sil_token: str = "SIL",
- sil_prob: float = 0.5,
need_self_loops: bool = False,
) -> k2.Fsa:
- """Convert a lexicon to an FST (in k2 format) with optional silence at
- the beginning and end of each word.
+ """Convert a lexicon to an FST (in k2 format).
Args:
lexicon:
@ -251,11 +272,6 @@ def lexicon_to_fst(
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
- sil_token:
- The silence token.
- sil_prob:
- The probability for adding a silence at the beginning and end
- of the word.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
@ -263,50 +279,43 @@ def lexicon_to_fst(
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
- assert sil_prob > 0.0 and sil_prob < 1.0
- # CAUTION: we use score, i.e, negative cost.
- sil_score = math.log(sil_prob)
- no_sil_score = math.log(1.0 - sil_prob)
- start_state = 0
- loop_state = 1  # words enter and leave from here
- sil_state = 2  # words terminate here when followed by silence; this state
- # has a silence transition to loop_state.
- next_state = 3  # the next un-allocated state, will be incremented as we go.
+ loop_state = 0  # words enter and leave from here
+ next_state = 1  # the next un-allocated state, will be incremented as we go
arcs = []
- assert token2id["<eps>"] == 0
+ if "<blk>" in token2id:
+     # For BPE based lexicon
+     # The blank symbol <blk> is defined in local/train_bpe_model.py
+     assert token2id["<blk>"] == 0
+ else:
+     # For phone based lexicon in the CTC topo,
+     # 0 on the left side (i.e., as label) indicates a blank.
+     # 0 on the right side (i.e., as aux_label) represents an epsilon
+     assert token2id["<eps>"] == 0
assert word2id["<eps>"] == 0
eps = 0
- sil_token = token2id[sil_token]
- arcs.append([start_state, loop_state, eps, eps, no_sil_score])
- arcs.append([start_state, sil_state, eps, eps, sil_score])
- arcs.append([sil_state, loop_state, sil_token, eps, 0])
- for word, tokens in lexicon:
- assert len(tokens) > 0, f"{word} has no pronunciations"
+ for word, pieces in lexicon:
+ assert len(pieces) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
- tokens = [token2id[i] for i in tokens]
- for i in range(len(tokens) - 1):
+ pieces = [token2id[i] for i in pieces]
+ for i in range(len(pieces) - 1):
w = word if i == 0 else eps
- arcs.append([cur_state, next_state, tokens[i], w, 0])
+ arcs.append([cur_state, next_state, pieces[i], w, 0])
cur_state = next_state
next_state += 1
- # now for the last token of this word
- # It has two out-going arcs, one to the loop state,
- # the other one to the sil_state.
- i = len(tokens) - 1
+ # now for the last piece of this word
+ i = len(pieces) - 1
w = word if i == 0 else eps
- arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
- arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
+ arcs.append([cur_state, loop_state, pieces[i], w, 0])
if need_self_loops:
disambig_token = token2id["#0"] disambig_token = token2id["#0"]
@ -335,8 +344,6 @@ def main():
lang_dir = Path(args.lang_dir)
lexicon_filename = lang_dir / "lexicon.txt"
- sil_token = "SIL"
- sil_prob = 0.5
lexicon = read_lexicon(lexicon_filename)
tokens = get_tokens(lexicon)
@ -370,21 +377,29 @@ def main():
lexicon,
token2id=token2id,
word2id=word2id,
- sil_token=sil_token,
- sil_prob=sil_prob,
)
L_disambig = lexicon_to_fst(
lexicon_disambig,
token2id=token2id,
word2id=word2id,
- sil_token=sil_token,
- sil_prob=sil_prob,
need_self_loops=True,
)
torch.save(L.as_dict(), lang_dir / "L.pt")
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
if args.debug:
labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
L.labels_sym = labels_sym
L.aux_labels_sym = aux_labels_sym
L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
L_disambig.labels_sym = labels_sym
L_disambig.aux_labels_sym = aux_labels_sym
L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
if __name__ == "__main__": if __name__ == "__main__":
main() main()
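The rewritten, silence-free lexicon_to_fst can be exercised on a toy phone lexicon like this (a sketch; the symbol tables below are invented, and the script must be importable, e.g. run from egs/librispeech/ASR/local):

    from prepare_lang import lexicon_to_fst

    token2id = {"<eps>": 0, "a": 1, "t": 2, "#0": 3}
    word2id = {"<eps>": 0, "a": 1, "at": 2, "#0": 3}
    lexicon = [("a", ["a"]), ("at", ["a", "t"])]

    # Without silence there is a single loop state; every word is simply
    # a path of token arcs that leaves and re-enters that state.
    L = lexicon_to_fst(lexicon, token2id=token2id, word2id=word2id)
    print(L.num_arcs)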

View File

@ -44,86 +44,12 @@ import torch
from prepare_lang import (
Lexicon,
add_disambig_symbols,
- add_self_loops,
+ lexicon_to_fst,
write_lexicon,
write_mapping,
)
def lexicon_to_fst_no_sil(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format).
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
loop_state = 0 # words enter and leave from here
next_state = 1 # the next un-allocated state, will be incremented as we go
arcs = []
# The blank symbol <blk> is defined in local/train_bpe_model.py
assert token2id["<blk>"] == 0
assert word2id["<eps>"] == 0
eps = 0
for word, pieces in lexicon:
assert len(pieces) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
pieces = [token2id[i] for i in pieces]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, pieces[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last piece of this word
i = len(pieces) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, pieces[i], w, 0])
if need_self_loops:
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
def generate_lexicon(
model_file: str, words: List[str]
) -> Tuple[Lexicon, Dict[str, int]]:
@ -206,13 +132,13 @@ def main():
write_lexicon(lang_dir / "lexicon.txt", lexicon) write_lexicon(lang_dir / "lexicon.txt", lexicon)
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
- L = lexicon_to_fst_no_sil(
+ L = lexicon_to_fst(
lexicon,
token2id=token_sym_table,
word2id=word_sym_table,
)
- L_disambig = lexicon_to_fst_no_sil(
+ L_disambig = lexicon_to_fst(
lexicon_disambig,
token2id=token_sym_table,
word2id=word_sym_table,

View File

@ -1,106 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
import os
import tempfile
import k2
from prepare_lang import (
add_disambig_symbols,
generate_id_map,
get_phones,
get_words,
lexicon_to_fst,
read_lexicon,
write_lexicon,
write_mapping,
)
def generate_lexicon_file() -> str:
fd, filename = tempfile.mkstemp()
os.close(fd)
s = """
!SIL SIL
<SPOKEN_NOISE> SPN
<UNK> SPN
f f
a a
foo f o o
bar b a r
bark b a r k
food f o o d
food2 f o o d
fo f o
""".strip()
with open(filename, "w") as f:
f.write(s)
return filename
def test_read_lexicon(filename: str):
lexicon = read_lexicon(filename)
phones = get_phones(lexicon)
words = get_words(lexicon)
print(lexicon)
print(phones)
print(words)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
print(lexicon_disambig)
print("max disambig:", f"#{max_disambig}")
phones = ["<eps>", "SIL", "SPN"] + phones
for i in range(max_disambig + 1):
phones.append(f"#{i}")
words = ["<eps>"] + words
phone2id = generate_id_map(phones)
word2id = generate_id_map(words)
print(phone2id)
print(word2id)
write_mapping("phones.txt", phone2id)
write_mapping("words.txt", word2id)
write_lexicon("a.txt", lexicon)
write_lexicon("a_disambig.txt", lexicon_disambig)
fsa = lexicon_to_fst(lexicon, phone2id=phone2id, word2id=word2id)
fsa.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa.draw("L.pdf", title="L")
fsa_disambig = lexicon_to_fst(
lexicon_disambig, phone2id=phone2id, word2id=word2id
)
fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt")
fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt")
fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
def main():
filename = generate_lexicon_file()
test_read_lexicon(filename)
os.remove(filename)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,37 @@
#!/usr/bin/env bash
lang_dir=tmp_lang
mkdir -p $lang_dir
cat <<EOF > $lang_dir/lexicon.txt
<UNK> SPN
f f
a a
foo f o o
bar b a r
bark b a r k
food f o o d
food2 f o o d
fo f o
fo f o o
EOF
./prepare_lang.py --lang-dir $lang_dir --debug 1
./generate_unique_lexicon.py --lang-dir $lang_dir
cat <<EOF > $lang_dir/transcript_words.txt
foo bar bark food food2 fo f a foo bar
bar food2 fo bark
EOF
./convert_transcript_words_to_tokens.py \
--lexicon $lang_dir/uniq_lexicon.txt \
--transcript $lang_dir/transcript_words.txt \
--oov "<UNK>" \
> $lang_dir/transcript_tokens.txt
../shared/make_kn_lm.py \
-ngram-order 2 \
-text $lang_dir/transcript_tokens.txt \
-lm $lang_dir/P.arpa
echo "Please delete the directory '$lang_dir' manually"

View File

@ -38,7 +38,7 @@ def get_args():
"--lang-dir", "--lang-dir",
type=str, type=str,
help="""Input and output directory. help="""Input and output directory.
It should contain the training corpus: train.txt. It should contain the training corpus: transcript_words.txt.
The generated bpe.model is saved to this directory. The generated bpe.model is saved to this directory.
""", """,
) )
@ -59,7 +59,7 @@ def main():
model_type = "unigram" model_type = "unigram"
model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
train_text = f"{lang_dir}/train.txt" train_text = f"{lang_dir}/transcript_words.txt"
character_coverage = 1.0 character_coverage = 1.0
input_sentence_size = 100000000 input_sentence_size = 100000000
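For orientation, these variables are typically handed to sentencepiece roughly as follows (a sketch only; train_bpe_model.py may pass additional options such as special-symbol IDs, and the paths here are illustrative):

    import sentencepiece as spm

    lang_dir = "data/lang_bpe"
    vocab_size = 5000
    model_type = "unigram"

    # Train a unigram/BPE model on the plain-text transcripts.
    spm.SentencePieceTrainer.train(
        input=f"{lang_dir}/transcript_words.txt",
        vocab_size=vocab_size,
        model_type=model_type,
        model_prefix=f"{lang_dir}/{model_type}_{vocab_size}",
        character_coverage=1.0,
        input_sentence_size=100000000,
    )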

View File

@ -116,16 +116,18 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
lang_dir=data/lang_phone
mkdir -p $lang_dir
- (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
+ echo '<UNK> SPN' |
cat - $dl_dir/lm/librispeech-lexicon.txt |
sort | uniq > $lang_dir/lexicon.txt
+ ./local/generate_unique_lexicon.py --lang-dir $lang_dir
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang.py --lang-dir $lang_dir
fi
# Train a bigram P for MMI training
- if [ ! -f $lang_dir/train.txt ]; then
+ if [ ! -f $lang_dir/transcript_words.txt ]; then
log "Generate data to train phone based bigram P"
files=$(
find -L "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
@ -134,30 +136,21 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
)
for f in ${files[@]}; do
cat $f | cut -d " " -f 2-
- done > $lang_dir/train.txt
+ done > $lang_dir/transcript_words.txt
fi
- if [ ! -f $lang_dir/train_with_sil.txt ]; then
- ./local/add_silence_to_transcript.py \
- --transcript $lang_dir/train.txt \
- --sil-word "!SIL" \
- --sil-prob 0.5 \
- --seed 20210823 \
- > $lang_dir/train_with_sil.txt
- fi
- if [ ! -f $lang_dir/corpus.txt ]; then
- ./local/convert_transcript_to_corpus.py \
- --lexicon $lang_dir/lexicon.txt \
- --transcript $lang_dir/train_with_sil.txt \
+ if [ ! -f $lang_dir/transcript_tokens.txt ]; then
+ ./local/convert_transcript_words_to_tokens.py \
+ --lexicon $lang_dir/uniq_lexicon.txt \
+ --transcript $lang_dir/transcript_words.txt \
--oov "<UNK>" \
- > $lang_dir/corpus.txt
+ > $lang_dir/transcript_tokens.txt
fi
if [ ! -f $lang_dir/P.arpa ]; then
./shared/make_kn_lm.py \
-ngram-order 2 \
- -text $lang_dir/corpus.txt \
+ -text $lang_dir/transcript_tokens.txt \
-lm $lang_dir/P.arpa
fi
@ -180,7 +173,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
# so that the two can share G.pt later.
cp data/lang_phone/words.txt $lang_dir
- if [ ! -f $lang_dir/train.txt ]; then
+ if [ ! -f $lang_dir/transcript_words.txt ]; then
log "Generate data for BPE training"
files=$(
find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
@ -189,7 +182,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
)
for f in ${files[@]}; do
cat $f | cut -d " " -f 2-
- done > $lang_dir/train.txt
+ done > $lang_dir/transcript_words.txt
fi
./local/train_bpe_model.py \
@ -204,7 +197,7 @@ fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Prepare G"
- # We assume you have install kaldilm, if not, please install
+ # We assume you have installed kaldilm, if not, please install
# it using: pip install kaldilm
mkdir -p data/lm
@ -237,4 +230,4 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
done
fi
- cd data && ln -sfv lang_bpe_5000 lang_bpe
+ # cd data && ln -sfv lang_bpe_5000 lang_bpe

View File

@ -106,7 +106,7 @@ class CtcTrainingGraphCompiler(object):
word_ids_list = []
for text in texts:
word_ids = []
- for word in text.split(" "):
+ for word in text.split():
if word in self.word_table:
word_ids.append(self.word_table[word])
else:
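The split() change matters when a transcript contains consecutive spaces; a quick check in Python:

    text = "HELLO  WORLD"   # note the double space
    print(text.split(" "))  # ['HELLO', '', 'WORLD'] -- the empty string would be looked up as a word
    print(text.split())     # ['HELLO', 'WORLD']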

View File

@ -95,7 +95,7 @@ class Lexicon(object):
""" """
Args: Args:
lang_dir: lang_dir:
Path to the lang director. It is expected to contain the following Path to the lang directory. It is expected to contain the following
files: files:
- tokens.txt - tokens.txt
- words.txt - words.txt

View File

@ -130,6 +130,11 @@ class MmiTrainingGraphCompiler(object):
transcript_fsa_with_self_loops,
treat_epsilons_specially=False,
)
+ # CAUTION: Due to the presence of P,
+ # the resulting `num` may not be connected
+ num = k2.connect(num)
num = k2.arc_sort(num)
ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P])
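As a small illustration of why the extra k2.connect call helps, here is a toy acceptor (unrelated to the actual numerator graphs): state 2 cannot reach the final state, so connect() drops the arc pointing to it.

    import k2

    s = "\n".join([
        "0 1 1 0.5",
        "0 2 2 0.5",   # dead-end arc: state 2 never reaches the final state
        "1 3 -1 0.0",
        "3",
    ])
    fsa = k2.Fsa.from_str(s)
    print(fsa.num_arcs)              # 3
    print(k2.connect(fsa).num_arcs)  # 2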
@ -160,7 +165,7 @@ class MmiTrainingGraphCompiler(object):
word_ids_list = []
for text in texts:
word_ids = []
- for word in text.split(" "):
+ for word in text.split():
if word in self.lexicon.word_table:
word_ids.append(self.lexicon.word_table[word])
else:

test/test_mmi_graph_compiler.py Executable file
View File

@ -0,0 +1,174 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
You can run this file in one of the two ways:
(1) cd icefall; pytest test/test_mmi_graph_compiler.py
(2) cd icefall; ./test/test_mmi_graph_compiler.py
"""
import os
import shutil
import sys
import copy
from pathlib import Path
import k2
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
TMP_DIR = "/tmp/icefall-test-mmi-graph-compiler"
USING_PYTEST = "pytest" in sys.modules
ICEFALL_DIR = Path(__file__).resolve().parent.parent
print(ICEFALL_DIR)
def generate_test_data():
# if Path(TMP_DIR).exists():
# return
Path(TMP_DIR).mkdir(exist_ok=True)
lexicon = """
<UNK> SPN
cat c a t
at a t
at a a t
ac a c
ac a c c
"""
lexicon_filename = Path(TMP_DIR) / "lexicon.txt"
with open(lexicon_filename, "w") as f:
for line in lexicon.strip().split("\n"):
f.write(f"{line}\n")
transcript_words = """
cat at ta
at at cat ta
"""
transcript_words_filename = Path(TMP_DIR) / "transcript_words.txt"
with open(transcript_words_filename, "w") as f:
for line in transcript_words.strip().split("\n"):
f.write(f"{line}\n")
os.system(
f"""
cd {ICEFALL_DIR}/egs/librispeech/ASR
./local/generate_unique_lexicon.py --lang-dir {TMP_DIR}
./local/prepare_lang.py --lang-dir {TMP_DIR}
./local/convert_transcript_words_to_tokens.py \
--lexicon {TMP_DIR}/uniq_lexicon.txt \
--transcript {TMP_DIR}/transcript_words.txt \
--oov "<UNK>" \
> {TMP_DIR}/transcript_tokens.txt
shared/make_kn_lm.py \
-ngram-order 2 \
-text {TMP_DIR}/transcript_tokens.txt \
-lm {TMP_DIR}/P.arpa
python3 -m kaldilm \
--read-symbol-table="{TMP_DIR}/tokens.txt" \
--disambig-symbol='#0' \
--max-order=2 \
{TMP_DIR}/P.arpa > {TMP_DIR}/P.fst.txt
"""
)
def delete_test_data():
shutil.rmtree(TMP_DIR)
def mmi_graph_compiler_test():
graph_compiler = MmiTrainingGraphCompiler(lang_dir=TMP_DIR)
print(graph_compiler.device)
L_inv = graph_compiler.L_inv
L = k2.invert(L_inv)
L.labels_sym = graph_compiler.lexicon.token_table
L.aux_labels_sym = graph_compiler.lexicon.word_table
L.draw(f"{TMP_DIR}/L.svg", title="L")
L_inv.labels_sym = graph_compiler.lexicon.word_table
L_inv.aux_labels_sym = graph_compiler.lexicon.token_table
L_inv.draw(f"{TMP_DIR}/L_inv.svg", title="L")
ctc_topo_P = graph_compiler.ctc_topo_P
ctc_topo_P.labels_sym = copy.deepcopy(graph_compiler.lexicon.token_table)
ctc_topo_P.labels_sym._id2sym[0] = "<blk>"
ctc_topo_P.labels_sym._sym2id["<blk>"] = 0
ctc_topo_P.aux_labels_sym = graph_compiler.lexicon.token_table
ctc_topo_P.draw(f"{TMP_DIR}/ctc_topo_P.svg", title="ctc_topo_P")
print(ctc_topo_P.num_arcs)
print(k2.connect(ctc_topo_P).num_arcs)
with open(str(TMP_DIR) + "/P.fst.txt") as f:
# P is not an acceptor because there is
# a back-off state, whose incoming arcs
# have label #0 and aux_label 0 (i.e., <eps>).
P = k2.Fsa.from_openfst(f.read(), acceptor=False)
P.labels_sym = graph_compiler.lexicon.token_table
P.aux_labels_sym = graph_compiler.lexicon.token_table
P.draw(f"{TMP_DIR}/P.svg", title="P")
ctc_topo = k2.ctc_topo(max(graph_compiler.lexicon.tokens), False)
ctc_topo.labels_sym = ctc_topo_P.labels_sym
ctc_topo.aux_labels_sym = graph_compiler.lexicon.token_table
ctc_topo.draw(f"{TMP_DIR}/ctc_topo.svg", title="ctc_topo")
print("p num arcs", P.num_arcs)
print("ctc_topo num arcs", ctc_topo.num_arcs)
print("ctc_topo_P num arcs", ctc_topo_P.num_arcs)
texts = ["cat at ac at", "at ac cat zoo", "cat zoo"]
transcript_fsa = graph_compiler.build_transcript_fsa(texts)
transcript_fsa[0].draw(f"{TMP_DIR}/cat_at_ac_at.svg", title="cat_at_ac_at")
transcript_fsa[1].draw(
f"{TMP_DIR}/at_ac_cat_zoo.svg", title="at_ac_cat_zoo"
)
transcript_fsa[2].draw(f"{TMP_DIR}/cat_zoo.svg", title="cat_zoo")
num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
num_graphs[0].draw(
f"{TMP_DIR}/num_cat_at_ac_at.svg", title="num_cat_at_ac_at"
)
num_graphs[1].draw(
f"{TMP_DIR}/num_at_ac_cat_zoo.svg", title="num_at_ac_cat_zoo"
)
num_graphs[2].draw(f"{TMP_DIR}/num_cat_zoo.svg", title="num_cat_zoo")
den_graphs[0].draw(
f"{TMP_DIR}/den_cat_at_ac_at.svg", title="den_cat_at_ac_at"
)
den_graphs[2].draw(f"{TMP_DIR}/den_cat_zoo.svg", title="den_cat_zoo")
def test_main():
generate_test_data()
mmi_graph_compiler_test()
if USING_PYTEST:
delete_test_data()
def main():
test_main()
if __name__ == "__main__" and not USING_PYTEST:
main()