From 6c3ab1e706d67e6568c925611abd7e89b583d800 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Mon, 19 Jun 2023 04:54:39 +0800
Subject: [PATCH] Fixes

---
 egs/libriheavy/LM/zipformer1/evaluate.py      |  3 ++-
 egs/libriheavy/LM/zipformer1/lm_datamodule.py | 14 ++++-------
 egs/libriheavy/LM/zipformer1/train.py         | 23 ++++---------------
 3 files changed, 12 insertions(+), 28 deletions(-)

diff --git a/egs/libriheavy/LM/zipformer1/evaluate.py b/egs/libriheavy/LM/zipformer1/evaluate.py
index a65bf6d85..83ccfe922 100644
--- a/egs/libriheavy/LM/zipformer1/evaluate.py
+++ b/egs/libriheavy/LM/zipformer1/evaluate.py
@@ -209,7 +209,8 @@ def main():
     model.eval()

     valid = LmDataset(params.valid_file_list,
-                      bytes_per_segment=params.bytes_per_segment)
+                      bytes_per_segment=params.bytes_per_segment,
+                      training=False)
     valid_dl = torch.utils.data.DataLoader(
         dataset=valid,
         batch_size=params.batch_size,
diff --git a/egs/libriheavy/LM/zipformer1/lm_datamodule.py b/egs/libriheavy/LM/zipformer1/lm_datamodule.py
index f65b56892..fb6fa7177 100644
--- a/egs/libriheavy/LM/zipformer1/lm_datamodule.py
+++ b/egs/libriheavy/LM/zipformer1/lm_datamodule.py
@@ -38,8 +38,6 @@ class LmDataset(torch.utils.data.IterableDataset):
     def __init__(self,
                  file_list_fn: Path,
                  bytes_per_segment: int = 200,
-                 world_size: int = 1,
-                 rank: int = 0,
                  training: bool = True,
                  ):
         """
@@ -53,11 +51,9 @@ class LmDataset(torch.utils.data.IterableDataset):
           file_list_fn: a file in which each line contains: a number of bytes, then a space, then a filename.
             e.g. a line might contain the text "64324 foo/abc.txt". (filenames can not contain spaces).
-          world_size, rank: from DDP. We get the data-loader id and world-size separately.
           bytes_per_segment: the number of bytes in each segment of data.
         """

         self.training = training
-        self.skip_to_batch_idx = skip_to_batch_idx
         self.files = []
         self.num_bytes = []
         self.bytes_per_segment = bytes_per_segment
@@ -144,11 +140,11 @@ class LmDataset(torch.utils.data.IterableDataset):
                 b = b + b'\0' * (self.bytes_per_segment - len(b))
             yield torch.Tensor(np.frombuffer(b, dtype=np.uint8).copy()).to(torch.long)

-    def tot_tokens(self):
-        # Returns the total number of tokens, including padding tokens, in
-        # the dataset; this is for purposes of figuring out how many we
-        # epochs we have trained for.
-        return self.tot_positions
+    def num_tokens(self):
+        # Returns the total number of tokens, including padding tokens, in
+        # the dataset; this is for purposes of figuring out how many
+        # epochs we have trained for.
+        return self.tot_positions


 def _test():
diff --git a/egs/libriheavy/LM/zipformer1/train.py b/egs/libriheavy/LM/zipformer1/train.py
index 98ae03371..b7af8f331 100755
--- a/egs/libriheavy/LM/zipformer1/train.py
+++ b/egs/libriheavy/LM/zipformer1/train.py
@@ -459,8 +459,6 @@ def load_checkpoint_if_available(
     """
     if params.start_batch > 0:
         filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
-    elif params.start_epoch > 1:
-        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     else:
         return None

@@ -475,22 +473,11 @@ def load_checkpoint_if_available(
     )

     keys = [
-        "best_train_epoch",
-        "best_valid_epoch",
         "batch_idx_train",
-        "best_train_loss",
-        "best_valid_loss",
     ]
     for k in keys:
         params[k] = saved_params[k]

-    if params.start_batch > 0:
-        if "cur_epoch" in saved_params:
-            params["start_epoch"] = saved_params["cur_epoch"]
-
-        if "cur_batch_idx" in saved_params:
-            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
-
     return saved_params


@@ -903,7 +890,6 @@ def run(rank, world_size, args):
         # model_avg is only used with rank 0
         model_avg = copy.deepcopy(model).to(torch.float64)

-    assert params.start_epoch > 0, params.start_epoch
     checkpoints = load_checkpoint_if_available(
         params=params, model=model, model_avg=model_avg
     )
@@ -945,8 +931,8 @@ def run(rank, world_size, args):

     train = LmDataset(params.train_file_list,
-                      bytes_per_segment=params.bytes_per_segment,
-                      skip_to_batch_idx=getattr(params, 'cur_batch_idx', 0))
+                      bytes_per_segment=params.bytes_per_segment)
+
     params.tokens_per_epoch = train.num_tokens()  # helps us figure out epoch progress.

     batch_size = params.batch_size // (6 if params.print_diagnostics else 1)

@@ -959,7 +945,9 @@ def run(rank, world_size, args):

     valid = LmDataset(params.valid_file_list,
-                      bytes_per_segment=params.bytes_per_segment)
+                      bytes_per_segment=params.bytes_per_segment,
+                      training=False)
+
     valid_dl = torch.utils.data.DataLoader(
         dataset=valid,
         batch_size=batch_size,
@@ -997,7 +985,6 @@ def run(rank, world_size, args):

         if params.print_diagnostics:
             diagnostic.print_diagnostics()
-            break


     logging.info("Done!")
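
Note (not part of the patch): a minimal usage sketch of the LmDataset API as it stands
after these changes, mirroring the calls now made in train.py and evaluate.py.  The file
names and batch size below are placeholders for illustration, not values from the recipe.

    import torch
    from lm_datamodule import LmDataset

    # Training dataset: training=True is the default.
    train = LmDataset("train_file_list.txt", bytes_per_segment=200)

    # Validation dataset: constructed with training=False, as evaluate.py and train.py now do.
    valid = LmDataset("valid_file_list.txt", bytes_per_segment=200, training=False)

    # num_tokens() (renamed from tot_tokens()) counts all tokens including padding;
    # train.py stores it as params.tokens_per_epoch to track epoch progress.
    tokens_per_epoch = train.num_tokens()

    train_dl = torch.utils.data.DataLoader(dataset=train, batch_size=16)
    valid_dl = torch.utils.data.DataLoader(dataset=valid, batch_size=16)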