Implement train mode in lm_datamodule

This commit is contained in:
Daniel Povey 2023-05-23 11:08:05 +08:00
parent 3a71a53d8d
commit 3351402875

View File

@ -40,6 +40,7 @@ class LmDataset(torch.utils.data.IterableDataset):
bytes_per_segment: int = 200,
world_size: int = 1,
rank: int = 0,
training: bool = True
):
"""
Initialize LmDataset object. Args:
@ -48,6 +49,7 @@ class LmDataset(torch.utils.data.IterableDataset):
(filenames can not contain spaces).
bytes_per_segment: the number of bytes in each segment of data.
"""
self.training = training
self.files = []
self.num_bytes = []
self.bytes_per_segment = bytes_per_segment
@ -80,7 +82,7 @@ class LmDataset(torch.utils.data.IterableDataset):
# id includes both worker (within training job) and rank of training job
my_id = (0 if worker_info is None else worker_info.id) + 1000 * self.ddp_rank
seed = random.randint(0, 10000) + my_id
seed = (random.randint(0, 10000) if self.training else 0) + my_id
# the next line is because, for some reason, when we ran with --worle-size more than 1,
# this info message was not printed out.
logging.getLogger().setLevel(logging.INFO)