mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Simplify dataloader code
parent 9367ea3646
commit 15aca1fb4a
@@ -115,23 +115,14 @@ class LmDataset(torch.utils.data.IterableDataset):
 
 
 
-def LmDataloader(dataset: LmDataset,
-                 batch_size: int,
-                 num_workers: int):
-
-    return torch.utils.data.DataLoader(
-        dataset=dataset,
-        batch_size=batch_size,
-        num_workers=num_workers,
-        drop_last=False)
-
-
 
 def _test():
     l = LmDataset('files.txt')
 
-    d = LmDataloader(l, batch_size=5, num_workers=4)
+    d = torch.utils.data.DataLoader(
+        dataset=l, batch_size=5, num_workers=4, drop_last=True)
 
     for batch in d:
         logging.info("batch shape: ", batch.shape)
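Because LmDataset is an IterableDataset, handing it straight to torch.utils.data.DataLoader with num_workers > 0 has one subtlety: every worker process runs __iter__ on its own copy of the dataset, so the dataset must shard its stream by worker or each segment is yielded num_workers times. Below is a minimal runnable sketch of the pattern; ByteStreamDataset and its sharding scheme are illustrative stand-ins, not the actual lm_datamodule.py code. The sketch also uses %-style logging, since logging.info("batch shape: ", batch.shape) as written in the test passes batch.shape as a format argument with no placeholder, which triggers a logging formatting error when the record is emitted.

import logging

import torch


class ByteStreamDataset(torch.utils.data.IterableDataset):
    """Illustrative stand-in for LmDataset: yields fixed-size byte segments."""

    def __init__(self, data: bytes, bytes_per_segment: int = 8):
        self.data = data
        self.bytes_per_segment = bytes_per_segment

    def __iter__(self):
        n_segments = len(self.data) // self.bytes_per_segment
        # Each DataLoader worker runs __iter__ on its own copy of the
        # dataset, so shard the segments by worker id to avoid duplicates.
        info = torch.utils.data.get_worker_info()
        worker_id = info.id if info is not None else 0
        num_workers = info.num_workers if info is not None else 1
        for i in range(worker_id, n_segments, num_workers):
            start = i * self.bytes_per_segment
            seg = self.data[start:start + self.bytes_per_segment]
            yield torch.tensor(list(seg), dtype=torch.uint8)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    dataset = ByteStreamDataset(bytes(range(64)), bytes_per_segment=8)
    d = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=2, num_workers=2, drop_last=True)
    for batch in d:
        # %-style formatting: logging.info("batch shape: ", batch.shape)
        # would leave batch.shape as an unconsumed format argument.
        logging.info("batch shape: %s", batch.shape)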
@@ -59,7 +59,7 @@ import optim
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from lm_datamodule import LmDataset, LmDataloader
+from lm_datamodule import LmDataset
 from subformer import Subformer
 from scaling import ScheduledFloat
 from lhotse.utils import fix_random_seed
@@ -982,14 +982,19 @@ def run(rank, world_size, args):
 
     train = LmDataset(params.train_file_list,
                       bytes_per_segment=params.bytes_per_segment,)
-    train_dl = LmDataloader(train, batch_size=params.batch_size,
-                            num_workers=params.num_workers)
+    train_dl = torch.utils.data.DataLoader(
+        dataset=train,
+        batch_size=params.batch_size,
+        num_workers=params.num_workers,
+        drop_last=True)
 
     valid = LmDataset(params.valid_file_list,
                       bytes_per_segment=params.bytes_per_segment)
-    valid_dl = LmDataloader(valid, batch_size=params.batch_size,
-                            num_workers=params.num_workers)
+    valid_dl = torch.utils.data.DataLoader(
+        dataset=valid,
+        batch_size=params.batch_size,
+        num_workers=params.num_workers,
+        drop_last=False)
 
     scaler = GradScaler(enabled=params.use_fp16,
                         init_scale=1.0)
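The two loaders differ only in drop_last: True for training, so every step sees a full batch_size (and, under distributed training, all ranks take the same number of steps per epoch), but False for validation, so the final short batch is still evaluated. A small self-contained demo of that semantics, using a map-style TensorDataset as a stand-in for LmDataset:

import torch
from torch.utils.data import DataLoader, TensorDataset

# TensorDataset stands in for LmDataset (a map-style dataset is enough
# to show the batching behaviour).
data = TensorDataset(torch.arange(10))

train_dl = DataLoader(data, batch_size=4, drop_last=True)   # as for train
valid_dl = DataLoader(data, batch_size=4, drop_last=False)  # as for valid

print([b[0].shape[0] for b in train_dl])  # [4, 4]    -- trailing 2 samples dropped
print([b[0].shape[0] for b in valid_dl])  # [4, 4, 2] -- short final batch kept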