From ba6794265629e84a2184860f026935412d609555 Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Wed, 9 Aug 2023 16:27:36 +0800 Subject: [PATCH] minor updates --- .../asr_datamodule.py | 3 ++- egs/aidatatang_200zh/ASR/zipformer/decode.py | 24 ++++++++----------- .../ASR/zipformer/streaming_decode.py | 21 +++++++++------- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py index 167d5e15e..90d319a2b 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -102,7 +102,7 @@ class Aidatatang_200zhAsrDataModule: group.add_argument( "--bucketing-sampler", type=str2bool, - default=True, + default=False, help="When enabled, the batches will come from buckets of " "similar duration (saves padding frames).", ) @@ -289,6 +289,7 @@ class Aidatatang_200zhAsrDataModule: shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, drop_last=True, + buffer_size=50000, ) else: logging.info("Using SingleCutSampler.") diff --git a/egs/aidatatang_200zh/ASR/zipformer/decode.py b/egs/aidatatang_200zh/ASR/zipformer/decode.py index 0fbc8244b..8a6d2845a 100755 --- a/egs/aidatatang_200zh/ASR/zipformer/decode.py +++ b/egs/aidatatang_200zh/ASR/zipformer/decode.py @@ -88,7 +88,7 @@ from typing import Dict, List, Optional, Tuple import k2 import torch import torch.nn as nn -from asr_datamodule import WenetSpeechAsrDataModule +from asr_datamodule import Aidatatang_200zhAsrDataModule from beam_search import ( beam_search, fast_beam_search_nbest, @@ -596,7 +596,7 @@ def save_results( @torch.no_grad() def main(): parser = get_parser() - WenetSpeechAsrDataModule.add_arguments(parser) + Aidatatang_200zhAsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -770,7 +770,7 @@ def main(): # we need cut ids to display recognition results. args.return_cuts = True - wenetspeech = WenetSpeechAsrDataModule(args) + aidatatang_200zh = Aidatatang_200zhAsrDataModule(args) def remove_short_utt(c: Cut): T = ((c.num_frames - 7) // 2 + 1) // 2 @@ -780,20 +780,16 @@ def main(): ) return T > 0 - dev_cuts = wenetspeech.valid_cuts() + dev_cuts = aidatatang_200zh.valid_cuts() dev_cuts = dev_cuts.filter(remove_short_utt) - dev_dl = wenetspeech.valid_dataloaders(dev_cuts) + dev_dl = aidatatang_200zh.valid_dataloaders(dev_cuts) - test_net_cuts = wenetspeech.test_net_cuts() - test_net_cuts = test_net_cuts.filter(remove_short_utt) - test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) + test_cuts = aidatatang_200zh.test_cuts() + test_cuts = test_cuts.filter(remove_short_utt) + test_dl = aidatatang_200zh.test_dataloaders(test_cuts) - test_meeting_cuts = wenetspeech.test_meeting_cuts() - test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt) - test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts) - - test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] - test_dls = [dev_dl, test_net_dl, test_meeting_dl] + test_sets = ["valid_cuts", "test_cuts"] + test_dls = [dev_dl, test_dl] for test_set, test_dl in zip(test_sets, test_dls): results_dict = decode_dataset( diff --git a/egs/aidatatang_200zh/ASR/zipformer/streaming_decode.py b/egs/aidatatang_200zh/ASR/zipformer/streaming_decode.py index 94c5fae5f..f95f02cad 100755 --- a/egs/aidatatang_200zh/ASR/zipformer/streaming_decode.py +++ b/egs/aidatatang_200zh/ASR/zipformer/streaming_decode.py @@ -39,7 +39,7 @@ from typing import Dict, List, Optional, Tuple import k2 import numpy as np import torch -from asr_datamodule import WenetSpeechAsrDataModule +from asr_datamodule import Aidatatang_200zhAsrDataModule from decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions from lhotse import CutSet @@ -386,7 +386,11 @@ def streaming_forward( Returns encoder outputs, output lengths, and updated states. """ cached_embed_left_pad = states[-2] - (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( + ( + x, + x_lens, + new_cached_embed_left_pad, + ) = model.encoder_embed.streaming_forward( x=features, x_lens=feature_lens, cached_left_pad=cached_embed_left_pad, @@ -713,7 +717,7 @@ def save_results( @torch.no_grad() def main(): parser = get_parser() - WenetSpeechAsrDataModule.add_arguments(parser) + Aidatatang_200zhAsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -851,14 +855,13 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - wenetspeech = WenetSpeechAsrDataModule(args) + aidatatang_200zh = Aidatatang_200zhAsrDataModule(args) - dev_cuts = wenetspeech.valid_cuts() - test_net_cuts = wenetspeech.test_net_cuts() - test_meeting_cuts = wenetspeech.test_meeting_cuts() + dev_cuts = aidatatang_200zh.valid_cuts() + test_cuts = aidatatang_200zh.test_cuts() - test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] - test_cuts = [dev_cuts, test_net_cuts, test_meeting_cuts] + test_sets = ["valid_cuts", "test_cuts"] + test_cuts = [dev_cuts, test_cuts] for test_set, test_cut in zip(test_sets, test_cuts): results_dict = decode_dataset(