Add StatelessSampler

Yifan Yang 2023-08-08 21:35:30 +08:00
parent 1ee251c8b3
commit 0b61e6612b
2 changed files with 101 additions and 59 deletions

View File

@@ -21,7 +21,7 @@ import inspect
 import logging
 from functools import lru_cache
 from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional, Union

 import torch
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
@@ -33,6 +33,7 @@ from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     PrecomputedFeatures,
     SingleCutSampler,
     SpecAugment,
+    StatelessSampler,
 )
 from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
     AudioSamples,
@ -123,6 +124,12 @@ class LibriSpeechAsrDataModule:
help="The number of buckets for the DynamicBucketingSampler" help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).", "(you might want to increase it for larger datasets).",
) )
group.add_argument(
"--stateless-sampler",
type=str2bool,
default=False,
help="When enabled, the dataloader is completely stateless.",
)
group.add_argument( group.add_argument(
"--concatenate-cuts", "--concatenate-cuts",
type=str2bool, type=str2bool,
@@ -217,7 +224,7 @@ class LibriSpeechAsrDataModule:
     def train_dataloaders(
         self,
-        cuts_train: CutSet,
+        cuts_train: Union[CutSet, List[str]],
         sampler_state_dict: Optional[Dict[str, Any]] = None,
     ) -> DataLoader:
         """
@@ -305,21 +312,40 @@ class LibriSpeechAsrDataModule:
         )

         if self.args.bucketing_sampler:
-            logging.info("Using DynamicBucketingSampler.")
-            train_sampler = DynamicBucketingSampler(
-                cuts_train,
-                max_duration=self.args.max_duration,
-                shuffle=self.args.shuffle,
-                num_buckets=self.args.num_buckets,
-                drop_last=self.args.drop_last,
-            )
+            if self.args.stateless_sampler:
+                logging.info("Using bucketing StatelessSampler.")
+                train_sampler = StatelessSampler(
+                    cuts_train,
+                    base_seed=0,
+                    index_path=self.args.manifest_dir / "files.idx",
+                    max_duration=self.args.max_duration,
+                    num_buckets=self.args.num_buckets,
+                )
+            else:
+                logging.info("Using DynamicBucketingSampler.")
+                train_sampler = DynamicBucketingSampler(
+                    cuts_train,
+                    max_duration=self.args.max_duration,
+                    shuffle=self.args.shuffle,
+                    num_buckets=self.args.num_buckets,
+                    drop_last=self.args.drop_last,
+                )
         else:
-            logging.info("Using SingleCutSampler.")
-            train_sampler = SingleCutSampler(
-                cuts_train,
-                max_duration=self.args.max_duration,
-                shuffle=self.args.shuffle,
-            )
+            if self.args.stateless_sampler:
+                logging.info("Using non-bucketing StatelessSampler.")
+                train_sampler = StatelessSampler(
+                    cuts_train,
+                    base_seed=0,
+                    index_path=self.args.manifest_dir / "files.idx",
+                    max_duration=self.args.max_duration,
+                )
+            else:
+                logging.info("Using SingleCutSampler.")
+                train_sampler = SingleCutSampler(
+                    cuts_train,
+                    max_duration=self.args.max_duration,
+                    shuffle=self.args.shuffle,
+                )
         logging.info("About to create train dataloader")

         if sampler_state_dict is not None:
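For orientation, a minimal usage sketch of the new flag follows. Only --stateless-sampler, --bucketing-sampler, and the datamodule methods come from this diff; the argparse wiring, the add_arguments classmethod, and the data layout (uncompressed .jsonl manifests plus files.idx under --manifest-dir) are assumptions about the surrounding recipe, not something this commit guarantees.

import argparse

# Hypothetical driver code, not part of this commit.
parser = argparse.ArgumentParser()
LibriSpeechAsrDataModule.add_arguments(parser)  # assumed classmethod, per icefall convention
args = parser.parse_args(["--stateless-sampler", "1", "--bucketing-sampler", "1"])

librispeech = LibriSpeechAsrDataModule(args)
# With --stateless-sampler, the cut getters further down in this diff return a
# list of manifest paths instead of a lazily opened CutSet.
train_cuts = librispeech.train_clean_100_cuts()
# StatelessSampler additionally expects <manifest-dir>/files.idx to exist;
# creating that index is not covered by this diff.
train_dl = librispeech.train_dataloaders(train_cuts)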
@@ -379,7 +405,7 @@ class LibriSpeechAsrDataModule:
         return valid_dl

-    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+    def test_dataloaders(self, cuts_test: CutSet) -> DataLoader:
         logging.debug("About to create test dataset")
         test = K2SpeechRecognitionDataset(
             input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
@@ -388,7 +414,7 @@ class LibriSpeechAsrDataModule:
             return_cuts=self.args.return_cuts,
         )
         sampler = DynamicBucketingSampler(
-            cuts,
+            cuts_test,
             max_duration=self.args.max_duration,
             shuffle=False,
         )
@@ -402,42 +428,57 @@ class LibriSpeechAsrDataModule:
         return test_dl

     @lru_cache()
-    def train_clean_5_cuts(self) -> CutSet:
-        logging.info("mini_librispeech: About to get train-clean-5 cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
-        )
+    def train_clean_5_cuts(self) -> Union[CutSet, List[str]]:
+        if self.args.stateless_sampler:
+            return [self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl"]
+        else:
+            logging.info("mini_librispeech: About to get train-clean-5 cuts")
+            return load_manifest_lazy(
+                self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
+            )

     @lru_cache()
-    def train_clean_100_cuts(self) -> CutSet:
-        logging.info("About to get train-clean-100 cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
-        )
+    def train_clean_100_cuts(self) -> Union[CutSet, List[str]]:
+        if self.args.stateless_sampler:
+            return [self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl"]
+        else:
+            logging.info("About to get train-clean-100 cuts")
+            return load_manifest_lazy(
+                self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
+            )

     @lru_cache()
-    def train_clean_360_cuts(self) -> CutSet:
-        logging.info("About to get train-clean-360 cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
-        )
+    def train_clean_360_cuts(self) -> Union[CutSet, List[str]]:
+        if self.args.stateless_sampler:
+            return [self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl"]
+        else:
+            logging.info("About to get train-clean-360 cuts")
+            return load_manifest_lazy(
+                self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
+            )

     @lru_cache()
-    def train_other_500_cuts(self) -> CutSet:
-        logging.info("About to get train-other-500 cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
-        )
+    def train_other_500_cuts(self) -> Union[CutSet, List[str]]:
+        if self.args.stateless_sampler:
+            return [self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl"]
+        else:
+            logging.info("About to get train-other-500 cuts")
+            return load_manifest_lazy(
+                self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
+            )

     @lru_cache()
-    def train_all_shuf_cuts(self) -> CutSet:
-        logging.info(
-            "About to get the shuffled train-clean-100, \
-            train-clean-360 and train-other-500 cuts"
-        )
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
-        )
+    def train_all_shuf_cuts(self) -> Union[CutSet, List[str]]:
+        if self.args.stateless_sampler:
+            return [self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl"]
+        else:
+            logging.info(
+                "About to get the shuffled train-clean-100, \
+                train-clean-360 and train-other-500 cuts"
+            )
+            return load_manifest_lazy(
+                self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
+            )

     @lru_cache()
     def dev_clean_2_cuts(self) -> CutSet:

View File

@@ -604,11 +604,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
 def get_model(params: AttributeDict) -> nn.Module:
-    assert (
-        params.use_transducer or params.use_ctc
-    ), (f"At least one of them should be True, "
+    assert params.use_transducer or params.use_ctc, (
+        f"At least one of them should be True, "
         f"but got params.use_transducer={params.use_transducer}, "
-        f"params.use_ctc={params.use_ctc}")
+        f"params.use_ctc={params.use_ctc}"
+    )

     encoder_embed = get_encoder_embed(params)
     encoder = get_encoder_model(params)
@@ -808,17 +808,16 @@ def compute_loss(
             # take down the scale on the simple loss from 1.0 at the start
             # to params.simple_loss scale by warm_step.
             simple_loss_scale = (
-                s if batch_idx_train >= warm_step
+                s
+                if batch_idx_train >= warm_step
                 else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
             )
             pruned_loss_scale = (
-                1.0 if batch_idx_train >= warm_step
+                1.0
+                if batch_idx_train >= warm_step
                 else 0.1 + 0.9 * (batch_idx_train / warm_step)
             )
-            loss += (
-                simple_loss_scale * simple_loss
-                + pruned_loss_scale * pruned_loss
-            )
+            loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

         if params.use_ctc:
             loss += params.ctc_loss_scale * ctc_loss
@@ -1217,7 +1216,8 @@ def run(rank, world_size, args):
         return True

-    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    if not params.stateless_sampler:
+        train_cuts = train_cuts.filter(remove_short_and_long_utt)

     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
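A note on the filtering change above: with --stateless-sampler enabled, train_cuts is a list of manifest paths rather than a CutSet, so CutSet.filter() cannot be applied at this point and the short/long-utterance filter only runs on the stateful path. If equivalent filtering were still wanted, one option (hypothetical, not part of this commit) would be to pre-filter the manifests offline before training:

from lhotse import load_manifest_lazy

# Illustrative bounds and paths; adjust to the recipe's actual duration limits.
cuts = load_manifest_lazy("data/fbank/librispeech_cuts_train-clean-100.jsonl.gz")
cuts.filter(lambda c: 1.0 <= c.duration <= 20.0).to_file(
    "data/fbank/librispeech_cuts_train-clean-100.jsonl"
)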
@@ -1234,7 +1234,7 @@ def run(rank, world_size, args):
     valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)

-    if not params.print_diagnostics:
+    if 0 and not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,
@@ -1251,7 +1251,8 @@ def run(rank, world_size, args):
     for epoch in range(params.start_epoch, params.num_epochs + 1):
         scheduler.step_epoch(epoch - 1)
         fix_random_seed(params.seed + epoch - 1)
-        train_dl.sampler.set_epoch(epoch - 1)
+        if not params.stateless_sampler:
+            train_dl.sampler.set_epoch(epoch - 1)

         if tb_writer is not None:
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)