Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)
Add StatelessSampler

commit 0b61e6612b
parent 1ee251c8b3
The first group of hunks modifies the LibriSpeech ASR data module (class LibriSpeechAsrDataModule):

@@ -21,7 +21,7 @@ import inspect
 import logging
 from functools import lru_cache
 from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
@@ -33,6 +33,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
     PrecomputedFeatures,
     SingleCutSampler,
     SpecAugment,
+    StatelessSampler,
 )
 from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
     AudioSamples,
@@ -123,6 +124,12 @@ class LibriSpeechAsrDataModule:
             help="The number of buckets for the DynamicBucketingSampler"
             "(you might want to increase it for larger datasets).",
         )
+        group.add_argument(
+            "--stateless-sampler",
+            type=str2bool,
+            default=False,
+            help="When enabled, the dataloader is completely stateless.",
+        )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
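For reference, a minimal sketch of how the new flag is parsed. Only the argument definition is taken from the hunk above; the standalone parser and the example invocation are illustrative. str2bool is assumed to be icefall's icefall.utils.str2bool, which the existing --concatenate-cuts argument in this module already uses.

import argparse

from icefall.utils import str2bool  # converts "true"/"false"-style strings to bool

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="ASR data related options")
group.add_argument(
    "--stateless-sampler",
    type=str2bool,
    default=False,
    help="When enabled, the dataloader is completely stateless.",
)

args = parser.parse_args(["--stateless-sampler", "true"])
assert args.stateless_sampler is True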
@@ -217,7 +224,7 @@ class LibriSpeechAsrDataModule:
 
     def train_dataloaders(
         self,
-        cuts_train: CutSet,
+        cuts_train: Union[CutSet, List[str]],
         sampler_state_dict: Optional[Dict[str, Any]] = None,
     ) -> DataLoader:
         """
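With this change the data module accepts either a lazily opened CutSet or, in stateless mode, a plain list of manifest paths. A hedged sketch of a caller; the module path, add_arguments classmethod usage, and dataset choice follow the recipe's usual conventions and are assumptions, not part of this commit:

import argparse

from asr_datamodule import LibriSpeechAsrDataModule  # the module this diff edits

parser = argparse.ArgumentParser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args(["--stateless-sampler", "true"])

librispeech = LibriSpeechAsrDataModule(args)

# With --stateless-sampler=true the getter returns a list of .jsonl manifest
# paths; otherwise a lazily opened CutSet (see the cut getters further below).
train_cuts = librispeech.train_clean_100_cuts()
train_dl = librispeech.train_dataloaders(train_cuts)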
@@ -305,21 +312,40 @@ class LibriSpeechAsrDataModule:
         )
 
         if self.args.bucketing_sampler:
-            logging.info("Using DynamicBucketingSampler.")
-            train_sampler = DynamicBucketingSampler(
-                cuts_train,
-                max_duration=self.args.max_duration,
-                shuffle=self.args.shuffle,
-                num_buckets=self.args.num_buckets,
-                drop_last=self.args.drop_last,
-            )
+            if self.args.stateless_sampler:
+                logging.info("Using bucketing StatelessSampler.")
+                train_sampler = StatelessSampler(
+                    cuts_train,
+                    base_seed=0,
+                    index_path=self.args.manifest_dir / "files.idx",
+                    max_duration=self.args.max_duration,
+                    num_buckets=self.args.num_buckets,
+                )
+            else:
+                logging.info("Using DynamicBucketingSampler.")
+                train_sampler = DynamicBucketingSampler(
+                    cuts_train,
+                    max_duration=self.args.max_duration,
+                    shuffle=self.args.shuffle,
+                    num_buckets=self.args.num_buckets,
+                    drop_last=self.args.drop_last,
+                )
         else:
-            logging.info("Using SingleCutSampler.")
-            train_sampler = SingleCutSampler(
-                cuts_train,
-                max_duration=self.args.max_duration,
-                shuffle=self.args.shuffle,
-            )
+            if self.args.stateless_sampler:
+                logging.info("Using non-bucketing StatelessSampler.")
+                train_sampler = StatelessSampler(
+                    cuts_train,
+                    base_seed=0,
+                    index_path=self.args.manifest_dir / "files.idx",
+                    max_duration=self.args.max_duration,
+                )
+            else:
+                logging.info("Using SingleCutSampler.")
+                train_sampler = SingleCutSampler(
+                    cuts_train,
+                    max_duration=self.args.max_duration,
+                    shuffle=self.args.shuffle,
+                )
         logging.info("About to create train dataloader")
 
         if sampler_state_dict is not None:
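A hedged sketch of the new branch in isolation: it constructs lhotse's StatelessSampler with exactly the keyword arguments used above and iterates it directly. The manifest path, duration budget, and bucket count are illustrative; like other lhotse samplers, it should yield one CutSet per mini-batch.

from lhotse.dataset import StatelessSampler

# Sketch only: paths and budgets are illustrative, not taken from the recipe.
sampler = StatelessSampler(
    ["data/fbank/librispeech_cuts_train-clean-100.jsonl"],
    base_seed=0,
    index_path="data/fbank/files.idx",
    max_duration=600.0,
    num_buckets=30,
)

# The sampler is driven purely by its seed and the manifest index, which is
# why the training loop below can skip set_epoch and sampler checkpointing.
for batch_cuts in sampler:
    print(len(batch_cuts))  # number of cuts in the first mini-batch
    break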
@@ -379,7 +405,7 @@ class LibriSpeechAsrDataModule:
 
         return valid_dl
 
-    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+    def test_dataloaders(self, cuts_test: CutSet) -> DataLoader:
         logging.debug("About to create test dataset")
         test = K2SpeechRecognitionDataset(
             input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
@@ -388,7 +414,7 @@ class LibriSpeechAsrDataModule:
             return_cuts=self.args.return_cuts,
         )
         sampler = DynamicBucketingSampler(
-            cuts,
+            cuts_test,
             max_duration=self.args.max_duration,
             shuffle=False,
         )
@@ -402,42 +428,57 @@ class LibriSpeechAsrDataModule:
         return test_dl
 
     @lru_cache()
-    def train_clean_5_cuts(self) -> CutSet:
-        logging.info("mini_librispeech: About to get train-clean-5 cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
-        )
+    def train_clean_5_cuts(self) -> Union[CutSet, List[str]]:
+        if self.args.stateless_sampler:
+            return [self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl"]
+        else:
+            logging.info("mini_librispeech: About to get train-clean-5 cuts")
+            return load_manifest_lazy(
+                self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
+            )
 
     @lru_cache()
-    def train_clean_100_cuts(self) -> CutSet:
-        logging.info("About to get train-clean-100 cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
-        )
+    def train_clean_100_cuts(self) -> Union[CutSet, List[str]]:
+        if self.args.stateless_sampler:
+            return [self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl"]
+        else:
+            logging.info("About to get train-clean-100 cuts")
+            return load_manifest_lazy(
+                self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
+            )
 
     @lru_cache()
-    def train_clean_360_cuts(self) -> CutSet:
-        logging.info("About to get train-clean-360 cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
-        )
+    def train_clean_360_cuts(self) -> Union[CutSet, List[str]]:
+        if self.args.stateless_sampler:
+            return [self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl"]
+        else:
+            logging.info("About to get train-clean-360 cuts")
+            return load_manifest_lazy(
+                self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
+            )
 
     @lru_cache()
-    def train_other_500_cuts(self) -> CutSet:
-        logging.info("About to get train-other-500 cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
-        )
+    def train_other_500_cuts(self) -> Union[CutSet, List[str]]:
+        if self.args.stateless_sampler:
+            return [self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl"]
+        else:
+            logging.info("About to get train-other-500 cuts")
+            return load_manifest_lazy(
+                self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
+            )
 
     @lru_cache()
-    def train_all_shuf_cuts(self) -> CutSet:
-        logging.info(
-            "About to get the shuffled train-clean-100, \
-            train-clean-360 and train-other-500 cuts"
-        )
-        return load_manifest_lazy(
-            self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
-        )
+    def train_all_shuf_cuts(self) -> Union[CutSet, List[str]]:
+        if self.args.stateless_sampler:
+            return [self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl"]
+        else:
+            logging.info(
+                "About to get the shuffled train-clean-100, \
+                train-clean-360 and train-other-500 cuts"
+            )
+            return load_manifest_lazy(
+                self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
+            )
 
     @lru_cache()
     def dev_clean_2_cuts(self) -> CutSet:
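Note that the stateless branch above points at uncompressed .jsonl manifests, while the lazy branch keeps using the .jsonl.gz files. How the uncompressed copies are prepared is not shown in this diff; a hedged sketch of one way to produce them, purely as an assumption:

import gzip
import shutil
from pathlib import Path

# Assumption: the uncompressed .jsonl manifests must already exist on disk.
# One way to create them from the usual .jsonl.gz files (paths illustrative):
manifest_dir = Path("data/fbank")
for gz in manifest_dir.glob("librispeech_cuts_train-*.jsonl.gz"):
    dst = gz.with_suffix("")  # strips the trailing .gz, leaving .jsonl
    if not dst.exists():
        with gzip.open(gz, "rb") as f_in, open(dst, "wb") as f_out:
            shutil.copyfileobj(f_in, f_out)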
The remaining hunks apply to the recipe's training script:

@@ -604,11 +604,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
 
 
 def get_model(params: AttributeDict) -> nn.Module:
-    assert (
-        params.use_transducer or params.use_ctc
-    ), (f"At least one of them should be True, "
+    assert params.use_transducer or params.use_ctc, (
+        f"At least one of them should be True, "
         f"but got params.use_transducer={params.use_transducer}, "
-        f"params.use_ctc={params.use_ctc}")
+        f"params.use_ctc={params.use_ctc}"
+    )
 
     encoder_embed = get_encoder_embed(params)
     encoder = get_encoder_model(params)
@@ -808,17 +808,16 @@ def compute_loss(
         # take down the scale on the simple loss from 1.0 at the start
         # to params.simple_loss scale by warm_step.
         simple_loss_scale = (
-            s if batch_idx_train >= warm_step
+            s
+            if batch_idx_train >= warm_step
             else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
         )
         pruned_loss_scale = (
-            1.0 if batch_idx_train >= warm_step
+            1.0
+            if batch_idx_train >= warm_step
             else 0.1 + 0.9 * (batch_idx_train / warm_step)
         )
-        loss += (
-            simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
-        )
+        loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     if params.use_ctc:
         loss += params.ctc_loss_scale * ctc_loss
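The reformatting above does not change the warm-up schedule; a quick self-contained check of the two scales with illustrative values (s stands for params.simple_loss_scale):

# Illustrative values: s = params.simple_loss_scale, warm_step = 2000.
s, warm_step = 0.5, 2000


def scales(batch_idx_train):
    simple_loss_scale = (
        s
        if batch_idx_train >= warm_step
        else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
    )
    pruned_loss_scale = (
        1.0
        if batch_idx_train >= warm_step
        else 0.1 + 0.9 * (batch_idx_train / warm_step)
    )
    return simple_loss_scale, pruned_loss_scale


print(scales(0))     # (1.0, 0.1): rely mostly on the simple loss at the start
print(scales(1000))  # (0.75, 0.55): halfway through the warm-up
print(scales(2000))  # (0.5, 1.0): steady-state weights after warm_step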
@@ -1217,7 +1216,8 @@ def run(rank, world_size, args):
 
         return True
 
-    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    if not params.stateless_sampler:
+        train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
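In stateless mode the cuts are still a list of manifest paths at this point, so the length filter cannot be applied here; on the stateful path it stays as before. For reference, a hedged sketch of that lazy filter, with bounds and path that are illustrative rather than taken from the recipe:

from lhotse import load_manifest_lazy


def remove_short_and_long_utt(c):
    # Keep only cuts within an illustrative duration range (seconds).
    return 1.0 <= c.duration <= 20.0


train_cuts = load_manifest_lazy(
    "data/fbank/librispeech_cuts_train-clean-100.jsonl.gz"
)
# CutSet.filter is lazy: the predicate runs as the cuts are iterated.
train_cuts = train_cuts.filter(remove_short_and_long_utt)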
@@ -1234,7 +1234,7 @@ def run(rank, world_size, args):
     valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
-    if not params.print_diagnostics:
+    if 0 and not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,
@@ -1251,7 +1251,8 @@ def run(rank, world_size, args):
     for epoch in range(params.start_epoch, params.num_epochs + 1):
         scheduler.step_epoch(epoch - 1)
         fix_random_seed(params.seed + epoch - 1)
-        train_dl.sampler.set_epoch(epoch - 1)
+        if not params.stateless_sampler:
+            train_dl.sampler.set_epoch(epoch - 1)
 
         if tb_writer is not None:
             tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
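set_epoch (and the sampler state dict handled earlier in run()) only applies to lhotse's stateful samplers, which is why the call is now guarded. A hedged sketch of what the stateful path involves, with illustrative path and values:

from lhotse import load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler

cuts = load_manifest_lazy("data/fbank/librispeech_cuts_train-clean-100.jsonl.gz")
sampler = DynamicBucketingSampler(cuts, max_duration=600.0, shuffle=True)

sampler.set_epoch(0)            # reshuffles differently each epoch
state = sampler.state_dict()    # saved in checkpoints under the "sampler" key
sampler.load_state_dict(state)  # restored when resuming mid-epoch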