Simplify dataloader code

This commit is contained in:
Daniel Povey 2023-05-18 13:55:52 +08:00
parent 9367ea3646
commit 15aca1fb4a
2 changed files with 13 additions and 17 deletions

View File

@@ -115,23 +115,14 @@ class LmDataset(torch.utils.data.IterableDataset):
def LmDataloader(dataset: LmDataset,
batch_size: int,
num_workers: int):
return torch.utils.data.DataLoader(
dataset=dataset,
batch_size=batch_size,
num_workers=num_workers,
drop_last=False)
def _test(): def _test():
l = LmDataset('files.txt') l = LmDataset('files.txt')
d = LmDataloader(l, batch_size=5, num_workers=4) d = torch.utils.data.DataLoader(
dataset=l, batch_size=5, num_workers=4, drop_last=True)
for batch in d: for batch in d:
logging.info("batch shape: ", batch.shape) logging.info("batch shape: ", batch.shape)

View File

@@ -59,7 +59,7 @@ import optim
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from lm_datamodule import LmDataset, LmDataloader from lm_datamodule import LmDataset
from subformer import Subformer from subformer import Subformer
from scaling import ScheduledFloat from scaling import ScheduledFloat
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
@@ -982,14 +982,19 @@ def run(rank, world_size, args):
train = LmDataset(params.train_file_list, train = LmDataset(params.train_file_list,
bytes_per_segment=params.bytes_per_segment,) bytes_per_segment=params.bytes_per_segment,)
train_dl = LmDataloader(train, batch_size=params.batch_size, train_dl = torch.utils.data.DataLoader(
num_workers=params.num_workers) dataset=train,
batch_size=params.batch_size,
num_workers=params.num_workers,
drop_last=True)
valid = LmDataset(params.valid_file_list, valid = LmDataset(params.valid_file_list,
bytes_per_segment=params.bytes_per_segment) bytes_per_segment=params.bytes_per_segment)
valid_dl = LmDataloader(valid, batch_size=params.batch_size, valid_dl = torch.utils.data.DataLoader(
num_workers=params.num_workers) dataset=valid,
batch_size=params.batch_size,
num_workers=params.num_workers,
drop_last=False)
scaler = GradScaler(enabled=params.use_fp16, scaler = GradScaler(enabled=params.use_fp16,
init_scale=1.0) init_scale=1.0)