diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py
index 9d9c2af1f..18fa3e69f 100755
--- a/egs/librispeech/ASR/conformer_ctc2/train.py
+++ b/egs/librispeech/ASR/conformer_ctc2/train.py
@@ -166,13 +166,6 @@ def get_parser():
         """,
     )

-    parser.add_argument(
-        "--bpe-model",
-        type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
-    )
-
     parser.add_argument(
         "--initial-lr",
         type=float,
@@ -522,14 +515,6 @@ def compute_loss(
         nnet_output, encoder_memory, memory_mask = model(
             feature, supervisions, warmup=warmup
         )
-        # logging.info('feature shape: {}'.format(feature.shape))
-        # logging.info('nnet_output shape: {}'.format(nnet_output.shape))
-        # logging.info('encoder_memory shape: {}'.format(encoder_memory.shape))
-        # logging.info('memory_mask shape: {}'.format(memory_mask.shape))
-        # after the main warmup step, we keep pruned_loss_scale small
-        # for the same amount of time (model_warm_step), to avoid
-        # overwhelming the simple_loss and causing it to diverge,
-        # in case it had not fully learned the alignment yet.

     # NOTE: We need `encode_supervisions` to sort sequences with
     # different duration in decreasing order, required by
diff --git a/egs/librispeech/ASR/conformer_ctc2/transformer.py b/egs/librispeech/ASR/conformer_ctc2/transformer.py
index fa179acc0..3ef7edc23 100644
--- a/egs/librispeech/ASR/conformer_ctc2/transformer.py
+++ b/egs/librispeech/ASR/conformer_ctc2/transformer.py
@@ -417,7 +417,6 @@ class TransformerEncoderLayer(nn.Module):
         dim_feedforward: int = 2048,
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
-        activation: str = "relu",
     ) -> None:
         super(TransformerEncoderLayer, self).__init__()

@@ -443,11 +442,6 @@

         self.dropout = nn.Dropout(dropout)

-        # def __setstate__(self, state):
-        #     if "activation" not in state:
-        #         state["activation"] = nn.functional.relu
-        #     super(TransformerEncoderLayer, self).__setstate__(state)
-
     def forward(
         self,
         src: torch.Tensor,
@@ -539,7 +533,6 @@ class TransformerDecoderLayer(nn.Module):
         dim_feedforward: int = 2048,
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
-        # activation: str = "relu",
         normalize_before: bool = True,
     ) -> None:
         super(TransformerDecoderLayer, self).__init__()

@@ -564,11 +557,6 @@

         self.dropout = nn.Dropout(dropout)

-        # def __setstate__(self, state):
-        #     if "activation" not in state:
-        #         state["activation"] = nn.functional.relu
-        #     super(TransformerDecoderLayer, self).__setstate__(state)
-
     def forward(
         self,
         tgt: torch.Tensor,
@@ -653,17 +641,6 @@ class TransformerDecoderLayer(nn.Module):
         return tgt


-def _get_activation_fn(activation: str):
-    if activation == "relu":
-        return nn.functional.relu
-    elif activation == "gelu":
-        return nn.functional.gelu
-
-    raise RuntimeError(
-        "activation should be relu/gelu, not {}".format(activation)
-    )
-
-
 class TransformerEncoder(nn.Module):
     r"""TransformerEncoder is a stack of N encoder layers

@@ -708,7 +685,7 @@
         """
         output = src

-        for i, mod in enumerate(self.layers):
+        for mod in self.layers:
             output = mod(
                 output,
                 src_mask=mask,
@@ -769,7 +746,7 @@
         """
         output = tgt

-        for i, mod in enumerate(self.layers):
+        for mod in self.layers:
             output = mod(
                 output,
                 memory,
diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py
index 9a35750e0..c628dfd53 100755
--- a/egs/librispeech/ASR/local/compile_hlg.py
+++ b/egs/librispeech/ASR/local/compile_hlg.py
@@ -40,6 +40,13 @@ from icefall.lexicon import Lexicon

 def get_args():
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lm",
+        type=str,
+        default="G_3_gram",
+        help="""Stem name for the LM used in HLG compiling.
+        """,
+    )
     parser.add_argument(
         "--lang-dir",
         type=str,
@@ -50,11 +57,13 @@ def get_args():
     return parser.parse_args()


-def compile_HLG(lang_dir: str) -> k2.Fsa:
+def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
     """
     Args:
       lang_dir:
         The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
+      lm:
+        The stem name of the LM to be used, e.g., G_3_gram.

     Return:
       An FSA representing HLG.
@@ -65,15 +74,15 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     H = k2.ctc_topo(max_token_id)
     L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

-    if Path("data/lm/G_3_gram.pt").is_file():
-        logging.info("Loading pre-compiled G_3_gram")
-        d = torch.load("data/lm/G_3_gram.pt")
+    if Path(f"data/lm/{lm}.pt").is_file():
+        logging.info(f"Loading pre-compiled {lm}")
+        d = torch.load(f"data/lm/{lm}.pt")
         G = k2.Fsa.from_dict(d)
     else:
-        logging.info("Loading G_3_gram.fst.txt")
-        with open("data/lm/G_3_gram.fst.txt") as f:
+        logging.info(f"Loading {lm}.fst.txt")
+        with open(f"data/lm/{lm}.fst.txt") as f:
             G = k2.Fsa.from_openfst(f.read(), acceptor=False)
-            torch.save(G.as_dict(), "data/lm/G_3_gram.pt")
+            torch.save(G.as_dict(), f"data/lm/{lm}.pt")

     first_token_disambig_id = lexicon.token_table["#0"]
     first_word_disambig_id = lexicon.word_table["#0"]
@@ -144,7 +153,7 @@

     logging.info(f"Processing {lang_dir}")

-    HLG = compile_HLG(lang_dir)
+    HLG = compile_HLG(lang_dir, args.lm)

     logging.info(f"Saving HLG.pt to {lang_dir}")
     torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
index 51de46ae8..94784c4c4 100644
--- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
+++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
@@ -17,10 +17,11 @@

 import argparse
-import inspect
 import logging
+
 from functools import lru_cache
 from pathlib import Path
+from typing import Any, Dict, Optional

 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
@@ -28,7 +29,6 @@ from lhotse.dataset import (
     CutMix,
     DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
-    PrecomputedFeatures,
     SingleCutSampler,
     SpecAugment,
 )
@@ -140,7 +140,6 @@ class TedLiumAsrDataModule:
             "field: batch['supervisions']['cut'] with the cuts that "
             "were used to construct it.",
         )
-
         group.add_argument(
             "--num-workers",
             type=int,
@@ -148,14 +147,12 @@
             help="The number of training dataloader workers that "
             "collect the batches.",
         )
-
         group.add_argument(
             "--enable-spec-aug",
             type=str2bool,
             default=True,
             help="When enabled, use SpecAugment for training dataset.",
         )
-
         group.add_argument(
             "--spec-aug-time-warp-factor",
             type=int,
@@ -165,16 +162,48 @@
             "Larger values mean more warping. "
             "A value less than 1 means to disable time warp.",
         )
-
         group.add_argument(
             "--enable-musan",
             type=str2bool,
             default=True,
             help="When enabled, select noise from MUSAN and mix it"
-            "with training dataset. ",
+            "with training dataset.",
         )

-    def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
+    def train_dataloaders(
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> DataLoader:
+        """
+        Args:
+          cuts_train:
+            CutSet for training.
+          sampler_state_dict:
+            The state dict for the training sampler.
+        """
+
+        input_transforms = []
+        if self.args.enable_spec_aug:
+            logging.info("Enable SpecAugment")
+            logging.info(
+                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
+            )
+
+            input_transforms.append(
+                SpecAugment(
+                    time_warp_factor=self.args.spec_aug_time_warp_factor,
+                    num_frame_masks=10,
+                    features_mask_size=27,
+                    num_feature_masks=2,
+                    frames_mask_size=100,
+                    max_frames_mask_fraction=0.15,
+                    p=0.9,
+                )
+            )
+        else:
+            logging.info("Disable SpecAugment")
+
         logging.info("About to get Musan cuts")
         transforms = []
         if self.args.enable_musan:
@@ -204,42 +233,7 @@
                 )
             ] + transforms

-        input_transforms = []
-        if self.args.enable_spec_aug:
-            logging.info("Enable SpecAugment")
-            logging.info(
-                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
-            )
-            # Set the value of num_frame_masks according to Lhotse's version.
-            # In different Lhotse's versions, the default of num_frame_masks is
-            # different.
-            num_frame_masks = 10
-            num_frame_masks_parameter = inspect.signature(
-                SpecAugment.__init__
-            ).parameters["num_frame_masks"]
-            if num_frame_masks_parameter.default == 1:
-                num_frame_masks = 2
-            logging.info(f"Num frame mask: {num_frame_masks}")
-            input_transforms.append(
-                SpecAugment(
-                    time_warp_factor=self.args.spec_aug_time_warp_factor,
-                    num_frame_masks=num_frame_masks,
-                    features_mask_size=27,
-                    num_feature_masks=2,
-                    frames_mask_size=100,
-                    max_frames_mask_fraction=0.15,
-                    p=0.9,
-                )
-            )
-        else:
-            logging.info("Disable SpecAugment")
-
         logging.info("About to create train dataset")
-        train = K2SpeechRecognitionDataset(
-            cut_transforms=transforms,
-            input_transforms=input_transforms,
-            return_cuts=self.args.return_cuts,
-        )
         if self.args.on_the_fly_feats:
             # NOTE: the PerturbSpeed transform should be added only if we
             # remove it from data prep stage.
@@ -259,6 +253,12 @@
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
+        else:
+            train = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_transforms=input_transforms,
+                return_cuts=self.args.return_cuts,
+            )

         if self.args.bucketing_sampler:
             logging.info("Using DynamicBucketingSampler.")
@@ -276,6 +276,11 @@
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
             )
+
+        if sampler_state_dict is not None:
+            logging.info("Loading sampler state dict")
+            train_sampler.load_state_dict(sampler_state_dict)
+
         logging.info("About to create train dataloader")
         train_dl = DataLoader(
             train,
@@ -288,6 +293,7 @@
         return train_dl

     def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+
         transforms = []
         if self.args.concatenate_cuts:
             transforms = [
@@ -310,11 +316,13 @@
             cut_transforms=transforms,
             return_cuts=self.args.return_cuts,
         )
+
         valid_sampler = DynamicBucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
             shuffle=False,
         )
+
         logging.info("About to create dev dataloader")
         valid_dl = DataLoader(
             validate,
@@ -326,25 +334,34 @@
         return valid_dl

-    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+    def test_dataloaders(self, cuts_test: CutSet) -> DataLoader:
+
         logging.debug("About to create test dataset")
-        test = K2SpeechRecognitionDataset(
-            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
-            if self.args.on_the_fly_feats
-            else PrecomputedFeatures(),
-            return_cuts=self.args.return_cuts,
-        )
-        sampler = DynamicBucketingSampler(
-            cuts,
+        if self.args.on_the_fly_feats:
+            test = K2SpeechRecognitionDataset(
+                input_strategy=OnTheFlyFeatures(
+                    Fbank(FbankConfig(num_mel_bins=80))
+                ),
+                return_cuts=self.args.return_cuts,
+            )
+        else:
+            test = K2SpeechRecognitionDataset(
+                return_cuts=self.args.return_cuts,
+            )
+
+        test_sampler = DynamicBucketingSampler(
+            cuts_test,
             max_duration=self.args.max_duration,
             shuffle=False,
         )
+
         logging.debug("About to create test dataloader")
         test_dl = DataLoader(
             test,
             batch_size=None,
-            sampler=sampler,
+            sampler=test_sampler,
             num_workers=self.args.num_workers,
+            persistent_workers=False,
         )
         return test_dl
diff --git a/icefall/decode.py b/icefall/decode.py
index f04ee368c..099e2d171 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -459,7 +459,8 @@ class Nbest(object):
 def one_best_decoding(
     lattice: k2.Fsa,
     use_double_scores: bool = True,
-) -> k2.Fsa:
+    lm_scale_list: Optional[List[float]] = None,
+) -> Union[k2.Fsa, Dict[str, k2.Fsa]]:
     """Get the best path from a lattice.

     Args:
@@ -468,11 +469,28 @@
       use_double_scores:
         True to use double precision floating point in the computation.
         False to use single precision.
+      lm_scale_list:
+        A list of floats representing LM score scales.
     Return:
       An FsaVec containing linear paths.
""" - best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores) - return best_path + + if lm_scale_list is not None: + + ans = dict() + saved_am_scores = lattice.scores - lattice.lm_scores + for lm_scale in lm_scale_list: + am_scores = saved_am_scores / lm_scale + lattice.scores = am_scores + lattice.lm_scores + + best_path = k2.shortest_path( + lattice, use_double_scores=use_double_scores + ) + key = f"lm_scale_{lm_scale}" + ans[key] = best_path + return ans + + return k2.shortest_path(lattice, use_double_scores=use_double_scores) def nbest_decoding( diff --git a/icefall/utils.py b/icefall/utils.py index c502cb4d8..143c79497 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -194,8 +194,16 @@ def encode_supervisions( supervision_segments = torch.stack( ( supervisions["sequence_idx"], - supervisions["start_frame"] // subsampling_factor, - supervisions["num_frames"] // subsampling_factor, + torch.div( + supervisions["start_frame"], + subsampling_factor, + rounding_mode="floor", + ), + torch.div( + supervisions["num_frames"], + subsampling_factor, + rounding_mode="floor", + ) ), 1, ).to(torch.int32)