add speech io dataset

This commit is contained in:
root 2024-06-13 09:08:59 +00:00 committed by Yuekai Zhang
parent 8226b628f4
commit dbe85c1f12
3 changed files with 40 additions and 16 deletions

View File

@ -206,10 +206,11 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--use-aishell", "--dataset",
type=str2bool, type=str,
default=True, default="aishell",
help="Whether to only use aishell1 dataset for training.", choices=["aishell", "speechio", "wenetspeech_test_meeting", "multi_hans_zh"],
help="The dataset to decode",
) )
add_model_arguments(parser) add_model_arguments(parser)
@ -540,7 +541,7 @@ def main():
if params.avg > 1: if params.avg > 1:
start = params.epoch - params.avg start = params.epoch - params.avg + 1
assert start >= 1, start assert start >= 1, start
checkpoint = torch.load( checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
@ -551,18 +552,17 @@ def main():
f"{params.exp_dir}/epoch-{epoch}.pt" f"{params.exp_dir}/epoch-{epoch}.pt"
for epoch in range(start, params.epoch + 1) for epoch in range(start, params.epoch + 1)
] ]
model.load_state_dict(average_checkpoints(filenames), strict=False) avg_checkpoint = average_checkpoints(filenames)
model.load_state_dict(avg_checkpoint, strict=False)
filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
torch.save(model.state_dict(), filename) torch.save(avg_checkpoint, filename)
else: else:
checkpoint = torch.load( checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
) )
if "model" not in checkpoint:
model.load_state_dict(checkpoint, strict=False) model.load_state_dict(checkpoint, strict=False)
else:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.to(device) model.to(device)
model.eval() model.eval()
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
@ -584,11 +584,14 @@ def main():
return False return False
return True return True
if params.use_aishell: if params.dataset == "aishell":
test_sets_cuts = multi_dataset.aishell_test_cuts() test_sets_cuts = multi_dataset.aishell_test_cuts()
else: elif params.dataset == "speechio":
# test_sets_cuts = multi_dataset.test_cuts() test_sets_cuts = multi_dataset.speechio_test_cuts()
elif params.dataset == "wenetspeech_test_meeting":
test_sets_cuts = multi_dataset.wenetspeech_test_meeting_cuts() test_sets_cuts = multi_dataset.wenetspeech_test_meeting_cuts()
else:
test_sets_cuts = multi_dataset.test_cuts()
test_sets = test_sets_cuts.keys() test_sets = test_sets_cuts.keys()
test_dls = [ test_dls = [

View File

@ -331,3 +331,25 @@ class MultiDataset:
return { return {
"wenetspeech-meeting_test": wenetspeech_test_meeting_cuts, "wenetspeech-meeting_test": wenetspeech_test_meeting_cuts,
} }
def speechio_test_cuts(self) -> Dict[str, CutSet]:
    """Return the SpeechIO ASR test partitions as lazily-loaded CutSets.

    Loads partitions SPEECHIO_ASR_ZH00000 .. SPEECHIO_ASR_ZH00026 from
    ``{self.fbank_dir}/speechio_cuts_<PARTITION>.jsonl.gz``.

    Returns:
      Dict mapping each partition name (e.g. ``"SPEECHIO_ASR_ZH00005"``)
      to its lazily-loaded ``CutSet``.
    """
    logging.info("About to get multidataset test cuts")
    # SpeechIO partitions are numbered 00..26 (inclusive).
    start_index = 0
    end_index = 26
    prefix = "speechio"
    suffix = "jsonl.gz"
    results_dict = {}
    for i in range(start_index, end_index + 1):
        # Partition names use a fixed 5-digit suffix, e.g. SPEECHIO_ASR_ZH00007.
        partition = f"SPEECHIO_ASR_ZH000{i:02d}"
        path = f"{prefix}_cuts_{partition}.{suffix}"
        logging.info(f"Loading {path} set in lazy mode")
        results_dict[partition] = load_manifest_lazy(self.fbank_dir / path)
    return results_dict

View File

@ -751,7 +751,6 @@ def run(rank, world_size, args):
if params.pretrained_model_path: if params.pretrained_model_path:
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu") checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False) missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
assert len(unexpected_keys) == 0, unexpected_keys
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")