From 822fe21a837ba6ed8e888283fb2a22d5f31d5389 Mon Sep 17 00:00:00 2001 From: luomingshuang <739314837@qq.com> Date: Tue, 17 May 2022 09:38:25 +0800 Subject: [PATCH] use dynamicbucketsampler for decoding --- .../ASR/pruned_transducer_stateless2/asr_datamodule.py | 5 +++-- .../ASR/pruned_transducer_stateless2/decode.py | 7 ------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py index de83cef5a..343baf65a 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -34,6 +34,7 @@ from lhotse.dataset import ( BucketingSampler, CutConcatenate, CutMix, + DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, SingleCutSampler, @@ -350,7 +351,7 @@ class Aidatatang_200zhAsrDataModule: cut_transforms=transforms, return_cuts=self.args.return_cuts, ) - valid_sampler = BucketingSampler( + valid_sampler = DynamicBucketingSampler( cuts_valid, max_duration=self.args.max_duration, rank=0, @@ -382,7 +383,7 @@ class Aidatatang_200zhAsrDataModule: else PrecomputedFeatures(), return_cuts=self.args.return_cuts, ) - sampler = BucketingSampler( + sampler = DynamicBucketingSampler( cuts, max_duration=self.args.max_duration, rank=0, diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py index 83a442b90..b78c600c3 100755 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py @@ -508,13 +508,6 @@ def main(): model.to(device) model.load_state_dict(average_checkpoints(filenames, device=device)) - average = average_checkpoints(filenames, device=device) - checkpoint = {"model": average} - torch.save( - checkpoint, - "pruned_transducer_stateless2/pretrained_average_11_to_29.pt", - ) - model.to(device) model.eval() model.device = device