shuffle full Librispeech for zipformer recipes (#869)

* shuffle libri
This commit is contained in:
Zengwei Yao 2023-02-03 11:54:57 +08:00 committed by GitHub
parent e36ea89112
commit 1e6d6f8160
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 18 additions and 17 deletions

View File

@ -1043,10 +1043,10 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri: if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts() train_cuts = librispeech.train_all_shuf_cuts()
train_cuts += librispeech.train_other_500_cuts() else:
train_cuts = librispeech.train_clean_100_cuts()
def remove_short_and_long_utt(c: Cut): def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds # Keep only utterances with duration between 1 second and 20 seconds

View File

@ -1072,10 +1072,10 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri: if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts() train_cuts = librispeech.train_all_shuf_cuts()
train_cuts += librispeech.train_other_500_cuts() else:
train_cuts = librispeech.train_clean_100_cuts()
def remove_short_and_long_utt(c: Cut): def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds # Keep only utterances with duration between 1 second and 20 seconds

View File

@ -55,9 +55,9 @@ import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from decoder import Decoder from decoder import Decoder
from frame_reducer import FrameReducer
from joiner import Joiner from joiner import Joiner
from lconv import LConv from lconv import LConv
from frame_reducer import FrameReducer
from lhotse.cut import Cut from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
@ -1063,10 +1063,10 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri: if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts() train_cuts = librispeech.train_all_shuf_cuts()
train_cuts += librispeech.train_other_500_cuts() else:
train_cuts = librispeech.train_clean_100_cuts()
def remove_short_and_long_utt(c: Cut): def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds # Keep only utterances with duration between 1 second and 20 seconds

View File

@ -1049,10 +1049,10 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri: if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts() train_cuts = librispeech.train_all_shuf_cuts()
train_cuts += librispeech.train_other_500_cuts() else:
train_cuts = librispeech.train_clean_100_cuts()
def remove_short_and_long_utt(c: Cut): def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds # Keep only utterances with duration between 1 second and 20 seconds

View File

@ -1154,10 +1154,10 @@ def run(rank, world_size, args):
librispeech = LibriSpeech(manifest_dir=args.manifest_dir) librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri: if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts() train_cuts = librispeech.train_all_shuf_cuts()
train_cuts += librispeech.train_other_500_cuts() else:
train_cuts = librispeech.train_clean_100_cuts()
train_cuts = filter_short_and_long_utterances(train_cuts, sp) train_cuts = filter_short_and_long_utterances(train_cuts, sp)

View File

@ -30,6 +30,7 @@ import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer from conformer import Conformer
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from torch import Tensor from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
@ -50,7 +51,7 @@ from icefall.utils import (
setup_logger, setup_logger,
str2bool, str2bool,
) )
from lhotse.cut import Cut
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(