mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
minor fixes
This commit is contained in:
parent
256c446f06
commit
17d7174cd1
@ -1121,11 +1121,18 @@ def train_one_epoch(
|
||||
rank=0,
|
||||
)
|
||||
|
||||
def is_grad_limit_enabled():
|
||||
return (0 < params.limit_grad_start_batch <= params.batch_idx_train) and (
|
||||
params.batch_idx_train % params.limit_grad_every_n_batch == 0
|
||||
)
|
||||
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx % 10 == 0:
|
||||
set_batch_count(model, get_adjusted_batch_count(params))
|
||||
|
||||
params.batch_idx_train += 1
|
||||
|
||||
beta = max(0.5, 1.0 - 1.0 / (0.1 * params.batch_idx_train))
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
try:
|
||||
@ -1135,9 +1142,7 @@ def train_one_epoch(
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
model_prev=model_prev
|
||||
if 0 < params.limit_grad_start_batch < params.batch_idx_train
|
||||
else None,
|
||||
model_prev=model_prev if is_grad_limit_enabled() else None,
|
||||
sp=sp,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
@ -1155,14 +1160,10 @@ def train_one_epoch(
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if (
|
||||
0 < params.limit_grad_start_batch <= params.batch_idx_train
|
||||
and params.batch_idx_train % params.limit_grad_every_n_batch == 0
|
||||
):
|
||||
if is_grad_limit_enabled():
|
||||
if model_prev is None:
|
||||
model_prev = copy.deepcopy(model)
|
||||
else:
|
||||
beta = max(0.5, 1.0 - 1.0 / (0.1 * params.batch_idx_train))
|
||||
update_model_prev(model_prev=model_prev, model=model, beta=beta)
|
||||
|
||||
except Exception as e:
|
||||
|
Loading…
x
Reference in New Issue
Block a user