From 9e6bd0f07c29b467618cde424fedefe81faf3e6f Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 24 Sep 2021 20:34:34 +0800 Subject: [PATCH] Minor fixes. --- egs/librispeech/ASR/conformer_ctc/train.py | 2 +- egs/librispeech/ASR/conformer_mmi/train.py | 249 +++++++++------------ egs/librispeech/ASR/prepare.sh | 2 - icefall/mmi.py | 10 +- icefall/mmi_graph_compiler.py | 5 + 5 files changed, 119 insertions(+), 149 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 80b2d924a..8c1fc9595 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -99,7 +99,7 @@ def get_params() -> AttributeDict: """Return a dict containing training parameters. All training related parameters that are not passed from the commandline - is saved in the variable `params`. + are saved in the variable `params`. Commandline options are merged into `params` after they are parsed, so you can also access them via `params`. diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py index f11291bbf..6decbc189 100755 --- a/egs/librispeech/ASR/conformer_mmi/train.py +++ b/egs/librispeech/ASR/conformer_mmi/train.py @@ -1,4 +1,21 @@ #!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import argparse import logging @@ -11,21 +28,20 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer from lhotse.utils import fix_random_seed -from tdnn_lstm_ctc.model import TdnnLstm from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dataset.librispeech import LibriSpeechAsrDataModule from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.mmi import LFMMILoss +from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler from icefall.utils import ( AttributeDict, encode_supervisions, @@ -61,28 +77,22 @@ def get_parser(): ) parser.add_argument( - "--use-ali-model", - type=str2bool, - default=True, - help="If true, we assume that you have run tdnn_lstm_ctc/train_bpe.py " - "and you have some checkpoints inside the directory " - "tdnn_lstm_ctc/exp_bpe_500 ." 
- "It will use tdnn_lstm_ctc/exp_bpe_500/epoch-{ali-model-epoch}.pt " - "as the pre-trained alignment model", - ) - parser.add_argument( - "--ali-model-epoch", + "--num-epochs", type=int, - default=19, - help="If --use-ali-model is True, load " - "tdnn_lstm_ctc/exp_bpe_500/epoch-{ali-model-epoch}.pt as " - "the alignment model." - "Used only if --use-ali-model is True.", + default=50, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + conformer_mmi/exp/epoch-{start_epoch-1}.pt + """, ) - # TODO: add extra arguments and support DDP training. - # Currently, only single GPU training is implemented. Will add - # DDP training once single GPU training is finished. return parser @@ -90,7 +100,7 @@ def get_params() -> AttributeDict: """Return a dict containing training parameters. All training related parameters that are not passed from the commandline - is saved in the variable `params`. + are saved in the variable `params`. Commandline options are merged into `params` after they are parsed, so you can also access them via `params`. @@ -103,20 +113,6 @@ def get_params() -> AttributeDict: - lang_dir: It contains language related input files such as "lexicon.txt" - - lr: It specifies the initial learning rate - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - weight_decay: The weight_decay for the optimizer. - - - subsampling_factor: The subsampling factor for the model. - - - start_epoch: If it is not zero, load checkpoint `start_epoch-1` - and continue training from that checkpoint. - - - num_epochs: Number of epochs to train. - - best_train_loss: Best training loss so far. It is used to select the model that has the lowest training loss. It is updated during the training. @@ -135,36 +131,60 @@ def get_params() -> AttributeDict: - log_interval: Print training loss if batch_idx % log_interval` is 0 - - valid_interval: Run validation if batch_idx % valid_interval` is 0 + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - use_feat_batchnorm: Whether to do batch normalization for the + input features. + + - attention_dim: Hidden dim for multi-head attention model. + + - head: Number of heads of multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - weight_decay: The weight_decay for the optimizer. + + - lr_factor: The lr_factor for Noam optimizer. + + - warm_step: The warm_step for Noam optimizer. 
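For reference, a minimal sketch of the resumption behaviour described by the new --start-epoch help text. The maybe_resume helper below is hypothetical; the actual script delegates this to icefall.checkpoint.load_checkpoint, which is imported earlier in the file.

from pathlib import Path

import torch
import torch.nn as nn


def maybe_resume(model: nn.Module, exp_dir: Path, start_epoch: int) -> dict:
    # Hypothetical helper: when start_epoch is positive, load
    # exp_dir/epoch-{start_epoch-1}.pt and restore the model weights;
    # otherwise train from scratch and return an empty dict.
    if start_epoch <= 0:
        return {}
    ckpt_path = exp_dir / f"epoch-{start_epoch - 1}.pt"
    assert ckpt_path.is_file(), f"{ckpt_path} does not exist"
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    return checkpoint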
""" params = AttributeDict( { "exp_dir": Path("conformer_mmi/exp_500"), "lang_dir": Path("data/lang_bpe_500"), - "feature_dim": 80, - "weight_decay": 1e-6, - "subsampling_factor": 4, - "start_epoch": 0, - "num_epochs": 50, "best_train_loss": float("inf"), "best_valid_loss": float("inf"), "best_train_epoch": -1, "best_valid_epoch": -1, "batch_idx_train": 0, - "log_interval": 10, + "log_interval": 50, "reset_interval": 200, - "valid_interval": 10, - "use_pruned_intersect": False, - "den_scale": 1.0, - # - "att_rate": 0.7, + "valid_interval": 3000, + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, "attention_dim": 512, "nhead": 8, "num_decoder_layers": 6, - "is_espnet_structure": True, - "use_feat_batchnorm": True, + # parameters for loss + "beam_size": 10, + "reduction": "sum", + "use_double_scores": True, + "att_rate": 0.7, + # parameters for Noam + "weight_decay": 1e-6, "lr_factor": 5.0, "warm_step": 80000, + "use_pruned_intersect": False, + "den_scale": 1.0, } ) @@ -261,13 +281,12 @@ def save_checkpoint( def compute_loss( params: AttributeDict, model: nn.Module, - ali_model: Optional[nn.Module], batch: dict, - graph_compiler: BpeMmiTrainingGraphCompiler, + graph_compiler: MmiTrainingGraphCompiler, is_training: bool, ): """ - Compute MMI loss given the model and its inputs. + Compute LF-MMI loss given the model and its inputs. Args: params: @@ -278,7 +297,9 @@ def compute_loss( A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` for the content in it. graph_compiler: - It is used to build num_graphs and den_graphs. + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. is_training: True for training. False for validation. When it is True, this function enables autograd during computation; when it is False, it @@ -286,54 +307,34 @@ def compute_loss( """ device = graph_compiler.device feature = batch["inputs"] - # at entry, feature is [N, T, C] + # at entry, feature is (N, T, C) assert feature.ndim == 3 feature = feature.to(device) supervisions = batch["supervisions"] with torch.set_grad_enabled(is_training): nnet_output, encoder_memory, memory_mask = model(feature, supervisions) - # nnet_output is [N, T, C] - if ali_model is not None and params.batch_idx_train < 4000: - feature = feature.permute(0, 2, 1) # [N, T, C]->[N, C, T] - ali_model_output = ali_model(feature) - # subsampling is done slightly differently, may be small length - # differences. - min_len = min(ali_model_output.shape[1], nnet_output.shape[1]) - # scale less than one so it will be encouraged - # to mimic ali_model's output - ali_model_scale = 500.0 / (params.batch_idx_train + 500) + # nnet_output is (N, T, C) - # Use clone() here or log-softmax backprop will fail. 
- nnet_output = nnet_output.clone() + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `LFMMILoss.forward()` + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) - nnet_output[:, :min_len, :] += ( - ali_model_scale * ali_model_output[:, :min_len, :] - ) + loss_fn = LFMMILoss( + graph_compiler=graph_compiler, + use_pruned_intersect=params.use_pruned_intersect, + den_scale=params.den_scale, + ) - # NOTE: We need `encode_supervisions` to sort sequences with - # different duration in decreasing order, required by - # `k2.intersect_dense` called in LFMMILoss - # - # TODO: If params.use_pruned_intersect is True, there is no - # need to call encode_supervisions - supervision_segments, texts = encode_supervisions( - supervisions, subsampling_factor=params.subsampling_factor - ) - - dense_fsa_vec = k2.DenseFsaVec( - nnet_output, - supervision_segments, - allow_truncate=params.subsampling_factor - 1, - ) - - loss_fn = LFMMILoss( - graph_compiler=graph_compiler, - den_scale=params.den_scale, - use_pruned_intersect=params.use_pruned_intersect, - ) - - mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts) + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts) if params.att_rate != 0.0: token_ids = graph_compiler.texts_to_ids(texts) @@ -373,8 +374,7 @@ def compute_loss( def compute_validation_loss( params: AttributeDict, model: nn.Module, - ali_model: Optional[nn.Module], - graph_compiler: BpeMmiTrainingGraphCompiler, + graph_compiler: MmiTrainingGraphCompiler, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, ) -> None: @@ -391,7 +391,6 @@ def compute_validation_loss( loss, mmi_loss, att_loss = compute_loss( params=params, model=model, - ali_model=ali_model, batch=batch, graph_compiler=graph_compiler, is_training=False, @@ -432,9 +431,8 @@ def compute_validation_loss( def train_one_epoch( params: AttributeDict, model: nn.Module, - ali_model: Optional[nn.Module], optimizer: torch.optim.Optimizer, - graph_compiler: BpeMmiTrainingGraphCompiler, + graph_compiler: MmiTrainingGraphCompiler, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, tb_writer: Optional[SummaryWriter] = None, @@ -451,9 +449,6 @@ def train_one_epoch( It is returned by :func:`get_params`. model: The model for training. - ali_model: - The force alignment model for training. It is from - tdnn_lstm_ctc/train_bpe.py optimizer: The optimizer we are using. 
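The NOTE above is the important constraint: k2.intersect_dense expects the supervision segments, and the matching transcripts, ordered by decreasing duration. A torch-only sketch of that invariant, using the (sequence_idx, start_frame, num_frames) layout that encode_supervisions produces:

from typing import List, Tuple

import torch


def sort_by_duration(
    supervision_segments: torch.Tensor, texts: List[str]
) -> Tuple[torch.Tensor, List[str]]:
    # Sort the (sequence_idx, start_frame, num_frames) rows by decreasing
    # num_frames and keep the transcripts aligned with the reordered rows.
    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    return supervision_segments[indices], [texts[i] for i in indices.tolist()]


segments = torch.tensor([[0, 0, 80], [1, 0, 120]], dtype=torch.int32)
segments, texts = sort_by_duration(segments, ["short utterance", "longer utterance"])
# segments is now [[1, 0, 120], [0, 0, 80]] and texts starts with "longer utterance"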
graph_compiler: @@ -483,7 +478,6 @@ def train_one_epoch( loss, mmi_loss, att_loss = compute_loss( params=params, model=model, - ali_model=ali_model, batch=batch, graph_compiler=graph_compiler, is_training=True, @@ -494,7 +488,7 @@ def train_one_epoch( optimizer.zero_grad() loss.backward() - clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2.0) + clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() loss_cpu = loss.detach().cpu().item() @@ -519,7 +513,7 @@ def train_one_epoch( f"batch avg mmi loss {mmi_loss_cpu/params.train_frames:.4f}, " f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, " f"batch avg loss {loss_cpu/params.train_frames:.4f}, " - f"total avg mmi loss: {tot_avg_mmi_loss:.4f}, " + f"total avg mmiloss: {tot_avg_mmi_loss:.4f}, " f"total avg att loss: {tot_avg_att_loss:.4f}, " f"total avg loss: {tot_avg_loss:.4f}, " f"batch size: {batch_size}" @@ -568,7 +562,6 @@ def train_one_epoch( compute_validation_loss( params=params, model=model, - ali_model=ali_model, graph_compiler=graph_compiler, valid_dl=valid_dl, world_size=world_size, @@ -576,10 +569,10 @@ def train_one_epoch( model.train() logging.info( f"Epoch {params.cur_epoch}, " - f"valid mmi loss {params.valid_mmi_loss:.4f}, " - f"valid att loss {params.valid_att_loss:.4f}, " - f"valid loss {params.valid_loss:.4f}, " - f"best valid loss: {params.best_valid_loss:.4f}, " + f"valid mmi loss {params.valid_mmi_loss:.4f}," + f"valid att loss {params.valid_att_loss:.4f}," + f"valid loss {params.valid_loss:.4f}," + f" best valid loss: {params.best_valid_loss:.4f} " f"best valid epoch: {params.best_valid_epoch}" ) if tb_writer is not None: @@ -642,11 +635,13 @@ def run(rank, world_size, args): if torch.cuda.is_available(): device = torch.device("cuda", rank) - graph_compiler = BpeMmiTrainingGraphCompiler( + graph_compiler = MmiTrainingGraphCompiler( params.lang_dir, + uniq_filename="lexicon.txt", device=device, - sos_token="", - eos_token="", + oov="", + sos_id=1, + eos_id=1, ) logging.info("About to create model") @@ -658,7 +653,6 @@ def run(rank, world_size, args): subsampling_factor=params.subsampling_factor, num_decoder_layers=params.num_decoder_layers, vgg_frontend=False, - is_espnet_structure=params.is_espnet_structure, use_feat_batchnorm=params.use_feat_batchnorm, ) @@ -679,32 +673,6 @@ def run(rank, world_size, args): if checkpoints: optimizer.load_state_dict(checkpoints["optimizer"]) - if args.use_ali_model: - ali_model = TdnnLstm( - num_features=params.feature_dim, - num_classes=num_classes, - subsampling_factor=params.subsampling_factor, - ) - - ali_model_fname = Path( - f"tdnn_lstm_ctc/exp_bpe_500/epoch-{args.ali_model_epoch}.pt" - ) - assert ( - ali_model_fname.is_file() - ), f"ali model filename {ali_model_fname} does not exist!" 
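The loop below logs the MMI and attention losses separately; the scalar that is backpropagated combines them with att_rate (0.7 in get_params). That combination lives in compute_loss and is not part of this hunk, so the following is only a hedged sketch of the usual convex interpolation:

import torch


def combine_losses(
    mmi_loss: torch.Tensor, att_loss: torch.Tensor, att_rate: float
) -> torch.Tensor:
    # Sketch: weight the attention-decoder loss by att_rate and the
    # LF-MMI loss by (1 - att_rate); with att_rate == 0 only MMI is used.
    if att_rate == 0.0:
        return mmi_loss
    return (1.0 - att_rate) * mmi_loss + att_rate * att_loss


loss = combine_losses(torch.tensor(123.4), torch.tensor(98.7), att_rate=0.7)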
- - ali_model.load_state_dict( - torch.load(ali_model_fname, map_location="cpu")["model"] - ) - ali_model.to(device) - - ali_model.eval() - ali_model.requires_grad_(False) - logging.info(f"Use ali_model: {ali_model_fname}") - else: - ali_model = None - logging.info("No ali_model") - librispeech = LibriSpeechAsrDataModule(args) train_dl = librispeech.train_dataloaders() valid_dl = librispeech.valid_dataloaders() @@ -727,7 +695,6 @@ def run(rank, world_size, args): train_one_epoch( params=params, model=model, - ali_model=ali_model, optimizer=optimizer, graph_compiler=graph_compiler, train_dl=train_dl, diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 1965dc491..c1a532fc1 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -227,5 +227,3 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then ./local/compile_hlg.py --lang-dir $lang_dir done fi - -cd data && ln -sfv lang_bpe_500 lang_bpe diff --git a/icefall/mmi.py b/icefall/mmi.py index ec5d07dfe..f9ba46df9 100644 --- a/icefall/mmi.py +++ b/icefall/mmi.py @@ -4,13 +4,13 @@ import k2 import torch from torch import nn -from icefall.bpe_mmi_graph_compiler import BpeMmiTrainingGraphCompiler +from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler def _compute_mmi_loss_exact_optimized( dense_fsa_vec: k2.DenseFsaVec, texts: List[str], - graph_compiler: BpeMmiTrainingGraphCompiler, + graph_compiler: MmiTrainingGraphCompiler, den_scale: float = 1.0, ) -> torch.Tensor: """ @@ -98,7 +98,7 @@ def _compute_mmi_loss_exact_optimized( def _compute_mmi_loss_exact_non_optimized( dense_fsa_vec: k2.DenseFsaVec, texts: List[str], - graph_compiler: BpeMmiTrainingGraphCompiler, + graph_compiler: MmiTrainingGraphCompiler, den_scale: float = 1.0, ) -> torch.Tensor: """ @@ -133,7 +133,7 @@ def _compute_mmi_loss_exact_non_optimized( def _compute_mmi_loss_pruned( dense_fsa_vec: k2.DenseFsaVec, texts: List[str], - graph_compiler: BpeMmiTrainingGraphCompiler, + graph_compiler: MmiTrainingGraphCompiler, den_scale: float = 1.0, ) -> torch.Tensor: """ @@ -184,7 +184,7 @@ class LFMMILoss(nn.Module): def __init__( self, - graph_compiler: BpeMmiTrainingGraphCompiler, + graph_compiler: MmiTrainingGraphCompiler, use_pruned_intersect: bool = False, den_scale: float = 1.0, ): diff --git a/icefall/mmi_graph_compiler.py b/icefall/mmi_graph_compiler.py index 43f2a092a..0d901227d 100644 --- a/icefall/mmi_graph_compiler.py +++ b/icefall/mmi_graph_compiler.py @@ -15,6 +15,8 @@ class MmiTrainingGraphCompiler(object): uniq_filename: str = "uniq_lexicon.txt", device: Union[str, torch.device] = "cpu", oov: str = "", + sos_id: int = 1, + eos_id: int = 1, ): """ Args: @@ -45,6 +47,8 @@ class MmiTrainingGraphCompiler(object): self.L_inv = self.lexicon.L_inv.to(self.device) self.oov_id = self.lexicon.word_table[oov] + self.sos_id = sos_id + self.eos_id = eos_id self.build_ctc_topo_P() @@ -93,6 +97,7 @@ class MmiTrainingGraphCompiler(object): ).invert() self.ctc_topo_P = k2.arc_sort(ctc_topo_P) + logging.info(f"ctc_topo_P num_arcs: {self.ctc_topo_P.num_arcs}") def compile( self, texts: Iterable[str], replicate_den: bool = True
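Taken together, the pieces this patch touches are wired up roughly as follows. This is a usage sketch, not a drop-in script: it assumes data/lang_bpe_500 has already been prepared by prepare.sh, that the BPE vocabulary has 500 tokens, and it feeds random activations in place of real conformer output; the LFMMILoss and MmiTrainingGraphCompiler calls follow the code above.

from pathlib import Path

import k2
import torch

from icefall.mmi import LFMMILoss
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler

device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

graph_compiler = MmiTrainingGraphCompiler(
    Path("data/lang_bpe_500"),  # produced by prepare.sh
    uniq_filename="lexicon.txt",
    device=device,
    sos_id=1,
    eos_id=1,
)

loss_fn = LFMMILoss(
    graph_compiler=graph_compiler,
    use_pruned_intersect=False,  # exact intersection, as in get_params()
    den_scale=1.0,
)

# Stand-ins for the conformer output and encode_supervisions() results:
# (N, T, C) log-probs, one 100-frame utterance, transcripts sorted by duration.
nnet_output = torch.randn(1, 100, 500).log_softmax(dim=-1).to(device)
supervision_segments = torch.tensor([[0, 0, 100]], dtype=torch.int32)
texts = ["HELLO WORLD"]

dense_fsa_vec = k2.DenseFsaVec(
    nnet_output,
    supervision_segments,
    allow_truncate=3,  # subsampling_factor - 1
)

# Roughly, the loss is -(num_tot_scores - den_scale * den_tot_scores).sum()
# over lattices obtained by intersecting dense_fsa_vec with the graphs
# built from ctc_topo_P and the transcripts.
mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)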