use DynamicBucketingSampler for decoding

luomingshuang 2022-05-17 09:38:25 +08:00
parent 74b3861320
commit 822fe21a83
2 changed files with 3 additions and 9 deletions
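
For context: lhotse's BucketingSampler has to scan the whole CutSet up front to build its duration buckets, while DynamicBucketingSampler estimates bucket boundaries from a sample of cuts and forms batches on the fly, so it also works with lazily opened manifests and bounded memory. This commit switches the validation and test cut samplers of the data module to the dynamic variant and drops a leftover checkpoint-export block from the decoding script.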


@@ -34,6 +34,7 @@ from lhotse.dataset import (
     BucketingSampler,
     CutConcatenate,
     CutMix,
+    DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
     PrecomputedFeatures,
     SingleCutSampler,
@@ -350,7 +351,7 @@ class Aidatatang_200zhAsrDataModule:
             cut_transforms=transforms,
             return_cuts=self.args.return_cuts,
         )
-        valid_sampler = BucketingSampler(
+        valid_sampler = DynamicBucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
             rank=0,
@@ -382,7 +383,7 @@ class Aidatatang_200zhAsrDataModule:
             else PrecomputedFeatures(),
             return_cuts=self.args.return_cuts,
         )
-        sampler = BucketingSampler(
+        sampler = DynamicBucketingSampler(
             cuts,
             max_duration=self.args.max_duration,
             rank=0,
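
For reference, a minimal sketch of how a decoding DataLoader is typically assembled around DynamicBucketingSampler; the function name, the max_duration default, and num_workers below are illustrative assumptions, not values taken from this diff (the recipe reads them from self.args):

from torch.utils.data import DataLoader
from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler, K2SpeechRecognitionDataset


def make_decode_dataloader(cuts: CutSet, max_duration: float = 200.0) -> DataLoader:
    """Assemble a decoding DataLoader over a CutSet (illustrative sketch)."""
    dataset = K2SpeechRecognitionDataset(return_cuts=True)
    sampler = DynamicBucketingSampler(
        cuts,
        max_duration=max_duration,  # cap on total audio seconds per batch
        shuffle=False,              # keep a deterministic order for decoding
        rank=0,                     # a single process iterates the full set
        world_size=1,
    )
    # lhotse samplers yield whole batches of cuts, so automatic batching is disabled.
    return DataLoader(dataset, batch_size=None, sampler=sampler, num_workers=2)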


@@ -508,13 +508,6 @@ def main():
         model.to(device)
         model.load_state_dict(average_checkpoints(filenames, device=device))
-        average = average_checkpoints(filenames, device=device)
-        checkpoint = {"model": average}
-        torch.save(
-            checkpoint,
-            "pruned_transducer_stateless2/pretrained_average_11_to_29.pt",
-        )
     model.to(device)
     model.eval()
     model.device = device
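
The deleted lines only dumped the averaged weights to a hard-coded .pt file; what survives averages the checkpoints in memory and loads them straight into the model. A minimal sketch of that pattern, assuming the caller has already built the transducer model and the list of epoch checkpoint filenames (both come from earlier in main() and are not shown in this hunk):

import torch
from icefall.checkpoint import average_checkpoints


def load_averaged_model(model: torch.nn.Module, filenames: list, device: torch.device) -> None:
    """Average several epoch checkpoints and load the result into model (sketch)."""
    model.to(device)
    model.load_state_dict(average_checkpoints(filenames, device=device))
    model.eval()
    # If an exported averaged checkpoint is still needed, save it explicitly, e.g.
    # torch.save({"model": model.state_dict()}, "exp/pretrained_averaged.pt")
    # (the path here is a hypothetical example, not the one removed above).

Keeping the export out of decode.py avoids rewriting a fixed path such as pretrained_average_11_to_29.pt on every decoding run; the averaged weights can be saved by a separate export step when they are actually needed.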