mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Implement train mode in lm_datamodule
This commit is contained in:
parent
3a71a53d8d
commit
3351402875
@ -40,6 +40,7 @@ class LmDataset(torch.utils.data.IterableDataset):
|
|||||||
bytes_per_segment: int = 200,
|
bytes_per_segment: int = 200,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
rank: int = 0,
|
rank: int = 0,
|
||||||
|
training: bool = True
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize LmDataset object. Args:
|
Initialize LmDataset object. Args:
|
||||||
@ -48,6 +49,7 @@ class LmDataset(torch.utils.data.IterableDataset):
|
|||||||
(filenames can not contain spaces).
|
(filenames can not contain spaces).
|
||||||
bytes_per_segment: the number of bytes in each segment of data.
|
bytes_per_segment: the number of bytes in each segment of data.
|
||||||
"""
|
"""
|
||||||
|
self.training = training
|
||||||
self.files = []
|
self.files = []
|
||||||
self.num_bytes = []
|
self.num_bytes = []
|
||||||
self.bytes_per_segment = bytes_per_segment
|
self.bytes_per_segment = bytes_per_segment
|
||||||
@ -80,7 +82,7 @@ class LmDataset(torch.utils.data.IterableDataset):
|
|||||||
# id includes both worker (within training job) and rank of training job
|
# id includes both worker (within training job) and rank of training job
|
||||||
my_id = (0 if worker_info is None else worker_info.id) + 1000 * self.ddp_rank
|
my_id = (0 if worker_info is None else worker_info.id) + 1000 * self.ddp_rank
|
||||||
|
|
||||||
seed = random.randint(0, 10000) + my_id
|
seed = (random.randint(0, 10000) if self.training else 0) + my_id
|
||||||
# the next line is because, for some reason, when we ran with --worle-size more than 1,
|
# the next line is because, for some reason, when we ran with --worle-size more than 1,
|
||||||
# this info message was not printed out.
|
# this info message was not printed out.
|
||||||
logging.getLogger().setLevel(logging.INFO)
|
logging.getLogger().setLevel(logging.INFO)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user