Update train.py

This commit is contained in:
zr_jin 2024-12-09 20:04:20 +08:00
parent 32b7a449e7
commit c19d2b43af

View File

@ -488,9 +488,10 @@ def train_one_epoch(
loss = sum(losses.values())
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
loss_info = MetricsTracker()
loss_info["samples"] = batch_size