minor fixes

This commit is contained in:
Fangjun Kuang 2024-10-30 21:25:31 +08:00
parent 256c446f06
commit 17d7174cd1

View File

@ -1121,11 +1121,18 @@ def train_one_epoch(
rank=0,
)
def is_grad_limit_enabled():
return (0 < params.limit_grad_start_batch <= params.batch_idx_train) and (
params.batch_idx_train % params.limit_grad_every_n_batch == 0
)
for batch_idx, batch in enumerate(train_dl):
if batch_idx % 10 == 0:
set_batch_count(model, get_adjusted_batch_count(params))
params.batch_idx_train += 1
beta = max(0.5, 1.0 - 1.0 / (0.1 * params.batch_idx_train))
batch_size = len(batch["supervisions"]["text"])
try:
@ -1135,9 +1142,7 @@ def train_one_epoch(
loss, loss_info = compute_loss(
params=params,
model=model,
model_prev=model_prev
if 0 < params.limit_grad_start_batch < params.batch_idx_train
else None,
model_prev=model_prev if is_grad_limit_enabled() else None,
sp=sp,
batch=batch,
is_training=True,
@ -1155,14 +1160,10 @@ def train_one_epoch(
scaler.update()
optimizer.zero_grad()
if (
0 < params.limit_grad_start_batch <= params.batch_idx_train
and params.batch_idx_train % params.limit_grad_every_n_batch == 0
):
if is_grad_limit_enabled():
if model_prev is None:
model_prev = copy.deepcopy(model)
else:
beta = max(0.5, 1.0 - 1.0 / (0.1 * params.batch_idx_train))
update_model_prev(model_prev=model_prev, model=model, beta=beta)
except Exception as e: