Merge 723320e0159d65856d3c86ae4f75b9d44034fc3d into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9

This commit is contained in:
Zengwei Yao 2025-08-17 06:13:43 +00:00 committed by GitHub
commit 5f66f738b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 74 additions and 2 deletions

View File

@ -1076,6 +1076,20 @@ def run(rank, world_size, args):
)
return False
# Zipformer has DownsampledZipformerEncoders with different downsampling factors
# after encoder_embed that does T -> (T - 7) // 2
ds = tuple(map(int, params.zipformer_downsampling_factors.split(",")))
max_ds = max(ds)
T = (c.num_frames - 7) // 2
if T < max_ds:
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before encoder_embed): {c.num_frames}. "
f"Number of frames (after encoder_embed): {T}. "
f"Max downsampling factor in Zipformer: {max_ds}. "
)
return False
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)

View File

@ -1101,6 +1101,20 @@ def run(rank, world_size, args):
)
return False
# Zipformer has DownsampledZipformerEncoders with different downsampling factors
# after encoder_embed that does T -> (T - 7) // 2
ds = tuple(map(int, params.zipformer_downsampling_factors.split(",")))
max_ds = max(ds)
T = (c.num_frames - 7) // 2
if T < max_ds:
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before encoder_embed): {c.num_frames}. "
f"Number of frames (after encoder_embed): {T}. "
f"Max downsampling factor in Zipformer: {max_ds}. "
)
return False
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)

View File

@ -1091,6 +1091,20 @@ def run(rank, world_size, args):
)
return False
# Zipformer has DownsampledZipformerEncoders with different downsampling factors
# after encoder_embed that does T -> (T - 7) // 2
ds = tuple(map(int, params.zipformer_downsampling_factors.split(",")))
max_ds = max(ds)
T = (c.num_frames - 7) // 2
if T < max_ds:
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before encoder_embed): {c.num_frames}. "
f"Number of frames (after encoder_embed): {T}. "
f"Max downsampling factor in Zipformer: {max_ds}. "
)
return False
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)

View File

@ -1099,6 +1099,20 @@ def run(rank, world_size, args):
)
return False
# Zipformer has DownsampledZipformerEncoders with different downsampling factors
# after encoder_embed that does T -> (T - 7) // 2
ds = tuple(map(int, params.zipformer_downsampling_factors.split(",")))
max_ds = max(ds)
T = (c.num_frames - 7) // 2
if T < max_ds:
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before encoder_embed): {c.num_frames}. "
f"Number of frames (after encoder_embed): {T}. "
f"Max downsampling factor in Zipformer: {max_ds}. "
)
return False
return True
# train_cuts = train_cuts.filter(remove_short_and_long_utt)

View File

@ -1022,7 +1022,7 @@ def train_one_epoch(
def filter_short_and_long_utterances(
cuts: CutSet, sp: spm.SentencePieceProcessor
cuts: CutSet, sp: spm.SentencePieceProcessor, zipformer_downsampling_factors: str
) -> CutSet:
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
@ -1059,6 +1059,20 @@ def filter_short_and_long_utterances(
)
return False
# Zipformer has DownsampledZipformerEncoders with different downsampling factors
# after encoder_embed that does T -> (T - 7) // 2
ds = tuple(map(int, zipformer_downsampling_factors.split(",")))
max_ds = max(ds)
T = (c.num_frames - 7) // 2
if T < max_ds:
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before encoder_embed): {c.num_frames}. "
f"Number of frames (after encoder_embed): {T}. "
f"Max downsampling factor in Zipformer: {max_ds}. "
)
return False
return True
cuts = cuts.filter(remove_short_and_long_utt)
@ -1173,7 +1187,9 @@ def run(rank, world_size, args):
else:
train_cuts = librispeech.train_clean_100_cuts()
train_cuts = filter_short_and_long_utterances(train_cuts, sp)
train_cuts = filter_short_and_long_utterances(
train_cuts, sp, params.zipformer_downsampling_factors
)
gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir)
# XL 10k hours