mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-26 02:06:13 +00:00

add speech io dataset

This commit is contained in:
parent 8226b628f4
commit dbe85c1f12
```diff
@@ -206,10 +206,11 @@ def get_parser():
     )

     parser.add_argument(
-        "--use-aishell",
-        type=str2bool,
-        default=True,
-        help="Whether to only use aishell1 dataset for training.",
+        "--dataset",
+        type=str,
+        default="aishell",
+        choices=["aishell", "speechio", "wenetspeech_test_meeting", "multi_hans_zh"],
+        help="The dataset to decode",
     )

     add_model_arguments(parser)
```
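The boolean `--use-aishell` switch becomes a multi-way `--dataset` flag, with argparse's `choices` rejecting anything outside the four supported test sets. A minimal sketch of how the new flag behaves (the parser setup here is illustrative, not the recipe's full `get_parser`):

```python
import argparse

# Illustrative stand-in for the recipe's get_parser(); only the new flag is shown.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset",
    type=str,
    default="aishell",
    choices=["aishell", "speechio", "wenetspeech_test_meeting", "multi_hans_zh"],
    help="The dataset to decode",
)

# A valid choice is accepted; any other value exits with an argparse error.
args = parser.parse_args(["--dataset", "speechio"])
print(args.dataset)  # -> speechio
```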
```diff
@@ -540,7 +541,7 @@ def main():

     if params.avg > 1:
-        start = params.epoch - params.avg
+        start = params.epoch - params.avg + 1
         assert start >= 1, start
         checkpoint = torch.load(
             f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
```
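The `+ 1` fixes an off-by-one in the averaging window: the old start index pulled `params.avg + 1` checkpoints into the average instead of `params.avg`. A quick check of the arithmetic (the epoch and avg values are made up for illustration):

```python
epoch, avg = 10, 3

# Old: start = epoch - avg -> epochs 7..10, i.e. 4 checkpoints for avg=3.
old = list(range(epoch - avg, epoch + 1))
# New: start = epoch - avg + 1 -> epochs 8..10, exactly avg checkpoints.
new = list(range(epoch - avg + 1, epoch + 1))

print(old)  # [7, 8, 9, 10]
print(new)  # [8, 9, 10]
```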
```diff
@@ -551,18 +552,17 @@ def main():
                 f"{params.exp_dir}/epoch-{epoch}.pt"
                 for epoch in range(start, params.epoch + 1)
             ]
-            model.load_state_dict(average_checkpoints(filenames), strict=False)
+            avg_checkpoint = average_checkpoints(filenames)
+            model.load_state_dict(avg_checkpoint, strict=False)

             filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
-            torch.save(model.state_dict(), filename)
+            torch.save(avg_checkpoint, filename)
     else:
         checkpoint = torch.load(
             f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
         )
-        model.load_state_dict(checkpoint, strict=False)
+        if "model" not in checkpoint:
+            model.load_state_dict(checkpoint, strict=False)
+        else:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)

     model.to(device)
     model.eval()
     num_param = sum([p.numel() for p in model.parameters()])
```
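Saving `avg_checkpoint` instead of `model.state_dict()` writes the averaged weights exactly as computed, before a `strict=False` load can silently drop mismatched keys. A minimal sketch of what checkpoint averaging does in principle (`average_state_dicts` is an illustrative stand-in, not icefall's `average_checkpoints` implementation):

```python
from typing import Dict, List

import torch


def average_state_dicts(filenames: List[str]) -> Dict[str, torch.Tensor]:
    """Element-wise mean of the tensors stored in each checkpoint file."""
    avg: Dict[str, torch.Tensor] = {}
    for name in filenames:
        state = torch.load(name, map_location="cpu")
        for key, value in state.items():
            avg[key] = avg.get(key, 0) + value
    return {key: value / len(filenames) for key, value in avg.items()}
```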
```diff
@@ -584,11 +584,14 @@ def main():
             return False
         return True

-    if params.use_aishell:
+    if params.dataset == "aishell":
         test_sets_cuts = multi_dataset.aishell_test_cuts()
-    else:
-        # test_sets_cuts = multi_dataset.test_cuts()
+    elif params.dataset == "speechio":
+        test_sets_cuts = multi_dataset.speechio_test_cuts()
+    elif params.dataset == "wenetspeech_test_meeting":
         test_sets_cuts = multi_dataset.wenetspeech_test_meeting_cuts()
+    else:
+        test_sets_cuts = multi_dataset.test_cuts()

     test_sets = test_sets_cuts.keys()
     test_dls = [
```
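The if/elif chain dispatches on the flag value. A hypothetical table-driven equivalent would make adding a dataset a one-line change; `DummyMultiDataset` below is a toy stand-in for the recipe's `MultiDataset` class, used only so the sketch runs on its own:

```python
# Hypothetical table-driven version of the dispatch above.
class DummyMultiDataset:
    def aishell_test_cuts(self):
        return {"aishell_test": "..."}

    def speechio_test_cuts(self):
        return {"SPEECHIO_ASR_ZH00000": "..."}

    def wenetspeech_test_meeting_cuts(self):
        return {"wenetspeech-meeting_test": "..."}

    def test_cuts(self):
        return {"all_test_sets": "..."}


multi_dataset = DummyMultiDataset()
dataset = "speechio"  # would come from params.dataset

cuts_fn = {
    "aishell": multi_dataset.aishell_test_cuts,
    "speechio": multi_dataset.speechio_test_cuts,
    "wenetspeech_test_meeting": multi_dataset.wenetspeech_test_meeting_cuts,
}.get(dataset, multi_dataset.test_cuts)

test_sets_cuts = cuts_fn()
print(list(test_sets_cuts.keys()))
```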
```diff
@@ -331,3 +331,25 @@ class MultiDataset:
         return {
             "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
         }
+
+    def speechio_test_cuts(self) -> Dict[str, CutSet]:
+        logging.info("About to get multidataset test cuts")
+        start_index = 0
+        end_index = 26
+        dataset_parts = []
+        for i in range(start_index, end_index + 1):
+            idx = f"{i}".zfill(2)
+            dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}")
+
+        prefix = "speechio"
+        suffix = "jsonl.gz"
+
+        results_dict = {}
+        for partition in dataset_parts:
+            path = f"{prefix}_cuts_{partition}.{suffix}"
+
+            logging.info(f"Loading {path} set in lazy mode")
+            test_cuts = load_manifest_lazy(self.fbank_dir / path)
+            results_dict[partition] = test_cuts
+
+        return results_dict
```
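The loop builds the 27 SpeechIO manifest names SPEECHIO_ASR_ZH00000 through SPEECHIO_ASR_ZH00026: `zfill(2)` pads the counter to two digits and the literal `000` in the f-string supplies the rest. A standalone check of the naming:

```python
# Reproduces the partition-name construction from speechio_test_cuts.
dataset_parts = []
for i in range(0, 27):
    idx = f"{i}".zfill(2)
    dataset_parts.append(f"SPEECHIO_ASR_ZH000{idx}")

print(dataset_parts[0])    # SPEECHIO_ASR_ZH00000
print(dataset_parts[-1])   # SPEECHIO_ASR_ZH00026
print(len(dataset_parts))  # 27 test partitions
```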
```diff
@@ -751,7 +751,6 @@ def run(rank, world_size, args):
     if params.pretrained_model_path:
         checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
         missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
-        assert len(unexpected_keys) == 0, unexpected_keys

     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
```