From 9939c2b72d079dc939633f6d038475e9bd3a5c4a Mon Sep 17 00:00:00 2001
From: yfyeung
Date: Sun, 11 May 2025 17:03:44 +0000
Subject: [PATCH] remove duplicated torch autocast

---
 .../ASR_LLM/zipformer_llm_zh/train.py | 32 ++++++++++++++++----------------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py
index edea3bdb9..5d47f128a 100755
--- a/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py
+++ b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py
@@ -501,6 +501,8 @@ def compute_loss(
 
     feature = batch["inputs"]
     assert feature.ndim == 3
+    if params.use_fp16:
+        feature = feature.half()
 
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"]
@@ -559,14 +561,13 @@ def compute_validation_loss(
     tot_loss = MetricsTracker()
 
     for batch_idx, batch in enumerate(valid_dl):
-        with torch.cuda.amp.autocast(enabled=params.use_fp16):
-            loss, loss_info = compute_loss(
-                params=params,
-                tokenizer=tokenizer,
-                model=model,
-                batch=batch,
-                is_training=False,
-            )
+        loss, loss_info = compute_loss(
+            params=params,
+            tokenizer=tokenizer,
+            model=model,
+            batch=batch,
+            is_training=False,
+        )
 
         assert loss.requires_grad is False
         tot_loss = tot_loss + loss_info
@@ -680,14 +681,13 @@ def train_one_epoch(
                 f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
             )
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
-                loss, loss_info = compute_loss(
-                    params=params,
-                    tokenizer=tokenizer,
-                    model=model,
-                    batch=batch,
-                    is_training=True,
-                )
+            loss, loss_info = compute_loss(
+                params=params,
+                tokenizer=tokenizer,
+                model=model,
+                batch=batch,
+                is_training=True,
+            )
 
             # summary stats
             tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
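
Note on the change: the subject line "remove duplicated torch autocast" suggests
that compute_loss (or the model forward it calls) already runs under an autocast
context, so wrapping its call sites in a second
torch.cuda.amp.autocast(enabled=params.use_fp16) block was a duplicated, nested
context; the added feature.half() cast keeps the input features in fp16 on that
path instead of relying on the removed outer wrapper. Below is a minimal sketch
of the assumed pattern; compute_loss_sketch and its body are hypothetical
illustrations, not the actual icefall code:

    import torch

    def compute_loss_sketch(params, model, batch):
        # Hypothetical stand-in for the recipe's compute_loss.
        feature = batch["inputs"]
        if params.use_fp16:
            # Cast inputs once here, matching the model's fp16 path.
            feature = feature.half()
        # Autocast is applied at this single point, so call sites such as
        # train_one_epoch / compute_validation_loss need no wrapper of
        # their own.
        with torch.cuda.amp.autocast(enabled=params.use_fp16):
            loss = model(feature).mean()
        return loss

Keeping autocast in one place avoids nested contexts and makes it unambiguous
which code runs in mixed precision.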