from local

This commit is contained in:
dohe0342 2023-06-09 16:16:36 +09:00
parent 60a2837b5f
commit 29834f7a27
6 changed files with 30 additions and 29 deletions

Binary file not shown.

View File

@ -31,7 +31,10 @@ from lhotse.dataset import (
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.dataset.input_strategies import (
OnTheFlyFeatures,
AudioSamples,
)
from torch.utils.data import DataLoader
from icefall.utils import str2bool
@ -168,6 +171,18 @@ class TedLiumAsrDataModule:
help="When enabled, select noise from MUSAN and mix it"
"with training dataset.",
)
group.add_argument(
"--input-strategy",
type=str,
default="AudioSamples",
help="AudioSamples or PrecomputedFeatures",
)
group.add_argument(
"--spk-id",
type=int,
default=0,
)
def train_dataloaders(
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
@ -238,13 +253,15 @@ class TedLiumAsrDataModule:
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
#input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_strategy=eval(self.args.input_strategy)(),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
else:
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=eval(self.args.input_strategy)(),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
@ -295,12 +312,14 @@ class TedLiumAsrDataModule:
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_strategy=eval(self.args.input_strategy)(),
#input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
@ -326,11 +345,13 @@ class TedLiumAsrDataModule:
logging.debug("About to create test dataset")
if self.args.on_the_fly_feats:
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
#input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
else:
test = K2SpeechRecognitionDataset(
input_strategy=eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
@ -368,6 +389,6 @@ class TedLiumAsrDataModule:
return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz")
@lru_cache()
def user_test_cuts(self, user) -> CutSet:
def user_test_cuts(self, spk_id) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(self.args.manifest_dir / f"tedlium_cuts_test_{user}.jsonl.gz")
return load_manifest_lazy(self.args.manifest_dir / f"tedlium_cuts_test_{spk_id}.jsonl.gz")

View File

@ -20,27 +20,6 @@
"""
Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless7_ctc/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--exp-dir pruned_transducer_stateless7_ctc/exp \
--full-libri 1 \
--max-duration 300
# For mix precision training:
./pruned_transducer_stateless7_ctc/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless7_ctc/exp \
--full-libri 1 \
--max-duration 550
# For d2v-T training:
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
@ -87,7 +66,7 @@ import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from asr_datamodule import TedLiumAsrDataModule
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
@ -1365,7 +1344,8 @@ def run(rank, world_size, args, wb=None):
if params.inf_check:
register_inf_check_hooks(model)
librispeech = LibriSpeechAsrDataModule(args)
#librispeech = LibriSpeechAsrDataModule(args)
ted = TedLiumAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri: