diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless7/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless7/train.py
index 8927be227..e8d577df6 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless7/train.py
@@ -27,8 +27,8 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --num-epochs 30 \
   --start-epoch 1 \
   --exp-dir pruned_transducer_stateless7/exp \
-  --full-libri 1 \
-  --max-duration 300
+  --max-duration 750 \
+  --training-subset L
 
 # For mix precision training:
 
@@ -38,9 +38,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --start-epoch 1 \
   --use-fp16 1 \
   --exp-dir pruned_transducer_stateless7/exp \
-  --full-libri 1 \
-  --max-duration 550
-
+  --max-duration 750
 """
@@ -54,12 +52,10 @@ from typing import Any, Dict, Optional, Tuple, Union
 
 import k2
 import optim
-import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
-from zipformer import Zipformer
+from asr_datamodule import WenetSpeechAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -71,17 +67,20 @@ from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer
 
 from icefall import diagnostics
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import (
     save_checkpoint_with_global_batch_idx,
     update_averaged_model,
 )
-from icefall.hooks import register_inf_check_hooks
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
 LRSchedulerType = Union[
@@ -89,14 +88,12 @@ LRSchedulerType = Union[
 ]
 
 
-def set_batch_count(
-    model: Union[nn.Module, DDP], batch_count: float
-) -> None:
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
     if isinstance(model, DDP):
         # get underlying nn.Module
         model = model.module
     for module in model.modules():
-        if hasattr(module, 'batch_count'):
+        if hasattr(module, "batch_count"):
             module.batch_count = batch_count
 
 
@@ -126,7 +123,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         "--encoder-dims",
         type=str,
         default="384,384,384,384,384",
-        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated"
+        help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
     )
 
     parser.add_argument(
@@ -134,7 +131,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         type=str,
         default="192,192,192,192,192",
         help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
-        not the same as embedding dimension."""
+        not the same as embedding dimension.""",
     )
 
     parser.add_argument(
@@ -143,7 +140,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         default="256,256,256,256,256",
         help="Unmasked dimensions in the encoders, relates to augmentation during training. "
         "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
-        " worse."
+ " worse.", ) parser.add_argument( @@ -241,17 +238,17 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--lang-dir", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_char", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, ) parser.add_argument( - "--base-lr", - type=float, - default=0.05, - help="The base learning rate." + "--base-lr", type=float, default=0.05, help="The base learning rate." ) parser.add_argument( @@ -451,11 +448,14 @@ def get_params() -> AttributeDict: def get_encoder_model(params: AttributeDict) -> nn.Module: # TODO: We can add an option to switch between Zipformer and Transformer def to_int_tuple(s: str): - return tuple(map(int, s.split(','))) + return tuple(map(int, s.split(","))) + encoder = Zipformer( num_features=params.feature_dim, output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple(params.zipformer_downsampling_factors), + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), encoder_dims=to_int_tuple(params.encoder_dims), attention_dim=to_int_tuple(params.attention_dims), encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), @@ -479,7 +479,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module: def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(',')[-1]), + encoder_dim=int(params.encoder_dims.split(",")[-1]), decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, @@ -496,7 +496,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: encoder=encoder, decoder=decoder, joiner=joiner, - encoder_dim=int(params.encoder_dims.split(',')[-1]), + encoder_dim=int(params.encoder_dims.split(",")[-1]), decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, @@ -567,9 +567,6 @@ def load_checkpoint_if_available( if "cur_epoch" in saved_params: params["start_epoch"] = saved_params["cur_epoch"] - if "cur_batch_idx" in saved_params: - params["cur_batch_idx"] = saved_params["cur_batch_idx"] - return saved_params @@ -626,7 +623,7 @@ def save_checkpoint( def compute_loss( params: AttributeDict, model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, + graph_compiler: CharCtcTrainingGraphCompiler, batch: dict, is_training: bool, ) -> Tuple[Tensor, MetricsTracker]: @@ -665,7 +662,8 @@ def compute_loss( warm_step = params.warm_step texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) + + y = graph_compiler.texts_to_ids(texts) y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): @@ -682,18 +680,17 @@ def compute_loss( # take down the scale on the simple loss from 1.0 at the start # to params.simple_loss scale by warm_step. 
         simple_loss_scale = (
-            s if batch_idx_train >= warm_step
+            s
+            if batch_idx_train >= warm_step
             else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
         )
         pruned_loss_scale = (
-            1.0 if batch_idx_train >= warm_step
+            1.0
+            if batch_idx_train >= warm_step
             else 0.1 + 0.9 * (batch_idx_train / warm_step)
         )
-        loss = (
-            simple_loss_scale * simple_loss +
-            pruned_loss_scale * pruned_loss
-        )
+        loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
         assert loss.requires_grad == is_training
@@ -715,7 +712,7 @@ def compute_loss(
 def compute_validation_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
-    sp: spm.SentencePieceProcessor,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
 ) -> MetricsTracker:
@@ -728,7 +725,7 @@ def compute_validation_loss(
         loss, loss_info = compute_loss(
             params=params,
             model=model,
-            sp=sp,
+            graph_compiler=graph_compiler,
             batch=batch,
             is_training=False,
         )
@@ -751,7 +748,7 @@ def train_one_epoch(
     model: Union[nn.Module, DDP],
     optimizer: torch.optim.Optimizer,
     scheduler: LRSchedulerType,
-    sp: spm.SentencePieceProcessor,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     scaler: GradScaler,
@@ -795,13 +792,7 @@ def train_one_epoch(
     tot_loss = MetricsTracker()
 
-    cur_batch_idx = params.get("cur_batch_idx", 0)
-
     for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx
-
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
@@ -810,7 +801,7 @@ def train_one_epoch(
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
-                    sp=sp,
+                    graph_compiler=graph_compiler,
                     batch=batch,
                     is_training=True,
                 )
@@ -827,7 +818,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
-            display_and_save_batch(batch, params=params, sp=sp)
+            display_and_save_batch(batch, params=params)
             raise
 
         if params.print_diagnostics and batch_idx == 5:
@@ -848,7 +839,6 @@ def train_one_epoch(
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
         ):
-            params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
@@ -861,7 +851,6 @@ def train_one_epoch(
                 scaler=scaler,
                 rank=rank,
             )
-            del params.cur_batch_idx
             remove_checkpoints(
                 out_dir=params.exp_dir,
                 topk=params.keep_last_k,
@@ -873,12 +862,16 @@ def train_one_epoch(
            # of the grad scaler is configurable, but we can't configure it to have different
            # behavior depending on the current grad scale.
            cur_grad_scale = scaler._scale.item()
-            if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+            if cur_grad_scale < 1.0 or (
+                cur_grad_scale < 8.0 and batch_idx % 400 == 0
+            ):
                 scaler.update(cur_grad_scale * 2.0)
             if cur_grad_scale < 0.01:
                 logging.warning(f"Grad scale is small: {cur_grad_scale}")
             if cur_grad_scale < 1.0e-05:
-                raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
+                raise RuntimeError(
+                    f"grad_scale is too small, exiting: {cur_grad_scale}"
+                )
 
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
@@ -888,8 +881,12 @@ def train_one_epoch(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
-                f"lr: {cur_lr:.2e}, " +
-                (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+                f"lr: {cur_lr:.2e}, "
+                + (
+                    f"grad_scale: {scaler._scale.item()}"
+                    if params.use_fp16
+                    else ""
+                )
             )
 
             if tb_writer is not None:
@@ -905,23 +902,28 @@ def train_one_epoch(
                 )
                 if params.use_fp16:
                     tb_writer.add_scalar(
-                        "train/grad_scale", cur_grad_scale, params.batch_idx_train
+                        "train/grad_scale",
+                        cur_grad_scale,
+                        params.batch_idx_train,
                     )
-
-
-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        if (
+            batch_idx % params.valid_interval == 0
+            and not params.print_diagnostics
+        ):
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,
                 model=model,
-                sp=sp,
+                graph_compiler=graph_compiler,
                 valid_dl=valid_dl,
                 world_size=world_size,
             )
             model.train()
             logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-            logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
             if tb_writer is not None:
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
@@ -948,8 +950,6 @@ def run(rank, world_size, args):
     """
     params = get_params()
     params.update(vars(args))
-    if params.full_libri is False:
-        params.valid_interval = 1600
 
     fix_random_seed(params.seed)
     if world_size > 1:
@@ -968,12 +968,14 @@ def run(rank, world_size, args):
     device = torch.device("cuda", rank)
     logging.info(f"Device: {device}")
 
-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    lexicon = Lexicon(params.lang_dir)
+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )
 
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    params.blank_id = lexicon.token_table["<blk>"]
+    params.vocab_size = max(lexicon.tokens) + 1
 
     logging.info(params)
@@ -997,12 +999,11 @@ def run(rank, world_size, args):
     model.to(device)
     if world_size > 1:
         logging.info("Using DDP")
-        model = DDP(model, device_ids=[rank],
-                    find_unused_parameters=True)
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
-    optimizer = ScaledAdam(model.parameters(),
-                           lr=params.base_lr,
-                           clipping_scale=2.0)
+    optimizer = ScaledAdam(
+        model.parameters(), lr=params.base_lr, clipping_scale=2.0
+    )
 
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
@@ -1027,26 +1028,26 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)
 
-    librispeech = LibriSpeechAsrDataModule(args)
+    wenetspeech = WenetSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
-    if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+    train_cuts = wenetspeech.train_cuts()
+    valid_cuts = wenetspeech.valid_cuts()
 
     def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 20 seconds
+        # Keep only utterances with duration between 1 second and 19 seconds
         #
-        # Caution: There is a reason to select 20.0 here. Please see
+        # Caution: There is a reason to select 19.0 here. Please see
         # ../local/display_manifest_statistics.py
         #
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
-        return 1.0 <= c.duration <= 20.0
+        return 1.0 <= c.duration <= 19.0
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
+    valid_dl = wenetspeech.valid_dataloaders(valid_cuts)
+
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
         # saved in the middle of an epoch
@@ -1054,25 +1055,20 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None
 
-    train_dl = librispeech.train_dataloaders(
+    train_dl = wenetspeech.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
 
-    valid_cuts = librispeech.dev_clean_cuts()
-    valid_cuts += librispeech.dev_other_cuts()
-    valid_dl = librispeech.valid_dataloaders(valid_cuts)
-
-    if not params.print_diagnostics:
+    if False and not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,
             optimizer=optimizer,
-            sp=sp,
+            graph_compiler=graph_compiler,
             params=params,
         )
 
-    scaler = GradScaler(enabled=params.use_fp16,
-                        init_scale=1.0)
+    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1093,7 +1089,7 @@ def run(rank, world_size, args):
             model_avg=model_avg,
             optimizer=optimizer,
             scheduler=scheduler,
-            sp=sp,
+            graph_compiler=graph_compiler,
             train_dl=train_dl,
             valid_dl=valid_dl,
             scaler=scaler,
@@ -1127,7 +1123,6 @@ def run(rank, world_size, args):
 def display_and_save_batch(
     batch: dict,
     params: AttributeDict,
-    sp: spm.SentencePieceProcessor,
 ) -> None:
     """Display the batch statistics and save the batch into disk.
 
@@ -1137,8 +1132,6 @@ def display_and_save_batch(
         for the content in it.
       params:
        Parameters for training. See :func:`get_params`.
-      sp:
-        The BPE model.
""" from lhotse.utils import uuid4 @@ -1146,13 +1139,13 @@ def display_and_save_batch( logging.info(f"Saving batch to {filename}") torch.save(batch, filename) - supervisions = batch["supervisions"] features = batch["inputs"] logging.info(f"features shape: {features.shape}") - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) + texts = batch["supervisions"]["text"] + num_tokens = sum(len(i) for i in texts) + logging.info(f"num tokens: {num_tokens}") @@ -1160,7 +1153,7 @@ def scan_pessimistic_batches_for_oom( model: Union[nn.Module, DDP], train_dl: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, + graph_compiler: CharCtcTrainingGraphCompiler, params: AttributeDict, ): from lhotse.dataset import find_pessimistic_batches @@ -1176,7 +1169,7 @@ def scan_pessimistic_batches_for_oom( loss, _ = compute_loss( params=params, model=model, - sp=sp, + graph_compiler=graph_compiler, batch=batch, is_training=True, ) @@ -1191,15 +1184,18 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) - display_and_save_batch(batch, params=params, sp=sp) + display_and_save_batch(batch, params=params) raise - logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) def main(): parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) + WenetSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() + args.lang_dir = Path(args.lang_dir) args.exp_dir = Path(args.exp_dir) world_size = args.world_size