Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-09 05:55:26 +00:00)
Fix gigaspeech dataset iterator.
Previously, it was re-created at the start of every epoch, which could cause training to always draw from only the first part of the GigaSpeech dataset if you choose a small --giga-prob.
parent 693f069de7
commit 8905ef5d59
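
To make the fix concrete, here is a minimal, self-contained Python sketch of the sampling pattern the diff changes. It is not the icefall code itself: the lists, the batches_per_epoch argument, and the two toy function names are stand-ins for the real LibriSpeech/GigaSpeech dataloaders and training loop, and giga_prob plays the role of --giga-prob. The old version re-creates the GigaSpeech iterator inside the per-epoch function, so with a small sampling probability each epoch keeps re-reading the first few GigaSpeech batches; the fixed version creates the iterator once in the caller (as run() now does) and passes it in, so consumption continues across epochs.

# Toy sketch of the bug and the fix; names and arguments are illustrative,
# not icefall's actual API.
import random
from typing import Iterator, List


def train_one_epoch_old(train_dl: List[str], giga_train_dl: List[str],
                        giga_prob: float, rng: random.Random,
                        batches_per_epoch: int) -> List[str]:
    # Old behaviour: the GigaSpeech iterator is re-created on every call,
    # so each epoch starts again from the beginning of giga_train_dl.
    dl_weights = [1 - giga_prob, giga_prob]
    iter_libri = iter(train_dl)
    iter_giga = iter(giga_train_dl)  # reset every epoch  <-- the bug
    seen = []
    for _ in range(batches_per_epoch):
        idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
        it = iter_libri if idx == 0 else iter_giga
        seen.append(next(it, None))
    return seen


def train_one_epoch_new(train_dl: List[str], iter_giga: Iterator[str],
                        giga_prob: float, rng: random.Random,
                        batches_per_epoch: int) -> List[str]:
    # New behaviour: iter_giga is created once by the caller and passed in,
    # so successive epochs keep consuming GigaSpeech where the last one stopped.
    dl_weights = [1 - giga_prob, giga_prob]
    iter_libri = iter(train_dl)
    seen = []
    for _ in range(batches_per_epoch):
        idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
        it = iter_libri if idx == 0 else iter_giga
        seen.append(next(it, None))
    return seen


if __name__ == "__main__":
    rng = random.Random(42)
    libri = [f"libri-{i}" for i in range(100)]
    giga = [f"giga-{i}" for i in range(100)]

    # With a small giga_prob only a few GigaSpeech batches are drawn per epoch.
    for _ in range(3):
        print("old:", [b for b in train_one_epoch_old(libri, giga, 0.2, rng, 10)
                       if b and b.startswith("giga")])

    iter_giga = iter(giga)  # created once, as run() now does
    for _ in range(3):
        print("new:", [b for b in train_one_epoch_new(libri, iter_giga, 0.2, rng, 10)
                       if b and b.startswith("giga")])

With the old pattern every epoch's GigaSpeech batches start again at giga-0; with the new pattern they keep advancing through the dataset.
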
@@ -53,7 +53,7 @@ import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, Tuple, Union
 
 import k2
 import optim
@@ -770,7 +770,7 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
-    giga_train_dl: torch.utils.data.DataLoader,
+    iter_giga: Iterator,
     valid_dl: torch.utils.data.DataLoader,
     rng: random.Random,
     scaler: "GradScaler",
@@ -826,7 +826,6 @@ def train_one_epoch(
     dl_weights = [1 - params.giga_prob, params.giga_prob]
 
     iter_libri = iter(train_dl)
-    iter_giga = iter(giga_train_dl)
 
     batch_idx = 0
 
@@ -1177,6 +1176,8 @@ def run(rank, world_size, args):
     else:
         logging.info("Skip scan_pessimistic_batches_for_oom")
 
+    iter_giga = iter(giga_train_dl)
+
     scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
@@ -1200,7 +1201,7 @@ def run(rank, world_size, args):
             scheduler=scheduler,
             sp=sp,
             train_dl=train_dl,
-            giga_train_dl=giga_train_dl,
+            iter_giga=iter_giga,
             valid_dl=valid_dl,
             rng=rng,
             scaler=scaler,

@@ -53,7 +53,7 @@ import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, Tuple, Union
 
 import k2
 import optim
@@ -753,7 +753,7 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
-    giga_train_dl: torch.utils.data.DataLoader,
+    iter_giga: Iterator,
     valid_dl: torch.utils.data.DataLoader,
     rng: random.Random,
     scaler: "GradScaler",
@@ -806,7 +806,6 @@ def train_one_epoch(
     dl_weights = [1 - params.giga_prob, params.giga_prob]
 
     iter_libri = iter(train_dl)
-    iter_giga = iter(giga_train_dl)
 
     batch_idx = 0
 
@@ -950,9 +949,9 @@ def filter_short_and_long_utterances(
         # an utterance duration distribution for your dataset to select
        # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
-            logging.warning(
-                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            )
+            # logging.warning(
+            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            # )
             return False
 
         # In pruned RNN-T, we require that T >= S
@@ -965,14 +964,14 @@ def filter_short_and_long_utterances(
         tokens = sp.encode(c.supervisions[0].text, out_type=str)
 
         if T < len(tokens):
-            logging.warning(
-                f"Exclude cut with ID {c.id} from training. "
-                f"Number of frames (before subsampling): {c.num_frames}. "
-                f"Number of frames (after subsampling): {T}. "
-                f"Text: {c.supervisions[0].text}. "
-                f"Tokens: {tokens}. "
-                f"Number of tokens: {len(tokens)}"
-            )
+            # logging.warning(
+            #     f"Exclude cut with ID {c.id} from training. "
+            #     f"Number of frames (before subsampling): {c.num_frames}. "
+            #     f"Number of frames (after subsampling): {T}. "
+            #     f"Text: {c.supervisions[0].text}. "
+            #     f"Tokens: {tokens}. "
+            #     f"Number of tokens: {len(tokens)}"
+            # )
             return False
 
         return True
@@ -1117,6 +1116,8 @@ def run(rank, world_size, args):
     # It's time consuming to include `giga_train_dl` here
     # for dl in [train_dl, giga_train_dl]:
     for dl in [train_dl]:
+        # You can skip scan_pessimistic_batches_for_oom() if you are sure
+        # your selected params won't cause OOM
         if params.start_batch <= 0:
             scan_pessimistic_batches_for_oom(
                 model=model,
@@ -1127,6 +1128,8 @@ def run(rank, world_size, args):
                 warmup=0.0 if params.start_epoch == 0 else 1.0,
             )
 
+    iter_giga = iter(giga_train_dl)
+
     scaler = create_grad_scaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
@@ -1149,7 +1152,7 @@ def run(rank, world_size, args):
             scheduler=scheduler,
             sp=sp,
             train_dl=train_dl,
-            giga_train_dl=giga_train_dl,
+            iter_giga=iter_giga,
             valid_dl=valid_dl,
             rng=rng,
             scaler=scaler,

@@ -50,7 +50,7 @@ import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, Tuple, Union
 
 import k2
 import optim
@@ -798,7 +798,7 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
-    giga_train_dl: torch.utils.data.DataLoader,
+    iter_giga: Iterator,
     valid_dl: torch.utils.data.DataLoader,
     rng: random.Random,
     scaler: "GradScaler",
@@ -849,7 +849,6 @@ def train_one_epoch(
     # This sets the probabilities for choosing which datasets
     dl_weights = [1 - params.giga_prob, params.giga_prob]
     iter_libri = iter(train_dl)
-    iter_giga = iter(giga_train_dl)
 
     batch_idx = 0
 
@@ -1223,6 +1222,7 @@ def run(rank, world_size, args):
     #     sp=sp,
     #     params=params,
     # )
+    iter_giga = iter(giga_train_dl)
 
     scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
@@ -1247,7 +1247,7 @@ def run(rank, world_size, args):
             scheduler=scheduler,
             sp=sp,
             train_dl=train_dl,
-            giga_train_dl=giga_train_dl,
+            iter_giga=iter_giga,
             valid_dl=valid_dl,
             rng=rng,
             scaler=scaler,

@@ -55,7 +55,7 @@ import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, Tuple, Union
 
 import k2
 import optim
@@ -793,7 +793,7 @@ def train_one_epoch(
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
-    giga_train_dl: torch.utils.data.DataLoader,
+    iter_giga: Iterator,
     valid_dl: torch.utils.data.DataLoader,
     rng: random.Random,
     scaler: "GradScaler",
@@ -849,7 +849,6 @@ def train_one_epoch(
     dl_weights = [1 - params.giga_prob, params.giga_prob]
 
     iter_libri = iter(train_dl)
-    iter_giga = iter(giga_train_dl)
 
     batch_idx = 0
 
@@ -1225,6 +1224,8 @@ def run(rank, world_size, args):
             params=params,
         )
 
+    iter_giga = iter(giga_train_dl)
+
     scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
@@ -1248,7 +1249,7 @@ def run(rank, world_size, args):
            scheduler=scheduler,
             sp=sp,
             train_dl=train_dl,
-            giga_train_dl=giga_train_dl,
+            iter_giga=iter_giga,
             valid_dl=valid_dl,
             rng=rng,
             scaler=scaler,