Fix gigaspeech dataset iterator. (#2045)

Previously, the GigaSpeech iterator was reset at the start of every epoch,
which could cause training to always use the first part of the GigaSpeech
dataset if you choose a small --giga-prob.
Fangjun Kuang, 2025-11-28 11:42:20 +08:00, committed by GitHub
parent 693f069de7
commit 0904e490c5
4 changed files with 32 additions and 27 deletions
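
In short, train_one_epoch() no longer creates its own GigaSpeech iterator; run() now creates it once, before the epoch loop, and passes it in, so the iterator keeps advancing from epoch to epoch. Below is a minimal sketch of the idea, not the recipe's actual training loop; the names train_dl, giga_train_dl, iter_giga, dl_weights, giga_prob, and rng mirror the diff, everything else is illustrative.

import random

# Illustrative sketch only: shows why the GigaSpeech iterator has to
# outlive a single epoch when batches are sampled with a small giga_prob.

def train_one_epoch(train_dl, iter_giga, giga_prob, rng):
    # For each batch, pick LibriSpeech with probability (1 - giga_prob)
    # and GigaSpeech with probability giga_prob.
    dl_weights = [1 - giga_prob, giga_prob]
    iter_libri = iter(train_dl)  # LibriSpeech restarts every epoch
    # iter_giga is NOT re-created here; it is passed in so that GigaSpeech
    # keeps advancing from wherever the previous epoch stopped.
    while True:
        idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
        try:
            batch = next(iter_libri if idx == 0 else iter_giga)
        except StopIteration:
            break  # end the epoch once the chosen iterator is exhausted
        _ = batch  # forward/backward/optimizer step would go here


def run(train_dl, giga_train_dl, giga_prob, num_epochs):
    rng = random.Random(42)
    # Created once, before the epoch loop.  With the old placement of this
    # line inside train_one_epoch(), a small --giga-prob meant every epoch
    # sampled (mostly) from the first part of GigaSpeech again.
    iter_giga = iter(giga_train_dl)
    for _ in range(num_epochs):
        train_one_epoch(train_dl, iter_giga, giga_prob, rng)

The sketch ends an epoch as soon as the chosen iterator is exhausted; how exhaustion is actually handled is up to the recipe and is not shown in this diff.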

File 1 of 4

@@ -53,7 +53,7 @@ import random
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Iterator, Optional, Tuple, Union
import k2
import optim
@@ -770,7 +770,7 @@ def train_one_epoch(
scheduler: LRSchedulerType,
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
giga_train_dl: torch.utils.data.DataLoader,
iter_giga: Iterator,
valid_dl: torch.utils.data.DataLoader,
rng: random.Random,
scaler: "GradScaler",
@@ -826,7 +826,6 @@
dl_weights = [1 - params.giga_prob, params.giga_prob]
iter_libri = iter(train_dl)
iter_giga = iter(giga_train_dl)
batch_idx = 0
@@ -1177,6 +1176,8 @@ def run(rank, world_size, args):
else:
logging.info("Skip scan_pessimistic_batches_for_oom")
iter_giga = iter(giga_train_dl)
scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
@@ -1200,7 +1201,7 @@
scheduler=scheduler,
sp=sp,
train_dl=train_dl,
giga_train_dl=giga_train_dl,
iter_giga=iter_giga,
valid_dl=valid_dl,
rng=rng,
scaler=scaler,

File 2 of 4

@@ -53,7 +53,7 @@ import random
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Iterator, Optional, Tuple, Union
import k2
import optim
@@ -753,7 +753,7 @@ def train_one_epoch(
scheduler: LRSchedulerType,
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
giga_train_dl: torch.utils.data.DataLoader,
iter_giga: Iterator,
valid_dl: torch.utils.data.DataLoader,
rng: random.Random,
scaler: "GradScaler",
@@ -806,7 +806,6 @@
dl_weights = [1 - params.giga_prob, params.giga_prob]
iter_libri = iter(train_dl)
iter_giga = iter(giga_train_dl)
batch_idx = 0
@@ -950,9 +949,9 @@ def filter_short_and_long_utterances(
# an utterance duration distribution for your dataset to select
# the threshold
if c.duration < 1.0 or c.duration > 20.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
)
# logging.warning(
# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
# )
return False
# In pruned RNN-T, we require that T >= S
@@ -965,14 +964,14 @@
tokens = sp.encode(c.supervisions[0].text, out_type=str)
if T < len(tokens):
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
f"Number of tokens: {len(tokens)}"
)
# logging.warning(
# f"Exclude cut with ID {c.id} from training. "
# f"Number of frames (before subsampling): {c.num_frames}. "
# f"Number of frames (after subsampling): {T}. "
# f"Text: {c.supervisions[0].text}. "
# f"Tokens: {tokens}. "
# f"Number of tokens: {len(tokens)}"
# )
return False
return True
@@ -1117,6 +1116,8 @@ def run(rank, world_size, args):
# It's time consuming to include `giga_train_dl` here
# for dl in [train_dl, giga_train_dl]:
for dl in [train_dl]:
# You can skip scan_pessimistic_batches_for_oom() if you are sure
# your selected params won't cause OOM
if params.start_batch <= 0:
scan_pessimistic_batches_for_oom(
model=model,
@@ -1127,6 +1128,8 @@
warmup=0.0 if params.start_epoch == 0 else 1.0,
)
iter_giga = iter(giga_train_dl)
scaler = create_grad_scaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
@@ -1149,7 +1152,7 @@
scheduler=scheduler,
sp=sp,
train_dl=train_dl,
giga_train_dl=giga_train_dl,
iter_giga=iter_giga,
valid_dl=valid_dl,
rng=rng,
scaler=scaler,

File 3 of 4

@@ -50,7 +50,7 @@ import random
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Iterator, Optional, Tuple, Union
import k2
import optim
@@ -798,7 +798,7 @@ def train_one_epoch(
scheduler: LRSchedulerType,
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
giga_train_dl: torch.utils.data.DataLoader,
iter_giga: Iterator,
valid_dl: torch.utils.data.DataLoader,
rng: random.Random,
scaler: "GradScaler",
@@ -849,7 +849,6 @@
# This sets the probabilities for choosing which datasets
dl_weights = [1 - params.giga_prob, params.giga_prob]
iter_libri = iter(train_dl)
iter_giga = iter(giga_train_dl)
batch_idx = 0
@@ -1223,6 +1222,7 @@ def run(rank, world_size, args):
# sp=sp,
# params=params,
# )
iter_giga = iter(giga_train_dl)
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
@@ -1247,7 +1247,7 @@
scheduler=scheduler,
sp=sp,
train_dl=train_dl,
giga_train_dl=giga_train_dl,
iter_giga=iter_giga,
valid_dl=valid_dl,
rng=rng,
scaler=scaler,

File 4 of 4

@@ -55,7 +55,7 @@ import random
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Iterator, Optional, Tuple, Union
import k2
import optim
@@ -793,7 +793,7 @@ def train_one_epoch(
scheduler: LRSchedulerType,
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
giga_train_dl: torch.utils.data.DataLoader,
iter_giga: Iterator,
valid_dl: torch.utils.data.DataLoader,
rng: random.Random,
scaler: "GradScaler",
@@ -849,7 +849,6 @@
dl_weights = [1 - params.giga_prob, params.giga_prob]
iter_libri = iter(train_dl)
iter_giga = iter(giga_train_dl)
batch_idx = 0
@@ -1225,6 +1224,8 @@ def run(rank, world_size, args):
params=params,
)
iter_giga = iter(giga_train_dl)
scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
@@ -1248,7 +1249,7 @@
scheduler=scheduler,
sp=sp,
train_dl=train_dl,
giga_train_dl=giga_train_dl,
iter_giga=iter_giga,
valid_dl=valid_dl,
rng=rng,
scaler=scaler,