minor fixes

This commit is contained in:
jinzr 2023-08-13 01:14:51 +08:00
parent 8e1d8c9316
commit 235bb0537a
4 changed files with 18 additions and 18 deletions

View File

@ -770,7 +770,7 @@ def main():
# we need cut ids to display recognition results. # we need cut ids to display recognition results.
args.return_cuts = True args.return_cuts = True
aishell = AiShell2AsrDataModule(args) aishell2 = AiShell2AsrDataModule(args)
def remove_short_utt(c: Cut): def remove_short_utt(c: Cut):
T = ((c.num_frames - 7) // 2 + 1) // 2 T = ((c.num_frames - 7) // 2 + 1) // 2
@ -780,13 +780,13 @@ def main():
) )
return T > 0 return T > 0
dev_cuts = aishell.valid_cuts() dev_cuts = aishell2.valid_cuts()
dev_cuts = dev_cuts.filter(remove_short_utt) dev_cuts = dev_cuts.filter(remove_short_utt)
dev_dl = aishell.valid_dataloaders(dev_cuts) dev_dl = aishell2.valid_dataloaders(dev_cuts)
test_cuts = aishell.test_meeting_cuts() test_cuts = aishell2.test_cuts()
test_cuts = test_cuts.filter(remove_short_utt) test_cuts = test_cuts.filter(remove_short_utt)
test_dl = aishell.test_dataloaders(test_cuts) test_dl = aishell2.test_dataloaders(test_cuts)
test_sets = ["dev", "test"] test_sets = ["dev", "test"]
test_dls = [dev_dl, test_dl] test_dls = [dev_dl, test_dl]

View File

@ -241,7 +241,7 @@ def main():
# we need cut ids to display recognition results. # we need cut ids to display recognition results.
args.return_cuts = True args.return_cuts = True
aishell = AiShell2AsrDataModule(args) aishell2 = AiShell2AsrDataModule(args)
def remove_short_utt(c: Cut): def remove_short_utt(c: Cut):
T = ((c.num_frames - 7) // 2 + 1) // 2 T = ((c.num_frames - 7) // 2 + 1) // 2
@ -251,13 +251,13 @@ def main():
) )
return T > 0 return T > 0
dev_cuts = aishell.valid_cuts() dev_cuts = aishell2.valid_cuts()
dev_cuts = dev_cuts.filter(remove_short_utt) dev_cuts = dev_cuts.filter(remove_short_utt)
dev_dl = aishell.valid_dataloaders(dev_cuts) dev_dl = aishell2.valid_dataloaders(dev_cuts)
test_cuts = aishell.test_net_cuts() test_cuts = aishell2.test_net_cuts()
test_cuts = test_cuts.filter(remove_short_utt) test_cuts = test_cuts.filter(remove_short_utt)
test_dl = aishell.test_dataloaders(test_cuts) test_dl = aishell2.test_dataloaders(test_cuts)
test_sets = ["dev", "test"] test_sets = ["dev", "test"]
test_dl = [dev_dl, test_dl] test_dl = [dev_dl, test_dl]

View File

@ -855,10 +855,10 @@ def main():
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
aishell = AiShell2AsrDataModule(args) aishell2 = AiShell2AsrDataModule(args)
dev_cuts = aishell.valid_cuts() dev_cuts = aishell2.valid_cuts()
test_cuts = aishell.test_cuts() test_cuts = aishell2.test_cuts()
test_sets = ["dev", "test"] test_sets = ["dev", "test"]
test_cuts = [dev_cuts, test_cuts] test_cuts = [dev_cuts, test_cuts]

View File

@ -1135,10 +1135,10 @@ def run(rank, world_size, args):
if params.inf_check: if params.inf_check:
register_inf_check_hooks(model) register_inf_check_hooks(model)
aishell = AiShell2AsrDataModule(args) aishell2 = AiShell2AsrDataModule(args)
train_cuts = aishell.train_cuts() train_cuts = aishell2.train_cuts()
valid_cuts = aishell.valid_cuts() valid_cuts = aishell2.valid_cuts()
def remove_short_and_long_utt(c: Cut): def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 15 seconds # Keep only utterances with duration between 1 second and 15 seconds
@ -1186,11 +1186,11 @@ def run(rank, world_size, args):
else: else:
sampler_state_dict = None sampler_state_dict = None
train_dl = aishell.train_dataloaders( train_dl = aishell2.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict train_cuts, sampler_state_dict=sampler_state_dict
) )
valid_dl = aishell.valid_dataloaders(valid_cuts) valid_dl = aishell2.valid_dataloaders(valid_cuts)
if False and not params.print_diagnostics: if False and not params.print_diagnostics:
scan_pessimistic_batches_for_oom( scan_pessimistic_batches_for_oom(