diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
index a9178c8b3..81f6df790 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
@@ -141,17 +141,21 @@ class Transducer(nn.Module):
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens
 
-        simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
-            lm=self.simple_lm_proj(decoder_out),
-            am=self.simple_am_proj(encoder_out),
-            symbols=y_padded,
-            termination_symbol=blank_id,
-            lm_only_scale=lm_scale,
-            am_only_scale=am_scale,
-            boundary=boundary,
-            reduction="sum",
-            return_grad=True,
-        )
+        lm = self.simple_lm_proj(decoder_out)
+        am = self.simple_am_proj(encoder_out)
+
+        with torch.cuda.amp.autocast(enabled=False):
+            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
+                lm=lm.float(),
+                am=am.float(),
+                symbols=y_padded,
+                termination_symbol=blank_id,
+                lm_only_scale=lm_scale,
+                am_only_scale=am_scale,
+                boundary=boundary,
+                reduction="sum",
+                return_grad=True,
+            )
 
         # ranges : [B, T, prune_range]
         ranges = k2.get_rnnt_prune_ranges(
@@ -176,13 +180,14 @@ class Transducer(nn.Module):
         logits = self.joiner(am_pruned, lm_pruned, project_input=False)
 
-        pruned_loss = k2.rnnt_loss_pruned(
-            logits=logits,
-            symbols=y_padded,
-            ranges=ranges,
-            termination_symbol=blank_id,
-            boundary=boundary,
-            reduction="sum",
-        )
+        with torch.cuda.amp.autocast(enabled=False):
+            pruned_loss = k2.rnnt_loss_pruned(
+                logits=logits.float(),
+                symbols=y_padded,
+                ranges=ranges,
+                termination_symbol=blank_id,
+                boundary=boundary,
+                reduction="sum",
+            )
 
         return (simple_loss, pruned_loss)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index b9ea0def6..d08fa15b5 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -29,7 +29,16 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --full-libri 1 \
   --max-duration 300
 
+# For mixed precision training:
+./pruned_transducer_stateless2/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 0 \
+  --use-fp16 1 \
+  --exp-dir pruned_transducer_stateless2/exp \
+  --full-libri 1 \
+  --max-duration 550
 
 """
@@ -58,6 +67,7 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eve, Eden
 from torch import Tensor
+from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -249,6 +259,13 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
     return parser
@@ -447,6 +464,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -460,6 +478,8 @@ def save_checkpoint(
       optimizer:
         The optimizer used in the training.
       sampler:
        The sampler for the training dataset.
+      scaler:
+        The scaler used for mixed precision training.
""" if rank != 0: return @@ -471,6 +491,7 @@ def save_checkpoint( optimizer=optimizer, scheduler=scheduler, sampler=sampler, + scaler=scaler, rank=rank, ) @@ -599,6 +620,7 @@ def train_one_epoch( sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, @@ -622,6 +644,8 @@ def train_one_epoch( Dataloader for the training dataset. valid_dl: Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. tb_writer: Writer to write log messages to tensorboard. world_size: @@ -644,22 +668,24 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step) - ) + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step) + ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. - loss.backward() + scaler.scale(loss).backward() scheduler.step_batch(params.batch_idx_train) - optimizer.step() + scaler.step(optimizer) + scaler.update() optimizer.zero_grad() if params.print_diagnostics and batch_idx == 5: @@ -676,6 +702,7 @@ def train_one_epoch( optimizer=optimizer, scheduler=scheduler, sampler=train_dl.sampler, + scaler=scaler, rank=rank, ) del params.cur_batch_idx @@ -695,7 +722,9 @@ def train_one_epoch( ) if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train @@ -850,6 +879,11 @@ def run(rank, world_size, args): params=params, ) + scaler = GradScaler(enabled=params.use_fp16) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + for epoch in range(params.start_epoch, params.num_epochs): scheduler.step_epoch(epoch) fix_random_seed(params.seed + epoch) @@ -869,6 +903,7 @@ def run(rank, world_size, args): sp=sp, train_dl=train_dl, valid_dl=valid_dl, + scaler=scaler, tb_writer=tb_writer, world_size=world_size, rank=rank, @@ -884,6 +919,7 @@ def run(rank, world_size, args): optimizer=optimizer, scheduler=scheduler, sampler=train_dl.sampler, + scaler=scaler, rank=rank, ) @@ -913,14 +949,15 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - warmup = 0.0 - ) + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup = 0.0 + ) loss.backward() optimizer.step() optimizer.zero_grad()