diff --git a/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py index edea3bdb9..5d47f128a 100755 --- a/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py +++ b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py @@ -501,6 +501,8 @@ def compute_loss( feature = batch["inputs"] assert feature.ndim == 3 + if params.use_fp16: + feature = feature.half() supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"] @@ -559,14 +561,13 @@ def compute_validation_loss( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(valid_dl): - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - tokenizer=tokenizer, - model=model, - batch=batch, - is_training=False, - ) + loss, loss_info = compute_loss( + params=params, + tokenizer=tokenizer, + model=model, + batch=batch, + is_training=False, + ) assert loss.requires_grad is False tot_loss = tot_loss + loss_info @@ -680,14 +681,13 @@ def train_one_epoch( f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}" ) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - tokenizer=tokenizer, - model=model, - batch=batch, - is_training=True, - ) + loss, loss_info = compute_loss( + params=params, + tokenizer=tokenizer, + model=model, + batch=batch, + is_training=True, + ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info