Fix lhotse compatibility (drop SingleCutSampler; CutMix `prob` renamed to `p`)

This commit is contained in:
marcoyang 2023-09-21 10:22:56 +08:00
parent 974c1fff08
commit 21cc1dfff4
2 changed files with 6 additions and 10 deletions

View File

@@ -25,14 +25,13 @@ from typing import Any, Callable, Dict, List, Optional
import torch import torch
from dataset import PromptASRDataset from dataset import PromptASRDataset
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( from lhotse.dataset import ( # SingleCutSampler,
CutConcatenate, CutConcatenate,
CutMix, CutMix,
DynamicBucketingSampler, DynamicBucketingSampler,
ExtraPadding, ExtraPadding,
K2SpeechRecognitionDataset, K2SpeechRecognitionDataset,
PrecomputedFeatures, PrecomputedFeatures,
SingleCutSampler,
SpecAugment, SpecAugment,
) )
from lhotse.dataset.input_strategies import OnTheFlyFeatures from lhotse.dataset.input_strategies import OnTheFlyFeatures
@@ -211,7 +210,7 @@ class LibriHeavyAsrDataModule:
) )
group.add_argument( group.add_argument(
"--topk-k", "--top-k",
type=int, type=int,
default=10000, default=10000,
help="""The top-k words are identified as common words, help="""The top-k words are identified as common words,
@@ -261,7 +260,7 @@ class LibriHeavyAsrDataModule:
logging.info("About to get Musan cuts") logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append( transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
) )
else: else:
logging.info("Disable MUSAN") logging.info("Disable MUSAN")
@@ -345,11 +344,8 @@ class LibriHeavyAsrDataModule:
drop_last=True, drop_last=True,
) )
else: else:
logging.info("Using SingleCutSampler.") raise NotImplementedError(
train_sampler = SingleCutSampler( "SingleCutSampler is no longer supported by lhotse"
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
) )
logging.info("About to create train dataloader") logging.info("About to create train dataloader")

View File

@@ -205,7 +205,7 @@ def triplet_text_sampling(
rare_word_list: Optional[List[str]] = None, rare_word_list: Optional[List[str]] = None,
transforms: Optional[List[Callable[[str], str]]] = None, transforms: Optional[List[Callable[[str], str]]] = None,
min_len_style: Optional[int] = 80, min_len_style: Optional[int] = 80,
) -> Dict[str, str, str]: ) -> Dict[str, str]:
"""This function generates a triplet of """This function generates a triplet of
(pre_text, style_text, ref_text). The style of style_text and ref_text (pre_text, style_text, ref_text). The style of style_text and ref_text
should **always** match, whereas the style of pre_text is arbitrary. should **always** match, whereas the style of pre_text is arbitrary.