Fix gigaspeech dataset iterator. (#2045)

Previously, the GigaSpeech iterator was reset after every epoch, which may cause
training to always use the first part of the GigaSpeech dataset if you choose
a small --giga-prob.
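For context, here is a minimal sketch of the pattern this commit changes. The names `train_one_epoch`, `giga_train_dl`, `iter_giga`, and `iter_libri` follow the diffs below; the tiny lists and print statements are purely illustrative stand-ins, not the recipe's code. Creating `iter(giga_train_dl)` inside `train_one_epoch` restarts GigaSpeech every epoch, while creating the iterator once in `run()` and passing it in lets it keep advancing across epochs.

```python
from typing import Iterable, Iterator

# Illustrative stand-ins for the real LibriSpeech / GigaSpeech DataLoaders.
libri_dl = [f"libri-{i}" for i in range(2)]
giga_dl = [f"giga-{i}" for i in range(1000)]

def train_one_epoch_before(train_dl: Iterable, giga_train_dl: Iterable) -> None:
    iter_libri = iter(train_dl)      # (unused in this sketch)
    iter_giga = iter(giga_train_dl)  # re-created every epoch -> GigaSpeech restarts
    print(next(iter_giga))           # always prints "giga-0"

def train_one_epoch_after(train_dl: Iterable, iter_giga: Iterator) -> None:
    iter_libri = iter(train_dl)      # LibriSpeech still restarts each epoch
    print(next(iter_giga))           # keeps advancing: "giga-0", "giga-1", ...

# After this commit, run() creates the iterator once, outside the epoch loop.
iter_giga = iter(giga_dl)
for epoch in range(3):
    train_one_epoch_after(libri_dl, iter_giga)
```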
Fangjun Kuang 2025-11-28 11:42:20 +08:00 committed by GitHub
parent 693f069de7
commit 0904e490c5
4 changed files with 32 additions and 27 deletions

Changed file 1 of 4:

@@ -53,7 +53,7 @@ import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, Tuple, Union

 import k2
 import optim
@@ -770,7 +770,7 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
-    giga_train_dl: torch.utils.data.DataLoader,
+    iter_giga: Iterator,
     valid_dl: torch.utils.data.DataLoader,
     rng: random.Random,
     scaler: "GradScaler",
@@ -826,7 +826,6 @@ def train_one_epoch(
     dl_weights = [1 - params.giga_prob, params.giga_prob]

     iter_libri = iter(train_dl)
-    iter_giga = iter(giga_train_dl)

     batch_idx = 0
@@ -1177,6 +1176,8 @@ def run(rank, world_size, args):
     else:
         logging.info("Skip scan_pessimistic_batches_for_oom")

+    iter_giga = iter(giga_train_dl)
+
     scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
@@ -1200,7 +1201,7 @@ def run(rank, world_size, args):
         scheduler=scheduler,
         sp=sp,
         train_dl=train_dl,
-        giga_train_dl=giga_train_dl,
+        iter_giga=iter_giga,
         valid_dl=valid_dl,
         rng=rng,
         scaler=scaler,
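The `dl_weights = [1 - params.giga_prob, params.giga_prob]` line above is why the iterator's position matters: each batch is drawn from LibriSpeech or GigaSpeech at random with those weights, so with a small --giga-prob GigaSpeech is consumed slowly. The sketch below illustrates that mixing; the names `iter_libri`, `iter_giga`, `dl_weights`, and `rng` come from the diff, while the loop body and end-of-epoch handling are simplified and not the recipe's exact code.

```python
import random
from typing import Iterable, Iterator

def train_one_epoch(train_dl: Iterable, iter_giga: Iterator,
                    rng: random.Random, giga_prob: float) -> None:
    """Simplified per-batch mixing. If iter_giga were reset every epoch,
    a small giga_prob would keep revisiting only the first GigaSpeech batches."""
    dl_weights = [1 - giga_prob, giga_prob]
    iter_libri = iter(train_dl)  # LibriSpeech is restarted every epoch
    while True:
        # Pick the source for this batch: 0 -> LibriSpeech, 1 -> GigaSpeech.
        idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
        it = iter_libri if idx == 0 else iter_giga
        batch = next(it, None)
        if batch is None:  # end the epoch once the chosen source is exhausted
            break
        print("batch from", "libri" if idx == 0 else "giga", ":", batch)

rng = random.Random(42)
libri_dl = [f"libri-{i}" for i in range(5)]   # stand-ins for the real DataLoaders
giga_dl = [f"giga-{i}" for i in range(1000)]
iter_giga = iter(giga_dl)                     # created once, as in this commit
for epoch in range(2):
    train_one_epoch(libri_dl, iter_giga, rng, giga_prob=0.2)
```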

Changed file 2 of 4:

@@ -53,7 +53,7 @@ import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, Tuple, Union

 import k2
 import optim
@@ -753,7 +753,7 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
-    giga_train_dl: torch.utils.data.DataLoader,
+    iter_giga: Iterator,
     valid_dl: torch.utils.data.DataLoader,
     rng: random.Random,
     scaler: "GradScaler",
@@ -806,7 +806,6 @@ def train_one_epoch(
     dl_weights = [1 - params.giga_prob, params.giga_prob]

     iter_libri = iter(train_dl)
-    iter_giga = iter(giga_train_dl)

     batch_idx = 0
@@ -950,9 +949,9 @@ def filter_short_and_long_utterances(
         # an utterance duration distribution for your dataset to select
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
-            logging.warning(
-                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            )
+            # logging.warning(
+            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            # )
             return False

         # In pruned RNN-T, we require that T >= S
@@ -965,14 +964,14 @@
         tokens = sp.encode(c.supervisions[0].text, out_type=str)

         if T < len(tokens):
-            logging.warning(
-                f"Exclude cut with ID {c.id} from training. "
-                f"Number of frames (before subsampling): {c.num_frames}. "
-                f"Number of frames (after subsampling): {T}. "
-                f"Text: {c.supervisions[0].text}. "
-                f"Tokens: {tokens}. "
-                f"Number of tokens: {len(tokens)}"
-            )
+            # logging.warning(
+            #     f"Exclude cut with ID {c.id} from training. "
+            #     f"Number of frames (before subsampling): {c.num_frames}. "
+            #     f"Number of frames (after subsampling): {T}. "
+            #     f"Text: {c.supervisions[0].text}. "
+            #     f"Tokens: {tokens}. "
+            #     f"Number of tokens: {len(tokens)}"
+            # )
             return False

         return True
@@ -1117,6 +1116,8 @@ def run(rank, world_size, args):
     # It's time consuming to include `giga_train_dl` here
     # for dl in [train_dl, giga_train_dl]:
     for dl in [train_dl]:
+        # You can skip scan_pessimistic_batches_for_oom() if you are sure
+        # your selected params won't cause OOM
         if params.start_batch <= 0:
             scan_pessimistic_batches_for_oom(
                 model=model,
@@ -1127,6 +1128,8 @@ def run(rank, world_size, args):
                 warmup=0.0 if params.start_epoch == 0 else 1.0,
             )

+    iter_giga = iter(giga_train_dl)
+
     scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
@@ -1149,7 +1152,7 @@ def run(rank, world_size, args):
         scheduler=scheduler,
         sp=sp,
         train_dl=train_dl,
-        giga_train_dl=giga_train_dl,
+        iter_giga=iter_giga,
         valid_dl=valid_dl,
         rng=rng,
         scaler=scaler,
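This file also comments out the per-cut warnings in `filter_short_and_long_utterances`; now that the GigaSpeech iterator advances through the whole dataset, those warnings would otherwise flood the log. Below is a hedged sketch of the filter's logic for reference. The subsampling formula and the lhotse/sentencepiece calls are assumptions for illustration, not copied from the recipe.

```python
import sentencepiece as spm
from lhotse import CutSet

def remove_short_and_long_utt(c, sp: spm.SentencePieceProcessor) -> bool:
    """Illustrative version of the filter shown in the diff above."""
    # Keep only utterances whose duration lies in [1 s, 20 s].
    if c.duration < 1.0 or c.duration > 20.0:
        # Warning commented out, as in this commit, to avoid flooding the log:
        # logging.warning(f"Exclude cut with ID {c.id}. Duration: {c.duration}")
        return False

    # Pruned RNN-T requires T >= S: the number of encoder frames after
    # subsampling (T) must be at least the number of tokens (S).
    # The 4x-subsampling formula below is an assumption for illustration.
    T = ((c.num_frames - 7) // 2 + 1) // 2
    tokens = sp.encode(c.supervisions[0].text, out_type=str)
    if T < len(tokens):
        return False

    return True

def filter_short_and_long_utterances(
    cuts: CutSet, sp: spm.SentencePieceProcessor
) -> CutSet:
    # lhotse's CutSet.filter keeps cuts for which the predicate returns True.
    return cuts.filter(lambda c: remove_short_and_long_utt(c, sp))
```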

Changed file 3 of 4:

@@ -50,7 +50,7 @@ import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, Tuple, Union

 import k2
 import optim
@@ -798,7 +798,7 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
-    giga_train_dl: torch.utils.data.DataLoader,
+    iter_giga: Iterator,
     valid_dl: torch.utils.data.DataLoader,
     rng: random.Random,
     scaler: "GradScaler",
@@ -849,7 +849,6 @@ def train_one_epoch(
     # This sets the probabilities for choosing which datasets
     dl_weights = [1 - params.giga_prob, params.giga_prob]

     iter_libri = iter(train_dl)
-    iter_giga = iter(giga_train_dl)

     batch_idx = 0
@@ -1223,6 +1222,7 @@ def run(rank, world_size, args):
     #     sp=sp,
     #     params=params,
     # )

+    iter_giga = iter(giga_train_dl)
     scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
@@ -1247,7 +1247,7 @@ def run(rank, world_size, args):
         scheduler=scheduler,
         sp=sp,
         train_dl=train_dl,
-        giga_train_dl=giga_train_dl,
+        iter_giga=iter_giga,
         valid_dl=valid_dl,
         rng=rng,
         scaler=scaler,

Changed file 4 of 4:

@@ -55,7 +55,7 @@ import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, Tuple, Union

 import k2
 import optim
@@ -793,7 +793,7 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
-    giga_train_dl: torch.utils.data.DataLoader,
+    iter_giga: Iterator,
     valid_dl: torch.utils.data.DataLoader,
     rng: random.Random,
     scaler: "GradScaler",
@@ -849,7 +849,6 @@ def train_one_epoch(
     dl_weights = [1 - params.giga_prob, params.giga_prob]

     iter_libri = iter(train_dl)
-    iter_giga = iter(giga_train_dl)

     batch_idx = 0
@@ -1225,6 +1224,8 @@ def run(rank, world_size, args):
             params=params,
         )

+    iter_giga = iter(giga_train_dl)
+
     scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
@@ -1248,7 +1249,7 @@ def run(rank, world_size, args):
         scheduler=scheduler,
         sp=sp,
         train_dl=train_dl,
-        giga_train_dl=giga_train_dl,
+        iter_giga=iter_giga,
         valid_dl=valid_dl,
         rng=rng,
         scaler=scaler,