mirror of https://github.com/k2-fsa/icefall.git

Remove old code

parent 6c3ab1e706
commit 85b6450a8a
@@ -632,11 +632,6 @@ def compute_validation_loss(
     if world_size > 1:
         tot_loss.reduce(loss.device)
 
-    loss_value = tot_loss["loss"] / tot_loss["frames"]
-    if loss_value < params.best_valid_loss:
-        params.best_valid_epoch = params.cur_epoch
-        params.best_valid_loss = loss_value
-
     return tot_loss
 
 
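For context, `tot_loss` here is an icefall MetricsTracker-style dict, and `reduce(loss.device)` sums the tracked counts across distributed ranks so the validation totals are global rather than per-process. A minimal sketch of what such a reduce typically does under DDP (hypothetical `reduce_metrics` helper; assumes `torch.distributed` has been initialized):

import torch
import torch.distributed as dist

def reduce_metrics(metrics: dict, device: torch.device) -> dict:
    """Sum a dict of scalar metrics across all DDP ranks (sketch)."""
    keys = sorted(metrics.keys())
    # Pack the values into one tensor so a single all_reduce suffices.
    packed = torch.tensor([float(metrics[k]) for k in keys], device=device)
    dist.all_reduce(packed, op=dist.ReduceOp.SUM)
    # Unpack the summed totals; every rank now sees the same values.
    return dict(zip(keys, packed.cpu().tolist()))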
@@ -654,7 +649,7 @@ def train(
     rank: int = 0,
     batch_idx_offset: int = 0,
 ) -> None:
-    """Train the model for one epoch.
+    """Train the model until we have trained on the specified --num-tokens.
 
     The training loss from the mean of all frames is saved in
     `params.train_loss`. It runs the validation process every
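The docstring change reflects the new stopping criterion: training runs until a global token budget (--num-tokens) is spent, and "epoch" becomes a fraction derived from tokens seen. A runnable toy sketch of that bookkeeping (the Params fields mirror the ones used below; the step body is stubbed):

from dataclasses import dataclass

@dataclass
class Params:
    num_tokens: int = 1_000_000      # the --num-tokens budget
    num_tokens_seen: int = 0         # updated every training step
    tokens_per_epoch: int = 250_000  # size of one pass over the data

params = Params()
batch_tokens = 4096                  # tokens consumed per batch (stub)
while params.num_tokens_seen <= params.num_tokens:
    # forward/backward/optimizer step would go here
    params.num_tokens_seen += batch_tokens

# Matches the f"Epoch {num_tokens_seen / tokens_per_epoch:.3f}" log format below.
print(f"stopped at epoch {params.num_tokens_seen / params.tokens_per_epoch:.3f}")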
@@ -778,13 +773,13 @@ def train(
         if params.num_tokens_seen > params.num_tokens:
             break
 
-        if batch_idx % 100 == 0 and params.use_fp16:
+        if params.batch_idx_train % 100 == 0 and params.use_fp16:
             # If the grad scale was less than 1, try increasing it. The _growth_interval
             # of the grad scaler is configurable, but we can't configure it to have different
             # behavior depending on the current grad scale.
             cur_grad_scale = scaler._scale.item()
 
-            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
+            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0):
                 scaler.update(cur_grad_scale * 2.0)
             if cur_grad_scale < 0.01:
                 if not saved_bad_model:
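The comment in this hunk explains the workaround: GradScaler's built-in growth interval cannot vary with the current scale, so the loop grows small scales manually. A hedged sketch of that policy as a standalone helper (it relies on the private `_scale` attribute, which is only populated after the scaler has been used; `scaler.update(new_scale)` with an explicit scale is public PyTorch API):

def maybe_grow_grad_scale(scaler, batch_idx_train: int) -> None:
    # `_scale` is a private GradScaler attribute and is lazily created,
    # so only call this after at least one scaler.scale(...) / update().
    cur_grad_scale = scaler._scale.item()
    # Grow aggressively while the scale is tiny, and more slowly
    # (every 400 global batches) while it is merely small.
    if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx_train % 400 == 0):
        scaler.update(cur_grad_scale * 2.0)
    if cur_grad_scale < 0.01:
        raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")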
@@ -795,14 +790,14 @@ def train(
                     save_bad_model()
                     raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
 
-        if batch_idx % params.log_interval == 0:
+        if params.batch_idx_train % params.log_interval == 0:
             cur_lr = max(scheduler.get_last_lr())
             cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
 
             logging.info(
                 f"Epoch {params.num_tokens_seen / params.tokens_per_epoch:.3f}, "
-                f"batch {batch_idx}, loss[{loss_info}], "
-                f"tot_loss[{tot_loss}], tokens: {tokens_seen} "
+                f"batch {params.batch_idx_train}, loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], tokens: {params.num_tokens_seen} "
                 f"lr: {cur_lr:.2e}, " +
                 (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
             )
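The recurring substitution in these hunks (and in the validation hunk below) swaps the per-pass loop counter `batch_idx` for the global `params.batch_idx_train`: the former resets whenever the dataloader is re-entered, so interval-based actions such as logging, scaler growth, and validation could fire erratically. A toy illustration of the difference:

log_interval = 50
batch_idx_train = 0                      # global counter, never resets
for _ in range(3):                       # e.g. the dataloader is exhausted and re-created
    for batch_idx in range(120):         # per-pass counter, resets every restart
        batch_idx_train += 1
        if batch_idx_train % log_interval == 0:
            # Fires at a steady cadence; `batch_idx % log_interval == 0`
            # would instead fire at offsets that depend on the restart point.
            print(f"log at global batch {batch_idx_train} (local {batch_idx})")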
@@ -824,7 +819,7 @@ def train(
                     "train/grad_scale", cur_grad_scale, params.batch_idx_train
                 )
 
-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics:
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,
@@ -930,26 +925,26 @@ def run(rank, world_size, args):
         register_inf_check_hooks(model)
 
 
-    train = LmDataset(params.train_file_list,
-                      bytes_per_segment=params.bytes_per_segment)
+    train_data = LmDataset(params.train_file_list,
+                           bytes_per_segment=params.bytes_per_segment)
 
-    params.tokens_per_epoch = train.num_tokens() # helps us figure out epoch progress.
+    params.tokens_per_epoch = train_data.num_tokens() # helps us figure out epoch progress.
 
     batch_size = params.batch_size // (6 if params.print_diagnostics else 1)
 
     train_dl = torch.utils.data.DataLoader(
-        dataset=train,
+        dataset=train_data,
         batch_size=batch_size,
         num_workers=params.num_workers,
         drop_last=True)
 
 
-    valid = LmDataset(params.valid_file_list,
-                      bytes_per_segment=params.bytes_per_segment,
-                      training=False)
+    valid_data = LmDataset(params.valid_file_list,
+                           bytes_per_segment=params.bytes_per_segment,
+                           training=False)
 
     valid_dl = torch.utils.data.DataLoader(
-        dataset=valid,
+        dataset=valid_data,
         batch_size=batch_size,
         num_workers=params.num_workers,
         drop_last=False)
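Beyond removing old code, this last hunk renames the dataset variables: binding the dataset to the name `train` shadowed the `train(...)` function defined earlier in the same module. A minimal demonstration of the hazard the rename presumably avoids (toy names, not the recipe's code):

def train():
    return "one training pass"

train = ["dataset"]   # rebinding shadows the function above
try:
    train()           # now fails: the name refers to the list
except TypeError as e:
    print(e)          # "'list' object is not callable"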