diff --git a/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py
index 2ace5c532..eaae4a33e 100755
--- a/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py
+++ b/egs/speech_llm/ASR_LLM/zipformer_llm_zh/train.py
@@ -500,10 +500,10 @@ def compute_loss(
     device = next(model.parameters()).device
 
-    feature = batch["inputs"]
-    assert feature.ndim == 3
+    features = batch["inputs"]
+    assert features.ndim == 3
     if params.use_fp16:
-        feature = feature.half()
+        features = features.half()
 
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"]
 
@@ -526,7 +526,7 @@ def compute_loss(
 
     with torch.set_grad_enabled(is_training):
         model_outputs, acc = model(
-            fbank=feature.to(device),
+            fbank=features.to(device),
             fbank_lens=feature_lens.to(device),
             input_ids=input_ids.to(device),
             attention_mask=attention_mask.to(device),
@@ -663,30 +663,6 @@ def train_one_epoch(
                     valid_info.write_summary(
                         tb_writer, "train/valid_", params.batch_idx_train
                     )
-            if batch_idx != 0:
-                model.save_checkpoint(
-                    save_dir=params.exp_dir,
-                    tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
-                    client_state={},
-                    exclude_frozen_parameters=True,
-                )
-
-                if rank == 0:
-                    convert_zero_checkpoint_to_fp32_state_dict(
-                        params.exp_dir,
-                        f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
-                        tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
-                        exclude_frozen_parameters=True,
-                    )
-                    # save sampler state dict into checkpoint
-                    sampler_state_dict = train_dl.sampler.state_dict()
-                    torch.save(
-                        sampler_state_dict,
-                        f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}-sampler.pt",
-                    )
-                    os.system(
-                        f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
-                    )
         try:
             loss, loss_info = compute_loss(
                 params=params,
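For reference, below is a minimal, self-contained sketch of the mid-epoch DeepSpeed checkpointing pattern that this diff deletes from train_one_epoch: save a sharded ZeRO checkpoint, consolidate it into a single fp32 state dict on rank 0, persist the sampler state for mid-epoch resumption, then remove the sharded directory. The function name save_mid_epoch_checkpoint is hypothetical; the sketch assumes model is a DeepSpeed engine and that train_dl.sampler exposes state_dict(), and it substitutes shutil.rmtree for the original shell rm -rf.

```python
import shutil

import torch
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict


def save_mid_epoch_checkpoint(model, train_dl, exp_dir, cur_epoch, batch_idx, rank):
    """Sketch of the checkpoint logic removed by this diff.

    `model` is assumed to be a DeepSpeed engine; `train_dl.sampler` is
    assumed to expose state_dict() for mid-epoch resumption.
    """
    tag = f"epoch-{cur_epoch}-checkpoint-{batch_idx}"

    # Every rank writes its ZeRO shard under {exp_dir}/{tag}/.
    model.save_checkpoint(
        save_dir=exp_dir,
        tag=tag,
        client_state={},
        exclude_frozen_parameters=True,
    )

    if rank == 0:
        # Merge the sharded ZeRO checkpoint into one fp32 .pt file.
        convert_zero_checkpoint_to_fp32_state_dict(
            exp_dir,
            f"{exp_dir}/{tag}.pt",
            tag=tag,
            exclude_frozen_parameters=True,
        )
        # Save the sampler state so a restart can skip already-seen batches.
        torch.save(train_dl.sampler.state_dict(), f"{exp_dir}/{tag}-sampler.pt")
        # Drop the sharded directory; the consolidated .pt file remains.
        shutil.rmtree(f"{exp_dir}/{tag}", ignore_errors=True)
```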