From 85b6450a8ab10e8f0e2f64178e9d7f54d4670f5a Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 19 Jun 2023 07:45:57 +0800
Subject: [PATCH] Remove old code

---
 egs/libriheavy/LM/zipformer1/train.py | 35 ++++++++++++---------------
 1 file changed, 15 insertions(+), 20 deletions(-)

diff --git a/egs/libriheavy/LM/zipformer1/train.py b/egs/libriheavy/LM/zipformer1/train.py
index b7af8f331..5f38c6ce4 100755
--- a/egs/libriheavy/LM/zipformer1/train.py
+++ b/egs/libriheavy/LM/zipformer1/train.py
@@ -632,11 +632,6 @@ def compute_validation_loss(
     if world_size > 1:
         tot_loss.reduce(loss.device)
 
-    loss_value = tot_loss["loss"] / tot_loss["frames"]
-    if loss_value < params.best_valid_loss:
-        params.best_valid_epoch = params.cur_epoch
-        params.best_valid_loss = loss_value
-
     return tot_loss
 
 
@@ -654,7 +649,7 @@ def train(
     rank: int = 0,
     batch_idx_offset: int = 0,
 ) -> None:
-    """Train the model for one epoch.
+    """Train the model until we have trained on the specified --num-tokens.
 
     The training loss from the mean of all frames is saved in
     `params.train_loss`. It runs the validation process every
@@ -778,13 +773,13 @@ def train(
         if params.num_tokens_seen > params.num_tokens:
             break
 
-        if batch_idx % 100 == 0 and params.use_fp16:
+        if params.batch_idx_train % 100 == 0 and params.use_fp16:
            # If the grad scale was less than 1, try increasing it. The _growth_interval
            # of the grad scaler is configurable, but we can't configure it to have different
            # behavior depending on the current grad scale.
             cur_grad_scale = scaler._scale.item()
 
-            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0):
                 scaler.update(cur_grad_scale * 2.0)
             if cur_grad_scale < 0.01:
                 if not saved_bad_model:
@@ -795,14 +790,14 @@ def train(
                     save_bad_model()
                     raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
 
-        if batch_idx % params.log_interval == 0:
+        if params.batch_idx_train % params.log_interval == 0:
             cur_lr = max(scheduler.get_last_lr())
             cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
 
             logging.info(
                 f"Epoch {params.num_tokens_seen / params.tokens_per_epoch:.3f}, "
-                f"batch {batch_idx}, loss[{loss_info}], "
-                f"tot_loss[{tot_loss}], tokens: {tokens_seen} "
+                f"batch {params.batch_idx_train}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], tokens: {params.num_tokens_seen} "
                 f"lr: {cur_lr:.2e}, "
                 + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
             )
@@ -824,7 +819,7 @@ def train(
                         "train/grad_scale", cur_grad_scale, params.batch_idx_train
                     )
 
-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics:
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,
@@ -930,26 +925,26 @@ def run(rank, world_size, args):
         register_inf_check_hooks(model)
 
-    train = LmDataset(params.train_file_list,
-                      bytes_per_segment=params.bytes_per_segment)
+    train_data = LmDataset(params.train_file_list,
+                           bytes_per_segment=params.bytes_per_segment)
 
-    params.tokens_per_epoch = train.num_tokens()  # helps us figure out epoch progress.
+    params.tokens_per_epoch = train_data.num_tokens()  # helps us figure out epoch progress.
 
     batch_size = params.batch_size // (6 if params.print_diagnostics else 1)
 
     train_dl = torch.utils.data.DataLoader(
-        dataset=train,
+        dataset=train_data,
         batch_size=batch_size,
         num_workers=params.num_workers,
         drop_last=True)
 
-    valid = LmDataset(params.valid_file_list,
-                      bytes_per_segment=params.bytes_per_segment,
-                      training=False)
+    valid_data = LmDataset(params.valid_file_list,
+                           bytes_per_segment=params.bytes_per_segment,
+                           training=False)
 
     valid_dl = torch.utils.data.DataLoader(
-        dataset=valid,
+        dataset=valid_data,
         batch_size=batch_size,
         num_workers=params.num_workers,
         drop_last=False)
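Note on the substance of the change: besides renaming the dataset variables so they no longer shadow the train() function, the interval checks for grad-scale adjustment, logging, and validation now key off the global params.batch_idx_train counter rather than the per-epoch batch_idx, which matters once training runs to a --num-tokens budget instead of a single epoch. The snippet below is a hypothetical, simplified sketch of that pattern only; the Params, training_loop, and batches names are illustrative and are not taken from train.py.

# Hypothetical sketch (not part of the patch): interval checks driven by a
# global batch counter that keeps growing across the whole run, instead of the
# per-epoch enumerate() index, so logging/validation stay evenly spaced.

class Params:
    def __init__(self):
        self.batch_idx_train = 0   # global counter, never reset between epochs
        self.log_interval = 100
        self.valid_interval = 2000


def training_loop(params, batches):
    for batch_idx, batch in enumerate(batches):  # batch_idx restarts at 0 every epoch
        params.batch_idx_train += 1              # global counter does not restart

        if params.batch_idx_train % params.log_interval == 0:
            print(f"batch {params.batch_idx_train}: log training stats here")

        if params.batch_idx_train % params.valid_interval == 0:
            print(f"batch {params.batch_idx_train}: run validation here")


if __name__ == "__main__":
    params = Params()
    for epoch in range(3):                       # the same counter persists across epochs
        training_loop(params, batches=range(250))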