fix train.py and decode.py

decode.py: write all decoding outputs under params.res_dir (exp_dir / method) and load the DeepSpeed-converted epoch-N/pytorch_model.bin checkpoints. train.py: save DeepSpeed checkpoints under a zero-epoch-N tag, convert them to fp32 state dicts in epoch-N/, and remove the raw ZeRO checkpoints afterwards.
Your Name 2025-04-23 02:32:05 -07:00
parent 5ec95e5482
commit 3511b7db12
2 changed files with 22 additions and 22 deletions

decode.py

@@ -66,7 +66,7 @@ from train import DEFAULT_SPEECH_TOKEN
from transformers import AutoModelForCausalLM, AutoTokenizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
+from icefall.checkpoint import load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
@@ -446,7 +446,7 @@ def save_results(
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
@@ -456,7 +456,7 @@ def save_results(
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
# we compute CER for aishell dataset.
results_char = []
@@ -472,7 +472,7 @@ def save_results(
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
+errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
for key, val in test_set_wers:
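For reference, the cer-summary file written here is a small tab-separated report sorted so the best-scoring decoding setting comes first. A self-contained illustration of that format; the method names and CER values below are made up:

```python
# Illustration of the cer-summary file format (tab-separated, best setting first).
# All values are made-up examples, not real results.
test_set_wers = {"greedy_search": 5.21, "beam_search_4": 4.87}
ranked = sorted(test_set_wers.items(), key=lambda x: x[1])
with open("cer-summary-test-epoch-10-avg-5.txt", "w") as f:
    print("settings\tCER", file=f)
    for key, val in ranked:
        print(f"{key}\t{val}", file=f)
```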
@@ -495,9 +495,13 @@ def main():
params = get_params()
params.update(vars(args))
+params.res_dir = params.exp_dir / f"{params.method}"
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
setup_logger(
f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
params.res_dir
/ f"log-decode-{params.method}-beam{params.beam_size}-{params.suffix}"
)
logging.info("Decoding started")
@@ -574,23 +578,20 @@ def main():
if params.avg > 1:
start = params.epoch - params.avg + 1
assert start >= 1, start
-checkpoint = torch.load(
-f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
-)
-assert "model" not in checkpoint
# deepspeed converted checkpoint only contains model state_dict
filenames = [
f"{params.exp_dir}/epoch-{epoch}.pt"
f"{params.exp_dir}/epoch-{epoch}/pytorch_model.bin"
for epoch in range(start, params.epoch + 1)
]
avg_checkpoint = average_checkpoints(filenames)
model.load_state_dict(avg_checkpoint, strict=False)
filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
torch.save(avg_checkpoint, filename)
# filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
# torch.save(avg_checkpoint, filename)
else:
checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
f"{params.exp_dir}/epoch-{params.epoch}/pytorch_model.bin",
map_location="cpu",
)
model.load_state_dict(checkpoint, strict=False)
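The diff does not show the average_checkpoints() helper itself, only that it is now fed epoch-N/pytorch_model.bin files, which hold a bare state_dict rather than a {"model": ...} wrapper (hence the dropped assert). The sketch below is only an illustration of averaging such converted checkpoints under that assumption; it is not the actual helper used by decode.py:

```python
# Minimal sketch of averaging DeepSpeed-converted checkpoints that contain a
# bare state_dict (no "model" key). Illustration only; not the real
# average_checkpoints() helper referenced in decode.py.
from typing import Dict, List

import torch


def average_state_dicts(filenames: List[str]) -> Dict[str, torch.Tensor]:
    n = len(filenames)
    avg = torch.load(filenames[0], map_location="cpu")
    for f in filenames[1:]:
        state = torch.load(f, map_location="cpu")
        for k in avg:
            avg[k] += state[k]
    for k in avg:
        # Floating tensors are averaged; integer tensors (e.g. counters) are
        # floor-divided to keep their dtype.
        if avg[k].is_floating_point():
            avg[k] /= n
        else:
            avg[k] //= n
    return avg


# filenames = [f"{exp_dir}/epoch-{e}/pytorch_model.bin" for e in range(start, epoch + 1)]
# model.load_state_dict(average_state_dicts(filenames), strict=False)
```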
@@ -643,8 +644,7 @@ def main():
logging.info("Done!")
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
if __name__ == "__main__":
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
main()
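Moving the thread settings under the __main__ guard means they run only when the script is executed directly, not as an import-time side effect; that is the presumed motivation, not one stated in the commit. A minimal self-contained version of the pattern:

```python
# Thread limits applied only when the script is the entry point, so importing
# this module does not change global PyTorch threading state.
import torch


def main() -> None:
    print("intra-op threads:", torch.get_num_threads())


if __name__ == "__main__":
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    main()
```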

train.py

@@ -523,7 +523,7 @@ def train_one_epoch(
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
-if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+if batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
@@ -764,7 +764,7 @@ def run(rank, world_size, args):
if params.sampler_state_dict_path:
sampler_state_dict = torch.load(params.sampler_state_dict_path)
sampler_state_dict["max_duration"] = params.max_duration
# TODO: load sampler state dict
train_dl = data_module.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
@@ -806,15 +806,15 @@ def run(rank, world_size, args):
model.save_checkpoint(
save_dir=params.exp_dir,
tag=f"epoch-{params.cur_epoch}",
tag=f"zero-epoch-{params.cur_epoch}",
client_state={},
exclude_frozen_parameters=True,
)
if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir,
f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
tag=f"epoch-{params.cur_epoch}",
f"{params.exp_dir}/epoch-{params.cur_epoch}",
tag=f"zero-epoch-{params.cur_epoch}",
exclude_frozen_parameters=True,
)
# save sampler state dict into checkpoint
@@ -824,7 +824,7 @@ def run(rank, world_size, args):
f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt",
)
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
os.system(f"rm -rf {params.exp_dir}/zero-epoch-{params.cur_epoch}")
logging.info("Done!")