mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 18:12:19 +00:00
fix train.py and decode.py
fix
This commit is contained in:
parent
5ec95e5482
commit
3511b7db12
@ -66,7 +66,7 @@ from train import DEFAULT_SPEECH_TOKEN
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
||||
|
||||
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -446,7 +446,7 @@ def save_results(
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
@ -456,7 +456,7 @@ def save_results(
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
# we compute CER for aishell dataset.
|
||||
results_char = []
|
||||
@ -472,7 +472,7 @@ def save_results(
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tCER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
@ -495,9 +495,13 @@ def main():
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
params.res_dir = params.exp_dir / f"{params.method}"
|
||||
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
setup_logger(
|
||||
f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
|
||||
params.res_dir
|
||||
/ f"log-decode-{params.method}-beam{params.beam_size}-{params.suffix}"
|
||||
)
|
||||
|
||||
logging.info("Decoding started")
|
||||
@ -574,23 +578,20 @@ def main():
|
||||
if params.avg > 1:
|
||||
start = params.epoch - params.avg + 1
|
||||
assert start >= 1, start
|
||||
checkpoint = torch.load(
|
||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
|
||||
)
|
||||
assert "model" not in checkpoint
|
||||
# deepspeed converted checkpoint only contains model state_dict
|
||||
filenames = [
|
||||
f"{params.exp_dir}/epoch-{epoch}.pt"
|
||||
f"{params.exp_dir}/epoch-{epoch}/pytorch_model.bin"
|
||||
for epoch in range(start, params.epoch + 1)
|
||||
]
|
||||
avg_checkpoint = average_checkpoints(filenames)
|
||||
model.load_state_dict(avg_checkpoint, strict=False)
|
||||
|
||||
filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
|
||||
torch.save(avg_checkpoint, filename)
|
||||
# filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
|
||||
# torch.save(avg_checkpoint, filename)
|
||||
else:
|
||||
checkpoint = torch.load(
|
||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
|
||||
f"{params.exp_dir}/epoch-{params.epoch}/pytorch_model.bin",
|
||||
map_location="cpu",
|
||||
)
|
||||
model.load_state_dict(checkpoint, strict=False)
|
||||
|
||||
@ -643,8 +644,7 @@ def main():
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
main()
|
||||
|
@ -523,7 +523,7 @@ def train_one_epoch(
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
|
||||
if batch_idx % params.valid_interval == 0:
|
||||
logging.info("Computing validation loss")
|
||||
valid_info = compute_validation_loss(
|
||||
params=params,
|
||||
@ -764,7 +764,7 @@ def run(rank, world_size, args):
|
||||
if params.sampler_state_dict_path:
|
||||
sampler_state_dict = torch.load(params.sampler_state_dict_path)
|
||||
sampler_state_dict["max_duration"] = params.max_duration
|
||||
# TODO: load sampler state dict
|
||||
|
||||
train_dl = data_module.train_dataloaders(
|
||||
train_cuts, sampler_state_dict=sampler_state_dict
|
||||
)
|
||||
@ -806,15 +806,15 @@ def run(rank, world_size, args):
|
||||
|
||||
model.save_checkpoint(
|
||||
save_dir=params.exp_dir,
|
||||
tag=f"epoch-{params.cur_epoch}",
|
||||
tag=f"zero-epoch-{params.cur_epoch}",
|
||||
client_state={},
|
||||
exclude_frozen_parameters=True,
|
||||
)
|
||||
if rank == 0:
|
||||
convert_zero_checkpoint_to_fp32_state_dict(
|
||||
params.exp_dir,
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
|
||||
tag=f"epoch-{params.cur_epoch}",
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}",
|
||||
tag=f"zero-epoch-{params.cur_epoch}",
|
||||
exclude_frozen_parameters=True,
|
||||
)
|
||||
# save sampler state dict into checkpoint
|
||||
@ -824,7 +824,7 @@ def run(rank, world_size, args):
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt",
|
||||
)
|
||||
|
||||
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
|
||||
os.system(f"rm -rf {params.exp_dir}/zero-epoch-{params.cur_epoch}")
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user