diff --git a/egs/wenetspeech/ASR/whisper/train.py b/egs/wenetspeech/ASR/whisper/train.py index 7f4d1bbdc..42f0735ad 100644 --- a/egs/wenetspeech/ASR/whisper/train.py +++ b/egs/wenetspeech/ASR/whisper/train.py @@ -441,6 +441,9 @@ def compute_loss( assert feature.ndim == 3 feature = feature.to(device) feature = feature.transpose(1, 2) # (N, C, T) + # make sure feature T no more than 3000, otherwise cut it + if feature.shape[2] > 3000: + feature = feature[:, :, :3000] supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) @@ -604,6 +607,18 @@ def train_one_epoch( valid_info.write_summary( tb_writer, "train/valid_", params.batch_idx_train ) + if params.deepspeed: + model.save_checkpoint( + save_dir=params.exp_dir, + tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}", + client_state={}, + ) + if rank == 0: + convert_zero_checkpoint_to_fp32_state_dict( + params.exp_dir, + f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt", + tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}", + ) try: with torch.cuda.amp.autocast(enabled=params.use_fp16):