mirror of https://github.com/k2-fsa/icefall.git

commit 3511b7db12 (parent 5ec95e5482)

    fix train.py and decode.py
decode.py

@@ -66,7 +66,7 @@ from train import DEFAULT_SPEECH_TOKEN
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
 
-from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
+from icefall.checkpoint import load_checkpoint
 from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
@@ -446,7 +446,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -456,7 +456,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
         )
         # we compute CER for aishell dataset.
         results_char = []
@@ -472,7 +472,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
+    errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tCER", file=f)
         for key, val in test_set_wers:
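Note: with the three replacements above, save_results writes all of its artifacts (recogs-*, errs-*, cer-summary-*) under params.res_dir instead of the experiment root. The summary itself is a small tab-separated file; an invented example of its contents:

    settings	CER
    beam-search_beam-size-4	5.21
    greedy-search	5.87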
@@ -495,9 +495,13 @@ def main():
 
     params = get_params()
     params.update(vars(args))
 
+    params.res_dir = params.exp_dir / f"{params.method}"
+
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
     setup_logger(
-        f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
+        params.res_dir
+        / f"log-decode-{params.method}-beam{params.beam_size}-{params.suffix}"
     )
 
     logging.info("Decoding started")
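Note: params.res_dir groups every artifact of one decoding method (recogs-*, errs-*, cer-summary-*, and now the decode log) under its own subdirectory, so runs with different methods no longer overwrite each other's files. A minimal sketch of the resulting layout, with invented parameter values:

    from pathlib import Path

    exp_dir = Path("exp")        # assumed --exp-dir
    method = "beam-search"       # assumed --method
    res_dir = exp_dir / method   # mirrors params.res_dir
    suffix = "epoch-10-avg-5"    # mirrors params.suffix
    key = method                 # decoding key passed to save_results

    print(res_dir / f"recogs-aishell-{key}-{suffix}.txt")
    # exp/beam-search/recogs-aishell-beam-search-epoch-10-avg-5.txt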
@@ -574,23 +578,20 @@ def main():
     if params.avg > 1:
         start = params.epoch - params.avg + 1
         assert start >= 1, start
-        checkpoint = torch.load(
-            f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
-        )
-        assert "model" not in checkpoint
         # deepspeed converted checkpoint only contains model state_dict
         filenames = [
-            f"{params.exp_dir}/epoch-{epoch}.pt"
+            f"{params.exp_dir}/epoch-{epoch}/pytorch_model.bin"
             for epoch in range(start, params.epoch + 1)
         ]
         avg_checkpoint = average_checkpoints(filenames)
         model.load_state_dict(avg_checkpoint, strict=False)
 
-        filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
-        torch.save(avg_checkpoint, filename)
+        # filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
+        # torch.save(avg_checkpoint, filename)
     else:
         checkpoint = torch.load(
-            f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+            f"{params.exp_dir}/epoch-{params.epoch}/pytorch_model.bin",
+            map_location="cpu",
         )
         model.load_state_dict(checkpoint, strict=False)
 
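Note: the converted epoch-{N}/pytorch_model.bin files are bare state_dicts, which is why the old assert "model" not in checkpoint is gone, and strict=False is needed because parameters frozen during training were excluded from the saved checkpoint. A minimal sketch of what averaging such bare state_dicts involves (an illustration, not icefall's actual average_checkpoints):

    import torch

    def average_state_dicts(filenames):
        """Average bare model state_dicts loaded from `filenames`."""
        avg = torch.load(filenames[0], map_location="cpu")
        for f in filenames[1:]:
            sd = torch.load(f, map_location="cpu")
            for k in avg:
                avg[k] += sd[k]
        n = len(filenames)
        for k in avg:
            if avg[k].is_floating_point():
                avg[k] /= n
            else:
                avg[k] //= n  # integer buffers, e.g. step counters
        return avg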
@ -643,8 +644,7 @@ def main():
|
|||||||
logging.info("Done!")
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
main()
|
||||||
|
train.py

@@ -523,7 +523,7 @@ def train_one_epoch(
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        if batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,
@@ -764,7 +764,7 @@ def run(rank, world_size, args):
     if params.sampler_state_dict_path:
         sampler_state_dict = torch.load(params.sampler_state_dict_path)
         sampler_state_dict["max_duration"] = params.max_duration
-    # TODO: load sampler state dict
+
     train_dl = data_module.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
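Note: the stale "# TODO: load sampler state dict" comment is removed; the sampler state is in fact loaded and passed to the dataloader, with max_duration overwritten so that a resumed run honours the currently configured batch budget rather than the value stored at save time. A sketch of that resume path, with illustrative paths and values:

    import torch

    # Saved at the end of each epoch as epoch-{N}-sampler.pt (see below);
    # the path and the new max_duration here are made up.
    sampler_state_dict = torch.load("exp/epoch-2-sampler.pt")
    sampler_state_dict["max_duration"] = 200  # seconds of audio per batch

    # train_dl = data_module.train_dataloaders(
    #     train_cuts, sampler_state_dict=sampler_state_dict
    # )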
@@ -806,15 +806,15 @@ def run(rank, world_size, args):
 
         model.save_checkpoint(
             save_dir=params.exp_dir,
-            tag=f"epoch-{params.cur_epoch}",
+            tag=f"zero-epoch-{params.cur_epoch}",
             client_state={},
             exclude_frozen_parameters=True,
         )
         if rank == 0:
             convert_zero_checkpoint_to_fp32_state_dict(
                 params.exp_dir,
-                f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
-                tag=f"epoch-{params.cur_epoch}",
+                f"{params.exp_dir}/epoch-{params.cur_epoch}",
+                tag=f"zero-epoch-{params.cur_epoch}",
                 exclude_frozen_parameters=True,
             )
             # save sampler state dict into checkpoint
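Note: the ZeRO shard directory is now tagged zero-epoch-{N}, and the fp32 output path loses its .pt suffix. Recent DeepSpeed versions treat the second argument of convert_zero_checkpoint_to_fp32_state_dict as an output directory and write pytorch_model.bin inside it, so the converted output exp_dir/epoch-{N} would otherwise collide with a shard directory of the same name; the rename keeps the two apart and matches the epoch-{N}/pytorch_model.bin paths decode.py now loads. A sketch of the per-epoch save/convert/cleanup cycle (paths and values assumed):

    from deepspeed.utils.zero_to_fp32 import (
        convert_zero_checkpoint_to_fp32_state_dict,
    )

    exp_dir, epoch = "exp", 3  # illustrative values

    # 1. DeepSpeed wrote ZeRO shards under exp/zero-epoch-3/ via
    #    model.save_checkpoint(save_dir=exp_dir, tag=f"zero-epoch-{epoch}").
    # 2. Rank 0 merges them into an fp32 state_dict; recent DeepSpeed
    #    writes exp/epoch-3/pytorch_model.bin inside the output directory.
    convert_zero_checkpoint_to_fp32_state_dict(
        exp_dir,
        f"{exp_dir}/epoch-{epoch}",
        tag=f"zero-epoch-{epoch}",
        exclude_frozen_parameters=True,
    )
    # 3. The raw shards are then deleted (see the rm -rf below).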
@@ -824,7 +824,7 @@ def run(rank, world_size, args):
                 f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt",
             )
 
-            os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}")
+            os.system(f"rm -rf {params.exp_dir}/zero-epoch-{params.cur_epoch}")
 
     logging.info("Done!")
 
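Note: without this change the cleanup would have deleted exp_dir/epoch-{N}, i.e. the freshly converted fp32 checkpoint directory, instead of the ZeRO shards. As an aside, a portable alternative to shelling out (a suggestion, not part of this commit):

    import shutil

    exp_dir, epoch = "exp", 3  # illustrative values
    # Equivalent to os.system(f"rm -rf {exp_dir}/zero-epoch-{epoch}")
    # but without spawning a shell.
    shutil.rmtree(f"{exp_dir}/zero-epoch-{epoch}", ignore_errors=True)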
|
Loading…
x
Reference in New Issue
Block a user