mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
minor fixes
This commit is contained in:
parent
256c446f06
commit
17d7174cd1
@ -1121,11 +1121,18 @@ def train_one_epoch(
|
|||||||
rank=0,
|
rank=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def is_grad_limit_enabled():
|
||||||
|
return (0 < params.limit_grad_start_batch <= params.batch_idx_train) and (
|
||||||
|
params.batch_idx_train % params.limit_grad_every_n_batch == 0
|
||||||
|
)
|
||||||
|
|
||||||
for batch_idx, batch in enumerate(train_dl):
|
for batch_idx, batch in enumerate(train_dl):
|
||||||
if batch_idx % 10 == 0:
|
if batch_idx % 10 == 0:
|
||||||
set_batch_count(model, get_adjusted_batch_count(params))
|
set_batch_count(model, get_adjusted_batch_count(params))
|
||||||
|
|
||||||
params.batch_idx_train += 1
|
params.batch_idx_train += 1
|
||||||
|
|
||||||
|
beta = max(0.5, 1.0 - 1.0 / (0.1 * params.batch_idx_train))
|
||||||
batch_size = len(batch["supervisions"]["text"])
|
batch_size = len(batch["supervisions"]["text"])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -1135,9 +1142,7 @@ def train_one_epoch(
|
|||||||
loss, loss_info = compute_loss(
|
loss, loss_info = compute_loss(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
model_prev=model_prev
|
model_prev=model_prev if is_grad_limit_enabled() else None,
|
||||||
if 0 < params.limit_grad_start_batch < params.batch_idx_train
|
|
||||||
else None,
|
|
||||||
sp=sp,
|
sp=sp,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
@ -1155,14 +1160,10 @@ def train_one_epoch(
|
|||||||
scaler.update()
|
scaler.update()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
if (
|
if is_grad_limit_enabled():
|
||||||
0 < params.limit_grad_start_batch <= params.batch_idx_train
|
|
||||||
and params.batch_idx_train % params.limit_grad_every_n_batch == 0
|
|
||||||
):
|
|
||||||
if model_prev is None:
|
if model_prev is None:
|
||||||
model_prev = copy.deepcopy(model)
|
model_prev = copy.deepcopy(model)
|
||||||
else:
|
else:
|
||||||
beta = max(0.5, 1.0 - 1.0 / (0.1 * params.batch_idx_train))
|
|
||||||
update_model_prev(model_prev=model_prev, model=model, beta=beta)
|
update_model_prev(model_prev=model_prev, model=model, beta=beta)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user