This commit is contained in:
Daniel Povey 2023-06-19 04:54:39 +08:00
parent 03ad0d7910
commit 6c3ab1e706
3 changed files with 12 additions and 28 deletions

View File

@ -209,7 +209,8 @@ def main():
model.eval()
valid = LmDataset(params.valid_file_list,
bytes_per_segment=params.bytes_per_segment)
bytes_per_segment=params.bytes_per_segment,
training=False)
valid_dl = torch.utils.data.DataLoader(
dataset=valid,
batch_size=params.batch_size,

View File

@ -38,8 +38,6 @@ class LmDataset(torch.utils.data.IterableDataset):
def __init__(self,
file_list_fn: Path,
bytes_per_segment: int = 200,
world_size: int = 1,
rank: int = 0,
training: bool = True,
):
"""
@ -53,11 +51,9 @@ class LmDataset(torch.utils.data.IterableDataset):
file_list_fn: a file in which each line contains: a number of bytes, then a space, then a filename.
e.g. a line might contain the text "64324 foo/abc.txt".
(filenames can not contain spaces).
world_size, rank: from DDP. We get the data-loader id and world-size separately.
bytes_per_segment: the number of bytes in each segment of data.
"""
self.training = training
self.skip_to_batch_idx = skip_to_batch_idx
self.files = []
self.num_bytes = []
self.bytes_per_segment = bytes_per_segment
@ -144,11 +140,11 @@ class LmDataset(torch.utils.data.IterableDataset):
b = b + b'\0' * (self.bytes_per_segment - len(b))
yield torch.Tensor(np.frombuffer(b, dtype=np.uint8).copy()).to(torch.long)
def tot_tokens(self):
# Returns the total number of tokens, including padding tokens, in
# the dataset; this is for purposes of figuring out how many we
# epochs we have trained for.
return self.tot_positions
def num_tokens(self):
# Returns the total number of tokens, including padding tokens, in
# the dataset; this is for purposes of figuring out how many we
# epochs we have trained for.
return self.tot_positions
def _test():

View File

@ -459,8 +459,6 @@ def load_checkpoint_if_available(
"""
if params.start_batch > 0:
filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
elif params.start_epoch > 1:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
else:
return None
@ -475,22 +473,11 @@ def load_checkpoint_if_available(
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
if params.start_batch > 0:
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params
@ -903,7 +890,6 @@ def run(rank, world_size, args):
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model).to(torch.float64)
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(
params=params, model=model, model_avg=model_avg
)
@ -945,8 +931,8 @@ def run(rank, world_size, args):
train = LmDataset(params.train_file_list,
bytes_per_segment=params.bytes_per_segment,
skip_to_batch_idx=getattr(params, 'cur_batch_idx', 0))
bytes_per_segment=params.bytes_per_segment)
params.tokens_per_epoch = train.num_tokens() # helps us figure out epoch progress.
batch_size = params.batch_size // (6 if params.print_diagnostics else 1)
@ -959,7 +945,9 @@ def run(rank, world_size, args):
valid = LmDataset(params.valid_file_list,
bytes_per_segment=params.bytes_per_segment)
bytes_per_segment=params.bytes_per_segment,
training=False)
valid_dl = torch.utils.data.DataLoader(
dataset=valid,
batch_size=batch_size,
@ -997,7 +985,6 @@ def run(rank, world_size, args):
if params.print_diagnostics:
diagnostic.print_diagnostics()
break
logging.info("Done!")