Set the seed for dataloader. (#282)

Also, suppress torch warnings about truncating division (`//` rounding toward zero).
Fangjun Kuang 2022-03-31 16:48:46 +08:00 committed by GitHub
parent fc40bfea82
commit 9a11808ed3
8 changed files with 56 additions and 7 deletions
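
For context, the warning being suppressed comes from tensor floor division: on the PyTorch releases current at the time of this commit (roughly 1.8 through 1.12), `tensor // int` emits a UserWarning that `__floordiv__` rounds toward zero (truncation) and that its behavior will change in a future release. A minimal sketch, not part of the commit, that reproduces and then silences it:

    import warnings

    import torch

    feature_lens = torch.tensor([160, 284, 312])
    subsampling_factor = 4

    # `//` on a tensor triggers the deprecation warning on affected
    # PyTorch versions; catch_warnings() limits the suppression to
    # this block instead of silencing warnings globally.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        num_frames = (feature_lens // subsampling_factor).sum().item()

    print(num_frames)  # 40 + 71 + 78 = 189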

View File

@@ -33,6 +33,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 import argparse
 import logging
+import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple
@@ -496,7 +497,11 @@ def compute_loss(
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
-    info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()

View File

@@ -23,6 +23,7 @@ from functools import lru_cache
 from pathlib import Path
 from typing import Any, Dict, Optional
 
+import torch
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest
 from lhotse.dataset import (
     BucketingSampler,
@@ -34,6 +35,7 @@ from lhotse.dataset import (
     SpecAugment,
 )
 from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.utils import fix_random_seed
 from torch.utils.data import DataLoader
 
 from icefall.utils import str2bool
@@ -301,12 +303,20 @@ class LibriSpeechAsrDataModule:
             logging.info("Loading sampler state dict")
             train_sampler.load_state_dict(sampler_state_dict)
 
+        # 'seed' is derived from the current random state, which will have
+        # previously been set in the main process.
+        seed = torch.randint(0, 100000, ()).item()
+
+        def worker_init_fn(worker_id: int):
+            fix_random_seed(seed + worker_id)
+
         train_dl = DataLoader(
             train,
             sampler=train_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
             persistent_workers=False,
+            worker_init_fn=worker_init_fn,
         )
 
         return train_dl
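
The reason for worker_init_fn: DataLoader workers are separate processes, and a forked worker inherits its RNG state from the parent, so without explicit per-worker seeding every worker can draw the identical sequence of "random" augmentation values (a classic pitfall with numpy's global RNG). Deriving one base seed in the main process and offsetting it by worker_id keeps runs reproducible while making workers differ. A self-contained sketch of the same pattern; the local fix_random_seed below is a stand-in for lhotse.utils.fix_random_seed, which seeds Python's random, numpy, and torch:

    import random

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, Dataset


    class NoisyDataset(Dataset):
        """Toy dataset whose items depend on the worker's RNG state."""

        def __len__(self):
            return 8

        def __getitem__(self, idx):
            # Without per-worker seeding, forked workers can all draw
            # the same sequence of "random" values here.
            return np.random.rand()


    def fix_random_seed(seed: int):
        # Stand-in for lhotse.utils.fix_random_seed: seed all three RNGs.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)


    if __name__ == "__main__":
        # Derive the base seed from the main-process RNG (already seeded
        # by the training script), then offset by worker_id so each
        # worker gets a distinct but reproducible seed.
        seed = torch.randint(0, 100000, ()).item()

        def worker_init_fn(worker_id: int):
            fix_random_seed(seed + worker_id)

        dl = DataLoader(
            NoisyDataset(),
            batch_size=2,
            num_workers=2,
            worker_init_fn=worker_init_fn,
        )
        for batch in dl:
            print(batch)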

View File

@@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 import argparse
 import logging
+import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Optional, Tuple
@@ -393,7 +394,11 @@ def compute_loss(
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
-    info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()

View File

@@ -35,6 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2"
 
 import argparse
 import logging
+import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Optional, Tuple
@@ -397,7 +398,11 @@ def compute_loss(
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
-    info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()

View File

@@ -109,8 +109,11 @@ class Conformer(Transformer):
         x, pos_emb = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
-        # Caution: We assume the subsampling factor is 4!
-        lengths = ((x_lens - 1) // 2 - 1) // 2
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            # Caution: We assume the subsampling factor is 4!
+            lengths = ((x_lens - 1) // 2 - 1) // 2
         assert x.size(0) == lengths.max().item()
         mask = make_pad_mask(lengths)
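
The "subsampling factor is 4" caution refers to the two stride-2 convolutions in the encoder front-end, and ((x_lens - 1) // 2 - 1) // 2 applies the per-layer length map twice. A quick sketch to verify the formula, assuming a Conv2dSubsampling-style front-end with kernel_size=3, stride=2, and no padding (so each layer maps L to (L - 3) // 2 + 1, which equals (L - 1) // 2 for L >= 3):

    import torch
    import torch.nn as nn

    # Assumed front-end: two Conv2d layers, kernel_size=3, stride=2,
    # no padding, mirroring a typical Conv2dSubsampling module.
    conv = nn.Sequential(
        nn.Conv2d(1, 1, kernel_size=3, stride=2),
        nn.Conv2d(1, 1, kernel_size=3, stride=2),
    )

    for L in (7, 100, 163, 500):
        x = torch.zeros(1, 1, L, 80)  # (N, C, T, num_features)
        T_out = conv(x).size(2)
        lengths = ((torch.tensor([L]) - 1) // 2 - 1) // 2
        assert T_out == lengths.item(), (L, T_out, lengths.item())
    print("length formula matches the conv output for every tested L")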

View File

@@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 import argparse
 import logging
+import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Optional, Tuple
@@ -419,7 +420,11 @@ def compute_loss(
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
-    info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()

View File

@@ -22,6 +22,7 @@ import logging
 from pathlib import Path
 from typing import Optional
 
+import torch
 from lhotse import CutSet, Fbank, FbankConfig
 from lhotse.dataset import (
     BucketingSampler,
@@ -34,6 +35,7 @@ from lhotse.dataset.input_strategies import (
     OnTheFlyFeatures,
     PrecomputedFeatures,
 )
+from lhotse.utils import fix_random_seed
 from torch.utils.data import DataLoader
 
 from icefall.utils import str2bool
@@ -253,12 +255,21 @@ class AsrDataModule:
         )
 
         logging.info("About to create train dataloader")
+
+        # 'seed' is derived from the current random state, which will have
+        # previously been set in the main process.
+        seed = torch.randint(0, 100000, ()).item()
+
+        def worker_init_fn(worker_id: int):
+            fix_random_seed(seed + worker_id)
+
         train_dl = DataLoader(
             train,
             sampler=train_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
             persistent_workers=False,
+            worker_init_fn=worker_init_fn,
         )
 
         return train_dl

View File

@@ -58,6 +58,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 import argparse
 import logging
 import random
+import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Optional, Tuple
@@ -466,7 +467,11 @@ def compute_loss(
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
-    info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()