minor fixes

This commit is contained in:
Fangjun Kuang 2024-10-30 21:25:31 +08:00
parent 256c446f06
commit 17d7174cd1

View File

@@ -1121,11 +1121,18 @@ def train_one_epoch(
         rank=0,
     )

+    def is_grad_limit_enabled():
+        return (0 < params.limit_grad_start_batch <= params.batch_idx_train) and (
+            params.batch_idx_train % params.limit_grad_every_n_batch == 0
+        )
+
     for batch_idx, batch in enumerate(train_dl):
         if batch_idx % 10 == 0:
             set_batch_count(model, get_adjusted_batch_count(params))

         params.batch_idx_train += 1
+        beta = max(0.5, 1.0 - 1.0 / (0.1 * params.batch_idx_train))
+
         batch_size = len(batch["supervisions"]["text"])

         try:
@@ -1135,9 +1142,7 @@ def train_one_epoch(
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
-                    model_prev=model_prev
-                    if 0 < params.limit_grad_start_batch < params.batch_idx_train
-                    else None,
+                    model_prev=model_prev if is_grad_limit_enabled() else None,
                     sp=sp,
                     batch=batch,
                     is_training=True,
@@ -1155,14 +1160,10 @@ def train_one_epoch(
                 scaler.update()
                 optimizer.zero_grad()

-                if (
-                    0 < params.limit_grad_start_batch <= params.batch_idx_train
-                    and params.batch_idx_train % params.limit_grad_every_n_batch == 0
-                ):
+                if is_grad_limit_enabled():
                     if model_prev is None:
                         model_prev = copy.deepcopy(model)
                     else:
-                        beta = max(0.5, 1.0 - 1.0 / (0.1 * params.batch_idx_train))
                         update_model_prev(model_prev=model_prev, model=model, beta=beta)

            except Exception as e: