mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
from local
This commit is contained in:
parent
60a2837b5f
commit
29834f7a27
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -31,7 +31,10 @@ from lhotse.dataset import (
|
||||
SingleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
from lhotse.dataset.input_strategies import (
|
||||
OnTheFlyFeatures,
|
||||
AudioSamples,
|
||||
)
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
@ -168,6 +171,18 @@ class TedLiumAsrDataModule:
|
||||
help="When enabled, select noise from MUSAN and mix it"
|
||||
"with training dataset.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--input-strategy",
|
||||
type=str,
|
||||
default="AudioSamples",
|
||||
help="AudioSamples or PrecomputedFeatures",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--spk-id",
|
||||
type=int,
|
||||
default=0,
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
|
||||
@ -238,13 +253,15 @@ class TedLiumAsrDataModule:
|
||||
# Drop feats to be on the safe side.
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
#input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
input_strategy=eval(self.args.input_strategy)(),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=eval(self.args.input_strategy)(),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
@ -295,12 +312,14 @@ class TedLiumAsrDataModule:
|
||||
if self.args.on_the_fly_feats:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
input_strategy=eval(self.args.input_strategy)(),
|
||||
#input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
@ -326,11 +345,13 @@ class TedLiumAsrDataModule:
|
||||
logging.debug("About to create test dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
#input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
@ -368,6 +389,6 @@ class TedLiumAsrDataModule:
|
||||
return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def user_test_cuts(self, user) -> CutSet:
|
||||
def user_test_cuts(self, spk_id) -> CutSet:
|
||||
logging.info("About to get test cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / f"tedlium_cuts_test_{user}.jsonl.gz")
|
||||
return load_manifest_lazy(self.args.manifest_dir / f"tedlium_cuts_test_{spk_id}.jsonl.gz")
|
||||
|
||||
@ -20,27 +20,6 @@
|
||||
"""
|
||||
Usage:
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
./pruned_transducer_stateless7_ctc/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 1 \
|
||||
--exp-dir pruned_transducer_stateless7_ctc/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 300
|
||||
|
||||
# For mix precision training:
|
||||
|
||||
./pruned_transducer_stateless7_ctc/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir pruned_transducer_stateless7_ctc/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 550
|
||||
|
||||
# For d2v-T training:
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||
|
||||
@ -87,7 +66,7 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from asr_datamodule import TedLiumAsrDataModule
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
@ -1365,7 +1344,8 @@ def run(rank, world_size, args, wb=None):
|
||||
if params.inf_check:
|
||||
register_inf_check_hooks(model)
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
#librispeech = LibriSpeechAsrDataModule(args)
|
||||
ted = TedLiumAsrDataModule(args)
|
||||
|
||||
train_cuts = librispeech.train_clean_100_cuts()
|
||||
if params.full_libri:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user