Update conformer.py for aishell4 (#484)

* update conformer.py for aishell4

* update conformer.py

* add strict=False when model.load_state_dict
This commit is contained in:
Mingshuang Luo 2022-07-20 21:32:53 +08:00 committed by GitHub
parent a8696b36fc
commit 3d2986b4c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 21 additions and 1340 deletions

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py

View File

@ -523,7 +523,9 @@ def main():
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
@ -534,7 +536,9 @@ def main():
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
else:
if params.iter > 0:
filenames = find_checkpoints(
@ -562,7 +566,8 @@ def main():
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
),
strict=False,
)
else:
assert params.avg > 0, params.avg
@ -580,7 +585,8 @@ def main():
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
),
strict=False,
)
model.to(device)

View File

@ -184,7 +184,9 @@ def main():
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
@ -195,7 +197,9 @@ def main():
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
else:
if params.iter > 0:
filenames = find_checkpoints(
@ -223,7 +227,8 @@ def main():
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
),
strict=False,
)
else:
assert params.avg > 0, params.avg
@ -241,7 +246,8 @@ def main():
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
),
strict=False,
)
model.eval()