minor fix

This commit is contained in:
Dongji Gao 2023-09-18 23:21:58 -04:00
parent 092fb4766d
commit 914cdce956
7 changed files with 473 additions and 87 deletions

View File

@@ -69,7 +69,7 @@ def get_parser():
parser.add_argument(
"--otc-token", type=str, default="▁<star>", help="OTC token",
)
)
parser.add_argument(
"--epoch",
@@ -173,17 +173,11 @@ def get_parser():
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc2/exp",
help="The experiment dir",
"--exp-dir", type=str, default="conformer_ctc2/exp", help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_500",
help="The lang dir",
"--lang-dir", type=str, default="data/lang_bpe_500", help="The lang dir",
)
parser.add_argument(
@@ -230,10 +224,7 @@ def get_parser():
)
parser.add_argument(
"--rnn-lm-hidden-dim",
type=int,
default=2048,
help="Hidden dim of the model",
"--rnn-lm-hidden-dim", type=int, default=2048, help="Hidden dim of the model",
)
parser.add_argument(
@@ -449,11 +440,7 @@ def decode_one_batch(
return {key: hyps}
if params.method == "ctc-greedy-search":
hyps, _ = ctc_greedy_search(
nnet_output,
memory,
memory_key_padding_mask,
)
hyps, _ = ctc_greedy_search(nnet_output, memory, memory_key_padding_mask,)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(hyps)
@@ -521,16 +508,12 @@ def decode_one_batch(
)
elif params.method == "whole-lattice-rescoring":
best_path_dict = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=lm_scale_list,
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list,
)
elif params.method == "attention-decoder":
# lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=None,
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None,
)
# TODO: pass `lattice` instead of `rescored_lattice` to
# `rescore_with_attention_decoder`
@@ -548,9 +531,7 @@ def decode_one_batch(
elif params.method == "rnn-lm":
# lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=None,
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None,
)
best_path_dict = rescore_with_rnn_lm(
@@ -734,6 +715,8 @@ def main():
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
args.lm_dir = Path(args.lm_dir)
assert "" not in args.otc_token
args.otc_token = f"{args.otc_token}"
params = get_params()
params.update(vars(args))
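The "▁" prepended above is U+2581, SentencePiece's word-boundary marker: the user passes the bare OTC token on the command line and the script converts it to the piece form used internally. A trivial illustration, not part of the recipe:

# Illustration only: "▁" (U+2581) is the SentencePiece word-boundary marker.
otc_token = "<star>"           # bare token passed via --otc-token
assert "▁" not in otc_token
otc_token = f"▁{otc_token}"    # piece form, matching the "▁<star>" default above
print(otc_token)               # ▁<star>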
@@ -769,11 +752,7 @@ def main():
if params.method == "ctc-decoding" or params.method == "ctc-greedy-search":
HLG = None
H = k2.ctc_topo(
max_token=max_token_id,
modified=False,
device=device,
)
H = k2.ctc_topo(max_token=max_token_id, modified=False, device=device,)
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(str(params.lang_dir / "bpe.model"))
else:
@@ -943,8 +922,7 @@ def main():
)
if params.rnn_lm_avg == 1:
load_checkpoint(
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
rnn_lm_model,
f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", rnn_lm_model,
)
rnn_lm_model.to(device)
else:

View File

@@ -26,23 +26,18 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
./conformer_ctc2/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--manifest-dir data/ssl \
--train-manifest librispeech_cuts_train-clean-100_0.17_0.17_0.17.jsonl.gz \
--exp-dir conformer_ctc2/exp \
--full-libri 1 \
--max-duration 300
# For mixed precision training:
./conformer_ctc2/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir conformer_ctc2/exp \
--full-libri 1 \
--max-duration 550
--lang-dir data/lang_bpe_200 \
--otc-token "<star>" \
--allow-bypass-arc true \
--allow-self-loop-arc true \
--initial-bypass-weight -19 \
--initial-self-loop-weight 3.75 \
--bypass-weight-decay 0.975 \
--self-loop-weight-decay 0.999 \
--show-alignment true
"""
@@ -260,12 +255,12 @@ def get_parser():
)
parser.add_argument(
"--otc-token", type=str, default="<star>", help="OTC token",
"--otc-token", type=str, default="_<star>", help="OTC token",
)
parser.add_argument(
"--allow-bypass-arc",
type=bool,
type=str2bool,
default=True,
help="""Whether to add bypass arc to training graph for substitution
and insertion errors (wrong or extra words in the transcript).""",
@@ -273,7 +268,7 @@ def get_parser():
parser.add_argument(
"--allow-self-loop-arc",
type=bool,
type=str2bool,
default=True,
help="""Whether to self-loop bypass arc to training graph for deletion errors
(missing words in the transcript).""",
@@ -311,7 +306,7 @@ def get_parser():
parser.add_argument(
"--show-alignment",
type=bool,
type=str2bool,
default=True,
help="Whether to print OTC alignment during training",
)
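The change from type=bool to str2bool matters because argparse applies the type callable to the raw string, and bool("false") is True, so a flag like --show-alignment false would silently stay enabled. A minimal sketch of the idea behind str2bool (icefall ships its own in icefall.utils; the parser here is only for illustration):

import argparse


def str2bool(v: str) -> bool:
    # Accept the usual textual spellings of a boolean; reject anything else.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {v!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--show-alignment", type=str2bool, default=True)
print(parser.parse_args(["--show-alignment", "false"]))  # Namespace(show_alignment=False)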
@@ -560,7 +555,7 @@ def compute_loss(
feature, supervisions, warmup=warmup
)
# Set the probability of OTC token as the average of non-blank tokens
# under the assumption that blank is the first and
# under the assumption that blank is the first and
# OTC token is the last token in tokens.txt
_, _, V = nnet_output.shape
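One way to realize the comment above, as a hedged sketch rather than the recipe's exact code: assuming nnet_output holds log-probabilities of shape (N, T, V) with blank at index 0 and the OTC token at index V - 1, the log of the average non-blank probability is a logsumexp over indices 1 .. V - 2 minus log(V - 2).

import math

import torch

# Toy shapes; in the recipe nnet_output is assumed to already be log-probs.
N, T, V = 2, 5, 8
nnet_output = torch.randn(N, T, V).log_softmax(dim=-1)

# log of the mean probability of the real non-blank tokens (indices 1 .. V-2),
# computed stably in log space.
otc_log_prob = torch.logsumexp(nnet_output[:, :, 1:-1], dim=-1, keepdim=True) - math.log(V - 2)

# Overwrite the OTC token's column (the last one) with that value.
nnet_output = torch.cat([nnet_output[:, :, :-1], otc_log_prob], dim=-1)
print(nnet_output.shape)  # torch.Size([2, 5, 8])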
@@ -592,9 +587,7 @@ def compute_loss(
self_loop_weight=self_loop_weight,
)
dense_fsa_vec = k2.DenseFsaVec(
nnet_output, supervision_segments, allow_truncate=3,
)
dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments, allow_truncate=3,)
otc_loss = k2.ctc_loss(
decoding_graph=decoding_graph,
@@ -632,7 +625,7 @@ def compute_loss(
for index, utt_id in enumerate(utt_ids):
verbatim_text = verbatim_texts[index]
utt_id = utt_ids[index]
lattice = k2.intersect_dense(
decoding_graph, dense_fsa_vec, params.beam_size,
)
@@ -642,7 +635,7 @@ def compute_loss(
hyp_ids = get_texts(best_path)[index]
hyp_text_list = [graph_compiler.token_table[i] for i in hyp_ids]
hyp_text = " ".join(hyp_text_list)
logging.info(f"[utterance id]: {utt_id}")
logging.info(f"[verbatim text]: {verbatim_text}")
logging.info(f"[best alignment]: {hyp_text}")
@@ -959,7 +952,6 @@ def run(rank, world_size, args):
train_cuts = train_cuts.filter(remove_short_and_long_utt)
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
# saved in the middle of an epoch
@@ -1084,6 +1076,8 @@ def main():
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
assert "" not in args.otc_token
args.otc_token = f"{args.otc_token}"
world_size = args.world_size
assert world_size >= 1
@@ -1092,6 +1086,7 @@ def main():
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

View File

@@ -18,14 +18,12 @@ def get_args():
)
parser.add_argument(
"--otc-token",
type=str,
help="OTC token to be added to BPE model",
"--otc-token", type=str, help="OTC token to be added to words.txt",
)
return parser.parse_args()
def main():
args = get_args()
lang_dir = Path(args.lang_dir)

View File

@@ -90,26 +90,18 @@ def modify_cut_text(
sub_index_set = set()
ins_index_set = set()
# preprocessing
for i in range(len(text_list)):
prob = random.random()
if prob <= del_ratio:
del_index_set.add(i)
elif prob <= del_ratio + sub_ratio:
sub_index_set.add(i)
elif prob <= del_ratio + sub_ratio + ins_ratio:
ins_index_set.add(i)
# We follow the order: deletion -> substitution -> insertion
for i, token in enumerate(text_list):
for token in text_list:
marked_token = token
modified_token = token
if i in del_index_set:
prob = random.random()
if prob <= del_ratio:
marked_token = f"-{token}-"
modified_token = ""
elif i in sub_index_set or i in ins_index_set:
if i in sub_index_set:
elif prob <= del_ratio + sub_ratio + ins_ratio:
if prob <= del_ratio + sub_ratio:
marked_token = f"[{token}]"
else:
marked_verbatim_text_list.append(marked_token)
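The rewritten loop draws a single uniform number per token and partitions it into cumulative bands, which is why the deletion, substitution, and insertion checks appear in that fixed order. A compact sketch of the scheme, with a hypothetical pick_random_word() standing in for however the corrupting word is actually chosen:

import random


def corrupt(text_list, del_ratio, sub_ratio, ins_ratio, pick_random_word):
    # Sketch only; assumes del_ratio + sub_ratio + ins_ratio <= 1.
    out = []
    for token in text_list:
        prob = random.random()
        if prob <= del_ratio:
            continue                           # deletion: drop the token
        elif prob <= del_ratio + sub_ratio:
            out.append(pick_random_word())     # substitution: replace the token
        elif prob <= del_ratio + sub_ratio + ins_ratio:
            out.append(token)                  # insertion: keep the token ...
            out.append(pick_random_word())     # ... and add a spurious word
        else:
            out.append(token)                  # leave the token untouched
    return out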

View File

@@ -0,0 +1,411 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
consisting of words and tokens (i.e., phones) and does the following:
1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
2. Generate tokens.txt, the token table mapping a token to a unique integer.
3. Generate words.txt, the word table mapping a word to a unique integer.
4. Generate L.pt, in k2 format. It can be loaded by
d = torch.load("L.pt")
lexicon = k2.Fsa.from_dict(d)
5. Generate L_disambig.pt, in k2 format.
"""
import argparse
import math
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple
import k2
import torch
from icefall.lexicon import read_lexicon, write_lexicon
from icefall.utils import str2bool
Lexicon = List[Tuple[str, List[str]]]
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
It should contain a file lexicon.txt.
Generated files by this script are saved into this directory.
""",
)
parser.add_argument(
"--debug",
type=str2bool,
default=False,
help="""True for debugging, which will generate
a visualization of the lexicon FST.
Caution: If your lexicon contains hundreds of thousands
of lines, please set it to False!
""",
)
return parser.parse_args()
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
"""Write a symbol to ID mapping to a file.
Note:
No need to implement `read_mapping` as it can be done
through :func:`k2.SymbolTable.from_file`.
Args:
filename:
Filename to save the mapping.
sym2id:
A dict mapping symbols to IDs.
Returns:
Return None.
"""
with open(filename, "w", encoding="utf-8") as f:
for sym, i in sym2id.items():
f.write(f"{sym} {i}\n")
def get_tokens(lexicon: Lexicon) -> List[str]:
"""Get tokens from a lexicon.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
Returns:
Return a list of unique tokens.
"""
ans = set()
for _, tokens in lexicon:
ans.update(tokens)
sorted_ans = sorted(list(ans))
return sorted_ans
def get_words(lexicon: Lexicon) -> List[str]:
"""Get words from a lexicon.
Args:
lexicon:
It is the return value of :func:`read_lexicon`.
Returns:
Return a list of unique words.
"""
ans = set()
for word, _ in lexicon:
ans.add(word)
sorted_ans = sorted(list(ans))
return sorted_ans
def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
"""It adds pseudo-token disambiguation symbols #1, #2 and so on
at the ends of tokens to ensure that all pronunciations are different,
and that none is a prefix of another.
See also add_lex_disambig.pl from kaldi.
Args:
lexicon:
It is returned by :func:`read_lexicon`.
Returns:
Return a tuple with two elements:
- The output lexicon with disambiguation symbols
- The ID of the max disambiguation symbol that appears
in the lexicon
"""
# (1) Work out the count of each token-sequence in the
# lexicon.
count = defaultdict(int)
for _, tokens in lexicon:
count[" ".join(tokens)] += 1
# (2) For each left sub-sequence of each token-sequence, note down
# that it exists (for identifying prefixes of longer strings).
issubseq = defaultdict(int)
for _, tokens in lexicon:
tokens = tokens.copy()
tokens.pop()
while tokens:
issubseq[" ".join(tokens)] = 1
tokens.pop()
# (3) For each entry in the lexicon:
# if the token sequence is unique and is not a
# prefix of another word, no disambig symbol.
# Else output #1, or #2, #3, ... if the same token-seq
# has already been assigned a disambig symbol.
ans = []
# We start with #1 since #0 has its own purpose
first_allowed_disambig = 1
max_disambig = first_allowed_disambig - 1
last_used_disambig_symbol_of = defaultdict(int)
for word, tokens in lexicon:
tokenseq = " ".join(tokens)
assert tokenseq != ""
if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
ans.append((word, tokens))
continue
cur_disambig = last_used_disambig_symbol_of[tokenseq]
if cur_disambig == 0:
cur_disambig = first_allowed_disambig
else:
cur_disambig += 1
if cur_disambig > max_disambig:
max_disambig = cur_disambig
last_used_disambig_symbol_of[tokenseq] = cur_disambig
tokenseq += f" #{cur_disambig}"
ans.append((word, tokenseq.split()))
return ans, max_disambig
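# Worked illustration (not part of this file): in the toy lexicon below the
# pronunciation "a b" is shared by two words and is also a prefix of "a b c",
# so it receives #1 and #2, while the unique, non-prefix "a b c" needs no
# disambiguation symbol.
#
#   toy = [("foo", ["a", "b"]), ("bar", ["a", "b"]), ("baz", ["a", "b", "c"])]
#   ans, max_disambig = add_disambig_symbols(toy)
#   # ans == [("foo", ["a", "b", "#1"]), ("bar", ["a", "b", "#2"]),
#   #         ("baz", ["a", "b", "c"])] and max_disambig == 2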
def generate_id_map(symbols: List[str]) -> Dict[str, int]:
"""Generate ID maps, i.e., map a symbol to a unique ID.
Args:
symbols:
A list of unique symbols.
Returns:
A dict containing the mapping between symbols and IDs.
"""
return {sym: i for i, sym in enumerate(symbols)}
def add_self_loops(
arcs: List[List[Any]], disambig_token: int, disambig_word: int
) -> List[List[Any]]:
"""Adds self-loops to states of an FST to propagate disambiguation symbols
through it. They are added on each state with non-epsilon output symbols
on at least one arc out of the state.
See also fstaddselfloops.pl from Kaldi. One difference is that
Kaldi uses OpenFst style FSTs and it has multiple final states.
This function uses k2 style FSTs and it does not need to add self-loops
to the final state.
The input label of a self-loop is `disambig_token`, while the output
label is `disambig_word`.
Args:
arcs:
A list-of-list. The sublist contains
`[src_state, dest_state, label, aux_label, score]`
disambig_token:
It is the token ID of the symbol `#0`.
disambig_word:
It is the word ID of the symbol `#0`.
Return:
Return new `arcs` containing self-loops.
"""
states_needs_self_loops = set()
for arc in arcs:
src, dst, ilabel, olabel, score = arc
if olabel != 0:
states_needs_self_loops.add(src)
ans = []
for s in states_needs_self_loops:
ans.append([s, s, disambig_token, disambig_word, 0])
return arcs + ans
def lexicon_to_fst(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
sil_token: str = "SIL",
sil_prob: float = 0.5,
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format) with optional silence at
the beginning and end of each word.
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
sil_token:
The silence token.
sil_prob:
The probability for adding a silence at the beginning and end
of the word.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
assert sil_prob > 0.0 and sil_prob < 1.0
# CAUTION: we use score, i.e., negative cost.
sil_score = math.log(sil_prob)
no_sil_score = math.log(1.0 - sil_prob)
start_state = 0
loop_state = 1 # words enter and leave from here
sil_state = 2 # words terminate here when followed by silence; this state
# has a silence transition to loop_state.
next_state = 3 # the next un-allocated state, will be incremented as we go.
arcs = []
assert token2id["<eps>"] == 0
assert word2id["<eps>"] == 0
eps = 0
sil_token = token2id[sil_token]
arcs.append([start_state, loop_state, eps, eps, no_sil_score])
arcs.append([start_state, sil_state, eps, eps, sil_score])
arcs.append([sil_state, loop_state, sil_token, eps, 0])
for word, tokens in lexicon:
assert len(tokens) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
tokens = [token2id[i] for i in tokens]
for i in range(len(tokens) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, tokens[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last token of this word
# It has two out-going arcs, one to the loop state,
# the other one to the sil_state.
i = len(tokens) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score])
arcs.append([cur_state, sil_state, tokens[i], w, sil_score])
if need_self_loops:
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs, disambig_token=disambig_token, disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
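# Quick sanity check (illustration only, not part of this file): a two-word toy
# lexicon yields an L FST whose arcs can be printed directly, assuming k2 is
# installed and the symbol maps satisfy the asserts above.
#
#   toy_lexicon = [("hello", ["HH", "EH"]), ("world", ["W", "ER"])]
#   token2id = {"<eps>": 0, "SIL": 1, "HH": 2, "EH": 3, "W": 4, "ER": 5}
#   word2id = {"<eps>": 0, "hello": 1, "world": 2}
#   L = lexicon_to_fst(toy_lexicon, token2id=token2id, word2id=word2id)
#   print(L)  # one line per arc: src dst token_id word_id score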
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
lexicon_filename = lang_dir / "lexicon.txt"
sil_token = "SIL"
sil_prob = 0.5
lexicon = read_lexicon(lexicon_filename)
tokens = get_tokens(lexicon)
words = get_words(lexicon)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in tokens
tokens.append(f"#{i}")
assert "<eps>" not in tokens
tokens = ["<eps>"] + tokens
assert "<eps>" not in words
assert "#0" not in words
assert "<s>" not in words
assert "</s>" not in words
words = ["<eps>"] + words + ["#0", "<s>", "</s>"]
token2id = generate_id_map(tokens)
word2id = generate_id_map(words)
write_mapping(lang_dir / "tokens.txt", token2id)
write_mapping(lang_dir / "words.txt", word2id)
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst(
lexicon,
token2id=token2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
)
L_disambig = lexicon_to_fst(
lexicon_disambig,
token2id=token2id,
word2id=word2id,
sil_token=sil_token,
sil_prob=sil_prob,
need_self_loops=True,
)
torch.save(L.as_dict(), lang_dir / "L.pt")
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
if args.debug:
labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
L.labels_sym = labels_sym
L.aux_labels_sym = aux_labels_sym
L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
L_disambig.labels_sym = labels_sym
L_disambig.aux_labels_sym = aux_labels_sym
L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
if __name__ == "__main__":
main()

View File

@@ -53,6 +53,13 @@ def get_args():
help="Path to bpe.model",
)
parser.add_argument(
"--otc-token",
required=True,
type=str,
help="OTC token",
)
return parser.parse_args()
@@ -67,6 +74,7 @@ def main():
sp.load(str(args.bpe_model))
word_pieces = set(sp.id_to_piece(list(range(sp.vocab_size()))))
word_pieces.add(f"{args.otc_token}")
for word, pieces in lexicon:
for p in pieces:
if p not in word_pieces:
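Adding the OTC token to word_pieces is what keeps this check from failing: the token never comes out of the BPE model itself, yet it appears in the OTC lexicon. A toy version of the check with made-up pieces and lexicon entries:

# Toy version of the validation loop above; all values are made up.
word_pieces = {"▁HE", "LLO", "▁WO", "RLD"}
word_pieces.add("▁<star>")  # the OTC token is not produced by the BPE model
lexicon = [("HELLO", ["▁HE", "LLO"]), ("<star>", ["▁<star>"])]
for word, pieces in lexicon:
    for p in pieces:
        assert p in word_pieces, f"{p} (from {word}) is not in the BPE model"
print("lexicon is valid")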

View File

@@ -37,9 +37,6 @@ feature_dir="data/ssl"
lang_dir="data/lang"
lm_dir="data/lm"
log_dir="tdnn_lstm_ctc/prepare/log"
mkdir -p "${log_dir}"
. ./cmd.sh
. shared/parse_options.sh || exit 1
@@ -90,7 +87,12 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
# to $dl_dir/LibriSpeech
mkdir -p data/manifests
if [ ! -e data/manifests/.librispeech.done ]; then
lhotse prepare librispeech -j $nj "${dl_dir}/LibriSpeech" "${manifests_dir}"
lhotse prepare librispeech -j ${nj} \
-p dev-clean \
-p dev-other \
-p test-clean \
-p test-other \
-p train-clean-100 "${dl_dir}/LibriSpeech" "${manifests_dir}"
touch data/manifests/.librispeech.done
fi
fi
@@ -165,18 +167,20 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
if [ ! -f ${bpe_lang_dir}/L_disambig.pt ]; then
./local/prepare_otc_lang_bpe.py \
--lang-dir "${bpe_lang_dir}"
--lang-dir "${bpe_lang_dir}" \
--otc-token "${otc_token}"
log "Validating ${bpe_lang_dir}/lexicon.txt"
./local/validate_bpe_lexicon.py \
--lexicon ${bpe_lang_dir}/lexicon.txt \
--bpe-model ${bpe_lang_dir}/bpe.model
--bpe-model ${bpe_lang_dir}/bpe.model \
--otc-token "${otc_token}"
fi
done
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 4: Prepare G"
log "Stage 5: Prepare G"
# We assume you have installed kaldilm; if not, please install
# it using: pip install kaldilm
@@ -214,4 +218,4 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
--lm-dir "${lm_dir}" \
--lang-dir "${bpe_lang_dir}"
done
fi
fi