mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-18 21:44:18 +00:00)
Fixes
parent 03ad0d7910
commit 6c3ab1e706
@@ -209,7 +209,8 @@ def main():
     model.eval()
 
     valid = LmDataset(params.valid_file_list,
-                      bytes_per_segment=params.bytes_per_segment)
+                      bytes_per_segment=params.bytes_per_segment,
+                      training=False)
     valid_dl = torch.utils.data.DataLoader(
         dataset=valid,
         batch_size=params.batch_size,
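Note: the `training` flag presumably switches the dataset between randomized training iteration and deterministic validation iteration; the diff itself does not show the flag's implementation. A toy sketch of that pattern, with hypothetical names:

import random
import torch

class ToySegments(torch.utils.data.IterableDataset):
    # Hypothetical illustration only: deterministic order when training=False.
    def __init__(self, segments, training: bool = True):
        self.segments = segments
        self.training = training

    def __iter__(self):
        order = list(range(len(self.segments)))
        if self.training:
            random.shuffle(order)  # randomize only in training mode
        for i in order:
            yield self.segments[i]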
@@ -38,8 +38,6 @@ class LmDataset(torch.utils.data.IterableDataset):
     def __init__(self,
                  file_list_fn: Path,
                  bytes_per_segment: int = 200,
-                 world_size: int = 1,
-                 rank: int = 0,
                  training: bool = True,
                  ):
         """
@@ -53,11 +51,9 @@ class LmDataset(torch.utils.data.IterableDataset):
           file_list_fn: a file in which each line contains: a number of bytes, then a space, then a filename.
              e.g. a line might contain the text "64324 foo/abc.txt".
              (filenames cannot contain spaces).
-          world_size, rank: from DDP. We get the data-loader id and world-size separately.
           bytes_per_segment: the number of bytes in each segment of data.
         """
         self.training = training
-        self.skip_to_batch_idx = skip_to_batch_idx
         self.files = []
         self.num_bytes = []
         self.bytes_per_segment = bytes_per_segment
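The file-list format described in the docstring ("64324 foo/abc.txt") takes only a few lines to parse; a minimal sketch, assuming well-formed input (read_file_list is a hypothetical helper, not part of this diff):

from pathlib import Path

def read_file_list(file_list_fn: Path):
    # Each line: "<num_bytes> <filename>"; filenames contain no spaces,
    # so splitting on the first space is safe.
    files, num_bytes = [], []
    for line in file_list_fn.read_text().splitlines():
        size, fname = line.split(" ", 1)
        num_bytes.append(int(size))
        files.append(fname)
    return files, num_bytes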
@@ -144,11 +140,11 @@ class LmDataset(torch.utils.data.IterableDataset):
             b = b + b'\0' * (self.bytes_per_segment - len(b))
             yield torch.Tensor(np.frombuffer(b, dtype=np.uint8).copy()).to(torch.long)
 
-    def tot_tokens(self):
-        # Returns the total number of tokens, including padding tokens, in
-        # the dataset; this is for purposes of figuring out how many
-        # epochs we have trained for.
-        return self.tot_positions
+    def num_tokens(self):
+        # Returns the total number of tokens, including padding tokens, in
+        # the dataset; this is for purposes of figuring out how many
+        # epochs we have trained for.
+        return self.tot_positions
 
 
 def _test():
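For reference, the padding-and-conversion idiom in the yield above can be checked in isolation:

import numpy as np
import torch

bytes_per_segment = 8
b = b"abc"                                    # segment shorter than bytes_per_segment
b = b + b"\0" * (bytes_per_segment - len(b))  # zero-pad to the fixed length
# np.frombuffer returns a read-only view of the bytes, hence the .copy()
t = torch.Tensor(np.frombuffer(b, dtype=np.uint8).copy()).to(torch.long)
print(t)  # tensor([97, 98, 99, 0, 0, 0, 0, 0])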
@@ -459,8 +459,6 @@ def load_checkpoint_if_available(
     """
-    if params.start_batch > 0:
-        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
-    elif params.start_epoch > 1:
+    if params.start_epoch > 1:
         filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     else:
         return None
 
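With the start_batch branch removed, checkpoint selection reduces to the epoch-based rule; equivalently:

from pathlib import Path
from typing import Optional

def checkpoint_filename(exp_dir: Path, start_epoch: int) -> Optional[Path]:
    # Resume from the checkpoint written at the end of the previous epoch,
    # or return None to start training from scratch.
    if start_epoch > 1:
        return exp_dir / f"epoch-{start_epoch - 1}.pt"
    return None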
@@ -475,22 +473,11 @@ def load_checkpoint_if_available(
     )
 
     keys = [
         "best_train_epoch",
         "best_valid_epoch",
         "batch_idx_train",
         "best_train_loss",
         "best_valid_loss",
     ]
     for k in keys:
         params[k] = saved_params[k]
 
-    if params.start_batch > 0:
-        if "cur_epoch" in saved_params:
-            params["start_epoch"] = saved_params["cur_epoch"]
-
-        if "cur_batch_idx" in saved_params:
-            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
-
     return saved_params
@@ -903,7 +890,6 @@ def run(rank, world_size, args):
         # model_avg is only used with rank 0
         model_avg = copy.deepcopy(model).to(torch.float64)
 
-    assert params.start_epoch > 0, params.start_epoch
     checkpoints = load_checkpoint_if_available(
         params=params, model=model, model_avg=model_avg
     )
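model_avg is kept in float64 to reduce accumulation error when averaging weights over many updates. One common update rule looks like the sketch below; the exact averaging scheme the training loop uses is not shown in this diff, so treat this as an assumption:

import torch

def update_model_avg(model_avg, model, num_updates: int):
    # Running mean of parameters in float64; the real weighting may differ.
    with torch.no_grad():
        for p_avg, p in zip(model_avg.parameters(), model.parameters()):
            p_avg.mul_((num_updates - 1) / num_updates)
            p_avg.add_(p.to(torch.float64), alpha=1.0 / num_updates)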
@@ -945,8 +931,8 @@ def run(rank, world_size, args):
 
 
     train = LmDataset(params.train_file_list,
-                      bytes_per_segment=params.bytes_per_segment,
-                      skip_to_batch_idx=getattr(params, 'cur_batch_idx', 0))
+                      bytes_per_segment=params.bytes_per_segment)
+
     params.tokens_per_epoch = train.num_tokens()  # helps us figure out epoch progress.
 
     batch_size = params.batch_size // (6 if params.print_diagnostics else 1)
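Since LmDataset is an iterable stream rather than a fixed-length dataset, epoch progress comes from token counts; for example (all numbers and names below are hypothetical):

# One byte == one token in this byte-level LM.
batch_size, bytes_per_segment, num_batches = 16, 200, 5000
tokens_per_epoch = 100_000_000  # would come from train.num_tokens()
tokens_seen = num_batches * batch_size * bytes_per_segment
fractional_epoch = tokens_seen / tokens_per_epoch  # 0.16 epochs completed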
@@ -959,7 +945,9 @@ def run(rank, world_size, args):
 
 
     valid = LmDataset(params.valid_file_list,
-                      bytes_per_segment=params.bytes_per_segment)
+                      bytes_per_segment=params.bytes_per_segment,
+                      training=False)
+
     valid_dl = torch.utils.data.DataLoader(
         dataset=valid,
         batch_size=batch_size,
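A minimal sketch of consuming the validation loader built above (the model is assumed to return a scalar loss per batch, which this diff does not establish):

import torch

def compute_valid_loss(model, valid_dl, device="cpu"):
    model.eval()
    total, n = 0.0, 0
    with torch.no_grad():
        for batch in valid_dl:  # each batch: (batch_size, bytes_per_segment) LongTensor
            total += model(batch.to(device)).item()
            n += 1
    return total / max(n, 1)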
@@ -997,7 +985,6 @@ def run(rank, world_size, args):
 
         if params.print_diagnostics:
             diagnostic.print_diagnostics()
             break
 
     logging.info("Done!")