remove duplicated torch autocast

This commit is contained in:
yfyeung 2025-05-11 17:03:44 +00:00
parent 5fbeed9f96
commit 9939c2b72d

View File

@@ -501,6 +501,8 @@ def compute_loss(
     feature = batch["inputs"]
     assert feature.ndim == 3
+    if params.use_fp16:
+        feature = feature.half()

     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"]
@@ -559,14 +561,13 @@ def compute_validation_loss(
     tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(valid_dl):
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
-            loss, loss_info = compute_loss(
-                params=params,
-                tokenizer=tokenizer,
-                model=model,
-                batch=batch,
-                is_training=False,
-            )
+        loss, loss_info = compute_loss(
+            params=params,
+            tokenizer=tokenizer,
+            model=model,
+            batch=batch,
+            is_training=False,
+        )
         assert loss.requires_grad is False
         tot_loss = tot_loss + loss_info
@@ -680,14 +681,13 @@ def train_one_epoch(
                 f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
             )

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
-                loss, loss_info = compute_loss(
-                    params=params,
-                    tokenizer=tokenizer,
-                    model=model,
-                    batch=batch,
-                    is_training=True,
-                )
+            loss, loss_info = compute_loss(
+                params=params,
+                tokenizer=tokenizer,
+                model=model,
+                batch=batch,
+                is_training=True,
+            )

             # summary stats
             tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info