Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-10 22:45:27 +00:00)
Fix gigaspeech dataset iterator. (#2045)
Previously, the GigaSpeech iterator was reset at the start of every epoch, which could cause training to always use the first part of the GigaSpeech dataset if you choose a small --giga-prob.
parent 693f069de7
commit 0904e490c5
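To see why resetting the iterator each epoch interacts badly with a small --giga-prob, here is a toy, self-contained illustration using plain Python iterators (the integers stand in for GigaSpeech batches; this is not icefall code):

# Demonstration: an iterator that is re-created every "epoch" keeps revisiting
# the same leading items, while one created once keeps advancing.
data = list(range(10))  # stand-in for the GigaSpeech batches

# Old behaviour: reset every epoch, so only the first items are ever seen.
for _ in range(3):
    it = iter(data)
    print([next(it) for _ in range(2)])  # [0, 1] every epoch

# New behaviour: created once, so each epoch continues where the last stopped.
it = iter(data)
for _ in range(3):
    print([next(it) for _ in range(2)])  # [0, 1], then [2, 3], then [4, 5]

With a small --giga-prob only a few GigaSpeech batches are drawn per epoch, so the old per-epoch reset meant every epoch re-read the same leading portion of the corpus; a single long-lived iterator spreads the draws across the whole dataset over the course of training.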
@@ -53,7 +53,7 @@ import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, Tuple, Union

 import k2
 import optim
@@ -770,7 +770,7 @@ def train_one_epoch(
 scheduler: LRSchedulerType,
 sp: spm.SentencePieceProcessor,
 train_dl: torch.utils.data.DataLoader,
-giga_train_dl: torch.utils.data.DataLoader,
+iter_giga: Iterator,
 valid_dl: torch.utils.data.DataLoader,
 rng: random.Random,
 scaler: "GradScaler",
@@ -826,7 +826,6 @@ def train_one_epoch(
 dl_weights = [1 - params.giga_prob, params.giga_prob]

 iter_libri = iter(train_dl)
-iter_giga = iter(giga_train_dl)

 batch_idx = 0

@@ -1177,6 +1176,8 @@ def run(rank, world_size, args):
 else:
 logging.info("Skip scan_pessimistic_batches_for_oom")

+iter_giga = iter(giga_train_dl)
+
 scaler = create_grad_scaler(enabled=params.use_fp16)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
@@ -1200,7 +1201,7 @@ def run(rank, world_size, args):
 scheduler=scheduler,
 sp=sp,
 train_dl=train_dl,
-giga_train_dl=giga_train_dl,
+iter_giga=iter_giga,
 valid_dl=valid_dl,
 rng=rng,
 scaler=scaler,

@@ -53,7 +53,7 @@ import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, Tuple, Union

 import k2
 import optim
@@ -753,7 +753,7 @@ def train_one_epoch(
 scheduler: LRSchedulerType,
 sp: spm.SentencePieceProcessor,
 train_dl: torch.utils.data.DataLoader,
-giga_train_dl: torch.utils.data.DataLoader,
+iter_giga: Iterator,
 valid_dl: torch.utils.data.DataLoader,
 rng: random.Random,
 scaler: "GradScaler",
@@ -806,7 +806,6 @@ def train_one_epoch(
 dl_weights = [1 - params.giga_prob, params.giga_prob]

 iter_libri = iter(train_dl)
-iter_giga = iter(giga_train_dl)

 batch_idx = 0

@@ -950,9 +949,9 @@ def filter_short_and_long_utterances(
 # an utterance duration distribution for your dataset to select
 # the threshold
 if c.duration < 1.0 or c.duration > 20.0:
-logging.warning(
-f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-)
+# logging.warning(
+# f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+# )
 return False

 # In pruned RNN-T, we require that T >= S
@@ -965,14 +964,14 @@ def filter_short_and_long_utterances(
 tokens = sp.encode(c.supervisions[0].text, out_type=str)

 if T < len(tokens):
-logging.warning(
-f"Exclude cut with ID {c.id} from training. "
-f"Number of frames (before subsampling): {c.num_frames}. "
-f"Number of frames (after subsampling): {T}. "
-f"Text: {c.supervisions[0].text}. "
-f"Tokens: {tokens}. "
-f"Number of tokens: {len(tokens)}"
-)
+# logging.warning(
+# f"Exclude cut with ID {c.id} from training. "
+# f"Number of frames (before subsampling): {c.num_frames}. "
+# f"Number of frames (after subsampling): {T}. "
+# f"Text: {c.supervisions[0].text}. "
+# f"Tokens: {tokens}. "
+# f"Number of tokens: {len(tokens)}"
+# )
 return False

 return True
@@ -1117,6 +1116,8 @@ def run(rank, world_size, args):
 # It's time consuming to include `giga_train_dl` here
 # for dl in [train_dl, giga_train_dl]:
 for dl in [train_dl]:
+# You can skip scan_pessimistic_batches_for_oom() if you are sure
+# your selected params won't cause OOM
 if params.start_batch <= 0:
 scan_pessimistic_batches_for_oom(
 model=model,
@@ -1127,6 +1128,8 @@ def run(rank, world_size, args):
 warmup=0.0 if params.start_epoch == 0 else 1.0,
 )

+iter_giga = iter(giga_train_dl)
+
 scaler = create_grad_scaler(enabled=params.use_fp16)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
@@ -1149,7 +1152,7 @@ def run(rank, world_size, args):
 scheduler=scheduler,
 sp=sp,
 train_dl=train_dl,
-giga_train_dl=giga_train_dl,
+iter_giga=iter_giga,
 valid_dl=valid_dl,
 rng=rng,
 scaler=scaler,

@@ -50,7 +50,7 @@ import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, Tuple, Union

 import k2
 import optim
@@ -798,7 +798,7 @@ def train_one_epoch(
 scheduler: LRSchedulerType,
 sp: spm.SentencePieceProcessor,
 train_dl: torch.utils.data.DataLoader,
-giga_train_dl: torch.utils.data.DataLoader,
+iter_giga: Iterator,
 valid_dl: torch.utils.data.DataLoader,
 rng: random.Random,
 scaler: "GradScaler",
@@ -849,7 +849,6 @@ def train_one_epoch(
 # This sets the probabilities for choosing which datasets
 dl_weights = [1 - params.giga_prob, params.giga_prob]
 iter_libri = iter(train_dl)
-iter_giga = iter(giga_train_dl)

 batch_idx = 0

@@ -1223,6 +1222,7 @@ def run(rank, world_size, args):
 # sp=sp,
 # params=params,
 # )
+iter_giga = iter(giga_train_dl)

 scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
@@ -1247,7 +1247,7 @@ def run(rank, world_size, args):
 scheduler=scheduler,
 sp=sp,
 train_dl=train_dl,
-giga_train_dl=giga_train_dl,
+iter_giga=iter_giga,
 valid_dl=valid_dl,
 rng=rng,
 scaler=scaler,

@@ -55,7 +55,7 @@ import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, Optional, Tuple, Union

 import k2
 import optim
@@ -793,7 +793,7 @@ def train_one_epoch(
 scheduler: LRSchedulerType,
 sp: spm.SentencePieceProcessor,
 train_dl: torch.utils.data.DataLoader,
-giga_train_dl: torch.utils.data.DataLoader,
+iter_giga: Iterator,
 valid_dl: torch.utils.data.DataLoader,
 rng: random.Random,
 scaler: "GradScaler",
@@ -849,7 +849,6 @@ def train_one_epoch(
 dl_weights = [1 - params.giga_prob, params.giga_prob]

 iter_libri = iter(train_dl)
-iter_giga = iter(giga_train_dl)

 batch_idx = 0

@@ -1225,6 +1224,8 @@ def run(rank, world_size, args):
 params=params,
 )

+iter_giga = iter(giga_train_dl)
+
 scaler = create_grad_scaler(enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
 logging.info("Loading grad scaler state dict")
@@ -1248,7 +1249,7 @@ def run(rank, world_size, args):
 scheduler=scheduler,
 sp=sp,
 train_dl=train_dl,
-giga_train_dl=giga_train_dl,
+iter_giga=iter_giga,
 valid_dl=valid_dl,
 rng=rng,
 scaler=scaler,
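Taken together, the hunks above apply the same pattern to each affected training script: iter(giga_train_dl) is now created once in run() and handed to train_one_epoch() as an Iterator, instead of being re-created at the top of every epoch. A condensed sketch of that structure follows, with simplified names and a made-up batch-sampling loop rather than the recipes' real training loop:

import random
from typing import Iterator

import torch


def train_one_epoch(
    train_dl: torch.utils.data.DataLoader,
    iter_giga: Iterator,  # persistent GigaSpeech iterator, created once in run()
    rng: random.Random,
    giga_prob: float,
) -> None:
    # Only the LibriSpeech loader is restarted each epoch; iter_giga keeps its position.
    iter_libri = iter(train_dl)
    dl_weights = [1 - giga_prob, giga_prob]
    for _ in range(len(train_dl)):
        # Pick which corpus the next batch comes from.
        idx = rng.choices((0, 1), weights=dl_weights, k=1)[0]
        batch = next(iter_libri) if idx == 0 else next(iter_giga)
        _ = batch  # forward/backward/optimizer step would go here


def run(
    train_dl: torch.utils.data.DataLoader,
    giga_train_dl: torch.utils.data.DataLoader,
    num_epochs: int,
    giga_prob: float,
) -> None:
    rng = random.Random(42)
    # Created once, so each epoch continues from where the previous one stopped.
    iter_giga = iter(giga_train_dl)
    for _ in range(num_epochs):
        train_one_epoch(
            train_dl=train_dl,
            iter_giga=iter_giga,
            rng=rng,
            giga_prob=giga_prob,
        )

Note that next(iter_giga) raises StopIteration once a full pass over GigaSpeech completes; how the real training loops handle that case is outside the scope of this diff.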