diff --git a/egs/libriheavy/LM/zipformer1/lm_datamodule.py b/egs/libriheavy/LM/zipformer1/lm_datamodule.py
index 8b5236c2e..25ffc6843 100644
--- a/egs/libriheavy/LM/zipformer1/lm_datamodule.py
+++ b/egs/libriheavy/LM/zipformer1/lm_datamodule.py
@@ -40,7 +40,8 @@ class LmDataset(torch.utils.data.IterableDataset):
                  bytes_per_segment: int = 200,
                  world_size: int = 1,
                  rank: int = 0,
-                 training: bool = True
+                 training: bool = True,
+                 skip_to_batch_idx: int = 0,
     ):
         """
         Initialize LmDataset object. Args:
@@ -48,8 +49,10 @@
              e.g. a line might contain the text "64324 foo/abc.txt".
              (filenames can not contain spaces).
           bytes_per_segment: the number of bytes in each segment of data.
+          skip_to_batch_idx: if provided, the first time we iterate we will skip this many batches.
         """
         self.training = training
+        self.skip_to_batch_idx = skip_to_batch_idx
         self.files = []
         self.num_bytes = []
         self.bytes_per_segment = bytes_per_segment
@@ -88,6 +91,12 @@
         logging.getLogger().setLevel(logging.INFO)
         logging.info(f"my_id={my_id}, seed={seed}, num_segments={self.num_segments}")
         rng = np.random.default_rng(seed=seed)
+
+        skip_to_batch_idx = self.skip_to_batch_idx
+        if skip_to_batch_idx != 0:
+            logging.info(f"skip-to-batch-idx={skip_to_batch_idx}")
+            self.skip_to_batch_idx = 0  # so only the 1st time we iterate, we respect this.
+
         for n in range(self.num_segments):
             # np.random.multinomial / np.random.Generator.multinomial has an interface
             # where it gives counts of different categories, instead of the chosen category,
@@ -97,6 +106,9 @@
             file_idx, = np.nonzero(rng.multinomial(1, self.probs))
             file_idx, = file_idx

+            if n < skip_to_batch_idx:
+                continue
+
             fn = self.files[file_idx]
             num_bytes = self.num_bytes[file_idx]

@@ -139,5 +151,5 @@
 if __name__ == '__main__':
 # cd libriheavy/LM
 # find /ceph-data3/xiaoyu/librilight_text/output_text_large_cleaned -name text.txt -exec stat --printf='%s ' {} \; -print > files.txt
-# head -n 2 files.txt > valid.txt
-# tail -n +3 files.txt > train.txt
+# head -n 4 files.txt > valid.txt
+# tail -n +5 files.txt > train.txt
diff --git a/egs/libriheavy/LM/zipformer1/train.py b/egs/libriheavy/LM/zipformer1/train.py
index 84f94d99e..8841938a6 100755
--- a/egs/libriheavy/LM/zipformer1/train.py
+++ b/egs/libriheavy/LM/zipformer1/train.py
@@ -762,9 +762,6 @@ def train_one_epoch(
     for batch_idx, batch in enumerate(train_dl):
         if batch_idx % 10 == 0:
             set_batch_count(model, get_adjusted_batch_count(params))
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx

         params.batch_idx_train += 1
@@ -991,18 +988,23 @@ def run(rank, world_size, args):

     train = LmDataset(params.train_file_list,
-                      bytes_per_segment=params.bytes_per_segment,)
+                      bytes_per_segment=params.bytes_per_segment,
+                      skip_to_batch_idx=getattr(params, 'cur_batch_idx', 0))
+
+    batch_size = params.batch_size // (6 if params.print_diagnostics else 1)
+
     train_dl = torch.utils.data.DataLoader(
         dataset=train,
-        batch_size=params.batch_size,
+        batch_size=batch_size,
         num_workers=params.num_workers,
         drop_last=True)
+
     valid = LmDataset(params.valid_file_list,
                       bytes_per_segment=params.bytes_per_segment)

     valid_dl = torch.utils.data.DataLoader(
         dataset=valid,
-        batch_size=params.batch_size,
+        batch_size=batch_size,
         num_workers=params.num_workers,
         drop_last=False)

@@ -1017,7 +1019,7 @@ def run(rank, world_size, args):
         # to let it know how many tokens we have processed so far, and have a
         # soft-cutoff lr_tokens measured in tokens.
         # scheduler.step_epoch(epoch - 1)
-        fix_random_seed(params.seed + epoch - 1 + params.start_batch)
+        fix_random_seed(params.seed + epoch)
         # the above will affect random seeds in the dataloaders.

         if tb_writer is not None:
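
For context, below is a minimal, self-contained sketch (not part of the patch) of the resume pattern the diff adopts: instead of the training loop fast-forwarding with "if batch_idx < cur_batch_idx: continue", which still pulls every already-seen batch through the DataLoader, the IterableDataset skips its first skip_to_batch_idx items itself, and only on its first pass. ToyLmDataset is a hypothetical stand-in for the real LmDataset.

# Sketch of the "skip inside the dataset" resume pattern; ToyLmDataset is a
# hypothetical stand-in for LmDataset, not code from this repository.
import torch


class ToyLmDataset(torch.utils.data.IterableDataset):
    def __init__(self, num_segments: int, skip_to_batch_idx: int = 0):
        self.num_segments = num_segments
        self.skip_to_batch_idx = skip_to_batch_idx

    def __iter__(self):
        skip = self.skip_to_batch_idx
        self.skip_to_batch_idx = 0  # honored only on the first pass, as in the patch
        for n in range(self.num_segments):
            if n < skip:
                continue  # nothing is generated for skipped segments
            yield torch.full((4,), n)  # stand-in for one segment of bytes


if __name__ == '__main__':
    # Resume as if the first 3 segments had already been consumed.
    dl = torch.utils.data.DataLoader(ToyLmDataset(8, skip_to_batch_idx=3), batch_size=2)
    for batch in dl:
        print(batch)  # the first yielded batch starts at segment 3

In the patch itself, run() passes getattr(params, 'cur_batch_idx', 0) into LmDataset, so a checkpoint resumed mid-epoch does not revisit segments it has already trained on, while any later pass over the dataset starts from the beginning because the skip value is cleared the first time it is used.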