mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Implement train mode in lm_datamodule
This commit is contained in:
parent
3a71a53d8d
commit
3351402875
@ -40,6 +40,7 @@ class LmDataset(torch.utils.data.IterableDataset):
|
||||
bytes_per_segment: int = 200,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
training: bool = True
|
||||
):
|
||||
"""
|
||||
Initialize LmDataset object. Args:
|
||||
@ -48,6 +49,7 @@ class LmDataset(torch.utils.data.IterableDataset):
|
||||
(filenames can not contain spaces).
|
||||
bytes_per_segment: the number of bytes in each segment of data.
|
||||
"""
|
||||
self.training = training
|
||||
self.files = []
|
||||
self.num_bytes = []
|
||||
self.bytes_per_segment = bytes_per_segment
|
||||
@ -80,7 +82,7 @@ class LmDataset(torch.utils.data.IterableDataset):
|
||||
# id includes both worker (within training job) and rank of training job
|
||||
my_id = (0 if worker_info is None else worker_info.id) + 1000 * self.ddp_rank
|
||||
|
||||
seed = random.randint(0, 10000) + my_id
|
||||
seed = (random.randint(0, 10000) if self.training else 0) + my_id
|
||||
# the next line is because, for some reason, when we ran with --world-size more than 1,
|
||||
# this info message was not printed out.
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user