Merge branch 'k2-fsa:master' into repeat-k

This commit is contained in:
Yifan Yang 2023-02-03 11:59:22 +08:00 committed by GitHub
commit b27c11c867
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 18 additions and 17 deletions

View File

@ -1043,10 +1043,10 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri: if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts() train_cuts = librispeech.train_all_shuf_cuts()
train_cuts += librispeech.train_other_500_cuts() else:
train_cuts = librispeech.train_clean_100_cuts()
def remove_short_and_long_utt(c: Cut): def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds # Keep only utterances with duration between 1 second and 20 seconds

View File

@ -1072,10 +1072,10 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri: if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts() train_cuts = librispeech.train_all_shuf_cuts()
train_cuts += librispeech.train_other_500_cuts() else:
train_cuts = librispeech.train_clean_100_cuts()
def remove_short_and_long_utt(c: Cut): def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds # Keep only utterances with duration between 1 second and 20 seconds

View File

@ -55,9 +55,9 @@ import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from decoder import Decoder from decoder import Decoder
from frame_reducer import FrameReducer
from joiner import Joiner from joiner import Joiner
from lconv import LConv from lconv import LConv
from frame_reducer import FrameReducer
from lhotse.cut import Cut from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
@ -1063,10 +1063,10 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri: if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts() train_cuts = librispeech.train_all_shuf_cuts()
train_cuts += librispeech.train_other_500_cuts() else:
train_cuts = librispeech.train_clean_100_cuts()
def remove_short_and_long_utt(c: Cut): def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds # Keep only utterances with duration between 1 second and 20 seconds

View File

@ -1049,10 +1049,10 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri: if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts() train_cuts = librispeech.train_all_shuf_cuts()
train_cuts += librispeech.train_other_500_cuts() else:
train_cuts = librispeech.train_clean_100_cuts()
def remove_short_and_long_utt(c: Cut): def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds # Keep only utterances with duration between 1 second and 20 seconds

View File

@ -1154,10 +1154,10 @@ def run(rank, world_size, args):
librispeech = LibriSpeech(manifest_dir=args.manifest_dir) librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri: if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts() train_cuts = librispeech.train_all_shuf_cuts()
train_cuts += librispeech.train_other_500_cuts() else:
train_cuts = librispeech.train_clean_100_cuts()
train_cuts = filter_short_and_long_utterances(train_cuts, sp) train_cuts = filter_short_and_long_utterances(train_cuts, sp)

View File

@ -30,6 +30,7 @@ import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer from conformer import Conformer
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from torch import Tensor from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
@ -50,7 +51,7 @@ from icefall.utils import (
setup_logger, setup_logger,
str2bool, str2bool,
) )
from lhotse.cut import Cut
def get_parser(): def get_parser():