diff --git a/egs/librispeech/ASR/local/filter_cuts.py b/egs/librispeech/ASR/local/filter_cuts.py index 53dbb8211..dff98a954 100644 --- a/egs/librispeech/ASR/local/filter_cuts.py +++ b/egs/librispeech/ASR/local/filter_cuts.py @@ -101,6 +101,9 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): # Note: for ./lstm_transducer_stateless/lstm.py, the formula is # T = ((num_frames - 3) // 2 - 1) // 2 + # Note: for ./pruned_transducer_stateless7/zipformer.py, the formula is + # T = ((num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) if T < len(tokens): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 8927be227..3f27736b3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -59,7 +59,6 @@ import torch import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from zipformer import Zipformer from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut @@ -71,6 +70,7 @@ from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints @@ -79,9 +79,9 @@ from icefall.checkpoint import ( save_checkpoint_with_global_batch_idx, update_averaged_model, ) -from icefall.hooks import register_inf_check_hooks from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool LRSchedulerType = Union[ @@ -89,14 +89,12 @@ LRSchedulerType = Union[ ] -def set_batch_count( - model: Union[nn.Module, DDP], batch_count: float -) -> None: +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: if isinstance(model, DDP): # get underlying nn.Module model = model.module for module in model.modules(): - if hasattr(module, 'batch_count'): + if hasattr(module, "batch_count"): module.batch_count = batch_count @@ -126,7 +124,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-dims", type=str, default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated" + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", ) parser.add_argument( @@ -134,7 +132,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): type=str, default="192,192,192,192,192", help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""" + not the same as embedding dimension.""", ) parser.add_argument( @@ -143,7 +141,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): default="256,256,256,256,256", help="Unmasked dimensions in the encoders, relates to augmentation during training. " "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse." + " worse.", ) parser.add_argument( @@ -248,10 +246,7 @@ def get_parser(): ) parser.add_argument( - "--base-lr", - type=float, - default=0.05, - help="The base learning rate." + "--base-lr", type=float, default=0.05, help="The base learning rate." ) parser.add_argument( @@ -451,11 +446,14 @@ def get_params() -> AttributeDict: def get_encoder_model(params: AttributeDict) -> nn.Module: # TODO: We can add an option to switch between Zipformer and Transformer def to_int_tuple(s: str): - return tuple(map(int, s.split(','))) + return tuple(map(int, s.split(","))) + encoder = Zipformer( num_features=params.feature_dim, output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple(params.zipformer_downsampling_factors), + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), encoder_dims=to_int_tuple(params.encoder_dims), attention_dim=to_int_tuple(params.attention_dims), encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), @@ -479,7 +477,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module: def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(',')[-1]), + encoder_dim=int(params.encoder_dims.split(",")[-1]), decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, @@ -496,7 +494,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: encoder=encoder, decoder=decoder, joiner=joiner, - encoder_dim=int(params.encoder_dims.split(',')[-1]), + encoder_dim=int(params.encoder_dims.split(",")[-1]), decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, @@ -682,18 +680,17 @@ def compute_loss( # take down the scale on the simple loss from 1.0 at the start # to params.simple_loss scale by warm_step. simple_loss_scale = ( - s if batch_idx_train >= warm_step + s + if batch_idx_train >= warm_step else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) ) pruned_loss_scale = ( - 1.0 if batch_idx_train >= warm_step + 1.0 + if batch_idx_train >= warm_step else 0.1 + 0.9 * (batch_idx_train / warm_step) ) - loss = ( - simple_loss_scale * simple_loss + - pruned_loss_scale * pruned_loss - ) + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -873,12 +870,16 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + if cur_grad_scale < 1.0 or ( + cur_grad_scale < 8.0 and batch_idx % 400 == 0 + ): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: - raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}") + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] @@ -888,8 +889,12 @@ def train_one_epoch( f"Epoch {params.cur_epoch}, " f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " + - (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + f"lr: {cur_lr:.2e}, " + + ( + f"grad_scale: {scaler._scale.item()}" + if params.use_fp16 + else "" + ) ) if tb_writer is not None: @@ -905,12 +910,15 @@ def train_one_epoch( ) if params.use_fp16: tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, ) - - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + if ( + batch_idx % params.valid_interval == 0 + and not params.print_diagnostics + ): logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -921,7 +929,9 @@ def train_one_epoch( ) model.train() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) if tb_writer is not None: valid_info.write_summary( tb_writer, "train/valid_", params.batch_idx_train @@ -997,12 +1007,11 @@ def run(rank, world_size, args): model.to(device) if world_size > 1: logging.info("Using DDP") - model = DDP(model, device_ids=[rank], - find_unused_parameters=True) + model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam(model.parameters(), - lr=params.base_lr, - clipping_scale=2.0) + optimizer = ScaledAdam( + model.parameters(), lr=params.base_lr, clipping_scale=2.0 + ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1043,7 +1052,34 @@ def run(rank, world_size, args): # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 20.0 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True train_cuts = train_cuts.filter(remove_short_and_long_utt) @@ -1071,8 +1107,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, - init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1193,7 +1228,9 @@ def scan_pessimistic_batches_for_oom( ) display_and_save_batch(batch, params=params, sp=sp) raise - logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) def main(): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index c14066d38..023dec97d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1828,6 +1828,7 @@ def _test_zipformer_main(): torch.randn(batch_size, seq_len, feature_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), ) + assert ((seq_len - 7) // 2 + 1) // 2 == f[0].shape[1], (seq_len, f.shape[1]) f[0].sum().backward() c.eval() f = c( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index b4177d3f0..2603bb854 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -90,12 +90,7 @@ from icefall.checkpoint import ( from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks -from icefall.utils import ( - AttributeDict, - MetricsTracker, - setup_logger, - str2bool, -) +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool LRSchedulerType = Union[ torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler @@ -1045,7 +1040,9 @@ def train_one_epoch( params.best_train_loss = params.train_loss -def filter_short_and_long_utterances(cuts: CutSet) -> CutSet: +def filter_short_and_long_utterances( + cuts: CutSet, sp: spm.SentencePieceProcessor +) -> CutSet: def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds # @@ -1055,7 +1052,34 @@ def filter_short_and_long_utterances(cuts: CutSet) -> CutSet: # You should use ../local/display_manifest_statistics.py to get # an utterance duration distribution for your dataset to select # the threshold - return 1.0 <= c.duration <= 20.0 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True cuts = cuts.filter(remove_short_and_long_utt) @@ -1162,7 +1186,7 @@ def run(rank, world_size, args): train_cuts += librispeech.train_clean_360_cuts() train_cuts += librispeech.train_other_500_cuts() - train_cuts = filter_short_and_long_utterances(train_cuts) + train_cuts = filter_short_and_long_utterances(train_cuts, sp) gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir) # XL 10k hours @@ -1179,7 +1203,7 @@ def run(rank, world_size, args): logging.info("Using the S subset of GigaSpeech (250 hours)") train_giga_cuts = gigaspeech.train_S_cuts() - train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts) + train_giga_cuts = filter_short_and_long_utterances(train_giga_cuts, sp) train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: