diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py index de83cef5a..343baf65a 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -34,6 +34,7 @@ from lhotse.dataset import ( BucketingSampler, CutConcatenate, CutMix, + DynamicBucketingSampler, K2SpeechRecognitionDataset, PrecomputedFeatures, SingleCutSampler, @@ -350,7 +351,7 @@ class Aidatatang_200zhAsrDataModule: cut_transforms=transforms, return_cuts=self.args.return_cuts, ) - valid_sampler = BucketingSampler( + valid_sampler = DynamicBucketingSampler( cuts_valid, max_duration=self.args.max_duration, rank=0, @@ -382,7 +383,7 @@ class Aidatatang_200zhAsrDataModule: else PrecomputedFeatures(), return_cuts=self.args.return_cuts, ) - sampler = BucketingSampler( + sampler = DynamicBucketingSampler( cuts, max_duration=self.args.max_duration, rank=0, diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py index 83a442b90..b78c600c3 100755 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py @@ -508,13 +508,6 @@ def main(): model.to(device) model.load_state_dict(average_checkpoints(filenames, device=device)) - average = average_checkpoints(filenames, device=device) - checkpoint = {"model": average} - torch.save( - checkpoint, - "pruned_transducer_stateless2/pretrained_average_11_to_29.pt", - ) - model.to(device) model.eval() model.device = device