From 3511b7db12cf71591bbb6abae1ec60589264748a Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 23 Apr 2025 02:32:05 -0700 Subject: [PATCH] fix train.py and decode.py fix --- .../ASR_LLM/whisper_llm_zh/decode.py | 32 +++++++++---------- .../ASR_LLM/whisper_llm_zh/train.py | 12 +++---- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py index 882ce4fbf..14159af58 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/decode.py @@ -66,7 +66,7 @@ from train import DEFAULT_SPEECH_TOKEN from transformers import AutoModelForCausalLM, AutoTokenizer from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward -from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint +from icefall.checkpoint import load_checkpoint from icefall.env import get_env_info from icefall.utils import ( AttributeDict, @@ -446,7 +446,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -456,7 +456,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" ) # we compute CER for aishell dataset. results_char = [] @@ -472,7 +472,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" + errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tCER", file=f) for key, val in test_set_wers: @@ -495,9 +495,13 @@ def main(): params = get_params() params.update(vars(args)) + + params.res_dir = params.exp_dir / f"{params.method}" + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" setup_logger( - f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}" + params.res_dir + / f"log-decode-{params.method}-beam{params.beam_size}-{params.suffix}" ) logging.info("Decoding started") @@ -574,23 +578,20 @@ def main(): if params.avg > 1: start = params.epoch - params.avg + 1 assert start >= 1, start - checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" - ) - assert "model" not in checkpoint # deepspeed converted checkpoint only contains model state_dict filenames = [ - f"{params.exp_dir}/epoch-{epoch}.pt" + f"{params.exp_dir}/epoch-{epoch}/pytorch_model.bin" for epoch in range(start, params.epoch + 1) ] avg_checkpoint = average_checkpoints(filenames) model.load_state_dict(avg_checkpoint, strict=False) - filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" - torch.save(avg_checkpoint, filename) + # filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" + # torch.save(avg_checkpoint, filename) else: checkpoint = torch.load( - f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" + f"{params.exp_dir}/epoch-{params.epoch}/pytorch_model.bin", + map_location="cpu", ) model.load_state_dict(checkpoint, strict=False) @@ -643,8 +644,7 @@ def main(): logging.info("Done!") -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) main() diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py index 5f224c984..239080014 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py @@ -523,7 +523,7 @@ def train_one_epoch( for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + if batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -764,7 +764,7 @@ def run(rank, world_size, args): if params.sampler_state_dict_path: sampler_state_dict = torch.load(params.sampler_state_dict_path) sampler_state_dict["max_duration"] = params.max_duration - # TODO: load sampler state dict + train_dl = data_module.train_dataloaders( train_cuts, sampler_state_dict=sampler_state_dict ) @@ -806,15 +806,15 @@ def run(rank, world_size, args): model.save_checkpoint( save_dir=params.exp_dir, - tag=f"epoch-{params.cur_epoch}", + tag=f"zero-epoch-{params.cur_epoch}", client_state={}, exclude_frozen_parameters=True, ) if rank == 0: convert_zero_checkpoint_to_fp32_state_dict( params.exp_dir, - f"{params.exp_dir}/epoch-{params.cur_epoch}.pt", - tag=f"epoch-{params.cur_epoch}", + f"{params.exp_dir}/epoch-{params.cur_epoch}", + tag=f"zero-epoch-{params.cur_epoch}", exclude_frozen_parameters=True, ) # save sampler state dict into checkpoint @@ -824,7 +824,7 @@ def run(rank, world_size, args): f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt", ) - os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}") + os.system(f"rm -rf {params.exp_dir}/zero-epoch-{params.cur_epoch}") logging.info("Done!")