Fix bug in streaming_conformer_ctc egs (#862)

* Update train.py

Fix transducer LSTM egs bug, as mentioned in issue #579

* Update train.py

Fix dataloader bug
This commit is contained in:
BuaaAlban 2023-01-31 15:19:50 +08:00 committed by GitHub
parent e277e31e37
commit e9019511eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -50,7 +50,7 @@ from icefall.utils import (
setup_logger,
str2bool,
)
from lhotse.cut import Cut
def get_parser():
parser = argparse.ArgumentParser(
@ -645,8 +645,23 @@ def run(rank, world_size, args):
optimizer.load_state_dict(checkpoints["optimizer"])
librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders()
if params.full_libri:
train_cuts = librispeech.train_all_shuf_cuts()
else:
train_cuts = librispeech.train_clean_100_cuts()
def remove_short_and_long_utt(c: Cut):
    """Tell whether a cut's duration lies within the accepted range.

    Used as a filter predicate: a cut is kept when its duration is at
    least 1 second and at most 20 seconds (both bounds inclusive).

    Args:
        c: The cut to check; only its ``duration`` attribute is read.

    Returns:
        True if ``1.0 <= c.duration <= 20.0``, False otherwise.
    """
    duration = c.duration
    # Explicit conjunction, equivalent to the chained comparison.
    return 1.0 <= duration and duration <= 20.0
train_cuts = train_cuts.filter(remove_short_and_long_utt)
train_dl = librispeech.train_dataloaders(train_cuts)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
scan_pessimistic_batches_for_oom(
model=model,