From 0904e490c5fb424dc5cb4d14ae468e4d32a07dc4 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Fri, 28 Nov 2025 11:42:20 +0800
Subject: [PATCH] Fix gigaspeech dataset iterator. (#2045)

Previously, it was reset after every epoch, which may cause it to
always use the first part of the gigaspeech dataset if you choose a
small --giga-prob.
---
 .../ASR/lstm_transducer_stateless2/train.py   |  9 ++---
 .../ASR/pruned_transducer_stateless3/train.py | 33 ++++++++++---------
 .../train.py                                  |  8 ++---
 .../ASR/pruned_transducer_stateless8/train.py |  9 ++---
 4 files changed, 32 insertions(+), 27 deletions(-)

diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py
index 1b31b5485..271b21f5e 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py
@@ -53,7 +53,7 @@ import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, Tuple, Union
 
 import k2
 import optim
@@ -770,7 +770,7 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
-    giga_train_dl: torch.utils.data.DataLoader,
+    iter_giga: Iterator,
     valid_dl: torch.utils.data.DataLoader,
     rng: random.Random,
     scaler: "GradScaler",
@@ -826,7 +826,6 @@ def train_one_epoch(
     dl_weights = [1 - params.giga_prob, params.giga_prob]
 
     iter_libri = iter(train_dl)
-    iter_giga = iter(giga_train_dl)
 
     batch_idx = 0
 
@@ -1177,6 +1176,8 @@ def run(rank, world_size, args):
     else:
         logging.info("Skip scan_pessimistic_batches_for_oom")
 
+    iter_giga = iter(giga_train_dl)
+
     scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
@@ -1200,7 +1201,7 @@ def run(rank, world_size, args):
             scheduler=scheduler,
             sp=sp,
             train_dl=train_dl,
-            giga_train_dl=giga_train_dl,
+            iter_giga=iter_giga,
             valid_dl=valid_dl,
             rng=rng,
             scaler=scaler,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
index 50670d1b2..f0396eb2f 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
@@ -53,7 +53,7 @@ import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, Tuple, Union
 
 import k2
 import optim
@@ -753,7 +753,7 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
-    giga_train_dl: torch.utils.data.DataLoader,
+    iter_giga: Iterator,
     valid_dl: torch.utils.data.DataLoader,
     rng: random.Random,
     scaler: "GradScaler",
@@ -806,7 +806,6 @@ def train_one_epoch(
     dl_weights = [1 - params.giga_prob, params.giga_prob]
 
     iter_libri = iter(train_dl)
-    iter_giga = iter(giga_train_dl)
 
     batch_idx = 0
 
@@ -950,9 +949,9 @@ def filter_short_and_long_utterances(
         # an utterance duration distribution for your dataset to select
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
-            logging.warning(
-                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            )
+            # logging.warning(
+            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            # )
             return False
 
         # In pruned RNN-T, we require that T >= S
@@ -965,14 +964,14 @@ def filter_short_and_long_utterances(
         tokens = sp.encode(c.supervisions[0].text, out_type=str)
 
         if T < len(tokens):
-            logging.warning(
-                f"Exclude cut with ID {c.id} from training. "
-                f"Number of frames (before subsampling): {c.num_frames}. "
-                f"Number of frames (after subsampling): {T}. "
-                f"Text: {c.supervisions[0].text}. "
-                f"Tokens: {tokens}. "
-                f"Number of tokens: {len(tokens)}"
-            )
+            # logging.warning(
+            #     f"Exclude cut with ID {c.id} from training. "
+            #     f"Number of frames (before subsampling): {c.num_frames}. "
+            #     f"Number of frames (after subsampling): {T}. "
+            #     f"Text: {c.supervisions[0].text}. "
+            #     f"Tokens: {tokens}. "
+            #     f"Number of tokens: {len(tokens)}"
+            # )
             return False
 
         return True
@@ -1117,6 +1116,8 @@ def run(rank, world_size, args):
     # It's time consuming to include `giga_train_dl` here
     # for dl in [train_dl, giga_train_dl]:
     for dl in [train_dl]:
+        # You can skip scan_pessimistic_batches_for_oom() if you are sure
+        # your selected params won't cause OOM
         if params.start_batch <= 0:
             scan_pessimistic_batches_for_oom(
                 model=model,
@@ -1127,6 +1128,8 @@ def run(rank, world_size, args):
                 warmup=0.0 if params.start_epoch == 0 else 1.0,
             )
 
+    iter_giga = iter(giga_train_dl)
+
     scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
@@ -1149,7 +1152,7 @@ def run(rank, world_size, args):
             scheduler=scheduler,
             sp=sp,
             train_dl=train_dl,
-            giga_train_dl=giga_train_dl,
+            iter_giga=iter_giga,
             valid_dl=valid_dl,
             rng=rng,
             scaler=scaler,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py
index 4b97575e6..1442aa121 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py
@@ -50,7 +50,7 @@ import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, Tuple, Union
 
 import k2
 import optim
@@ -798,7 +798,7 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
-    giga_train_dl: torch.utils.data.DataLoader,
+    iter_giga: Iterator,
     valid_dl: torch.utils.data.DataLoader,
     rng: random.Random,
     scaler: "GradScaler",
@@ -849,7 +849,6 @@ def train_one_epoch(
     # This sets the probabilities for choosing which datasets
     dl_weights = [1 - params.giga_prob, params.giga_prob]
     iter_libri = iter(train_dl)
-    iter_giga = iter(giga_train_dl)
 
     batch_idx = 0
 
@@ -1223,6 +1222,7 @@ def run(rank, world_size, args):
     #     sp=sp,
     #     params=params,
     # )
+    iter_giga = iter(giga_train_dl)
 
     scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
@@ -1247,7 +1247,7 @@ def run(rank, world_size, args):
             scheduler=scheduler,
             sp=sp,
             train_dl=train_dl,
-            giga_train_dl=giga_train_dl,
+            iter_giga=iter_giga,
             valid_dl=valid_dl,
             rng=rng,
             scaler=scaler,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py
index ad14ec9dc..9b372109f 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py
@@ -55,7 +55,7 @@ import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, Tuple, Union
 
 import k2
 import optim
@@ -793,7 +793,7 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
-    giga_train_dl: torch.utils.data.DataLoader,
+    iter_giga: Iterator,
     valid_dl: torch.utils.data.DataLoader,
     rng: random.Random,
     scaler: "GradScaler",
@@ -849,7 +849,6 @@ def train_one_epoch(
     dl_weights = [1 - params.giga_prob, params.giga_prob]
 
     iter_libri = iter(train_dl)
-    iter_giga = iter(giga_train_dl)
 
     batch_idx = 0
 
@@ -1225,6 +1224,8 @@ def run(rank, world_size, args):
             params=params,
         )
 
+    iter_giga = iter(giga_train_dl)
+
    scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
@@ -1248,7 +1249,7 @@ def run(rank, world_size, args):
             scheduler=scheduler,
             sp=sp,
             train_dl=train_dl,
-            giga_train_dl=giga_train_dl,
+            iter_giga=iter_giga,
             valid_dl=valid_dl,
             rng=rng,
             scaler=scaler,
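For readers skimming the diff, here is a minimal sketch of the pattern the patch applies. It is a simplified, hypothetical example (the signatures and argument lists are placeholders, not the actual icefall APIs): the GigaSpeech iterator is created once in run() and handed to train_one_epoch(), so it keeps its position across epochs, whereas calling iter(giga_train_dl) inside train_one_epoch() restarted it from the first batches every epoch.

import random
from typing import Iterator


def train_one_epoch(
    train_dl,
    iter_giga: Iterator,
    giga_prob: float,
    rng: random.Random,
) -> None:
    # Simplified stand-in for the real training loop: the LibriSpeech iterator
    # is re-created every epoch (intended), while the GigaSpeech iterator is
    # received from the caller and resumes where the previous epoch stopped.
    iter_libri = iter(train_dl)
    while True:
        use_giga = rng.random() < giga_prob
        try:
            batch = next(iter_giga if use_giga else iter_libri)
        except StopIteration:
            break  # one of the dataloaders is exhausted; end the epoch
        _ = batch  # forward/backward would happen here


def run(num_epochs: int, train_dl, giga_train_dl, giga_prob: float) -> None:
    rng = random.Random(42)
    # Create the GigaSpeech iterator once, outside the epoch loop.  Creating
    # it inside train_one_epoch() (the old behaviour) meant that with a small
    # giga_prob only the first part of GigaSpeech was ever consumed.
    iter_giga = iter(giga_train_dl)
    for _ in range(num_epochs):
        train_one_epoch(train_dl, iter_giga, giga_prob, rng)

In the actual recipes the batch selection and exhaustion handling are more involved; the sketch only shows where the iterator is created.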