fix lhotse compatibility

This commit is contained in:
marcoyang 2023-09-21 10:22:56 +08:00
parent 974c1fff08
commit 21cc1dfff4
2 changed files with 6 additions and 10 deletions

View File

@ -25,14 +25,13 @@ from typing import Any, Callable, Dict, List, Optional
import torch
from dataset import PromptASRDataset
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
from lhotse.dataset import ( # SingleCutSampler,
CutConcatenate,
CutMix,
DynamicBucketingSampler,
ExtraPadding,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
@ -211,7 +210,7 @@ class LibriHeavyAsrDataModule:
)
group.add_argument(
"--topk-k",
"--top-k",
type=int,
default=10000,
help="""The top-k words are identified as common words,
@ -261,7 +260,7 @@ class LibriHeavyAsrDataModule:
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
@ -345,11 +344,8 @@ class LibriHeavyAsrDataModule:
drop_last=True,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
raise NotImplementedError(
"SingleCutSampler is no longer supported by lhotse"
)
logging.info("About to create train dataloader")

View File

@ -205,7 +205,7 @@ def triplet_text_sampling(
rare_word_list: Optional[List[str]] = None,
transforms: Optional[List[Callable[[str], str]]] = None,
min_len_style: Optional[int] = 80,
) -> Dict[str, str, str]:
) -> Dict[str, str]:
"""This function generates a triplet of
(pre_text, style_text, ref_text). The style of style_text and ref_text
should **always** match, whereas the style of pre_text is arbitrary.