mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
isort and black formatting
This commit is contained in:
parent
73b30aeda5
commit
57633e1eb0
@ -55,7 +55,9 @@ def is_cut_long(c: MonoCut) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def compute_fbank_musan(
|
def compute_fbank_musan(
|
||||||
num_mel_bins: int = 80, whisper_fbank: bool = False, output_dir: str = "data/manifests"
|
num_mel_bins: int = 80,
|
||||||
|
whisper_fbank: bool = False,
|
||||||
|
output_dir: str = "data/manifests",
|
||||||
):
|
):
|
||||||
src_dir = Path("data/manifests")
|
src_dir = Path("data/manifests")
|
||||||
output_dir = Path(output_dir)
|
output_dir = Path(output_dir)
|
||||||
|
@ -180,8 +180,8 @@ class ReazonSpeechAsrDataModule:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def train_dataloaders(
|
def train_dataloaders(
|
||||||
self,
|
self,
|
||||||
cuts_train: CutSet,
|
cuts_train: CutSet,
|
||||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||||
cuts_musan: Optional[CutSet] = None,
|
cuts_musan: Optional[CutSet] = None,
|
||||||
) -> DataLoader:
|
) -> DataLoader:
|
||||||
|
@ -65,10 +65,10 @@ import torch.nn as nn
|
|||||||
from asr_datamodule import ReazonSpeechAsrDataModule
|
from asr_datamodule import ReazonSpeechAsrDataModule
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
|
from lhotse import load_manifest
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
from lhotse.dataset.sampling.base import CutSampler
|
from lhotse.dataset.sampling.base import CutSampler
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from lhotse import load_manifest
|
|
||||||
from model import AsrModel
|
from model import AsrModel
|
||||||
from optim import Eden, ScaledAdam
|
from optim import Eden, ScaledAdam
|
||||||
from scaling import ScheduledFloat
|
from scaling import ScheduledFloat
|
||||||
@ -1226,14 +1226,16 @@ def run(rank, world_size, args):
|
|||||||
cuts_musan = load_manifest(musan_path)
|
cuts_musan = load_manifest(musan_path)
|
||||||
logging.info(f"Loaded MUSAN manifest from {musan_path}")
|
logging.info(f"Loaded MUSAN manifest from {musan_path}")
|
||||||
else:
|
else:
|
||||||
logging.warning(f"MUSAN manifest not found at {musan_path}, disabling MUSAN augmentation")
|
logging.warning(
|
||||||
|
f"MUSAN manifest not found at {musan_path}, disabling MUSAN augmentation"
|
||||||
|
)
|
||||||
cuts_musan = None
|
cuts_musan = None
|
||||||
else:
|
else:
|
||||||
cuts_musan = None
|
cuts_musan = None
|
||||||
|
|
||||||
train_dl = reazonspeech_corpus.train_dataloaders(
|
train_dl = reazonspeech_corpus.train_dataloaders(
|
||||||
train_cuts,
|
train_cuts,
|
||||||
sampler_state_dict=sampler_state_dict,
|
sampler_state_dict=sampler_state_dict,
|
||||||
cuts_musan=cuts_musan,
|
cuts_musan=cuts_musan,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user