Merge branch 'conformer_ctc2' of github.com:WayneWiser/icefall into conformer_ctc2

This commit is contained in:
Quandwang 2022-07-21 13:35:35 +08:00
commit 1ad9862b39
3 changed files with 21 additions and 1340 deletions

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py

View File

@ -523,7 +523,9 @@ def main():
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
@ -534,7 +536,9 @@ def main():
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
else:
if params.iter > 0:
filenames = find_checkpoints(
@ -562,7 +566,8 @@ def main():
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
),
strict=False,
)
else:
assert params.avg > 0, params.avg
@ -580,7 +585,8 @@ def main():
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
),
strict=False,
)
model.to(device)

View File

@ -184,7 +184,9 @@ def main():
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
@ -195,7 +197,9 @@ def main():
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
else:
if params.iter > 0:
filenames = find_checkpoints(
@ -223,7 +227,8 @@ def main():
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
),
strict=False,
)
else:
assert params.avg > 0, params.avg
@ -241,7 +246,8 @@ def main():
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
),
strict=False,
)
model.eval()