Remove old code

This commit is contained in:
Daniel Povey 2023-06-19 07:45:57 +08:00
parent 6c3ab1e706
commit 85b6450a8a

View File

@ -632,11 +632,6 @@ def compute_validation_loss(
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
@ -654,7 +649,7 @@ def train(
rank: int = 0,
batch_idx_offset: int = 0,
) -> None:
"""Train the model for one epoch.
"""Train the model until we have trained on the specified --num-tokens.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
@ -778,13 +773,13 @@ def train(
if params.num_tokens_seen > params.num_tokens:
break
if batch_idx % 100 == 0 and params.use_fp16:
if params.batch_idx_train % 100 == 0 and params.use_fp16:
# If the grad scale was less than 1, try increasing it. The _growth_interval
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
if not saved_bad_model:
@ -795,14 +790,14 @@ def train(
save_bad_model()
raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
if batch_idx % params.log_interval == 0:
if params.batch_idx_train % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
logging.info(
f"Epoch {params.num_tokens_seen / params.tokens_per_epoch:.3f}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], tokens: {tokens_seen} "
f"batch {params.batch_idx_train}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], tokens: {params.num_tokens_seen} "
f"lr: {cur_lr:.2e}, " +
(f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
)
@ -824,7 +819,7 @@ def train(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
)
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
@ -930,26 +925,26 @@ def run(rank, world_size, args):
register_inf_check_hooks(model)
train = LmDataset(params.train_file_list,
bytes_per_segment=params.bytes_per_segment)
train_data = LmDataset(params.train_file_list,
bytes_per_segment=params.bytes_per_segment)
params.tokens_per_epoch = train.num_tokens() # helps us figure out epoch progress.
params.tokens_per_epoch = train_data.num_tokens() # helps us figure out epoch progress.
batch_size = params.batch_size // (6 if params.print_diagnostics else 1)
train_dl = torch.utils.data.DataLoader(
dataset=train,
dataset=train_data,
batch_size=batch_size,
num_workers=params.num_workers,
drop_last=True)
valid = LmDataset(params.valid_file_list,
bytes_per_segment=params.bytes_per_segment,
training=False)
valid_data = LmDataset(params.valid_file_list,
bytes_per_segment=params.bytes_per_segment,
training=False)
valid_dl = torch.utils.data.DataLoader(
dataset=valid,
dataset=valid_data,
batch_size=batch_size,
num_workers=params.num_workers,
drop_last=False)