Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-18 21:44:18 +00:00
Remove concept of epochs from training subformer for language modeling;
revert dimensions to how they were in zlm53.
This commit is contained in:
parent c7e8a7349d
commit 03ad0d7910
@@ -43,20 +43,11 @@ def get_parser():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
 
-    parser.add_argument(
-        "--epoch",
-        type=int,
-        default=30,
-        help="""It specifies the checkpoint to use for decoding.
-        Note: Epoch counts from 1.
-        You can specify --avg to use more checkpoints for model averaging.""",
-    )
-
     parser.add_argument(
         "--iter",
         type=int,
         default=0,
-        help="""If positive, --epoch is ignored and it
+        help="""If positive, it
         will use the checkpoint exp_dir/checkpoint-iter.pt.
         You can specify --avg to use more checkpoints for model averaging.
         """,
@@ -68,18 +59,15 @@ def get_parser():
         default=9,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch' and '--iter'",
+        "'--iter'",
     )
 
     parser.add_argument(
         "--use-averaged-model",
         type=str2bool,
         default=True,
-        help="Whether to load averaged model. Currently it only supports "
-        "using --epoch. If True, it would decode with the averaged model "
-        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
-        "Actually only the models with epoch number of `epoch-avg` and "
-        "`epoch` are loaded for averaging. ",
+        help="Whether to load averaged model. If True, it would decode "
+        "with the averaged model over this many checkpoints.",
     )
 
     parser.add_argument(
@@ -150,10 +138,8 @@ def main():
 
     params.res_dir = params.exp_dir / "log-evaluation"
 
-    if params.iter > 0:
-        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
-    else:
-        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+    assert params.iter > 0
+    params.suffix = f"iter-{params.iter}-avg-{params.avg}"
 
     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"
@@ -173,81 +159,51 @@ def main():
     logging.info(f"Number of model parameters: {num_param}")
 
     if not params.use_averaged_model:
-        if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg
-            ]
-            if len(filenames) == 0:
-                raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
-                )
-            elif len(filenames) < params.avg:
-                raise ValueError(
-                    f"Not enough checkpoints ({len(filenames)}) found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
-                )
-            logging.info(f"averaging {filenames}")
-            model.to(device)
-            model.load_state_dict(average_checkpoints(filenames, device=device))
-        elif params.avg == 1:
-            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
-        else:
-            start = params.epoch - params.avg + 1
-            filenames = []
-            for i in range(start, params.epoch + 1):
-                if i >= 1:
-                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
-            logging.info(f"averaging {filenames}")
-            model.to(device)
-            model.load_state_dict(average_checkpoints(filenames, device=device))
+        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+            : params.avg
+        ]
+        if len(filenames) == 0:
+            raise ValueError(
+                f"No checkpoints found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
+        elif len(filenames) < params.avg:
+            raise ValueError(
+                f"Not enough checkpoints ({len(filenames)}) found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(average_checkpoints(filenames, device=device))
 
     else:
-        if params.iter > 0:
-            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-                : params.avg + 1
-            ]
-            if len(filenames) == 0:
-                raise ValueError(
-                    f"No checkpoints found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
-                )
-            elif len(filenames) < params.avg + 1:
-                raise ValueError(
-                    f"Not enough checkpoints ({len(filenames)}) found for"
-                    f" --iter {params.iter}, --avg {params.avg}"
-                )
-            filename_start = filenames[-1]
-            filename_end = filenames[0]
-            logging.info(
-                "Calculating the averaged model over iteration checkpoints"
-                f" from {filename_start} (excluded) to {filename_end}"
-            )
-            model.to(device)
-            model.load_state_dict(
-                average_checkpoints_with_averaged_model(
-                    filename_start=filename_start,
-                    filename_end=filename_end,
-                    device=device,
-                )
-            )
-        else:
-            assert params.avg > 0, params.avg
-            start = params.epoch - params.avg
-            assert start >= 1, start
-            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
-            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
-            logging.info(
-                f"Calculating the averaged model over epoch range from "
-                f"{start} (excluded) to {params.epoch}"
-            )
-            model.to(device)
-            model.load_state_dict(
-                average_checkpoints_with_averaged_model(
-                    filename_start=filename_start,
-                    filename_end=filename_end,
-                    device=device,
-                )
-            )
+        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+            : params.avg + 1
+        ]
+        if len(filenames) == 0:
+            raise ValueError(
+                f"No checkpoints found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
+        elif len(filenames) < params.avg + 1:
+            raise ValueError(
+                f"Not enough checkpoints ({len(filenames)}) found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
+        filename_start = filenames[-1]
+        filename_end = filenames[0]
+        logging.info(
+            "Calculating the averaged model over iteration checkpoints"
+            f" from {filename_start} (excluded) to {filename_end}"
+        )
+        model.to(device)
+        model.load_state_dict(
+            average_checkpoints_with_averaged_model(
+                filename_start=filename_start,
+                filename_end=filename_end,
+                device=device,
+            )
+        )
 
     model.to(device)
     model.eval()
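The `else:` branch above relies on icefall's trick of keeping a running average of the parameters inside each checkpoint: `average_checkpoints_with_averaged_model` recovers the mean over the checkpoints strictly after `filename_start` up to `filename_end`. A minimal sketch of the arithmetic, assuming each checkpoint stores the running mean over its first n batches (a hypothetical helper, not the library code):

    def average_over_range(avg_start, n_start, avg_end, n_end):
        # avg_end is the mean over batches 1..n_end, avg_start over 1..n_start;
        # the mean over the half-open range (n_start, n_end] is a weighted
        # difference of the two running means.
        w_end = n_end / (n_end - n_start)
        w_start = n_start / (n_end - n_start)
        return {k: avg_end[k] * w_end - avg_start[k] * w_start for k in avg_end}

This is why the log message says the start checkpoint is "(excluded)": its parameters enter only through the subtraction.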
@@ -41,15 +41,20 @@ class LmDataset(torch.utils.data.IterableDataset):
         world_size: int = 1,
         rank: int = 0,
         training: bool = True,
         skip_to_batch_idx: int = 0,
     ):
         """
-        Initialize LmDataset object. Args:
+        Initialize LmDataset object. This keeps no state, it just gives you a totally random
+        segment each time. The training files are just viewed as sequences of bytes, from which
+        we select chunks of a fixed size. In training mode we just loop infinitely, and let
+        the training code decide when to stop based on the count of tokens. In test mode
+        we loop so that we see each byte about once.
+
+        Args:
           file_list_fn: a file in which each line contains: a number of bytes, then a space, then a filename,
               e.g. a line might contain the text "64324 foo/abc.txt".
               (filenames can not contain spaces).
           world_size, rank: from DDP. We get the data-loader id and world-size separately.
           bytes_per_segment: the number of bytes in each segment of data.
           skip_to_batch_idx: if provided, the first time we iterate we will skip this many batches.
         """
         self.training = training
         self.skip_to_batch_idx = skip_to_batch_idx
@@ -66,18 +71,25 @@ class LmDataset(torch.utils.data.IterableDataset):
             fn = line[len(num_bytes) + 1:]  # this works even if fn has spaces in
             self.files.append(fn)
             self.num_bytes.append(int(num_bytes))
         tot_bytes = sum(self.num_bytes)
         N = len(self.num_bytes)
-        self.probs = np.array([ x / tot_bytes for x in self.num_bytes ])
+
+        # For purposes of choosing the possible start-positions of a segment: we
+        # need to pad on the left by bytes_per_segment - 1.  This is part of a
+        # scheme to ensure that each byte in each training file is chosen with
+        # equal probability, while also choosing different shifts of the data
+        # with equal probability.  We end up padding with zeroes if we
+        # are outside the file either on the left or the right.
+        pad = self.bytes_per_segment - 1
+        tot_positions = sum([ x + pad for x in self.num_bytes ])
+        self.probs = np.array([ (x + pad) / tot_positions for x in self.num_bytes ])
+        self.tot_positions = tot_positions
 
         worker_info = torch.utils.data.get_worker_info()
         num_workers = (1 if worker_info is None else worker_info.num_workers)
 
-        # world_size is for ddp training, num_workers for data-loader worker threads.
+        # num_workers for data-loader worker threads; world_size is for ddp training.
         tot_workers = num_workers * get_world_size()
 
-        self.num_segments = tot_bytes // (bytes_per_segment * tot_workers)
+        self.num_segments = float('inf') if training else 1 + tot_positions // (bytes_per_segment * tot_workers)
 
 
     def __iter__(self):
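To see why the `x + pad` weighting makes every byte equally likely: a file of n bytes offers n + bytes_per_segment - 1 (partly out-of-range) start positions, and each byte then lies inside exactly bytes_per_segment of them. A quick standalone check with hypothetical sizes:

    n, B = 10, 4  # hypothetical file length and bytes_per_segment
    starts = range(-(B - 1), n)  # the n + B - 1 padded start positions
    coverage = [sum(1 for s in starts if s <= p < s + B) for p in range(n)]
    assert coverage == [B] * n  # each byte is covered by exactly B starts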
@@ -85,19 +97,22 @@ class LmDataset(torch.utils.data.IterableDataset):
         # id includes both worker (within training job) and rank of training job
         my_id = (0 if worker_info is None else worker_info.id) + 1000 * self.ddp_rank
 
         # note: the seed depends on the current random state, which will be different
         # depending on the DDP worker id and also depending which batch we restarted
         # training on.  This does not guarantee that you get repeatability if you
         # restart training, but it does ensure you don't see exactly repeated data.
         seed = (random.randint(0, 10000) if self.training else 0) + my_id
         # the next line is because, for some reason, when we ran with --world-size more than 1,
         # this info message was not printed out.
         logging.getLogger().setLevel(logging.INFO)
         logging.info(f"my_id={my_id}, seed={seed}, num_segments={self.num_segments}")
         # use numpy's generator, not random's, because we need np.random.multinomial.
         rng = np.random.default_rng(seed=seed)
 
         skip_to_batch_idx = self.skip_to_batch_idx
         if skip_to_batch_idx != 0:
             logging.info(f"skip-to-batch-idx={skip_to_batch_idx}")
             self.skip_to_batch_idx = 0  # so only the 1st time we iterate, we respect this.
-        for n in range(self.num_segments):
+        n = 0
+        while n < self.num_segments:  # if self.num_segments is infinity, just keep going.
+            n += 1
+
             # np.random.multinomial / np.random.Generator.multinomial has an interface
             # where it gives counts of different categories, instead of the chosen category,
             # so we need to use np.nonzero to get the chosen category (i.e. the file index)
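The idiom the comment describes, shown in isolation with toy probabilities (not the real file list):

    import numpy as np

    rng = np.random.default_rng(seed=0)
    probs = np.array([0.5, 0.3, 0.2])   # toy per-file probabilities
    counts = rng.multinomial(1, probs)  # one draw: a one-hot count vector, e.g. [0, 1, 0]
    file_idx, = np.nonzero(counts)      # unpack the 1-tuple -> length-1 index array
    file_idx, = file_idx                # unpack the array -> plain integer index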
@@ -106,30 +121,34 @@ class LmDataset(torch.utils.data.IterableDataset):
             file_idx, = np.nonzero(rng.multinomial(1, self.probs))
             file_idx, = file_idx
 
             if n < skip_to_batch_idx:
                 continue
 
             fn = self.files[file_idx]
             num_bytes = self.num_bytes[file_idx]
 
             # begin_pos, end_pos are the begin,end of a range from which we'll pick
-            # randomly, for where the start of the segment might be.
-            begin_pos = 0
-            end_pos = max(1, num_bytes - self.bytes_per_segment)
+            # randomly, for where the start of the segment might be.  We only
+            # guarantee that a segment should contain at least one byte of data;
+            # this helps ensure that each byte is chosen with the exact same probability,
+            # which is easier for analysis.
+            begin_pos = - (self.bytes_per_segment - 1)
+            end_pos = max(1, num_bytes - 1)
 
             begin, = rng.integers(low=begin_pos, high=end_pos, size=1)
 
             with open(fn, "rb") as f:
-                f.seek(begin)
-                b = f.read(self.bytes_per_segment)  # b is bytes object
-                read_size = len(b)
-                if read_size < self.bytes_per_segment:
-                    b = b + b'\0' * (self.bytes_per_segment - read_size)
+                if begin >= 0:
+                    f.seek(begin)
+                    b = f.read(self.bytes_per_segment)  # b is bytes object
+                else:
+                    b = b'\0' * -begin + f.read(self.bytes_per_segment + begin)
+                if len(b) < self.bytes_per_segment:
+                    b = b + b'\0' * (self.bytes_per_segment - len(b))
                 yield torch.Tensor(np.frombuffer(b, dtype=np.uint8).copy()).to(torch.long)
 
 
+    def tot_tokens(self):
+        # Returns the total number of tokens, including padding tokens, in
+        # the dataset; this is for purposes of figuring out how many
+        # epochs we have trained for.
+        return self.tot_positions
+
+
 def _test():
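The negative-`begin` branch above substitutes left zero-padding for seeking before the start of the file. The same logic on an in-memory bytes object, with hypothetical values, exercising all three cases:

    def read_segment(data: bytes, begin: int, seg_len: int) -> bytes:
        # sketch of the file-reading logic above, on a bytes object
        if begin >= 0:
            b = data[begin:begin + seg_len]
        else:
            # begin is negative: pad on the left with -begin zero bytes
            b = b'\0' * -begin + data[:seg_len + begin]
        if len(b) < seg_len:  # ran off the right end: pad with zeros
            b = b + b'\0' * (seg_len - len(b))
        return b

    assert read_segment(b'abcdefgh', 2, 4) == b'cdef'     # interior read
    assert read_segment(b'abcdefgh', -2, 4) == b'\0\0ab'  # left padding
    assert read_segment(b'abcdefgh', 6, 4) == b'gh\0\0'   # right padding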
@@ -22,24 +22,11 @@ Usage:
 
 export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
-./pruned_transducer_stateless7/train.py \
+./zipformer1/train.py \
   --world-size 4 \
-  --num-epochs 30 \
-  --start-epoch 1 \
-  --exp-dir pruned_transducer_stateless7/exp \
-  --full-libri 1 \
-  --max-duration 300
-
-# For mix precision training:
-
-./pruned_transducer_stateless7/train.py \
-  --world-size 4 \
-  --num-epochs 30 \
-  --start-epoch 1 \
-  --use-fp16 1 \
-  --exp-dir pruned_transducer_stateless7/exp \
-  --full-libri 1 \
-  --max-duration 550
+  --exp-dir zipformer1/exp \
+  --use-fp16 True
 
 """
@@ -143,7 +130,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--encoder-dim",
         type=str,
-        default="384,512,512,768,512,512,384",
+        default="256,384,512,768,512,384,256",
         help="Embedding dimension in encoder stacks: a single int or comma-separated list."
     )
 
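`--encoder-dim` arrives as a string; a plausible way such a spec is expanded into per-stack dimensions (a sketch, not necessarily the exact helper zipformer1 uses):

    def parse_encoder_dims(spec: str, num_stacks: int = 7):
        # "256,384,512,768,512,384,256" -> (256, 384, 512, 768, 512, 384, 256);
        # a single int such as "512" would be broadcast to every stack.
        dims = tuple(int(d) for d in spec.split(','))
        return dims * num_stacks if len(dims) == 1 else dims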
@@ -214,27 +201,18 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--num-epochs",
+        "--num-tokens",
         type=int,
-        default=30,
-        help="Number of epochs to train.",
+        default=10000000000,
+        help="Number of tokens to train on.",
     )
 
-    parser.add_argument(
-        "--start-epoch",
-        type=int,
-        default=1,
-        help="""Resume training from this epoch. It should be positive.
-        If larger than 1, it will load checkpoint from
-        exp-dir/epoch-{start_epoch-1}.pt
-        """,
-    )
-
     parser.add_argument(
         "--start-batch",
         type=int,
         default=0,
-        help="""If positive, --start-epoch is ignored and
+        help="""If positive, we start training from this batch and
         it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
         """,
     )
@@ -270,12 +248,11 @@ def get_parser():
         "--lr-tokens",
         type=float,
         default=1000000000,
-        help="""Number of tokens beyond which the LR will start to decrease per token, defines
-        LR schedule, replacing lr-epochs
+        help="""Number of tokens beyond which the LR will start to significantly
+        decrease per token; defines the LR schedule.
         """,
     )
 
 
     parser.add_argument(
         "--seed",
         type=int,
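`--lr-tokens` plays the role `--lr-epochs` used to: the scheduler's "epoch" term is fed a token count (see `scheduler.step_epoch(params.num_tokens_seen)` further down). A sketch of the resulting decay factor, assuming an Eden-style schedule as used elsewhere in icefall:

    def token_lr_factor(num_tokens_seen: float, lr_tokens: float = 1.0e9) -> float:
        # close to 1.0 while num_tokens_seen << lr_tokens, then decays roughly
        # like num_tokens_seen ** -0.5 (the Eden exponent is -0.25 on the ratio).
        return ((num_tokens_seen ** 2 + lr_tokens ** 2) / lr_tokens ** 2) ** -0.25

    token_lr_factor(1.0e8)  # ~1.0: still in the flat region
    token_lr_factor(4.0e9)  # ~0.49: well into the decay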
@@ -305,8 +282,6 @@ def get_parser():
         periodically. We save checkpoint to exp-dir/ whenever
         params.batch_idx_train % save_every_n == 0. The checkpoint filename
         has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
-        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
-        end of each epoch where `xxx` is the epoch number counting from 0.
         """,
     )
 
@@ -357,21 +332,9 @@ def get_params() -> AttributeDict:
 
     Explanation of options saved in `params`:
 
-    - best_train_loss: Best training loss so far. It is used to select
-                       the model that has the lowest training loss. It is
-                       updated during the training.
+    - batch_idx_train: It contains number of batches trained so far.
 
-    - best_valid_loss: Best validation loss so far. It is used to select
-                       the model that has the lowest validation loss. It is
-                       updated during the training.
-
-    - best_train_epoch: It is the epoch that has the best training loss.
-
-    - best_valid_epoch: It is the epoch that has the best validation loss.
-
-    - batch_idx_train: Used to writing statistics to tensorboard. It
-                       contains number of batches trained so far across
-                       epochs.
+    - num_tokens_seen: Total number of tokens that have been seen so far.
 
     - log_interval: Print training loss if batch_idx % `log_interval` is 0
@@ -393,11 +356,8 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
-            "best_train_loss": float("inf"),
-            "best_valid_loss": float("inf"),
-            "best_train_epoch": -1,
-            "best_valid_epoch": -1,
             "batch_idx_train": 0,
+            "num_tokens_seen": 0,
             "log_interval": 50,
             "reset_interval": 200,
             "valid_interval": 3000,
|
||||
"""Load checkpoint from file.
|
||||
|
||||
If params.start_batch is positive, it will load the checkpoint from
|
||||
`params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
|
||||
params.start_epoch is larger than 1, it will load the checkpoint from
|
||||
`params.start_epoch - 1`.
|
||||
`params.exp_dir/checkpoint-{params.start_batch}.pt`.
|
||||
|
||||
Apart from loading state dict for `model` and `optimizer` it also updates
|
||||
`best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
|
||||
@@ -573,14 +531,6 @@ def save_checkpoint(
         rank=rank,
     )
 
-    if params.best_train_epoch == params.cur_epoch:
-        best_train_filename = params.exp_dir / "best-train-loss.pt"
-        copyfile(src=filename, dst=best_train_filename)
-
-    if params.best_valid_epoch == params.cur_epoch:
-        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
-        copyfile(src=filename, dst=best_valid_filename)
-
 
 def _encode_texts_as_bytes(texts: List[str], device: torch.device) -> Tuple[Tensor, Tensor, Tensor]:
     """
     Encode texts as bytes and then integer tensors.
@@ -627,8 +577,8 @@ def compute_loss(
     model:
       The model for training. It is an instance of Subformer in our case.
     batch:
-      A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
-      for the content in it.
+      A batch of data: a tensor of integers from 0 to 255, of shape
+      (num_sequences, sequence_length).
     is_training:
       True for training. False for validation. When it is True, this
       function enables autograd during computation; when it is False, it
@@ -655,8 +605,14 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
+        # this logprob can be treated as somewhat like the log of the 'ppl1' printed in SRILM:
+        # that is, the total log-probability of the sequence, divided by the
+        # number of just the non-terminating elements (treating \0 as
+        # a terminator, like EOF).  In fact this is not 100% correct, since
+        # we may also pad with leading zeros in case the 'window' starts before
+        # the start of the file.  But this is a small effect if the files are long.
         info["frames"] = (
-            labels.numel()
+            (labels != 0).sum()
         )
 
     # Note: We use reduction=sum while computing the loss.
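Counting only the non-zero labels in `info["frames"]` makes the reported per-frame loss line up with the SRILM-style ppl1 the comment mentions: total log-probability divided by the number of non-terminator tokens. With hypothetical numbers:

    import math

    loss_sum = 2500.0          # summed cross-entropy over the batch, in nats
    num_nonzero_labels = 2000  # labels != 0, i.e. excluding \0 terminators/padding

    loss_per_token = loss_sum / num_nonzero_labels  # 1.25, the normalized loss
    ppl1_like = math.exp(loss_per_token)            # ~3.49, analogous to ppl1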
@@ -697,7 +653,7 @@ def compute_validation_loss(
     return tot_loss
 
 
-def train_one_epoch(
+def train(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
     optimizer: torch.optim.Optimizer,
@@ -746,8 +702,6 @@ def train_one_epoch(
 
     tot_loss = MetricsTracker()
 
-    cur_batch_idx = params.get("cur_batch_idx", 0)
-
     saved_bad_model = False
     def save_bad_model(suffix: str = ""):
        save_checkpoint_impl(filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
@@ -761,11 +715,11 @@ def train_one_epoch(
 
 
     for batch_idx_, batch in enumerate(train_dl):
-        batch_idx = batch_idx_ + batch_idx_offset
-        if batch_idx % 10 == 0:
+        params.batch_idx_train += 1
+
+        if params.batch_idx_train % 10 == 0:
             set_batch_count(model, get_adjusted_batch_count(params))
 
-        params.batch_idx_train += 1
-
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
@@ -782,9 +736,12 @@ def train_one_epoch(
             # in the batch and there is no normalization to it so far.
             scaler.scale(loss).backward()
             scheduler.step_batch(params.batch_idx_train)
-            tokens_seen = params.batch_idx_train * params.bytes_per_segment * params.batch_size * get_world_size()
-            # we make the formula depend on tokens not epochs, replacing lr_epochs with lr_tokens.
-            scheduler.step_epoch(tokens_seen)
+            scheduler.step_epoch(params.num_tokens_seen)
+
+            # this doesn't take into account padding, but it doesn't matter
+            # much, it is just to determine when we terminate.
+            params.num_tokens_seen += params.bytes_per_segment * params.batch_size * get_world_size()
 
             scaler.step(optimizer)
             scaler.update()
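Because the increment of `num_tokens_seen` is the same every step, the `--num-tokens` budget maps directly onto a step count, and `num_tokens_seen / tokens_per_epoch` is the fractional "Epoch" printed in the log lines below. With hypothetical settings:

    bytes_per_segment = 2048  # hypothetical values, not the script's defaults
    batch_size = 16
    world_size = 4

    tokens_per_step = bytes_per_segment * batch_size * world_size  # 131072
    num_tokens = 10_000_000_000            # the --num-tokens default
    steps = num_tokens // tokens_per_step  # 76293 steps to exhaust the budget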
@@ -812,7 +769,6 @@ def train_one_epoch(
                 params.batch_idx_train > 0
                 and params.batch_idx_train % params.save_every_n == 0
             ):
-                params.cur_batch_idx = batch_idx
                 save_checkpoint_with_global_batch_idx(
                     out_dir=params.exp_dir,
                     global_batch_idx=params.batch_idx_train,
@@ -824,12 +780,16 @@ def train_one_epoch(
                     scaler=scaler,
                     rank=rank,
                 )
-                del params.cur_batch_idx
-
                 remove_checkpoints(
                     out_dir=params.exp_dir,
                     topk=params.keep_last_k,
                     rank=rank,
                 )
+                # wait till just after writing a checkpoint to finish training,
+                # to avoid wasted training iterations.
+                if params.num_tokens_seen > params.num_tokens:
+                    break
 
             if batch_idx % 100 == 0 and params.use_fp16:
                 # If the grad scale was less than 1, try increasing it. The _growth_interval
@@ -853,7 +813,7 @@ def train_one_epoch(
             cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
 
             logging.info(
-                f"Epoch {params.cur_epoch}, "
+                f"Epoch {params.num_tokens_seen / params.tokens_per_epoch:.3f}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], tokens: {tokens_seen} "
                 f"lr: {cur_lr:.2e}, " +
@@ -871,13 +831,12 @@ def train_one_epoch(
                 tot_loss.write_summary(
                     tb_writer, "train/tot_", params.batch_idx_train
                 )
+                tb_writer.add_scalar("train/epoch", params.num_tokens_seen / params.tokens_per_epoch)
                 if params.use_fp16:
                     tb_writer.add_scalar(
                         "train/grad_scale", cur_grad_scale, params.batch_idx_train
                     )
-
 
             if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
                 logging.info("Computing validation loss")
                 valid_info = compute_validation_loss(
@@ -887,7 +846,7 @@ def train_one_epoch(
                     world_size=world_size,
                 )
                 model.train()
-                logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+                logging.info(f"Epoch {params.num_tokens_seen/params.tokens_per_epoch:.3f}, batch {params.batch_idx_train}, validation: {valid_info}")
                 logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
                 if tb_writer is not None:
                     valid_info.write_summary(
@@ -896,9 +855,6 @@ def train_one_epoch(
 
     loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
-    if params.train_loss < params.best_train_loss:
-        params.best_train_epoch = params.cur_epoch
-        params.best_train_loss = params.train_loss
 
 
 def run(rank, world_size, args):
@@ -991,6 +947,7 @@ def run(rank, world_size, args):
     train = LmDataset(params.train_file_list,
                       bytes_per_segment=params.bytes_per_segment,
                       skip_to_batch_idx=getattr(params, 'cur_batch_idx', 0))
+    params.tokens_per_epoch = train.num_tokens()  # helps us figure out epoch progress.
 
     batch_size = params.batch_size // (6 if params.print_diagnostics else 1)
 
@@ -1015,47 +972,32 @@ def run(rank, world_size, args):
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
 
-    for epoch in range(params.start_epoch, params.num_epochs + 1):
-        fix_random_seed(params.seed + epoch)
-        # the above will affect random seeds in the dataloaders.
-
-        if tb_writer is not None:
-            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
-
-        params.cur_epoch = epoch
-
-        train_one_epoch(
-            params=params,
-            model=model,
-            model_avg=model_avg,
-            optimizer=optimizer,
-            scheduler=scheduler,
-            train_dl=train_dl,
-            valid_dl=valid_dl,
-            scaler=scaler,
-            tb_writer=tb_writer,
-            world_size=world_size,
-            rank=rank,
-            batch_idx_offset=(getattr(params, 'cur_batch_idx', 0) if epoch == params.start_epoch else 0),
-        )
-
-        if params.print_diagnostics:
-            diagnostic.print_diagnostics()
-            break
-
-        save_checkpoint(
-            params=params,
-            model=model,
-            model_avg=model_avg,
-            optimizer=optimizer,
-            scheduler=scheduler,
-            scaler=scaler,
-            rank=rank,
-        )
+    # we don't do step_epoch per epoch as the dataset might be large; we do this
+    # to let it know how many tokens we have processed so far, and have a
+    # soft cutoff lr_tokens measured in tokens.
+    # scheduler.step_epoch(epoch - 1)
+    # the "+ params.start_batch" is to ensure that we use a different random
+    # seed generator in the data loaders if we resume training using --start-batch;
+    # this will prevent us from using the exact same data as we used before, although
+    # at the expense of exact repeatability.
+    fix_random_seed(params.seed * 123456 + params.start_batch)
+    # the above will affect random seeds in the dataloaders.
+
+    train(
+        params=params,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        train_dl=train_dl,
+        valid_dl=valid_dl,
+        scaler=scaler,
+        tb_writer=tb_writer,
+        world_size=world_size,
+        rank=rank,
+    )
+
+    if params.print_diagnostics:
+        diagnostic.print_diagnostics()
 
     logging.info("Done!")