mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 08:34:19 +00:00
use dynamicbucketsampler for decoding
This commit is contained in:
parent
74b3861320
commit
822fe21a83
@ -34,6 +34,7 @@ from lhotse.dataset import (
|
|||||||
BucketingSampler,
|
BucketingSampler,
|
||||||
CutConcatenate,
|
CutConcatenate,
|
||||||
CutMix,
|
CutMix,
|
||||||
|
DynamicBucketingSampler,
|
||||||
K2SpeechRecognitionDataset,
|
K2SpeechRecognitionDataset,
|
||||||
PrecomputedFeatures,
|
PrecomputedFeatures,
|
||||||
SingleCutSampler,
|
SingleCutSampler,
|
||||||
@ -350,7 +351,7 @@ class Aidatatang_200zhAsrDataModule:
|
|||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
valid_sampler = BucketingSampler(
|
valid_sampler = DynamicBucketingSampler(
|
||||||
cuts_valid,
|
cuts_valid,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
rank=0,
|
rank=0,
|
||||||
@ -382,7 +383,7 @@ class Aidatatang_200zhAsrDataModule:
|
|||||||
else PrecomputedFeatures(),
|
else PrecomputedFeatures(),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
sampler = BucketingSampler(
|
sampler = DynamicBucketingSampler(
|
||||||
cuts,
|
cuts,
|
||||||
max_duration=self.args.max_duration,
|
max_duration=self.args.max_duration,
|
||||||
rank=0,
|
rank=0,
|
||||||
|
@ -508,13 +508,6 @@ def main():
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
|
||||||
average = average_checkpoints(filenames, device=device)
|
|
||||||
checkpoint = {"model": average}
|
|
||||||
torch.save(
|
|
||||||
checkpoint,
|
|
||||||
"pruned_transducer_stateless2/pretrained_average_11_to_29.pt",
|
|
||||||
)
|
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
model.device = device
|
model.device = device
|
||||||
|
Loading…
x
Reference in New Issue
Block a user