From 3d0474c98639d06fe638d7dc98769e7fa7c70032 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Thu, 21 Apr 2022 11:49:52 +0800
Subject: [PATCH] Fix style issues.

---
 .../ASR/transducer_stateless3/model.py |   9 +-
 .../ASR/transducer_stateless3/train.py | 136 ++++--------------
 2 files changed, 32 insertions(+), 113 deletions(-)

diff --git a/egs/librispeech/ASR/transducer_stateless3/model.py b/egs/librispeech/ASR/transducer_stateless3/model.py
index 7c4881c72..8ef3af5f4 100644
--- a/egs/librispeech/ASR/transducer_stateless3/model.py
+++ b/egs/librispeech/ASR/transducer_stateless3/model.py
@@ -18,8 +18,8 @@
 import k2
 import torch
 import torch.nn as nn
+import torchaudio
 from encoder_interface import EncoderInterface
-from scaling import ScaledLinear
 
 from icefall.utils import add_sos
 
@@ -51,9 +51,10 @@ class Transducer(nn.Module):
             is (N, U) and its output shape is (N, U, decoder_dim).
             It should contain one attribute: `blank_id`.
           joiner:
-            It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
-            Its output shape is (N, T, U, vocab_size). Note that its output contains
-            unnormalized probs, i.e., not processed by log-softmax.
+            It has two inputs with shapes: (N, T, encoder_dim) and
+            (N, U, decoder_dim).
+            Its output shape is (N, T, U, vocab_size). Note that its output
+            contains unnormalized probs, i.e., not processed by log-softmax.
         """
         super().__init__()
         assert isinstance(encoder, EncoderInterface), type(encoder)
diff --git a/egs/librispeech/ASR/transducer_stateless3/train.py b/egs/librispeech/ASR/transducer_stateless3/train.py
index 80617847a..ba67ba8d3 100755
--- a/egs/librispeech/ASR/transducer_stateless3/train.py
+++ b/egs/librispeech/ASR/transducer_stateless3/train.py
@@ -21,22 +21,21 @@
 Usage:
 
 export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
-./pruned_transducer_stateless2/train.py \
+./transducer_stateless3/train.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 0 \
-  --exp-dir pruned_transducer_stateless2/exp \
+  --exp-dir transducer_stateless3/exp \
   --full-libri 1 \
   --max-duration 300
 
 # For mix precision training:
 
-./pruned_transducer_stateless2/train.py \
+./transducer_stateless3/train.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 0 \
-  --use_fp16 1 \
-  --exp-dir pruned_transducer_stateless2/exp \
+  --exp-dir transducer_stateless3/exp \
   --full-libri 1 \
   --max-duration 550
@@ -138,7 +137,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless2/exp",
+        default="transducer_stateless3/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -156,7 +155,8 @@ def get_parser():
     parser.add_argument(
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate. This value should not need to be changed.",
+        help="The initial learning rate. This value should not need to be "
+        "changed.",
     )
 
@@ -183,40 +183,6 @@ def get_parser():
         "2 means tri-gram",
     )
 
-    parser.add_argument(
-        "--prune-range",
-        type=int,
-        default=5,
-        help="The prune range for rnnt loss, it means how many symbols(context)"
-        "we are using to compute the loss",
-    )
-
-    parser.add_argument(
-        "--lm-scale",
-        type=float,
-        default=0.25,
-        help="The scale to smooth the loss with lm "
-        "(output of prediction network) part.",
-    )
-
-    parser.add_argument(
-        "--am-scale",
-        type=float,
-        default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
-    )
-
-    parser.add_argument(
-        "--simple-loss-scale",
-        type=float,
-        default=0.5,
-        help="To get pruning ranges, we will calculate a simple version"
-        "loss(joiner is just addition), this simple loss also uses for"
-        "training (as a regularization item). We will scale the simple loss"
-        "with this parameter before adding to the final loss.",
-    )
-
     parser.add_argument(
         "--seed",
         type=int,
@@ -255,13 +221,6 @@ def get_parser():
         """,
     )
 
-    parser.add_argument(
-        "--use-fp16",
-        type=str2bool,
-        default=False,
-        help="Whether to use half precision training.",
-    )
-
     return parser
 
 
@@ -318,7 +277,7 @@ def get_params() -> AttributeDict:
             "batch_idx_train": 0,
             "log_interval": 50,
             "reset_interval": 200,
-            "valid_interval": 3000,  # For the 100h subset, use 800
+            "valid_interval": 3000,  # For the 100h subset, use 1600
             # parameters for conformer
             "feature_dim": 80,
             "subsampling_factor": 4,
@@ -506,7 +465,6 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
-    warmup: float = 1.0,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
@@ -523,8 +481,6 @@ def compute_loss(
         True for training. False for validation. When it is True, this
         function enables autograd during computation; when it is False, it
         disables autograd.
-      warmup: a floating point value which increases throughout training;
-        values >= 1.0 are fully warmed up and have all modules present.
     """
     device = model.device
     feature = batch["inputs"]
@@ -540,27 +496,10 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
-            prune_range=params.prune_range,
-            am_scale=params.am_scale,
-            lm_scale=params.lm_scale,
-            warmup=warmup,
-        )
-        # after the main warmup step, we keep pruned_loss_scale small
-        # for the same amount of time (model_warm_step), to avoid
-        # overwhelming the simple_loss and causing it to diverge,
-        # in case it had not fully learned the alignment yet.
-        pruned_loss_scale = (
-            0.0
-            if warmup < 1.0
-            else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
-        )
-        loss = (
-            params.simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
         )
 
     assert loss.requires_grad == is_training
@@ -574,8 +513,6 @@ def compute_loss(
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
-    info["simple_loss"] = simple_loss.detach().cpu().item()
-    info["pruned_loss"] = pruned_loss.detach().cpu().item()
 
     return loss, info
 
@@ -622,7 +559,6 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
-    scaler: GradScaler,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
@@ -646,8 +582,6 @@ def train_one_epoch(
         Dataloader for the training dataset.
       valid_dl:
         Dataloader for the validation dataset.
-      scaler:
-        The scaler used for mix precision training.
       tb_writer:
         Writer to write log messages to tensorboard.
       world_size:
@@ -670,25 +604,22 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
-            loss, loss_info = compute_loss(
-                params=params,
-                model=model,
-                sp=sp,
-                batch=batch,
-                is_training=True,
-                warmup=(params.batch_idx_train / params.model_warm_step),
-            )
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            sp=sp,
+            batch=batch,
+            is_training=True,
+        )
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
 
         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
-        scaler.scale(loss).backward()
-        scheduler.step_batch(params.batch_idx_train)
-        scaler.step(optimizer)
-        scaler.update()
         optimizer.zero_grad()
+        loss.backward()
+        scheduler.step_batch(params.batch_idx_train)
+        optimizer.step()
 
         if params.print_diagnostics and batch_idx == 5:
             return
@@ -706,7 +637,6 @@ def train_one_epoch(
                 optimizer=optimizer,
                 scheduler=scheduler,
                 sampler=train_dl.sampler,
-                scaler=scaler,
                 rank=rank,
             )
             del params.cur_batch_idx
@@ -883,11 +813,6 @@ def run(rank, world_size, args):
             params=params,
         )
 
-    scaler = GradScaler(enabled=params.use_fp16)
-    if checkpoints and "grad_scaler" in checkpoints:
-        logging.info("Loading grad scaler state dict")
-        scaler.load_state_dict(checkpoints["grad_scaler"])
-
     for epoch in range(params.start_epoch, params.num_epochs):
         scheduler.step_epoch(epoch)
         fix_random_seed(params.seed + epoch)
@@ -906,7 +831,6 @@ def run(rank, world_size, args):
             sp=sp,
             train_dl=train_dl,
             valid_dl=valid_dl,
-            scaler=scaler,
            tb_writer=tb_writer,
             world_size=world_size,
             rank=rank,
@@ -922,7 +846,6 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             scheduler=scheduler,
             sampler=train_dl.sampler,
-            scaler=scaler,
             rank=rank,
         )
 
@@ -949,21 +872,16 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
-            # (i.e. are not remembered by the decaying-average in adam), because
-            # we want to avoid these params being subject to shrinkage in adam.
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
-                loss, _ = compute_loss(
-                    params=params,
-                    model=model,
-                    sp=sp,
-                    batch=batch,
-                    is_training=True,
-                    warmup=0.0,
-                )
+            loss, _ = compute_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                batch=batch,
+                is_training=True,
+            )
+            optimizer.zero_grad()
             loss.backward()
             optimizer.step()
-            optimizer.zero_grad()
         except RuntimeError as e:
             if "CUDA out of memory" in str(e):
                 logging.error(