mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-10 01:24:19 +00:00
minor fixes
This commit is contained in:
parent
8e1d8c9316
commit
235bb0537a
@ -770,7 +770,7 @@ def main():
|
|||||||
|
|
||||||
# we need cut ids to display recognition results.
|
# we need cut ids to display recognition results.
|
||||||
args.return_cuts = True
|
args.return_cuts = True
|
||||||
aishell = AiShell2AsrDataModule(args)
|
aishell2 = AiShell2AsrDataModule(args)
|
||||||
|
|
||||||
def remove_short_utt(c: Cut):
|
def remove_short_utt(c: Cut):
|
||||||
T = ((c.num_frames - 7) // 2 + 1) // 2
|
T = ((c.num_frames - 7) // 2 + 1) // 2
|
||||||
@ -780,13 +780,13 @@ def main():
|
|||||||
)
|
)
|
||||||
return T > 0
|
return T > 0
|
||||||
|
|
||||||
dev_cuts = aishell.valid_cuts()
|
dev_cuts = aishell2.valid_cuts()
|
||||||
dev_cuts = dev_cuts.filter(remove_short_utt)
|
dev_cuts = dev_cuts.filter(remove_short_utt)
|
||||||
dev_dl = aishell.valid_dataloaders(dev_cuts)
|
dev_dl = aishell2.valid_dataloaders(dev_cuts)
|
||||||
|
|
||||||
test_cuts = aishell.test_meeting_cuts()
|
test_cuts = aishell2.test_cuts()
|
||||||
test_cuts = test_cuts.filter(remove_short_utt)
|
test_cuts = test_cuts.filter(remove_short_utt)
|
||||||
test_dl = aishell.test_dataloaders(test_cuts)
|
test_dl = aishell2.test_dataloaders(test_cuts)
|
||||||
|
|
||||||
test_sets = ["dev", "test"]
|
test_sets = ["dev", "test"]
|
||||||
test_dls = [dev_dl, test_dl]
|
test_dls = [dev_dl, test_dl]
|
||||||
|
@ -241,7 +241,7 @@ def main():
|
|||||||
# we need cut ids to display recognition results.
|
# we need cut ids to display recognition results.
|
||||||
args.return_cuts = True
|
args.return_cuts = True
|
||||||
|
|
||||||
aishell = AiShell2AsrDataModule(args)
|
aishell2 = AiShell2AsrDataModule(args)
|
||||||
|
|
||||||
def remove_short_utt(c: Cut):
|
def remove_short_utt(c: Cut):
|
||||||
T = ((c.num_frames - 7) // 2 + 1) // 2
|
T = ((c.num_frames - 7) // 2 + 1) // 2
|
||||||
@ -251,13 +251,13 @@ def main():
|
|||||||
)
|
)
|
||||||
return T > 0
|
return T > 0
|
||||||
|
|
||||||
dev_cuts = aishell.valid_cuts()
|
dev_cuts = aishell2.valid_cuts()
|
||||||
dev_cuts = dev_cuts.filter(remove_short_utt)
|
dev_cuts = dev_cuts.filter(remove_short_utt)
|
||||||
dev_dl = aishell.valid_dataloaders(dev_cuts)
|
dev_dl = aishell2.valid_dataloaders(dev_cuts)
|
||||||
|
|
||||||
test_cuts = aishell.test_net_cuts()
|
test_cuts = aishell2.test_net_cuts()
|
||||||
test_cuts = test_cuts.filter(remove_short_utt)
|
test_cuts = test_cuts.filter(remove_short_utt)
|
||||||
test_dl = aishell.test_dataloaders(test_cuts)
|
test_dl = aishell2.test_dataloaders(test_cuts)
|
||||||
|
|
||||||
test_sets = ["dev", "test"]
|
test_sets = ["dev", "test"]
|
||||||
test_dl = [dev_dl, test_dl]
|
test_dl = [dev_dl, test_dl]
|
||||||
|
@ -855,10 +855,10 @@ def main():
|
|||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
aishell = AiShell2AsrDataModule(args)
|
aishell2 = AiShell2AsrDataModule(args)
|
||||||
|
|
||||||
dev_cuts = aishell.valid_cuts()
|
dev_cuts = aishell2.valid_cuts()
|
||||||
test_cuts = aishell.test_cuts()
|
test_cuts = aishell2.test_cuts()
|
||||||
|
|
||||||
test_sets = ["dev", "test"]
|
test_sets = ["dev", "test"]
|
||||||
test_cuts = [dev_cuts, test_cuts]
|
test_cuts = [dev_cuts, test_cuts]
|
||||||
|
@ -1135,10 +1135,10 @@ def run(rank, world_size, args):
|
|||||||
if params.inf_check:
|
if params.inf_check:
|
||||||
register_inf_check_hooks(model)
|
register_inf_check_hooks(model)
|
||||||
|
|
||||||
aishell = AiShell2AsrDataModule(args)
|
aishell2 = AiShell2AsrDataModule(args)
|
||||||
|
|
||||||
train_cuts = aishell.train_cuts()
|
train_cuts = aishell2.train_cuts()
|
||||||
valid_cuts = aishell.valid_cuts()
|
valid_cuts = aishell2.valid_cuts()
|
||||||
|
|
||||||
def remove_short_and_long_utt(c: Cut):
|
def remove_short_and_long_utt(c: Cut):
|
||||||
# Keep only utterances with duration between 1 second and 15 seconds
|
# Keep only utterances with duration between 1 second and 15 seconds
|
||||||
@ -1186,11 +1186,11 @@ def run(rank, world_size, args):
|
|||||||
else:
|
else:
|
||||||
sampler_state_dict = None
|
sampler_state_dict = None
|
||||||
|
|
||||||
train_dl = aishell.train_dataloaders(
|
train_dl = aishell2.train_dataloaders(
|
||||||
train_cuts, sampler_state_dict=sampler_state_dict
|
train_cuts, sampler_state_dict=sampler_state_dict
|
||||||
)
|
)
|
||||||
|
|
||||||
valid_dl = aishell.valid_dataloaders(valid_cuts)
|
valid_dl = aishell2.valid_dataloaders(valid_cuts)
|
||||||
|
|
||||||
if False and not params.print_diagnostics:
|
if False and not params.print_diagnostics:
|
||||||
scan_pessimistic_batches_for_oom(
|
scan_pessimistic_batches_for_oom(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user