diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
index 1f52370fd..dad2b2b18 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py
@@ -51,8 +51,8 @@
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
 from torch import Tensor
+from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam
@@ -221,6 +221,13 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
     return parser
@@ -411,6 +418,7 @@ def save_checkpoint(
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
     sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -424,6 +432,8 @@ def save_checkpoint(
         The optimizer used in the training.
       sampler:
         The sampler for the training dataset.
+      scaler:
+        The scaler used for mixed precision training.
     """
     if rank != 0:
         return
@@ -434,6 +444,7 @@ def save_checkpoint(
         params=params,
         optimizer=optimizer,
         sampler=sampler,
+        scaler=scaler,
         rank=rank,
     )
@@ -473,6 +484,7 @@ def compute_loss(
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
+    feature = feature.to(device)
 
     supervisions = batch["supervisions"]
@@ -547,6 +559,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
@@ -568,6 +581,8 @@ def train_one_epoch(
         Dataloader for the training dataset.
       valid_dl:
         Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mixed precision training.
       tb_writer:
         Writer to write log messages to tensorboard.
       world_size:
@@ -610,14 +625,17 @@ def train_one_epoch(
             and tb_writer is not None
             and params.batch_idx_train % (params.log_interval * 5) == 0
         ):
-            deltas = optim_step_and_measure_param_change(model, optimizer)
+            deltas = optim_step_and_measure_param_change(
+                model, optimizer, scaler=scaler
+            )
             tb_writer.add_scalars(
                 "train/relative_param_change_per_minibatch",
                 deltas,
                 global_step=params.batch_idx_train,
             )
         else:
-            optimizer.step()
+            scaler.step(optimizer)
+            scaler.update()
 
     cur_batch_idx = params.get("cur_batch_idx", 0)
 
@@ -629,20 +647,23 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
-        loss, loss_info = compute_loss(
-            params=params,
-            model=model,
-            sp=sp,
-            batch=batch,
-            is_training=True,
-        )
+        with torch.autocast(
+            device_type=model.device.type, enabled=params.use_fp16
+        ):
+            loss, loss_info = compute_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                batch=batch,
+                is_training=True,
+            )
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
 
         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
 
-        loss.backward()
+        scaler.scale(loss).backward()
 
         maybe_log_weights("train/param_norms")
         maybe_log_gradients("train/grad_norms")
@@ -662,6 +683,7 @@ def train_one_epoch(
                 params=params,
                 optimizer=optimizer,
                 sampler=train_dl.sampler,
+                scaler=scaler,
                 rank=rank,
             )
             del params.cur_batch_idx
@@ -831,6 +853,11 @@ def run(rank, world_size, args):
         params=params,
     )
 
+    scaler = GradScaler(enabled=params.use_fp16)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
     for epoch in range(params.start_epoch, params.num_epochs):
         fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)
@@ -857,6 +884,7 @@ def run(rank, world_size, args):
             tb_writer=tb_writer,
             world_size=world_size,
             rank=rank,
+            scaler=scaler,
         )
 
         save_checkpoint(
@@ -864,6 +892,7 @@ def run(rank, world_size, args):
             model=model,
             optimizer=optimizer,
             sampler=train_dl.sampler,
+            scaler=scaler,
             rank=rank,
         )
 
@@ -891,15 +920,17 @@ def scan_pessimistic_batches_for_oom(
         batch = train_dl.dataset[cuts]
         try:
             optimizer.zero_grad()
-            loss, _ = compute_loss(
-                params=params,
-                model=model,
-                sp=sp,
-                batch=batch,
-                is_training=True,
-            )
+            with torch.autocast(
+                device_type=model.device.type, enabled=params.use_fp16
+            ):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                )
             loss.backward()
-            clip_grad_norm_(model.parameters(), 5.0, 2.0)
             optimizer.step()
         except RuntimeError as e:
             if "CUDA out of memory" in str(e):
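
Below is a minimal, self-contained sketch of the torch.cuda.amp pattern this patch threads through train_one_epoch() and run(). The tiny linear model, SGD optimizer, and random tensors are placeholders chosen only for illustration and are not icefall code; the GradScaler/autocast/scale-step-update sequence and the scaler state_dict handling are what the diff above adds.

    # Sketch of the mixed precision training loop introduced by the patch.
    # Assumptions: toy model/optimizer/data; only the amp calls mirror train.py.
    import torch
    from torch.cuda.amp import GradScaler, autocast

    use_fp16 = torch.cuda.is_available()  # amp is a no-op when disabled
    device = torch.device("cuda" if use_fp16 else "cpu")

    model = torch.nn.Linear(80, 500).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    scaler = GradScaler(enabled=use_fp16)  # mirrors GradScaler(enabled=params.use_fp16)

    for step in range(10):
        feature = torch.randn(8, 80, device=device)
        target = torch.randn(8, 500, device=device)

        optimizer.zero_grad()
        # Run the forward pass (and hence the loss) under autocast so that
        # matmuls execute in fp16 while precision-sensitive ops stay in fp32.
        with autocast(enabled=use_fp16):
            loss = torch.nn.functional.mse_loss(
                model(feature), target, reduction="sum"
            )

        # Scale the loss before backward to avoid fp16 gradient underflow,
        # then let the scaler unscale, step the optimizer, and adapt the scale.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    # The scaler carries state, which is why the patch passes it to
    # save_checkpoint() and restores it from checkpoints["grad_scaler"].
    checkpoint = {"grad_scaler": scaler.state_dict()}
    scaler.load_state_dict(checkpoint["grad_scaler"])

With the patch applied, half precision is enabled from the command line via --use-fp16 true (the flag is parsed with str2bool and defaults to False), and the scaler state is saved with each checkpoint so an interrupted fp16 run resumes with its loss scale intact.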