From 15aca1fb4ab3cc96efee457caf8fffe5f585ed35 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Thu, 18 May 2023 13:55:52 +0800
Subject: [PATCH] Simplify dataloader code

---
 egs/libriheavy/LM/zipformer1/lm_datamodule.py | 13 ++-----------
 egs/libriheavy/LM/zipformer1/train.py         | 17 +++++++++++------
 2 files changed, 13 insertions(+), 17 deletions(-)

diff --git a/egs/libriheavy/LM/zipformer1/lm_datamodule.py b/egs/libriheavy/LM/zipformer1/lm_datamodule.py
index 0ef0ff98b..1f6a70828 100644
--- a/egs/libriheavy/LM/zipformer1/lm_datamodule.py
+++ b/egs/libriheavy/LM/zipformer1/lm_datamodule.py
@@ -115,23 +115,14 @@ class LmDataset(torch.utils.data.IterableDataset):
 
 
-def LmDataloader(dataset: LmDataset,
-                 batch_size: int,
-                 num_workers: int):
-
-    return torch.utils.data.DataLoader(
-        dataset=dataset,
-        batch_size=batch_size,
-        num_workers=num_workers,
-        drop_last=False)
-
 def _test():
     l = LmDataset('files.txt')
-    d = LmDataloader(l, batch_size=5, num_workers=4)
+    d = torch.utils.data.DataLoader(
+        dataset=l, batch_size=5, num_workers=4, drop_last=True)
 
     for batch in d:
         logging.info("batch shape: ", batch.shape)
diff --git a/egs/libriheavy/LM/zipformer1/train.py b/egs/libriheavy/LM/zipformer1/train.py
index b825d6b85..a04e398e9 100755
--- a/egs/libriheavy/LM/zipformer1/train.py
+++ b/egs/libriheavy/LM/zipformer1/train.py
@@ -59,7 +59,7 @@ import optim
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from lm_datamodule import LmDataset, LmDataloader
+from lm_datamodule import LmDataset
 from subformer import Subformer
 from scaling import ScheduledFloat
 from lhotse.utils import fix_random_seed
@@ -982,14 +982,19 @@ def run(rank, world_size, args):
 
     train = LmDataset(params.train_file_list,
                       bytes_per_segment=params.bytes_per_segment,)
 
-    train_dl = LmDataloader(train, batch_size=params.batch_size,
-                            num_workers=params.num_workers)
+    train_dl = torch.utils.data.DataLoader(
+        dataset=train,
+        batch_size=params.batch_size,
+        num_workers=params.num_workers,
+        drop_last=True)
 
     valid = LmDataset(params.valid_file_list,
                       bytes_per_segment=params.bytes_per_segment)
-    valid_dl = LmDataloader(valid, batch_size=params.batch_size,
-                            num_workers=params.num_workers)
-
+    valid_dl = torch.utils.data.DataLoader(
+        dataset=valid,
+        batch_size=params.batch_size,
+        num_workers=params.num_workers,
+        drop_last=False)
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
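
The patch drops the one-line LmDataloader wrapper in favour of constructing torch.utils.data.DataLoader directly, with drop_last=True for training (every batch has a fixed shape) and drop_last=False for validation (no held-out data is discarded). Below is a minimal, self-contained sketch of that pattern; ToyByteDataset is a hypothetical stand-in for LmDataset, which (per the diff) evidently streams fixed-size byte segments from a list of files.

import torch


class ToyByteDataset(torch.utils.data.IterableDataset):
    """Hypothetical stand-in for LmDataset: yields fixed-size byte segments."""

    def __init__(self, num_segments: int = 23, bytes_per_segment: int = 8):
        self.num_segments = num_segments
        self.bytes_per_segment = bytes_per_segment

    def __iter__(self):
        for _ in range(self.num_segments):
            yield torch.randint(
                0, 256, (self.bytes_per_segment,), dtype=torch.uint8)


if __name__ == "__main__":
    dataset = ToyByteDataset()

    # Training: drop_last=True discards the final short batch, so every
    # batch has shape (batch_size, bytes_per_segment).
    train_dl = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=5, num_workers=0, drop_last=True)

    # Validation: drop_last=False keeps the final partial batch, so no
    # validation data is skipped.
    valid_dl = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=5, num_workers=0, drop_last=False)

    print([tuple(b.shape) for b in train_dl])  # four batches of (5, 8)
    print([tuple(b.shape) for b in valid_dl])  # same four plus a final (3, 8)

One caveat the sketch sidesteps by using num_workers=0: with an IterableDataset, every DataLoader worker runs the full __iter__ and duplicates the data unless the dataset shards its work via torch.utils.data.get_worker_info(), so the num_workers=4 settings in the patch rely on LmDataset handling that internally.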