mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
Update conformer.py for aishell4 (#484)
* update conformer.py for aishell4 * update conformer.py * add strict=False when model.load_state_dict
This commit is contained in:
parent
a8696b36fc
commit
3d2986b4c2
File diff suppressed because it is too large
Load Diff
1
egs/aishell4/ASR/pruned_transducer_stateless5/conformer.py
Symbolic link
1
egs/aishell4/ASR/pruned_transducer_stateless5/conformer.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py
|
@ -523,7 +523,9 @@ def main():
|
|||||||
)
|
)
|
||||||
logging.info(f"averaging {filenames}")
|
logging.info(f"averaging {filenames}")
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(
|
||||||
|
average_checkpoints(filenames, device=device), strict=False
|
||||||
|
)
|
||||||
elif params.avg == 1:
|
elif params.avg == 1:
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
else:
|
else:
|
||||||
@ -534,7 +536,9 @@ def main():
|
|||||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
logging.info(f"averaging {filenames}")
|
logging.info(f"averaging {filenames}")
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(
|
||||||
|
average_checkpoints(filenames, device=device), strict=False
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
filenames = find_checkpoints(
|
filenames = find_checkpoints(
|
||||||
@ -562,7 +566,8 @@ def main():
|
|||||||
filename_start=filename_start,
|
filename_start=filename_start,
|
||||||
filename_end=filename_end,
|
filename_end=filename_end,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
),
|
||||||
|
strict=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert params.avg > 0, params.avg
|
assert params.avg > 0, params.avg
|
||||||
@ -580,7 +585,8 @@ def main():
|
|||||||
filename_start=filename_start,
|
filename_start=filename_start,
|
||||||
filename_end=filename_end,
|
filename_end=filename_end,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
),
|
||||||
|
strict=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
@ -184,7 +184,9 @@ def main():
|
|||||||
)
|
)
|
||||||
logging.info(f"averaging {filenames}")
|
logging.info(f"averaging {filenames}")
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(
|
||||||
|
average_checkpoints(filenames, device=device), strict=False
|
||||||
|
)
|
||||||
elif params.avg == 1:
|
elif params.avg == 1:
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
else:
|
else:
|
||||||
@ -195,7 +197,9 @@ def main():
|
|||||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
logging.info(f"averaging {filenames}")
|
logging.info(f"averaging {filenames}")
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(
|
||||||
|
average_checkpoints(filenames, device=device), strict=False
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
filenames = find_checkpoints(
|
filenames = find_checkpoints(
|
||||||
@ -223,7 +227,8 @@ def main():
|
|||||||
filename_start=filename_start,
|
filename_start=filename_start,
|
||||||
filename_end=filename_end,
|
filename_end=filename_end,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
),
|
||||||
|
strict=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert params.avg > 0, params.avg
|
assert params.avg > 0, params.avg
|
||||||
@ -241,7 +246,8 @@ def main():
|
|||||||
filename_start=filename_start,
|
filename_start=filename_start,
|
||||||
filename_end=filename_end,
|
filename_end=filename_end,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
),
|
||||||
|
strict=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user