mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
from local
This commit is contained in:
parent
60a2837b5f
commit
29834f7a27
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -31,7 +31,10 @@ from lhotse.dataset import (
|
|||||||
SingleCutSampler,
|
SingleCutSampler,
|
||||||
SpecAugment,
|
SpecAugment,
|
||||||
)
|
)
|
||||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
from lhotse.dataset.input_strategies import (
|
||||||
|
OnTheFlyFeatures,
|
||||||
|
AudioSamples,
|
||||||
|
)
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import str2bool
|
||||||
@ -168,6 +171,18 @@ class TedLiumAsrDataModule:
|
|||||||
help="When enabled, select noise from MUSAN and mix it"
|
help="When enabled, select noise from MUSAN and mix it"
|
||||||
"with training dataset.",
|
"with training dataset.",
|
||||||
)
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--input-strategy",
|
||||||
|
type=str,
|
||||||
|
default="AudioSamples",
|
||||||
|
help="AudioSamples or PrecomputedFeatures",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--spk-id",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
)
|
||||||
|
|
||||||
def train_dataloaders(
|
def train_dataloaders(
|
||||||
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
|
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
|
||||||
@ -238,13 +253,15 @@ class TedLiumAsrDataModule:
|
|||||||
# Drop feats to be on the safe side.
|
# Drop feats to be on the safe side.
|
||||||
train = K2SpeechRecognitionDataset(
|
train = K2SpeechRecognitionDataset(
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
#input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||||
|
input_strategy=eval(self.args.input_strategy)(),
|
||||||
input_transforms=input_transforms,
|
input_transforms=input_transforms,
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
train = K2SpeechRecognitionDataset(
|
train = K2SpeechRecognitionDataset(
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
|
input_strategy=eval(self.args.input_strategy)(),
|
||||||
input_transforms=input_transforms,
|
input_transforms=input_transforms,
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
@ -295,12 +312,14 @@ class TedLiumAsrDataModule:
|
|||||||
if self.args.on_the_fly_feats:
|
if self.args.on_the_fly_feats:
|
||||||
validate = K2SpeechRecognitionDataset(
|
validate = K2SpeechRecognitionDataset(
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
input_strategy=eval(self.args.input_strategy)(),
|
||||||
|
#input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
validate = K2SpeechRecognitionDataset(
|
validate = K2SpeechRecognitionDataset(
|
||||||
cut_transforms=transforms,
|
cut_transforms=transforms,
|
||||||
|
input_strategy=eval(self.args.input_strategy)(),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -326,11 +345,13 @@ class TedLiumAsrDataModule:
|
|||||||
logging.debug("About to create test dataset")
|
logging.debug("About to create test dataset")
|
||||||
if self.args.on_the_fly_feats:
|
if self.args.on_the_fly_feats:
|
||||||
test = K2SpeechRecognitionDataset(
|
test = K2SpeechRecognitionDataset(
|
||||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
#input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||||
|
input_strategy=eval(self.args.input_strategy)(),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
test = K2SpeechRecognitionDataset(
|
test = K2SpeechRecognitionDataset(
|
||||||
|
input_strategy=eval(self.args.input_strategy)(),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -368,6 +389,6 @@ class TedLiumAsrDataModule:
|
|||||||
return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz")
|
return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz")
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def user_test_cuts(self, user) -> CutSet:
|
def user_test_cuts(self, spk_id) -> CutSet:
|
||||||
logging.info("About to get test cuts")
|
logging.info("About to get test cuts")
|
||||||
return load_manifest_lazy(self.args.manifest_dir / f"tedlium_cuts_test_{user}.jsonl.gz")
|
return load_manifest_lazy(self.args.manifest_dir / f"tedlium_cuts_test_{spk_id}.jsonl.gz")
|
||||||
|
|||||||
@ -20,27 +20,6 @@
|
|||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
|
|
||||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
|
||||||
|
|
||||||
./pruned_transducer_stateless7_ctc/train.py \
|
|
||||||
--world-size 4 \
|
|
||||||
--num-epochs 30 \
|
|
||||||
--start-epoch 1 \
|
|
||||||
--exp-dir pruned_transducer_stateless7_ctc/exp \
|
|
||||||
--full-libri 1 \
|
|
||||||
--max-duration 300
|
|
||||||
|
|
||||||
# For mix precision training:
|
|
||||||
|
|
||||||
./pruned_transducer_stateless7_ctc/train.py \
|
|
||||||
--world-size 4 \
|
|
||||||
--num-epochs 30 \
|
|
||||||
--start-epoch 1 \
|
|
||||||
--use-fp16 1 \
|
|
||||||
--exp-dir pruned_transducer_stateless7_ctc/exp \
|
|
||||||
--full-libri 1 \
|
|
||||||
--max-duration 550
|
|
||||||
|
|
||||||
# For d2v-T training:
|
# For d2v-T training:
|
||||||
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||||
|
|
||||||
@ -87,7 +66,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import TedLiumAsrDataModule
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
@ -1365,7 +1344,8 @@ def run(rank, world_size, args, wb=None):
|
|||||||
if params.inf_check:
|
if params.inf_check:
|
||||||
register_inf_check_hooks(model)
|
register_inf_check_hooks(model)
|
||||||
|
|
||||||
librispeech = LibriSpeechAsrDataModule(args)
|
#librispeech = LibriSpeechAsrDataModule(args)
|
||||||
|
ted = TedLiumAsrDataModule(args)
|
||||||
|
|
||||||
train_cuts = librispeech.train_clean_100_cuts()
|
train_cuts = librispeech.train_clean_100_cuts()
|
||||||
if params.full_libri:
|
if params.full_libri:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user