diff --git a/egs/libriheavy/LM/zipformer1/lm_datamodule.py b/egs/libriheavy/LM/zipformer1/lm_datamodule.py
index 1f6a70828..8b5236c2e 100644
--- a/egs/libriheavy/LM/zipformer1/lm_datamodule.py
+++ b/egs/libriheavy/LM/zipformer1/lm_datamodule.py
@@ -40,6 +40,7 @@ class LmDataset(torch.utils.data.IterableDataset):
         bytes_per_segment: int = 200,
         world_size: int = 1,
         rank: int = 0,
+        training: bool = True,
     ):
         """
         Initialize LmDataset object.  Args:
@@ -48,6 +49,9 @@ class LmDataset(torch.utils.data.IterableDataset):
            (filenames can not contain spaces).
          bytes_per_segment: the number of bytes in each segment of data.
+         training: if False, use a fixed (worker/rank-dependent) seed instead of a
+           random one, so that segment sampling is deterministic across runs.
         """
+        self.training = training
         self.files = []
         self.num_bytes = []
         self.bytes_per_segment = bytes_per_segment
@@ -80,7 +84,7 @@ class LmDataset(torch.utils.data.IterableDataset):
         # id includes both worker (within training job) and rank of training job
         my_id = (0 if worker_info is None else worker_info.id) + 1000 * self.ddp_rank
 
-        seed = random.randint(0, 10000) + my_id
-        # the next line is because, for some reason, when we ran with --worle-size more than 1,
+        seed = (random.randint(0, 10000) if self.training else 0) + my_id
+        # the next line is because, for some reason, when we ran with --world-size more than 1,
         # this info message was not printed out.
         logging.getLogger().setLevel(logging.INFO)