Use DynamicBucketingSampler for decoding

This commit is contained in:
luomingshuang 2022-05-17 09:38:25 +08:00
parent 74b3861320
commit 822fe21a83
2 changed files with 3 additions and 9 deletions

View File

@@ -34,6 +34,7 @@ from lhotse.dataset import (
BucketingSampler, BucketingSampler,
CutConcatenate, CutConcatenate,
CutMix, CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset, K2SpeechRecognitionDataset,
PrecomputedFeatures, PrecomputedFeatures,
SingleCutSampler, SingleCutSampler,
@@ -350,7 +351,7 @@ class Aidatatang_200zhAsrDataModule:
cut_transforms=transforms, cut_transforms=transforms,
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
valid_sampler = BucketingSampler( valid_sampler = DynamicBucketingSampler(
cuts_valid, cuts_valid,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
rank=0, rank=0,
@@ -382,7 +383,7 @@ class Aidatatang_200zhAsrDataModule:
else PrecomputedFeatures(), else PrecomputedFeatures(),
return_cuts=self.args.return_cuts, return_cuts=self.args.return_cuts,
) )
sampler = BucketingSampler( sampler = DynamicBucketingSampler(
cuts, cuts,
max_duration=self.args.max_duration, max_duration=self.args.max_duration,
rank=0, rank=0,

View File

@@ -508,13 +508,6 @@ def main():
model.to(device) model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(average_checkpoints(filenames, device=device))
average = average_checkpoints(filenames, device=device)
checkpoint = {"model": average}
torch.save(
checkpoint,
"pruned_transducer_stateless2/pretrained_average_11_to_29.pt",
)
model.to(device) model.to(device)
model.eval() model.eval()
model.device = device model.device = device