Minor fixes.

Fangjun Kuang 2021-08-16 17:03:05 +08:00
parent 03242b3328
commit 56319b0903
5 changed files with 236 additions and 92 deletions

View File

@@ -13,7 +13,9 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
+from tdnn_lstm_ctc.model import TdnnLstm
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam
@@ -58,6 +60,26 @@ def get_parser():
         help="Should various information be logged in tensorboard.",
     )

+    parser.add_argument(
+        "--use-ali-model",
+        type=str2bool,
+        default=True,
+        help="If true, we assume that you have run tdnn_lstm_ctc/train_bpe.py "
+        "and that you have some checkpoints inside the directory "
+        "tdnn_lstm_ctc/exp_bpe_500. "
+        "It will use tdnn_lstm_ctc/exp_bpe_500/epoch-{ali-model-epoch}.pt "
+        "as the pre-trained alignment model.",
+    )
+
+    parser.add_argument(
+        "--ali-model-epoch",
+        type=int,
+        default=19,
+        help="Load tdnn_lstm_ctc/exp_bpe_500/epoch-{ali-model-epoch}.pt "
+        "as the alignment model. "
+        "Used only if --use-ali-model is True.",
+    )
+
     # TODO: add extra arguments and support DDP training.
     # Currently, only single GPU training is implemented. Will add
     # DDP training once single GPU training is finished.
@@ -117,24 +139,21 @@
     """
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_mmi/exp"),
-            "lang_dir": Path("data/lang_bpe"),
+            "exp_dir": Path("conformer_mmi/exp_500"),
+            "lang_dir": Path("data/lang_bpe_500"),
             "feature_dim": 80,
             "weight_decay": 1e-6,
             "subsampling_factor": 4,
             "start_epoch": 0,
-            "num_epochs": 10,
+            "num_epochs": 50,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
             "log_interval": 10,
-            # It takes about 10 minutes (1 GPU, max_duration=200)
-            # to run a validation process.
-            # For the 100 h subset, there are 85617 batches.
-            # For the 960 h dataset, there are 843723 batches
-            "valid_interval": 8000,
+            "reset_interval": 200,
+            "valid_interval": 10,
             "use_pruned_intersect": False,
             "den_scale": 1.0,
             #
@@ -242,6 +261,7 @@ def save_checkpoint(
 def compute_loss(
     params: AttributeDict,
     model: nn.Module,
+    ali_model: Optional[nn.Module],
     batch: dict,
     graph_compiler: BpeMmiTrainingGraphCompiler,
     is_training: bool,

@@ -274,6 +294,22 @@ def compute_loss(
     with torch.set_grad_enabled(is_training):
         nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
         # nnet_output is [N, T, C]
+        if ali_model is not None and params.batch_idx_train < 4000:
+            feature = feature.permute(0, 2, 1)  # [N, T, C] -> [N, C, T]
+            ali_model_output = ali_model(feature)
+            # Subsampling is done slightly differently, so there may be
+            # small length differences.
+            min_len = min(ali_model_output.shape[1], nnet_output.shape[1])
+            # The scale is at most 1 and decays with training, so the model
+            # is encouraged to mimic ali_model's output early on.
+            ali_model_scale = 500.0 / (params.batch_idx_train + 500)
+            # Use clone() here or log-softmax backprop will fail.
+            nnet_output = nnet_output.clone()
+            nnet_output[:, :min_len, :] += (
+                ali_model_scale * ali_model_output[:, :min_len, :]
+            )
+
     # NOTE: We need `encode_supervisions` to sort sequences with
     # different duration in decreasing order, required by
@@ -337,6 +373,7 @@ def compute_loss(
 def compute_validation_loss(
     params: AttributeDict,
     model: nn.Module,
+    ali_model: Optional[nn.Module],
     graph_compiler: BpeMmiTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,

@@ -354,6 +391,7 @@ def compute_validation_loss(
         loss, mmi_loss, att_loss = compute_loss(
             params=params,
             model=model,
+            ali_model=ali_model,
             batch=batch,
             graph_compiler=graph_compiler,
             is_training=False,

@@ -394,6 +432,7 @@ def compute_validation_loss(
 def train_one_epoch(
     params: AttributeDict,
     model: nn.Module,
+    ali_model: Optional[nn.Module],
     optimizer: torch.optim.Optimizer,
     graph_compiler: BpeMmiTrainingGraphCompiler,
     train_dl: torch.utils.data.DataLoader,

@@ -412,6 +451,9 @@ def train_one_epoch(
         It is returned by :func:`get_params`.
       model:
         The model for training.
+      ali_model:
+        The force alignment model for training. It is from
+        tdnn_lstm_ctc/train_bpe.py.
       optimizer:
         The optimizer we are using.
       graph_compiler:
@@ -432,7 +474,8 @@
     tot_att_loss = 0.0
     tot_frames = 0.0  # sum of frames over all batches

+    params.tot_loss = 0.0
+    params.tot_frames = 0.0
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

@@ -440,6 +483,7 @@ def train_one_epoch(
         loss, mmi_loss, att_loss = compute_loss(
             params=params,
             model=model,
+            ali_model=ali_model,
             batch=batch,
             graph_compiler=graph_compiler,
             is_training=True,

@@ -450,6 +494,7 @@
         optimizer.zero_grad()
         loss.backward()
+        clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2.0)
         optimizer.step()

         loss_cpu = loss.detach().cpu().item()

@@ -461,6 +506,9 @@
         tot_mmi_loss += mmi_loss_cpu
         tot_att_loss += att_loss_cpu

+        params.tot_frames += params.train_frames
+        params.tot_loss += loss_cpu
+
         tot_avg_loss = tot_loss / tot_frames
         tot_avg_mmi_loss = tot_mmi_loss / tot_frames
         tot_avg_att_loss = tot_att_loss / tot_frames
@@ -509,11 +557,18 @@
                     tot_avg_loss,
                     params.batch_idx_train,
                 )

+        if batch_idx > 0 and batch_idx % params.reset_interval == 0:
+            tot_loss = 0.0  # sum of losses over all batches
+            tot_mmi_loss = 0.0
+            tot_att_loss = 0.0
+            tot_frames = 0.0  # sum of frames over all batches
+
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             compute_validation_loss(
                 params=params,
                 model=model,
+                ali_model=ali_model,
                 graph_compiler=graph_compiler,
                 valid_dl=valid_dl,
                 world_size=world_size,

@@ -544,7 +599,7 @@
                 params.batch_idx_train,
             )

-    params.train_loss = tot_loss / tot_frames
+    params.train_loss = params.tot_loss / params.tot_frames

     if params.train_loss < params.best_train_loss:
         params.best_train_epoch = params.cur_epoch
@ -624,6 +679,32 @@ def run(rank, world_size, args):
if checkpoints: if checkpoints:
optimizer.load_state_dict(checkpoints["optimizer"]) optimizer.load_state_dict(checkpoints["optimizer"])
if args.use_ali_model:
ali_model = TdnnLstm(
num_features=params.feature_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
)
ali_model_fname = Path(
f"tdnn_lstm_ctc/exp_bpe_500/epoch-{args.ali_model_epoch}.pt"
)
assert (
ali_model_fname.is_file()
), f"ali model filename {ali_model_fname} does not exist!"
ali_model.load_state_dict(
torch.load(ali_model_fname, map_location="cpu")["model"]
)
ali_model.to(device)
ali_model.eval()
ali_model.requires_grad_(False)
logging.info(f"Use ali_model: {ali_model_fname}")
else:
ali_model = None
logging.info("No ali_model")
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders() train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders() valid_dl = librispeech.valid_dataloaders()
@ -646,6 +727,7 @@ def run(rank, world_size, args):
train_one_epoch( train_one_epoch(
params=params, params=params,
model=model, model=model,
ali_model=ali_model,
optimizer=optimizer, optimizer=optimizer,
graph_compiler=graph_compiler, graph_compiler=graph_compiler,
train_dl=train_dl, train_dl=train_dl,
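Taken together, these changes wire an optional pre-trained alignment model into MMI training: for the first 4000 batches its output is added to the network output with weight ali_model_scale = 500 / (batch_idx_train + 500), which starts at 1.0 and has decayed to about 500 / 4500 ≈ 0.11 by the time the blending is switched off. A rough invocation sketch, assuming this file is the conformer_mmi training script (e.g. conformer_mmi/train.py; the exact path is not shown in the diff) and that the tdnn_lstm_ctc checkpoints already exist:

  # Illustrative only: script path and epoch number are assumptions.
  ./conformer_mmi/train.py \
    --use-ali-model True \
    --ali-model-epoch 19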

View File

@@ -1,18 +1,18 @@
 #!/usr/bin/env python3

 """
-This script compiles HLG from
+This script takes as input lang_dir and generates HLG from

-    - H, the ctc topology, built from tokens contained in lexicon.txt
-    - L, the lexicon, built from L_disambig.pt
+    - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
+    - L, the lexicon, built from lang_dir/L_disambig.pt

       Caution: We use a lexicon that contains disambiguation symbols

     - G, the LM, built from data/lm/G_3_gram.fst.txt

-The generated HLG is saved in data/lm/HLG.pt (phone based)
-or data/lm/HLG_bpe.pt (BPE based)
+The generated HLG is saved in $lang_dir/HLG.pt
 """

+import argparse
 import logging
 from pathlib import Path

@@ -22,11 +22,23 @@ import torch
 from icefall.lexicon import Lexicon


+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        """,
+    )
+
+    return parser.parse_args()
+
+
 def compile_HLG(lang_dir: str) -> k2.Fsa:
     """
     Args:
       lang_dir:
-        The language directory, e.g., data/lang_phone or data/lang_bpe.
+        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.

     Return:
         An FSA representing HLG.

@@ -104,17 +116,18 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
 def main():
-    for d in ["data/lang_phone", "data/lang_bpe"]:
-        d = Path(d)
-        logging.info(f"Processing {d}")
-
-        if (d / "HLG.pt").is_file():
-            logging.info(f"{d}/HLG.pt already exists - skipping")
-            continue
-
-        HLG = compile_HLG(d)
-        logging.info(f"Saving HLG.pt to {d}")
-        torch.save(HLG.as_dict(), f"{d}/HLG.pt")
+    args = get_args()
+    lang_dir = Path(args.lang_dir)
+
+    if (lang_dir / "HLG.pt").is_file():
+        logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
+        return
+
+    logging.info(f"Processing {lang_dir}")
+
+    HLG = compile_HLG(lang_dir)
+    logging.info(f"Saving HLG.pt to {lang_dir}")
+    torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")


 if __name__ == "__main__":
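With the new --lang-dir argument, compile_hlg.py is run once per lang directory instead of looping over a hard-coded list; prepare.sh (below) calls it for data/lang_phone and for each data/lang_bpe_${vocab_size}. A manual run would look roughly like this, assuming the directory already contains the inputs listed in the docstring:

  ./local/compile_hlg.py --lang-dir data/lang_bpe_500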

View File

@@ -3,12 +3,13 @@
 # Copyright (c)  2021  Xiaomi Corporation (authors: Fangjun Kuang)

 """
-This script takes as inputs the following two files:
-
-    - data/lang_bpe/bpe.model,
-    - data/lang_bpe/words.txt
-
-and generates the following files in the directory data/lang_bpe:
+This script takes as input `lang_dir`, which should contain::
+
+    - lang_dir/bpe.model,
+    - lang_dir/words.txt
+
+and generates the following files in the directory `lang_dir`:

     - lexicon.txt
     - lexicon_disambig.txt

@@ -17,6 +18,7 @@ and generates the following files in the directory data/lang_bpe:
     - tokens.txt
 """

+import argparse
 from pathlib import Path
 from typing import Dict, List, Tuple

@@ -141,8 +143,22 @@ def generate_lexicon(
     return lexicon, token2id


+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        It should contain the bpe.model and words.txt
+        """,
+    )
+
+    return parser.parse_args()
+
+
 def main():
-    lang_dir = Path("data/lang_bpe")
+    args = get_args()
+    lang_dir = Path(args.lang_dir)
     model_file = lang_dir / "bpe.model"

     word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")

@@ -189,15 +205,6 @@ def main():
     torch.save(L.as_dict(), lang_dir / "L.pt")
     torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")

-    if False:
-        # Just for debugging, will remove it
-        L.labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
-        L.aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
-        L_disambig.labels_sym = L.labels_sym
-        L_disambig.aux_labels_sym = L.aux_labels_sym
-        L.draw(lang_dir / "L.svg", title="L")
-        L_disambig.draw(lang_dir / "L_disambig.svg", title="L_disambig")
-

 if __name__ == "__main__":
     main()
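The script now reads its working directory from the command line and expects bpe.model and words.txt to already be inside it. For example (illustrative path):

  ./local/prepare_lang_bpe.py --lang-dir data/lang_bpe_500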

View File

@@ -1,10 +1,5 @@
 #!/usr/bin/env python3

-"""
-This script takes as input "data/lang/bpe/train.txt"
-and generates "data/lang/bpe/bep.model".
-"""
-
 # You can install sentencepiece via:
 #
 #  pip install sentencepiece

@@ -14,17 +9,41 @@ and generates "data/lang/bpe/bep.model".
 #
 # Please install a version >=0.1.96

+import argparse
 import shutil
 from pathlib import Path

 import sentencepiece as spm


+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        It should contain the training corpus: train.txt.
+        The generated bpe.model is saved to this directory.
+        """,
+    )
+
+    parser.add_argument(
+        "--vocab-size",
+        type=int,
+        help="Vocabulary size for BPE training",
+    )
+
+    return parser.parse_args()
+
+
 def main():
+    args = get_args()
+    vocab_size = args.vocab_size
+    lang_dir = Path(args.lang_dir)
+
     model_type = "unigram"
-    vocab_size = 5000
-    model_prefix = f"data/lang_bpe/{model_type}_{vocab_size}"
-    train_text = "data/lang_bpe/train.txt"
+
+    model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
+    train_text = f"{lang_dir}/train.txt"
     character_coverage = 1.0
     input_sentence_size = 100000000

@@ -49,10 +68,7 @@ def main():
         eos_id=-1,
     )

-    sp = spm.SentencePieceProcessor(model_file=str(model_file))
-    vocab_size = sp.vocab_size()
-
-    shutil.copyfile(model_file, "data/lang_bpe/bpe.model")
+    shutil.copyfile(model_file, f"{lang_dir}/bpe.model")


 if __name__ == "__main__":
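The hard-coded vocabulary size of 5000 and the data/lang_bpe paths are replaced by command-line options, so the same script can train several BPE models of different sizes. It reads ${lang_dir}/train.txt and writes the trained model to ${lang_dir}/bpe.model. An illustrative invocation:

  ./local/train_bpe_model.py \
    --lang-dir data/lang_bpe_500 \
    --vocab-size 500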

View File

@@ -36,8 +36,17 @@ dl_dir=$PWD/download

 . shared/parse_options.sh || exit 1

+# vocab size for sentence piece models.
+# It will generate data/lang_bpe_500, data/lang_bpe_1000,
+# and data/lang_bpe_5000.
+vocab_sizes=(
+  500
+  1000
+  5000
+)
+
-# All generated files by this script are saved in "data"
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
 mkdir -p data

 log() {

@@ -116,14 +125,18 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   fi
 fi

 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   log "State 6: Prepare BPE based lang"
-  mkdir -p data/lang_bpe
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    mkdir -p $lang_dir
+
     # We reuse words.txt from phone based lexicon
     # so that the two can share G.pt later.
-  cp data/lang_phone/words.txt data/lang_bpe/
+    cp data/lang_phone/words.txt $lang_dir

-  if [ ! -f data/lang_bpe/train.txt ]; then
+    if [ ! -f $lang_dir/train.txt ]; then
       log "Generate data for BPE training"
       files=$(
         find "data/LibriSpeech/train-clean-100" -name "*.trans.txt"

@@ -132,40 +145,48 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
       )
       for f in ${files[@]}; do
         cat $f | cut -d " " -f 2-
-      done > data/lang_bpe/train.txt
+      done > $lang_dir/train.txt
     fi

-  python3 ./local/train_bpe_model.py
+    ./local/train_bpe_model.py \
+      --lang-dir $lang_dir \
+      --vocab-size $vocab_size

-  if [ ! -f data/lang_bpe/L_disambig.pt ]; then
-    ./local/prepare_lang_bpe.py
+    if [ ! -f $lang_dir/L_disambig.pt ]; then
+      ./local/prepare_lang_bpe.py --lang-dir $lang_dir
     fi
+  done
 fi

 if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
   log "Stage 7: Prepare bigram P"
-  if [ ! -f data/lang_bpe/corpus.txt ]; then
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+
+    if [ ! -f $lang_dir/corpus.txt ]; then
       ./local/convert_transcript_to_corpus.py \
         --lexicon data/lang_bpe/lexicon.txt \
         --transcript data/lang_bpe/train.txt \
         --oov "<UNK>" \
-      > data/lang_bpe/corpus.txt
+        > $lang_dir/corpus.txt
     fi

-  if [ ! -f data/lang_bpe/P.arpa ]; then
+    if [ ! -f $lang_dir/P.arpa ]; then
       ./shared/make_kn_lm.py \
         -ngram-order 2 \
-      -text data/lang_bpe/corpus.txt \
-      -lm data/lang_bpe/P.arpa
+        -text $lang_dir/corpus.txt \
+        -lm $lang_dir/P.arpa
     fi

-  if [ ! -f data/lang_bpe/P.fst.txt ]; then
+    if [ ! -f $lang_dir/P.fst.txt ]; then
       python3 -m kaldilm \
-      --read-symbol-table="data/lang_bpe/tokens.txt" \
+        --read-symbol-table="$lang_dir/tokens.txt" \
         --disambig-symbol='#0' \
         --max-order=2 \
-      data/lang_bpe/P.arpa > data/lang_bpe/P.fst.txt
+        $lang_dir/P.arpa > $lang_dir/P.fst.txt
     fi
+  done
 fi

 if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then

@@ -195,5 +216,10 @@ fi

 if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
   log "Stage 9: Compile HLG"
-  python3 ./local/compile_hlg.py
+  ./local/compile_hlg.py --lang-dir data/lang_phone
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    ./local/compile_hlg.py --lang-dir $lang_dir
+  done
 fi
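Stages 6, 7 and 9 now loop over vocab_sizes, so a single run produces data/lang_bpe_500, data/lang_bpe_1000 and data/lang_bpe_5000. Since the script is driven by shared/parse_options.sh, the BPE-related stages can be rerun on their own with something like the following (the --stage/--stop-stage flag names are assumed from the stage and stop_stage variables):

  ./prepare.sh --stage 6 --stop-stage 9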