mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
add speech io dataset
This commit is contained in:
parent
8226b628f4
commit
dbe85c1f12
@ -206,10 +206,11 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use-aishell",
|
"--dataset",
|
||||||
type=str2bool,
|
type=str,
|
||||||
default=True,
|
default="aishell",
|
||||||
help="Whether to only use aishell1 dataset for training.",
|
choices=["aishell", "speechio", "wenetspeech_test_meeting", "multi_hans_zh"],
|
||||||
|
help="The dataset to decode",
|
||||||
)
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
@ -540,7 +541,7 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if params.avg > 1:
|
if params.avg > 1:
|
||||||
start = params.epoch - params.avg
|
start = params.epoch - params.avg + 1
|
||||||
assert start >= 1, start
|
assert start >= 1, start
|
||||||
checkpoint = torch.load(
|
checkpoint = torch.load(
|
||||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
|
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
|
||||||
@ -551,18 +552,17 @@ def main():
|
|||||||
f"{params.exp_dir}/epoch-{epoch}.pt"
|
f"{params.exp_dir}/epoch-{epoch}.pt"
|
||||||
for epoch in range(start, params.epoch + 1)
|
for epoch in range(start, params.epoch + 1)
|
||||||
]
|
]
|
||||||
model.load_state_dict(average_checkpoints(filenames), strict=False)
|
avg_checkpoint = average_checkpoints(filenames)
|
||||||
|
model.load_state_dict(avg_checkpoint, strict=False)
|
||||||
|
|
||||||
filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
|
filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
|
||||||
torch.save(model.state_dict(), filename)
|
torch.save(avg_checkpoint, filename)
|
||||||
else:
|
else:
|
||||||
checkpoint = torch.load(
|
checkpoint = torch.load(
|
||||||
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
|
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
|
||||||
)
|
)
|
||||||
if "model" not in checkpoint:
|
model.load_state_dict(checkpoint, strict=False)
|
||||||
model.load_state_dict(checkpoint, strict=False)
|
|
||||||
else:
|
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
@ -584,11 +584,14 @@ def main():
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if params.use_aishell:
|
if params.dataset == "aishell":
|
||||||
test_sets_cuts = multi_dataset.aishell_test_cuts()
|
test_sets_cuts = multi_dataset.aishell_test_cuts()
|
||||||
else:
|
elif params.dataset == "speechio":
|
||||||
# test_sets_cuts = multi_dataset.test_cuts()
|
test_sets_cuts = multi_dataset.speechio_test_cuts()
|
||||||
|
elif params.dataset == "wenetspeech_test_meeting":
|
||||||
test_sets_cuts = multi_dataset.wenetspeech_test_meeting_cuts()
|
test_sets_cuts = multi_dataset.wenetspeech_test_meeting_cuts()
|
||||||
|
else:
|
||||||
|
test_sets_cuts = multi_dataset.test_cuts()
|
||||||
|
|
||||||
test_sets = test_sets_cuts.keys()
|
test_sets = test_sets_cuts.keys()
|
||||||
test_dls = [
|
test_dls = [
|
||||||
|
@ -331,3 +331,25 @@ class MultiDataset:
|
|||||||
return {
|
return {
|
||||||
"wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
|
"wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def speechio_test_cuts(self) -> Dict[str, CutSet]:
    """Lazily load every SpeechIO test partition.

    Returns:
      A dict keyed by partition name (``SPEECHIO_ASR_ZH00000`` through
      ``SPEECHIO_ASR_ZH00026``) mapping to the corresponding lazily
      loaded ``CutSet`` read from ``self.fbank_dir``.
    """
    logging.info("About to get multidataset test cuts")

    prefix = "speechio"
    suffix = "jsonl.gz"

    cuts_by_partition: Dict[str, CutSet] = {}
    # Partitions 00 .. 26 inclusive; {index:05d} yields e.g. "00007",
    # matching the original "ZH000" + zfill(2) construction.
    for index in range(27):
        partition = f"SPEECHIO_ASR_ZH{index:05d}"
        path = f"{prefix}_cuts_{partition}.{suffix}"
        logging.info(f"Loading {path} set in lazy mode")
        cuts_by_partition[partition] = load_manifest_lazy(self.fbank_dir / path)

    return cuts_by_partition
|
@ -751,7 +751,6 @@ def run(rank, world_size, args):
|
|||||||
if params.pretrained_model_path:
|
if params.pretrained_model_path:
|
||||||
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
|
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
|
||||||
missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
|
missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
|
||||||
assert len(unexpected_keys) == 0, unexpected_keys
|
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user