From 235bb0537aa04877c2476b6eb25fdd00be1b61e6 Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Sun, 13 Aug 2023 01:14:51 +0800 Subject: [PATCH] minor fixes --- egs/aishell2/ASR/zipformer/decode.py | 10 +++++----- egs/aishell2/ASR/zipformer/onnx_decode.py | 10 +++++----- egs/aishell2/ASR/zipformer/streaming_decode.py | 6 +++--- egs/aishell2/ASR/zipformer/train.py | 10 +++++----- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/egs/aishell2/ASR/zipformer/decode.py b/egs/aishell2/ASR/zipformer/decode.py index 2f0984cdc..2f0edfe57 100755 --- a/egs/aishell2/ASR/zipformer/decode.py +++ b/egs/aishell2/ASR/zipformer/decode.py @@ -770,7 +770,7 @@ def main(): # we need cut ids to display recognition results. args.return_cuts = True - aishell = AiShell2AsrDataModule(args) + aishell2 = AiShell2AsrDataModule(args) def remove_short_utt(c: Cut): T = ((c.num_frames - 7) // 2 + 1) // 2 @@ -780,13 +780,13 @@ def main(): ) return T > 0 - dev_cuts = aishell.valid_cuts() + dev_cuts = aishell2.valid_cuts() dev_cuts = dev_cuts.filter(remove_short_utt) - dev_dl = aishell.valid_dataloaders(dev_cuts) + dev_dl = aishell2.valid_dataloaders(dev_cuts) - test_cuts = aishell.test_meeting_cuts() + test_cuts = aishell2.test_cuts() test_cuts = test_cuts.filter(remove_short_utt) - test_dl = aishell.test_dataloaders(test_cuts) + test_dl = aishell2.test_dataloaders(test_cuts) test_sets = ["dev", "test"] test_dls = [dev_dl, test_dl] diff --git a/egs/aishell2/ASR/zipformer/onnx_decode.py b/egs/aishell2/ASR/zipformer/onnx_decode.py index cc32e3962..d518f4e86 100755 --- a/egs/aishell2/ASR/zipformer/onnx_decode.py +++ b/egs/aishell2/ASR/zipformer/onnx_decode.py @@ -241,7 +241,7 @@ def main(): # we need cut ids to display recognition results. args.return_cuts = True - aishell = AiShell2AsrDataModule(args) + aishell2 = AiShell2AsrDataModule(args) def remove_short_utt(c: Cut): T = ((c.num_frames - 7) // 2 + 1) // 2 @@ -251,13 +251,13 @@ def main(): ) return T > 0 - dev_cuts = aishell.valid_cuts() + dev_cuts = aishell2.valid_cuts() dev_cuts = dev_cuts.filter(remove_short_utt) - dev_dl = aishell.valid_dataloaders(dev_cuts) + dev_dl = aishell2.valid_dataloaders(dev_cuts) - test_cuts = aishell.test_net_cuts() + test_cuts = aishell2.test_net_cuts() test_cuts = test_cuts.filter(remove_short_utt) - test_dl = aishell.test_dataloaders(test_cuts) + test_dl = aishell2.test_dataloaders(test_cuts) test_sets = ["dev", "test"] test_dl = [dev_dl, test_dl] diff --git a/egs/aishell2/ASR/zipformer/streaming_decode.py b/egs/aishell2/ASR/zipformer/streaming_decode.py index 547b9856b..dfb0123a5 100755 --- a/egs/aishell2/ASR/zipformer/streaming_decode.py +++ b/egs/aishell2/ASR/zipformer/streaming_decode.py @@ -855,10 +855,10 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - aishell = AiShell2AsrDataModule(args) + aishell2 = AiShell2AsrDataModule(args) - dev_cuts = aishell.valid_cuts() - test_cuts = aishell.test_cuts() + dev_cuts = aishell2.valid_cuts() + test_cuts = aishell2.test_cuts() test_sets = ["dev", "test"] test_cuts = [dev_cuts, test_cuts] diff --git a/egs/aishell2/ASR/zipformer/train.py b/egs/aishell2/ASR/zipformer/train.py index 3cb517d9d..1c5869ef0 100755 --- a/egs/aishell2/ASR/zipformer/train.py +++ b/egs/aishell2/ASR/zipformer/train.py @@ -1135,10 +1135,10 @@ def run(rank, world_size, args): if params.inf_check: register_inf_check_hooks(model) - aishell = AiShell2AsrDataModule(args) + aishell2 = AiShell2AsrDataModule(args) - train_cuts = aishell.train_cuts() - valid_cuts = aishell.valid_cuts() + train_cuts = aishell2.train_cuts() + valid_cuts = aishell2.valid_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 15 seconds @@ -1186,11 +1186,11 @@ def run(rank, world_size, args): else: sampler_state_dict = None - train_dl = aishell.train_dataloaders( + train_dl = aishell2.train_dataloaders( train_cuts, sampler_state_dict=sampler_state_dict ) - valid_dl = aishell.valid_dataloaders(valid_cuts) + valid_dl = aishell2.valid_dataloaders(valid_cuts) if False and not params.print_diagnostics: scan_pessimistic_batches_for_oom(