Add StatelessSampler

Yifan Yang 2023-08-08 21:35:30 +08:00
parent 1ee251c8b3
commit 0b61e6612b
2 changed files with 101 additions and 59 deletions
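
For orientation before the diff: a minimal sketch of how the new sampler is constructed on the stateless code path. The argument names and the import path are taken from the diff below; the manifest directory and the duration/bucket values are placeholders, not values from the repository.

# Hedged sketch: StatelessSampler construction as it appears on the new code path.
# Argument names and the import path follow the diff; concrete values are assumed.
from pathlib import Path

from lhotse.dataset import StatelessSampler

manifest_dir = Path("data/fbank")  # placeholder manifest directory

train_sampler = StatelessSampler(
    # The stateless path passes a list of non-gzipped .jsonl cut manifests
    # instead of a loaded CutSet.
    [manifest_dir / "librispeech_cuts_train-clean-100.jsonl"],
    base_seed=0,
    index_path=manifest_dir / "files.idx",
    max_duration=600.0,  # seconds of audio per batch (placeholder)
    num_buckets=30,      # only passed by the bucketing variant (placeholder)
)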

View File

@@ -21,7 +21,7 @@ import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional, Union
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
@@ -33,6 +33,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
StatelessSampler,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
@@ -123,6 +124,12 @@ class LibriSpeechAsrDataModule:
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--stateless-sampler",
type=str2bool,
default=False,
help="When enabled, the dataloader is completely stateless.",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
@@ -217,7 +224,7 @@ class LibriSpeechAsrDataModule:
def train_dataloaders(
self,
cuts_train: CutSet,
cuts_train: Union[CutSet, List[str]],
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
@@ -305,21 +312,40 @@ class LibriSpeechAsrDataModule:
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
if self.args.stateless_sampler:
logging.info("Using bucketing StatelessSampler.")
train_sampler = StatelessSampler(
cuts_train,
base_seed=0,
index_path=self.args.manifest_dir / "files.idx",
max_duration=self.args.max_duration,
num_buckets=self.args.num_buckets,
)
else:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
if self.args.stateless_sampler:
logging.info("Using non-bucketing StatelessSampler.")
train_sampler = StatelessSampler(
cuts_train,
base_seed=0,
index_path=self.args.manifest_dir / "files.idx",
max_duration=self.args.max_duration,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
@@ -379,7 +405,7 @@ class LibriSpeechAsrDataModule:
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
def test_dataloaders(self, cuts_test: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
@@ -388,7 +414,7 @@ class LibriSpeechAsrDataModule:
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
cuts_test,
max_duration=self.args.max_duration,
shuffle=False,
)
@@ -402,42 +428,57 @@ class LibriSpeechAsrDataModule:
return test_dl
@lru_cache()
def train_clean_5_cuts(self) -> CutSet:
logging.info("mini_librispeech: About to get train-clean-5 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
)
def train_clean_5_cuts(self) -> Union[CutSet, List[str]]:
if self.args.stateless_sampler:
return [self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl"]
else:
logging.info("mini_librispeech: About to get train-clean-5 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
)
@lru_cache()
def train_clean_100_cuts(self) -> CutSet:
logging.info("About to get train-clean-100 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
)
def train_clean_100_cuts(self) -> Union[CutSet, List[str]]:
if self.args.stateless_sampler:
return [self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl"]
else:
logging.info("About to get train-clean-100 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
)
@lru_cache()
def train_clean_360_cuts(self) -> CutSet:
logging.info("About to get train-clean-360 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
)
def train_clean_360_cuts(self) -> Union[CutSet, List[str]]:
if self.args.stateless_sampler:
return [self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl"]
else:
logging.info("About to get train-clean-360 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
)
@lru_cache()
def train_other_500_cuts(self) -> CutSet:
logging.info("About to get train-other-500 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
)
def train_other_500_cuts(self) -> Union[CutSet, List[str]]:
if self.args.stateless_sampler:
return [self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl"]
else:
logging.info("About to get train-other-500 cuts")
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
)
@lru_cache()
def train_all_shuf_cuts(self) -> CutSet:
logging.info(
"About to get the shuffled train-clean-100, \
train-clean-360 and train-other-500 cuts"
)
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
)
def train_all_shuf_cuts(self) -> Union[CutSet, List[str]]:
if self.args.stateless_sampler:
return [self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl"]
else:
logging.info(
"About to get the shuffled train-clean-100, \
train-clean-360 and train-other-500 cuts"
)
return load_manifest_lazy(
self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
)
@lru_cache()
def dev_clean_2_cuts(self) -> CutSet:
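
A hedged usage sketch of the data-module changes above: when --stateless-sampler is set, the train_*_cuts() helpers return a list of plain .jsonl manifest paths (Path objects, despite the List[str] annotation) instead of a lazily loaded CutSet, and train_dataloaders() accepts either form. The import path and the add_arguments() wiring are assumptions based on the usual icefall recipe layout; only the data-module calls themselves come from this diff.

import argparse

from asr_datamodule import LibriSpeechAsrDataModule  # assumed module path

parser = argparse.ArgumentParser()
LibriSpeechAsrDataModule.add_arguments(parser)  # assumed existing classmethod
args = parser.parse_args(["--stateless-sampler", "1"])

librispeech = LibriSpeechAsrDataModule(args)

# With --stateless-sampler the helper returns a list of .jsonl manifest paths;
# otherwise it returns a lazily loaded CutSet as before.
train_cuts = librispeech.train_clean_100_cuts()

# train_dataloaders() is now annotated Union[CutSet, List[str]] and builds a
# StatelessSampler or a DynamicBucketingSampler/SingleCutSampler accordingly.
train_dl = librispeech.train_dataloaders(train_cuts)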

View File

@@ -604,11 +604,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
def get_model(params: AttributeDict) -> nn.Module:
assert (
params.use_transducer or params.use_ctc
), (f"At least one of them should be True, "
assert params.use_transducer or params.use_ctc, (
f"At least one of them should be True, "
f"but got params.use_transducer={params.use_transducer}, "
f"params.use_ctc={params.use_ctc}")
f"params.use_ctc={params.use_ctc}"
)
encoder_embed = get_encoder_embed(params)
encoder = get_encoder_model(params)
@@ -808,17 +808,16 @@ def compute_loss(
# take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss scale by warm_step.
simple_loss_scale = (
s if batch_idx_train >= warm_step
s
if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
pruned_loss_scale = (
1.0 if batch_idx_train >= warm_step
1.0
if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
loss += (
simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
)
loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
if params.use_ctc:
loss += params.ctc_loss_scale * ctc_loss
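
The hunk above only reformats the warm-up scaling, but the schedule itself is easy to miss in diff form; here is a standalone sketch of the two interpolations it encodes (the function name and framing are mine, the arithmetic is copied from the code):

# Hedged sketch: the warm-up schedules from the block above.
# s is params.simple_loss_scale; both scales interpolate linearly over warm_step.
def warmup_scales(batch_idx_train: int, warm_step: int, s: float):
    if batch_idx_train >= warm_step:
        return s, 1.0
    frac = batch_idx_train / warm_step
    simple_loss_scale = 1.0 - frac * (1.0 - s)  # decays from 1.0 down to s
    pruned_loss_scale = 0.1 + 0.9 * frac        # grows from 0.1 up to 1.0
    return simple_loss_scale, pruned_loss_scale

# Example: at batch 0 the scales are (1.0, 0.1); at warm_step they reach (s, 1.0).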
@@ -1217,7 +1216,8 @@ def run(rank, world_size, args):
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)
if not params.stateless_sampler:
train_cuts = train_cuts.filter(remove_short_and_long_utt)
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
@@ -1234,7 +1234,7 @@ def run(rank, world_size, args):
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
if 0 and not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
@@ -1251,7 +1251,8 @@ def run(rank, world_size, args):
for epoch in range(params.start_epoch, params.num_epochs + 1):
scheduler.step_epoch(epoch - 1)
fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
if not params.stateless_sampler:
train_dl.sampler.set_epoch(epoch - 1)
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)