mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Remove old code
This commit is contained in:
parent
6c3ab1e706
commit
85b6450a8a
@ -632,11 +632,6 @@ def compute_validation_loss(
|
|||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
tot_loss.reduce(loss.device)
|
tot_loss.reduce(loss.device)
|
||||||
|
|
||||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
|
||||||
if loss_value < params.best_valid_loss:
|
|
||||||
params.best_valid_epoch = params.cur_epoch
|
|
||||||
params.best_valid_loss = loss_value
|
|
||||||
|
|
||||||
return tot_loss
|
return tot_loss
|
||||||
|
|
||||||
|
|
||||||
@ -654,7 +649,7 @@ def train(
|
|||||||
rank: int = 0,
|
rank: int = 0,
|
||||||
batch_idx_offset: int = 0,
|
batch_idx_offset: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Train the model for one epoch.
|
"""Train the model until we have trained on the specified --num-tokens.
|
||||||
|
|
||||||
The training loss from the mean of all frames is saved in
|
The training loss from the mean of all frames is saved in
|
||||||
`params.train_loss`. It runs the validation process every
|
`params.train_loss`. It runs the validation process every
|
||||||
@ -778,13 +773,13 @@ def train(
|
|||||||
if params.num_tokens_seen > params.num_tokens:
|
if params.num_tokens_seen > params.num_tokens:
|
||||||
break
|
break
|
||||||
|
|
||||||
if batch_idx % 100 == 0 and params.use_fp16:
|
if params.batch_idx_train % 100 == 0 and params.use_fp16:
|
||||||
# If the grad scale was less than 1, try increasing it. The _growth_interval
|
# If the grad scale was less than 1, try increasing it. The _growth_interval
|
||||||
# of the grad scaler is configurable, but we can't configure it to have different
|
# of the grad scaler is configurable, but we can't configure it to have different
|
||||||
# behavior depending on the current grad scale.
|
# behavior depending on the current grad scale.
|
||||||
cur_grad_scale = scaler._scale.item()
|
cur_grad_scale = scaler._scale.item()
|
||||||
|
|
||||||
if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
|
if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0):
|
||||||
scaler.update(cur_grad_scale * 2.0)
|
scaler.update(cur_grad_scale * 2.0)
|
||||||
if cur_grad_scale < 0.01:
|
if cur_grad_scale < 0.01:
|
||||||
if not saved_bad_model:
|
if not saved_bad_model:
|
||||||
@ -795,14 +790,14 @@ def train(
|
|||||||
save_bad_model()
|
save_bad_model()
|
||||||
raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
|
raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}")
|
||||||
|
|
||||||
if batch_idx % params.log_interval == 0:
|
if params.batch_idx_train % params.log_interval == 0:
|
||||||
cur_lr = max(scheduler.get_last_lr())
|
cur_lr = max(scheduler.get_last_lr())
|
||||||
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
|
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
|
||||||
|
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Epoch {params.num_tokens_seen / params.tokens_per_epoch:.3f}, "
|
f"Epoch {params.num_tokens_seen / params.tokens_per_epoch:.3f}, "
|
||||||
f"batch {batch_idx}, loss[{loss_info}], "
|
f"batch {params.batch_idx_train}, loss[{loss_info}], "
|
||||||
f"tot_loss[{tot_loss}], tokens: {tokens_seen} "
|
f"tot_loss[{tot_loss}], tokens: {params.num_tokens_seen} "
|
||||||
f"lr: {cur_lr:.2e}, " +
|
f"lr: {cur_lr:.2e}, " +
|
||||||
(f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
|
(f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
|
||||||
)
|
)
|
||||||
@ -824,7 +819,7 @@ def train(
|
|||||||
"train/grad_scale", cur_grad_scale, params.batch_idx_train
|
"train/grad_scale", cur_grad_scale, params.batch_idx_train
|
||||||
)
|
)
|
||||||
|
|
||||||
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
|
if params.batch_idx_train % params.valid_interval == 0 and not params.print_diagnostics:
|
||||||
logging.info("Computing validation loss")
|
logging.info("Computing validation loss")
|
||||||
valid_info = compute_validation_loss(
|
valid_info = compute_validation_loss(
|
||||||
params=params,
|
params=params,
|
||||||
@ -930,26 +925,26 @@ def run(rank, world_size, args):
|
|||||||
register_inf_check_hooks(model)
|
register_inf_check_hooks(model)
|
||||||
|
|
||||||
|
|
||||||
train = LmDataset(params.train_file_list,
|
train_data = LmDataset(params.train_file_list,
|
||||||
bytes_per_segment=params.bytes_per_segment)
|
bytes_per_segment=params.bytes_per_segment)
|
||||||
|
|
||||||
params.tokens_per_epoch = train.num_tokens() # helps us figure out epoch progress.
|
params.tokens_per_epoch = train_data.num_tokens() # helps us figure out epoch progress.
|
||||||
|
|
||||||
batch_size = params.batch_size // (6 if params.print_diagnostics else 1)
|
batch_size = params.batch_size // (6 if params.print_diagnostics else 1)
|
||||||
|
|
||||||
train_dl = torch.utils.data.DataLoader(
|
train_dl = torch.utils.data.DataLoader(
|
||||||
dataset=train,
|
dataset=train_data,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_workers=params.num_workers,
|
num_workers=params.num_workers,
|
||||||
drop_last=True)
|
drop_last=True)
|
||||||
|
|
||||||
|
|
||||||
valid = LmDataset(params.valid_file_list,
|
valid_data = LmDataset(params.valid_file_list,
|
||||||
bytes_per_segment=params.bytes_per_segment,
|
bytes_per_segment=params.bytes_per_segment,
|
||||||
training=False)
|
training=False)
|
||||||
|
|
||||||
valid_dl = torch.utils.data.DataLoader(
|
valid_dl = torch.utils.data.DataLoader(
|
||||||
dataset=valid,
|
dataset=valid_data,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_workers=params.num_workers,
|
num_workers=params.num_workers,
|
||||||
drop_last=False)
|
drop_last=False)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user