From 56319b090350299b3d499c81c35a00a89de146ad Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 16 Aug 2021 17:03:05 +0800 Subject: [PATCH] Minor fixes. --- egs/librispeech/ASR/conformer_mmi/train.py | 102 +++++++++++++++-- egs/librispeech/ASR/local/compile_hlg.py | 43 ++++--- egs/librispeech/ASR/local/prepare_lang_bpe.py | 35 +++--- egs/librispeech/ASR/local/train_bpe_model.py | 40 +++++-- egs/librispeech/ASR/prepare.sh | 108 +++++++++++------- 5 files changed, 236 insertions(+), 92 deletions(-) diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py index 810a0a4df..f11291bbf 100755 --- a/egs/librispeech/ASR/conformer_mmi/train.py +++ b/egs/librispeech/ASR/conformer_mmi/train.py @@ -13,7 +13,9 @@ import torch.multiprocessing as mp import torch.nn as nn from conformer import Conformer from lhotse.utils import fix_random_seed +from tdnn_lstm_ctc.model import TdnnLstm from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam @@ -58,6 +60,26 @@ def get_parser(): help="Should various information be logged in tensorboard.", ) + parser.add_argument( + "--use-ali-model", + type=str2bool, + default=True, + help="If true, we assume that you have run tdnn_lstm_ctc/train_bpe.py " + "and you have some checkpoints inside the directory " + "tdnn_lstm_ctc/exp_bpe_500 ." + "It will use tdnn_lstm_ctc/exp_bpe_500/epoch-{ali-model-epoch}.pt " + "as the pre-trained alignment model", + ) + parser.add_argument( + "--ali-model-epoch", + type=int, + default=19, + help="If --use-ali-model is True, load " + "tdnn_lstm_ctc/exp_bpe_500/epoch-{ali-model-epoch}.pt as " + "the alignment model." + "Used only if --use-ali-model is True.", + ) + # TODO: add extra arguments and support DDP training. # Currently, only single GPU training is implemented. Will add # DDP training once single GPU training is finished. @@ -117,24 +139,21 @@ def get_params() -> AttributeDict: """ params = AttributeDict( { - "exp_dir": Path("conformer_mmi/exp"), - "lang_dir": Path("data/lang_bpe"), + "exp_dir": Path("conformer_mmi/exp_500"), + "lang_dir": Path("data/lang_bpe_500"), "feature_dim": 80, "weight_decay": 1e-6, "subsampling_factor": 4, "start_epoch": 0, - "num_epochs": 10, + "num_epochs": 50, "best_train_loss": float("inf"), "best_valid_loss": float("inf"), "best_train_epoch": -1, "best_valid_epoch": -1, "batch_idx_train": 0, "log_interval": 10, - # It takes about 10 minutes (1 GPU, max_duration=200) - # to run a validation process. - # For the 100 h subset, there are 85617 batches. - # For the 960 h dataset, there are 843723 batches - "valid_interval": 8000, + "reset_interval": 200, + "valid_interval": 10, "use_pruned_intersect": False, "den_scale": 1.0, # @@ -242,6 +261,7 @@ def save_checkpoint( def compute_loss( params: AttributeDict, model: nn.Module, + ali_model: Optional[nn.Module], batch: dict, graph_compiler: BpeMmiTrainingGraphCompiler, is_training: bool, @@ -274,6 +294,22 @@ def compute_loss( with torch.set_grad_enabled(is_training): nnet_output, encoder_memory, memory_mask = model(feature, supervisions) # nnet_output is [N, T, C] + if ali_model is not None and params.batch_idx_train < 4000: + feature = feature.permute(0, 2, 1) # [N, T, C]->[N, C, T] + ali_model_output = ali_model(feature) + # subsampling is done slightly differently, may be small length + # differences. 
+ min_len = min(ali_model_output.shape[1], nnet_output.shape[1]) + # scale less than one so it will be encouraged + # to mimic ali_model's output + ali_model_scale = 500.0 / (params.batch_idx_train + 500) + + # Use clone() here or log-softmax backprop will fail. + nnet_output = nnet_output.clone() + + nnet_output[:, :min_len, :] += ( + ali_model_scale * ali_model_output[:, :min_len, :] + ) # NOTE: We need `encode_supervisions` to sort sequences with # different duration in decreasing order, required by @@ -337,6 +373,7 @@ def compute_loss( def compute_validation_loss( params: AttributeDict, model: nn.Module, + ali_model: Optional[nn.Module], graph_compiler: BpeMmiTrainingGraphCompiler, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, @@ -354,6 +391,7 @@ def compute_validation_loss( loss, mmi_loss, att_loss = compute_loss( params=params, model=model, + ali_model=ali_model, batch=batch, graph_compiler=graph_compiler, is_training=False, @@ -394,6 +432,7 @@ def compute_validation_loss( def train_one_epoch( params: AttributeDict, model: nn.Module, + ali_model: Optional[nn.Module], optimizer: torch.optim.Optimizer, graph_compiler: BpeMmiTrainingGraphCompiler, train_dl: torch.utils.data.DataLoader, @@ -412,6 +451,9 @@ def train_one_epoch( It is returned by :func:`get_params`. model: The model for training. + ali_model: + The force alignment model for training. It is from + tdnn_lstm_ctc/train_bpe.py optimizer: The optimizer we are using. graph_compiler: @@ -432,7 +474,8 @@ def train_one_epoch( tot_att_loss = 0.0 tot_frames = 0.0 # sum of frames over all batches - + params.tot_loss = 0.0 + params.tot_frames = 0.0 for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -440,6 +483,7 @@ def train_one_epoch( loss, mmi_loss, att_loss = compute_loss( params=params, model=model, + ali_model=ali_model, batch=batch, graph_compiler=graph_compiler, is_training=True, @@ -450,6 +494,7 @@ def train_one_epoch( optimizer.zero_grad() loss.backward() + clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2.0) optimizer.step() loss_cpu = loss.detach().cpu().item() @@ -461,6 +506,9 @@ def train_one_epoch( tot_mmi_loss += mmi_loss_cpu tot_att_loss += att_loss_cpu + params.tot_frames += params.train_frames + params.tot_loss += loss_cpu + tot_avg_loss = tot_loss / tot_frames tot_avg_mmi_loss = tot_mmi_loss / tot_frames tot_avg_att_loss = tot_att_loss / tot_frames @@ -509,11 +557,18 @@ def train_one_epoch( tot_avg_loss, params.batch_idx_train, ) + if batch_idx > 0 and batch_idx % params.reset_interval == 0: + tot_loss = 0.0 # sum of losses over all batches + tot_mmi_loss = 0.0 + tot_att_loss = 0.0 + + tot_frames = 0.0 # sum of frames over all batches if batch_idx > 0 and batch_idx % params.valid_interval == 0: compute_validation_loss( params=params, model=model, + ali_model=ali_model, graph_compiler=graph_compiler, valid_dl=valid_dl, world_size=world_size, @@ -544,7 +599,7 @@ def train_one_epoch( params.batch_idx_train, ) - params.train_loss = tot_loss / tot_frames + params.train_loss = params.tot_loss / params.tot_frames if params.train_loss < params.best_train_loss: params.best_train_epoch = params.cur_epoch @@ -624,6 +679,32 @@ def run(rank, world_size, args): if checkpoints: optimizer.load_state_dict(checkpoints["optimizer"]) + if args.use_ali_model: + ali_model = TdnnLstm( + num_features=params.feature_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + ) + + ali_model_fname = Path( + 
f"tdnn_lstm_ctc/exp_bpe_500/epoch-{args.ali_model_epoch}.pt" + ) + assert ( + ali_model_fname.is_file() + ), f"ali model filename {ali_model_fname} does not exist!" + + ali_model.load_state_dict( + torch.load(ali_model_fname, map_location="cpu")["model"] + ) + ali_model.to(device) + + ali_model.eval() + ali_model.requires_grad_(False) + logging.info(f"Use ali_model: {ali_model_fname}") + else: + ali_model = None + logging.info("No ali_model") + librispeech = LibriSpeechAsrDataModule(args) train_dl = librispeech.train_dataloaders() valid_dl = librispeech.valid_dataloaders() @@ -646,6 +727,7 @@ def run(rank, world_size, args): train_one_epoch( params=params, model=model, + ali_model=ali_model, optimizer=optimizer, graph_compiler=graph_compiler, train_dl=train_dl, diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index b30402161..9f28bb74d 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -1,18 +1,18 @@ #!/usr/bin/env python3 """ -This script compiles HLG from +This script takes as input lang_dir and generates HLG from - - H, the ctc topology, built from tokens contained in lexicon.txt - - L, the lexicon, built from L_disambig.pt + - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt + - L, the lexicon, built from lang_dir/L_disambig.pt Caution: We use a lexicon that contains disambiguation symbols - G, the LM, built from data/lm/G_3_gram.fst.txt -The generated HLG is saved in data/lm/HLG.pt (phone based) -or data/lm/HLG_bpe.pt (BPE based) +The generated HLG is saved in $lang_dir/HLG.pt """ +import argparse import logging from pathlib import Path @@ -22,11 +22,23 @@ import torch from icefall.lexicon import Lexicon +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + """, + ) + + return parser.parse_args() + + def compile_HLG(lang_dir: str) -> k2.Fsa: """ Args: lang_dir: - The language directory, e.g., data/lang_phone or data/lang_bpe. + The language directory, e.g., data/lang_phone or data/lang_bpe_5000. Return: An FSA representing HLG. 
@@ -104,17 +116,18 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: def main(): - for d in ["data/lang_phone", "data/lang_bpe"]: - d = Path(d) - logging.info(f"Processing {d}") + args = get_args() + lang_dir = Path(args.lang_dir) - if (d / "HLG.pt").is_file(): - logging.info(f"{d}/HLG.pt already exists - skipping") - continue + if (lang_dir / "HLG.pt").is_file(): + logging.info(f"{lang_dir}/HLG.pt already exists - skipping") + return - HLG = compile_HLG(d) - logging.info(f"Saving HLG.pt to {d}") - torch.save(HLG.as_dict(), f"{d}/HLG.pt") + logging.info(f"Processing {lang_dir}") + + HLG = compile_HLG(lang_dir) + logging.info(f"Saving HLG.pt to {lang_dir}") + torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt") if __name__ == "__main__": diff --git a/egs/librispeech/ASR/local/prepare_lang_bpe.py b/egs/librispeech/ASR/local/prepare_lang_bpe.py index e31220d9b..68b8db966 100755 --- a/egs/librispeech/ASR/local/prepare_lang_bpe.py +++ b/egs/librispeech/ASR/local/prepare_lang_bpe.py @@ -3,12 +3,13 @@ # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) """ -This script takes as inputs the following two files: - - data/lang_bpe/bpe.model, - - data/lang_bpe/words.txt +This script takes as input `lang_dir`, which should contain:: -and generates the following files in the directory data/lang_bpe: + - lang_dir/bpe.model, + - lang_dir/words.txt + +and generates the following files in the directory `lang_dir`: - lexicon.txt - lexicon_disambig.txt @@ -17,6 +18,7 @@ and generates the following files in the directory data/lang_bpe: - tokens.txt """ +import argparse from pathlib import Path from typing import Dict, List, Tuple @@ -141,8 +143,22 @@ def generate_lexicon( return lexicon, token2id +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain the bpe.model and words.txt + """, + ) + + return parser.parse_args() + + def main(): - lang_dir = Path("data/lang_bpe") + args = get_args() + lang_dir = Path(args.lang_dir) model_file = lang_dir / "bpe.model" word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") @@ -189,15 +205,6 @@ def main(): torch.save(L.as_dict(), lang_dir / "L.pt") torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") - if False: - # Just for debugging, will remove it - L.labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") - L.aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") - L_disambig.labels_sym = L.labels_sym - L_disambig.aux_labels_sym = L.aux_labels_sym - L.draw(lang_dir / "L.svg", title="L") - L_disambig.draw(lang_dir / "L_disambig.svg", title="L_disambig") - if __name__ == "__main__": main() diff --git a/egs/librispeech/ASR/local/train_bpe_model.py b/egs/librispeech/ASR/local/train_bpe_model.py index 59746ad9a..9872a7c6a 100755 --- a/egs/librispeech/ASR/local/train_bpe_model.py +++ b/egs/librispeech/ASR/local/train_bpe_model.py @@ -1,10 +1,5 @@ #!/usr/bin/env python3 -""" -This script takes as input "data/lang/bpe/train.txt" -and generates "data/lang/bpe/bep.model". -""" - # You can install sentencepiece via: # # pip install sentencepiece @@ -14,17 +9,41 @@ and generates "data/lang/bpe/bep.model". # # Please install a version >=0.1.96 +import argparse import shutil from pathlib import Path import sentencepiece as spm +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain the training corpus: train.txt. 
+ The generated bpe.model is saved to this directory. + """, + ) + parser.add_argument( + "--vocab-size", + type=int, + help="Vocabulary size for BPE training", + ) + + return parser.parse_args() + + def main(): + args = get_args() + vocab_size = args.vocab_size + lang_dir = Path(args.lang_dir) + model_type = "unigram" - vocab_size = 5000 - model_prefix = f"data/lang_bpe/{model_type}_{vocab_size}" - train_text = "data/lang_bpe/train.txt" + + model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" + train_text = f"{lang_dir}/train.txt" character_coverage = 1.0 input_sentence_size = 100000000 @@ -49,10 +68,7 @@ def main(): eos_id=-1, ) - sp = spm.SentencePieceProcessor(model_file=str(model_file)) - vocab_size = sp.vocab_size() - - shutil.copyfile(model_file, "data/lang_bpe/bpe.model") + shutil.copyfile(model_file, f"{lang_dir}/bpe.model") if __name__ == "__main__": diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index c8e093177..6479973bf 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -36,8 +36,17 @@ dl_dir=$PWD/download . shared/parse_options.sh || exit 1 +# vocab size for sentence piece models. +# It will generate data/lang_bpe_500, data/lang_bpe_1000, +# and data/lang_bpe_5000. +vocab_sizes=( + 500 + 1000 + 5000 +) -# All generated files by this script are saved in "data" +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. mkdir -p data log() { @@ -116,56 +125,68 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then fi fi + if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then log "State 6: Prepare BPE based lang" - mkdir -p data/lang_bpe - # We reuse words.txt from phone based lexicon - # so that the two can share G.pt later. - cp data/lang_phone/words.txt data/lang_bpe/ - if [ ! -f data/lang_bpe/train.txt ]; then - log "Generate data for BPE training" - files=$( - find "data/LibriSpeech/train-clean-100" -name "*.trans.txt" - find "data/LibriSpeech/train-clean-360" -name "*.trans.txt" - find "data/LibriSpeech/train-other-500" -name "*.trans.txt" - ) - for f in ${files[@]}; do - cat $f | cut -d " " -f 2- - done > data/lang_bpe/train.txt - fi + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir + # We reuse words.txt from phone based lexicon + # so that the two can share G.pt later. + cp data/lang_phone/words.txt $lang_dir - python3 ./local/train_bpe_model.py + if [ ! -f $lang_dir/train.txt ]; then + log "Generate data for BPE training" + files=$( + find "data/LibriSpeech/train-clean-100" -name "*.trans.txt" + find "data/LibriSpeech/train-clean-360" -name "*.trans.txt" + find "data/LibriSpeech/train-other-500" -name "*.trans.txt" + ) + for f in ${files[@]}; do + cat $f | cut -d " " -f 2- + done > $lang_dir/train.txt + fi - if [ ! -f data/lang_bpe/L_disambig.pt ]; then - ./local/prepare_lang_bpe.py - fi + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + fi + done fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then log "Stage 7: Prepare bigram P" - if [ ! -f data/lang_bpe/corpus.txt ]; then - ./local/convert_transcript_to_corpus.py \ - --lexicon data/lang_bpe/lexicon.txt \ - --transcript data/lang_bpe/train.txt \ - --oov "" \ - > data/lang_bpe/corpus.txt - fi - if [ ! 
-f data/lang_bpe/P.arpa ]; then
-    ./shared/make_kn_lm.py \
-      -ngram-order 2 \
-      -text data/lang_bpe/corpus.txt \
-      -lm data/lang_bpe/P.arpa
-  fi
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
 
-  if [ ! -f data/lang_bpe/P.fst.txt ]; then
-    python3 -m kaldilm \
-      --read-symbol-table="data/lang_bpe/tokens.txt" \
-      --disambig-symbol='#0' \
-      --max-order=2 \
-      data/lang_bpe/P.arpa > data/lang_bpe/P.fst.txt
-  fi
+    if [ ! -f $lang_dir/corpus.txt ]; then
+      ./local/convert_transcript_to_corpus.py \
+        --lexicon $lang_dir/lexicon.txt \
+        --transcript $lang_dir/train.txt \
+        --oov "<UNK>" \
+        > $lang_dir/corpus.txt
+    fi
+
+    if [ ! -f $lang_dir/P.arpa ]; then
+      ./shared/make_kn_lm.py \
+        -ngram-order 2 \
+        -text $lang_dir/corpus.txt \
+        -lm $lang_dir/P.arpa
+    fi
+
+    if [ ! -f $lang_dir/P.fst.txt ]; then
+      python3 -m kaldilm \
+        --read-symbol-table="$lang_dir/tokens.txt" \
+        --disambig-symbol='#0' \
+        --max-order=2 \
+        $lang_dir/P.arpa > $lang_dir/P.fst.txt
+    fi
+  done
 fi
 
 if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
@@ -195,5 +216,10 @@ fi
 
 if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
   log "Stage 9: Compile HLG"
-  python3 ./local/compile_hlg.py
+  ./local/compile_hlg.py --lang-dir data/lang_phone
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    ./local/compile_hlg.py --lang-dir $lang_dir
+  done
 fi
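
For reference, below is a minimal, self-contained sketch of the alignment-model
blending that the new code in compute_loss() performs during the first 4000
batches: the conformer's output is nudged towards a pre-trained TDNN-LSTM
alignment model with a weight that decays as 500 / (batch_idx_train + 500).
This sketch is not code from the patch; the helper name blend_with_ali_model,
its default arguments, and the random example tensors are illustrative
assumptions only.

import torch


def blend_with_ali_model(
    nnet_output: torch.Tensor,       # (N, T1, C) log-probs from the conformer
    ali_model_output: torch.Tensor,  # (N, T2, C) log-probs from the TDNN-LSTM
    batch_idx_train: int,
    warmup_batches: int = 4000,
    decay: float = 500.0,
) -> torch.Tensor:
    """Add a decaying fraction of the alignment model's output to the main
    model's output during the first `warmup_batches` batches."""
    if batch_idx_train >= warmup_batches:
        return nnet_output

    # The two models subsample slightly differently, so the time axes can
    # differ by a frame or two; truncate to the shorter one before adding.
    min_len = min(ali_model_output.shape[1], nnet_output.shape[1])

    # Starts at 1.0 and decays towards 0, so the conformer is only encouraged
    # to mimic the alignment model early in training.
    ali_model_scale = decay / (batch_idx_train + decay)

    # clone() so the in-place addition below does not modify a tensor that
    # autograd still needs (the patch does the same).
    blended = nnet_output.clone()
    blended[:, :min_len, :] += ali_model_scale * ali_model_output[:, :min_len, :]
    return blended


if __name__ == "__main__":
    conformer_out = torch.randn(2, 98, 500).log_softmax(dim=-1)
    ali_out = torch.randn(2, 100, 500).log_softmax(dim=-1)
    out = blend_with_ali_model(conformer_out, ali_out, batch_idx_train=10)
    print(out.shape)  # torch.Size([2, 98, 500])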