mirror of https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00

add length filter condition

This commit is contained in:
parent e84630adf2
commit 723320e015
@@ -1064,6 +1064,20 @@ def run(rank, world_size, args):
             )
             return False

+        # Zipformer has DownsampledZipformerEncoders with different downsampling factors
+        # after encoder_embed that does T -> (T - 7) // 2
+        ds = tuple(map(int, params.zipformer_downsampling_factors.split(",")))
+        max_ds = max(ds)
+        T = (c.num_frames - 7) // 2
+        if T < max_ds:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before encoder_embed): {c.num_frames}. "
+                f"Number of frames (after encoder_embed): {T}. "
+                f"Max downsampling factor in Zipformer: {max_ds}. "
+            )
+            return False
+
         return True

     train_cuts = train_cuts.filter(remove_short_and_long_utt)
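For reference, the added predicate can be exercised on its own. Below is a minimal sketch; the downsampling factors "1,2,4,8" are an illustrative assumption, while in the training scripts the value comes from params.zipformer_downsampling_factors:

import logging

# Illustrative value only (hypothetical); the real scripts read this from
# params.zipformer_downsampling_factors.
zipformer_downsampling_factors = "1,2,4,8"

def has_enough_frames(num_frames: int) -> bool:
    # Parse the comma-separated factor string, e.g. "1,2,4,8" -> (1, 2, 4, 8).
    ds = tuple(map(int, zipformer_downsampling_factors.split(",")))
    max_ds = max(ds)  # 8 in this example
    # encoder_embed maps T input frames to (T - 7) // 2 output frames, so the
    # cut must still have at least max_ds frames after that subsampling.
    T = (num_frames - 7) // 2
    if T < max_ds:
        logging.warning(f"Too short: {num_frames} frames -> {T} after encoder_embed")
        return False
    return True

# (num_frames - 7) // 2 >= 8 requires num_frames >= 23:
assert not has_enough_frames(22)  # T = 7 < 8 -> excluded
assert has_enough_frames(23)      # T = 8 >= 8 -> kept

The same check is applied verbatim in the remaining training scripts below.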
@@ -1112,6 +1112,20 @@ def run(rank, world_size, args):
             )
             return False

+        # Zipformer has DownsampledZipformerEncoders with different downsampling factors
+        # after encoder_embed that does T -> (T - 7) // 2
+        ds = tuple(map(int, params.zipformer_downsampling_factors.split(",")))
+        max_ds = max(ds)
+        T = (c.num_frames - 7) // 2
+        if T < max_ds:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before encoder_embed): {c.num_frames}. "
+                f"Number of frames (after encoder_embed): {T}. "
+                f"Max downsampling factor in Zipformer: {max_ds}. "
+            )
+            return False
+
         return True

     train_cuts = train_cuts.filter(remove_short_and_long_utt)
@@ -55,9 +55,9 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from decoder import Decoder
+from frame_reducer import FrameReducer
 from joiner import Joiner
 from lconv import LConv
-from frame_reducer import FrameReducer
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
@@ -1103,6 +1103,20 @@ def run(rank, world_size, args):
             )
             return False

+        # Zipformer has DownsampledZipformerEncoders with different downsampling factors
+        # after encoder_embed that does T -> (T - 7) // 2
+        ds = tuple(map(int, params.zipformer_downsampling_factors.split(",")))
+        max_ds = max(ds)
+        T = (c.num_frames - 7) // 2
+        if T < max_ds:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before encoder_embed): {c.num_frames}. "
+                f"Number of frames (after encoder_embed): {T}. "
+                f"Max downsampling factor in Zipformer: {max_ds}. "
+            )
+            return False
+
         return True

     train_cuts = train_cuts.filter(remove_short_and_long_utt)
@@ -1089,6 +1089,20 @@ def run(rank, world_size, args):
             )
             return False

+        # Zipformer has DownsampledZipformerEncoders with different downsampling factors
+        # after encoder_embed that does T -> (T - 7) // 2
+        ds = tuple(map(int, params.zipformer_downsampling_factors.split(",")))
+        max_ds = max(ds)
+        T = (c.num_frames - 7) // 2
+        if T < max_ds:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before encoder_embed): {c.num_frames}. "
+                f"Number of frames (after encoder_embed): {T}. "
+                f"Max downsampling factor in Zipformer: {max_ds}. "
+            )
+            return False
+
         return True

     train_cuts = train_cuts.filter(remove_short_and_long_utt)
@@ -1017,7 +1017,7 @@ def train_one_epoch(


 def filter_short_and_long_utterances(
-    cuts: CutSet, sp: spm.SentencePieceProcessor
+    cuts: CutSet, sp: spm.SentencePieceProcessor, zipformer_downsampling_factors: str
 ) -> CutSet:
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -1054,6 +1054,20 @@ def filter_short_and_long_utterances(
             )
             return False

+        # Zipformer has DownsampledZipformerEncoders with different downsampling factors
+        # after encoder_embed that does T -> (T - 7) // 2
+        ds = tuple(map(int, zipformer_downsampling_factors.split(",")))
+        max_ds = max(ds)
+        T = (c.num_frames - 7) // 2
+        if T < max_ds:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. "
+                f"Number of frames (before encoder_embed): {c.num_frames}. "
+                f"Number of frames (after encoder_embed): {T}. "
+                f"Max downsampling factor in Zipformer: {max_ds}. "
+            )
+            return False
+
         return True

     cuts = cuts.filter(remove_short_and_long_utt)
@@ -1159,7 +1173,9 @@ def run(rank, world_size, args):
     train_cuts += librispeech.train_clean_360_cuts()
     train_cuts += librispeech.train_other_500_cuts()

-    train_cuts = filter_short_and_long_utterances(train_cuts, sp)
+    train_cuts = filter_short_and_long_utterances(
+        train_cuts, sp, params.zipformer_downsampling_factors
+    )

     gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir)
     # XL 10k hours
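Because the refactored helper now receives the factor string as an argument instead of reading params inside run(), it can be reused from other call sites. A minimal usage sketch; the manifest path, BPE-model path, and factor string below are hypothetical and do not appear in the commit:

import sentencepiece as spm
from lhotse import load_manifest_lazy

# Hypothetical paths and factors, for illustration only.
sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")

cuts = load_manifest_lazy("data/fbank/cuts_train.jsonl.gz")
cuts = filter_short_and_long_utterances(
    cuts, sp, zipformer_downsampling_factors="1,2,4,8,2"
)

Passing the string through explicitly keeps the helper independent of the script's argparse namespace, which is presumably why the call site was widened to three arguments.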