This commit is contained in:
Daniel Povey 2023-06-19 04:54:39 +08:00
parent 03ad0d7910
commit 6c3ab1e706
3 changed files with 12 additions and 28 deletions

View File

@ -209,7 +209,8 @@ def main():
model.eval() model.eval()
valid = LmDataset(params.valid_file_list, valid = LmDataset(params.valid_file_list,
bytes_per_segment=params.bytes_per_segment) bytes_per_segment=params.bytes_per_segment,
training=False)
valid_dl = torch.utils.data.DataLoader( valid_dl = torch.utils.data.DataLoader(
dataset=valid, dataset=valid,
batch_size=params.batch_size, batch_size=params.batch_size,

View File

@ -38,8 +38,6 @@ class LmDataset(torch.utils.data.IterableDataset):
def __init__(self, def __init__(self,
file_list_fn: Path, file_list_fn: Path,
bytes_per_segment: int = 200, bytes_per_segment: int = 200,
world_size: int = 1,
rank: int = 0,
training: bool = True, training: bool = True,
): ):
""" """
@ -53,11 +51,9 @@ class LmDataset(torch.utils.data.IterableDataset):
file_list_fn: a file in which each line contains: a number of bytes, then a space, then a filename. file_list_fn: a file in which each line contains: a number of bytes, then a space, then a filename.
e.g. a line might contain the text "64324 foo/abc.txt". e.g. a line might contain the text "64324 foo/abc.txt".
(filenames can not contain spaces). (filenames can not contain spaces).
world_size, rank: from DDP. We get the data-loader id and world-size separately.
bytes_per_segment: the number of bytes in each segment of data. bytes_per_segment: the number of bytes in each segment of data.
""" """
self.training = training self.training = training
self.skip_to_batch_idx = skip_to_batch_idx
self.files = [] self.files = []
self.num_bytes = [] self.num_bytes = []
self.bytes_per_segment = bytes_per_segment self.bytes_per_segment = bytes_per_segment
@ -144,7 +140,7 @@ class LmDataset(torch.utils.data.IterableDataset):
b = b + b'\0' * (self.bytes_per_segment - len(b)) b = b + b'\0' * (self.bytes_per_segment - len(b))
yield torch.Tensor(np.frombuffer(b, dtype=np.uint8).copy()).to(torch.long) yield torch.Tensor(np.frombuffer(b, dtype=np.uint8).copy()).to(torch.long)
def tot_tokens(self): def num_tokens(self):
# Returns the total number of tokens, including padding tokens, in # Returns the total number of tokens, including padding tokens, in
# the dataset; this is for purposes of figuring out how many we # the dataset; this is for purposes of figuring out how many we
# epochs we have trained for. # epochs we have trained for.

View File

@ -459,8 +459,6 @@ def load_checkpoint_if_available(
""" """
if params.start_batch > 0: if params.start_batch > 0:
filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
elif params.start_epoch > 1:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
else: else:
return None return None
@ -475,22 +473,11 @@ def load_checkpoint_if_available(
) )
keys = [ keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train", "batch_idx_train",
"best_train_loss",
"best_valid_loss",
] ]
for k in keys: for k in keys:
params[k] = saved_params[k] params[k] = saved_params[k]
if params.start_batch > 0:
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params return saved_params
@ -903,7 +890,6 @@ def run(rank, world_size, args):
# model_avg is only used with rank 0 # model_avg is only used with rank 0
model_avg = copy.deepcopy(model).to(torch.float64) model_avg = copy.deepcopy(model).to(torch.float64)
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available( checkpoints = load_checkpoint_if_available(
params=params, model=model, model_avg=model_avg params=params, model=model, model_avg=model_avg
) )
@ -945,8 +931,8 @@ def run(rank, world_size, args):
train = LmDataset(params.train_file_list, train = LmDataset(params.train_file_list,
bytes_per_segment=params.bytes_per_segment, bytes_per_segment=params.bytes_per_segment)
skip_to_batch_idx=getattr(params, 'cur_batch_idx', 0))
params.tokens_per_epoch = train.num_tokens() # helps us figure out epoch progress. params.tokens_per_epoch = train.num_tokens() # helps us figure out epoch progress.
batch_size = params.batch_size // (6 if params.print_diagnostics else 1) batch_size = params.batch_size // (6 if params.print_diagnostics else 1)
@ -959,7 +945,9 @@ def run(rank, world_size, args):
valid = LmDataset(params.valid_file_list, valid = LmDataset(params.valid_file_list,
bytes_per_segment=params.bytes_per_segment) bytes_per_segment=params.bytes_per_segment,
training=False)
valid_dl = torch.utils.data.DataLoader( valid_dl = torch.utils.data.DataLoader(
dataset=valid, dataset=valid,
batch_size=batch_size, batch_size=batch_size,
@ -997,7 +985,6 @@ def run(rank, world_size, args):
if params.print_diagnostics: if params.print_diagnostics:
diagnostic.print_diagnostics() diagnostic.print_diagnostics()
break
logging.info("Done!") logging.info("Done!")