mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Simplify dataloader code
This commit is contained in:
parent
9367ea3646
commit
15aca1fb4a
@ -115,23 +115,14 @@ class LmDataset(torch.utils.data.IterableDataset):
|
||||
|
||||
|
||||
|
||||
def LmDataloader(dataset: LmDataset,
|
||||
batch_size: int,
|
||||
num_workers: int):
|
||||
|
||||
return torch.utils.data.DataLoader(
|
||||
dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
drop_last=False)
|
||||
|
||||
|
||||
|
||||
|
||||
def _test():
|
||||
l = LmDataset('files.txt')
|
||||
|
||||
d = LmDataloader(l, batch_size=5, num_workers=4)
|
||||
d = torch.utils.data.DataLoader(
|
||||
dataset=l, batch_size=5, num_workers=4, drop_last=True)
|
||||
|
||||
for batch in d:
|
||||
logging.info("batch shape: ", batch.shape)
|
||||
|
||||
@ -59,7 +59,7 @@ import optim
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from lm_datamodule import LmDataset, LmDataloader
|
||||
from lm_datamodule import LmDataset
|
||||
from subformer import Subformer
|
||||
from scaling import ScheduledFloat
|
||||
from lhotse.utils import fix_random_seed
|
||||
@ -982,14 +982,19 @@ def run(rank, world_size, args):
|
||||
|
||||
train = LmDataset(params.train_file_list,
|
||||
bytes_per_segment=params.bytes_per_segment,)
|
||||
train_dl = LmDataloader(train, batch_size=params.batch_size,
|
||||
num_workers=params.num_workers)
|
||||
train_dl = torch.utils.data.DataLoader(
|
||||
dataset=train,
|
||||
batch_size=params.batch_size,
|
||||
num_workers=params.num_workers,
|
||||
drop_last=True)
|
||||
|
||||
valid = LmDataset(params.valid_file_list,
|
||||
bytes_per_segment=params.bytes_per_segment)
|
||||
valid_dl = LmDataloader(valid, batch_size=params.batch_size,
|
||||
num_workers=params.num_workers)
|
||||
|
||||
valid_dl = torch.utils.data.DataLoader(
|
||||
dataset=valid,
|
||||
batch_size=params.batch_size,
|
||||
num_workers=params.num_workers,
|
||||
drop_last=False)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16,
|
||||
init_scale=1.0)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user