Remove concept of epochs from subformer training for language modeling;

revert dimensions to how they were in zlm53.
Daniel Povey 2023-06-19 04:43:39 +08:00
parent c7e8a7349d
commit 03ad0d7910
3 changed files with 157 additions and 240 deletions

View File

@@ -43,20 +43,11 @@ def get_parser():
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
help="""If positive, it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
@@ -68,18 +59,15 @@ def get_parser():
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
"'--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
help="Whether to load averaged model. If True, it would decode "
"with the averaged model over this many checkpoints."
)
parser.add_argument(
@@ -150,10 +138,8 @@ def main():
params.res_dir = params.exp_dir / "log-evaluation"
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
assert params.iter > 0
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
@@ -173,81 +159,51 @@ def main():
logging.info(f"Number of model parameters: {num_param}")
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
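With --epoch gone, the checkpoint range for the averaged model is selected purely from iteration checkpoints. A minimal sketch of that selection (not the icefall implementation; the helper name is hypothetical, and it only mirrors how decode.py slices the result of find_checkpoints):

import re
from pathlib import Path
from typing import List

def select_iter_checkpoints(exp_dir: Path, iter_: int, avg: int) -> List[Path]:
    # Collect checkpoint-*.pt files whose global batch index is <= iter_,
    # newest first, and keep avg + 1 of them.
    ckpts = []
    for p in exp_dir.glob("checkpoint-*.pt"):
        m = re.match(r"checkpoint-(\d+)\.pt$", p.name)
        if m and int(m.group(1)) <= iter_:
            ckpts.append((int(m.group(1)), p))
    ckpts.sort(reverse=True)
    filenames = [p for _, p in ckpts[: avg + 1]]
    if len(filenames) < avg + 1:
        raise ValueError(f"Not enough checkpoints for --iter {iter_}, --avg {avg}")
    return filenames

# filenames[0] (filename_end) is the newest checkpoint and is included in the
# average; filenames[-1] (filename_start) is the excluded left endpoint.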

View File

@@ -41,15 +41,20 @@ class LmDataset(torch.utils.data.IterableDataset):
world_size: int = 1,
rank: int = 0,
training: bool = True,
skip_to_batch_idx: int = 0,
):
"""
Initialize LmDataset object. Args:
Initialize LmDataset object. This keeps no state; it just gives you a totally random
segment each time. The training files are just viewed as sequences of bytes, from which
we select chunks of a fixed size. In training mode we just loop infinitely, and let
the training code decide when to stop based on the count of tokens. In test mode
we loop so that we see each byte about once.
Args:
file_list_fn: a file in which each line contains: a number of bytes, then a space, then a filename.
e.g. a line might contain the text "64324 foo/abc.txt".
(filenames can not contain spaces).
world_size, rank: from DDP. We get the data-loader id and world-size separately.
bytes_per_segment: the number of bytes in each segment of data.
skip_to_batch_idx: if provided, the first time we iterate we will skip this many batches.
"""
self.training = training
self.skip_to_batch_idx = skip_to_batch_idx
@@ -66,18 +71,25 @@ class LmDataset(torch.utils.data.IterableDataset):
fn = line[len(num_bytes) + 1:] # this works even if fn has spaces in
self.files.append(fn)
self.num_bytes.append(int(num_bytes))
tot_bytes = sum(self.num_bytes)
N = len(self.num_bytes)
self.probs = np.array([ x / tot_bytes for x in self.num_bytes ])
# For purposes of choosing the possible start-positions of a segment: we
# need to pad on the left by bytes_per_segment - 1. This is part of a
# scheme to ensure that each byte in each training file is chosen with
# equal probability, while also choosing different shifts of the data
# with equal probability. We end up padding with zeroes if we
# are outside the file either on the left or the right.
pad = self.bytes_per_segment - 1
tot_positions = sum([ x + pad for x in self.num_bytes])
self.probs = np.array([ (x + pad) / tot_positions for x in self.num_bytes ])
self.tot_positions = tot_positions
worker_info = torch.utils.data.get_worker_info()
num_workers = (1 if worker_info is None else worker_info.num_workers)
# world_size is for ddp training, num_workers for data-loader worker threads.
# num_workers for data-loader worker threads; world_size is for ddp training.
tot_workers = num_workers * get_world_size()
self.num_segments = tot_bytes // (bytes_per_segment * tot_workers)
self.num_segments = float('inf') if training else 1 + tot_positions // (bytes_per_segment * tot_workers)
def __iter__(self):
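The padding scheme above can be checked numerically: each file contributes num_bytes + pad possible start positions (pad = bytes_per_segment - 1), so picking a file with probability proportional to that count and then a start position uniformly at random makes every start position, and hence every byte, equally likely. A small illustrative check (toy numbers, not part of the commit):

import numpy as np

bytes_per_segment = 4
num_bytes = [10, 25, 7]                             # three toy "files"
pad = bytes_per_segment - 1
positions = np.array([n + pad for n in num_bytes])  # start positions per file
tot_positions = positions.sum()
probs = positions / tot_positions                   # file-selection probabilities

# Probability of landing on any single start position, whichever file it is in:
per_position = probs / positions
assert np.allclose(per_position, 1.0 / tot_positions)
# Each byte of each file is covered by exactly bytes_per_segment of these
# equally likely start positions, so every byte is seen with equal probability.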
@@ -85,19 +97,22 @@ class LmDataset(torch.utils.data.IterableDataset):
# id includes both worker (within training job) and rank of training job
my_id = (0 if worker_info is None else worker_info.id) + 1000 * self.ddp_rank
# note: the seed depends on the current random state, which will be different
# depending on the DDP worker id and also on which batch we restarted
# training on. This does not guarantee that you get repeatability if you
# restart training, but it does ensure you don't see exactly repeated data.
seed = (random.randint(0, 10000) if self.training else 0) + my_id
# the next line is because, for some reason, when we ran with --world-size more than 1,
# this info message was not printed out.
logging.getLogger().setLevel(logging.INFO)
logging.info(f"my_id={my_id}, seed={seed}, num_segments={self.num_segments}")
# use numpy's generator, not random's, because we need np.random.multinomial.
rng = np.random.default_rng(seed=seed)
skip_to_batch_idx = self.skip_to_batch_idx
if skip_to_batch_idx != 0:
logging.info(f"skip-to-batch-idx={skip_to_batch_idx}")
self.skip_to_batch_idx = 0 # so only the 1st time we iterate, we respect this.
n = 0
while n < self.num_segments: # if self.num_segments is infinity, just keep going.
n += 1
for n in range(self.num_segments):
# np.random.multinomial / np.random.Generator.multinomial has an interface
# where it gives counts of different categories, instead of the chosen category,
# so we need to use np.nonzero to get the chosen category (i.e. the file index)
@@ -106,30 +121,34 @@ class LmDataset(torch.utils.data.IterableDataset):
file_idx, = np.nonzero(rng.multinomial(1, self.probs))
file_idx, = file_idx
if n < skip_to_batch_idx:
continue
fn = self.files[file_idx]
num_bytes = self.num_bytes[file_idx]
# begin_pos, end_pos are the begin,end of a range from which we'll pick
# randomly, for where the start of the segment might be.
begin_pos = 0
end_pos = max(1, num_bytes - self.bytes_per_segment)
# randomly, for where the start of the segment might be. We only
# guarantee that a segment contains at least one byte of data from the file;
# this helps ensure that each byte is chosen with the exact same probability,
# which is easier for analysis.
begin_pos = - (self.bytes_per_segment - 1)
end_pos = max(1, num_bytes - 1)
begin, = rng.integers(low=begin_pos, high=end_pos, size=1)
with open(fn, "rb") as f:
f.seek(begin)
b = f.read(self.bytes_per_segment) # b is bytes object
read_size = len(b)
if read_size < self.bytes_per_segment:
b = b + b'\0' * (self.bytes_per_segment - read_size)
if begin >= 0:
f.seek(begin)
b = f.read(self.bytes_per_segment) # b is bytes object
else:
b = b'\0' * -begin + f.read(self.bytes_per_segment + begin)
if len(b) < self.bytes_per_segment:
b = b + b'\0' * (self.bytes_per_segment - len(b))
yield torch.Tensor(np.frombuffer(b, dtype=np.uint8).copy()).to(torch.long)
def tot_tokens(self):
# Returns the total number of tokens, including padding tokens, in
# the dataset; this is for purposes of figuring out how many
# epochs we have trained for.
return self.tot_positions
def _test():
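For clarity, the padded read in __iter__ above behaves like the following standalone helper (the function name is hypothetical; it just restates the logic for a possibly negative start offset):

import numpy as np
import torch

def read_padded_segment(fn: str, begin: int, bytes_per_segment: int) -> torch.Tensor:
    # Read bytes_per_segment bytes starting at (possibly negative) offset
    # `begin`; zero-pad on the left if begin < 0 and on the right if the read
    # runs past end-of-file, so the result always has the same length.
    with open(fn, "rb") as f:
        if begin >= 0:
            f.seek(begin)
            b = f.read(bytes_per_segment)
        else:
            b = b"\0" * -begin + f.read(bytes_per_segment + begin)
    b = b + b"\0" * (bytes_per_segment - len(b))
    return torch.tensor(np.frombuffer(b, dtype=np.uint8).copy(), dtype=torch.long)

# e.g. for a 5-byte file containing b"abcde" and bytes_per_segment=4:
#   begin=-3 -> [0, 0, 0, 97]    (three bytes of left padding, then 'a')
#   begin= 4 -> [101, 0, 0, 0]   ('e', then three bytes of right padding)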

View File

@@ -22,24 +22,11 @@ Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless7/train.py \
./zipformer1/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--exp-dir pruned_transducer_stateless7/exp \
--full-libri 1 \
--max-duration 300
--exp-dir zipformer1/exp \
--use-fp16 True
# For mix precision training:
./pruned_transducer_stateless7/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless7/exp \
--full-libri 1 \
--max-duration 550
"""
@@ -143,7 +130,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--encoder-dim",
type=str,
default="384,512,512,768,512,512,384",
default="256,384,512,768,512,384,256",
help="Embedding dimension in encoder stacks: a single int or comma-separated list."
)
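The reverted default "256,384,512,768,512,384,256" gives the per-stack embedding dimensions. Presumably the training script turns this comma-separated string into a tuple of ints roughly as follows (a sketch; the actual helper in train.py may be named differently):

def to_int_tuple(s: str) -> tuple:
    # "256,384,512,768,512,384,256" -> (256, 384, 512, 768, 512, 384, 256)
    return tuple(int(x) for x in s.split(","))

assert to_int_tuple("256,384,512,768,512,384,256") == (256, 384, 512, 768, 512, 384, 256)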
@@ -214,27 +201,18 @@ def get_parser():
)
parser.add_argument(
"--num-epochs",
"--num-tokens",
type=int,
default=30,
help="Number of epochs to train.",
default=10000000000,
help="Number of tokens to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=1,
help="""Resume training from this epoch. It should be positive.
If larger than 1, it will load checkpoint from
exp-dir/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--start-batch",
type=int,
default=0,
help="""If positive, --start-epoch is ignored and
help="""If positive, we start training from this batch and "
it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
""",
)
@@ -270,12 +248,11 @@ def get_parser():
"--lr-tokens",
type=float,
default=1000000000,
help="""Number of tokens beyond which the LR will start to decrease per token, defines
LR schedule, replacing lr-epochs
help="""Number of tokens beyond which the LR will start to significantly
decrease per token; this defines the LR schedule
""",
)
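If the scheduler is Eden-style (as elsewhere in icefall), replacing lr-epochs with lr-tokens means the "epoch" term of the decay is driven by the token count passed to scheduler.step_epoch(). A hedged sketch of the resulting factor (the exact formula lives in the scheduler code and may differ; the lr_batches default here is illustrative):

def eden_like_factor(batch: int, tokens: int,
                     lr_batches: float = 5000.0,
                     lr_tokens: float = 1_000_000_000.0) -> float:
    # Stays close to 1 while tokens << lr_tokens, then decays roughly like
    # tokens ** -0.5 once the token count passes lr_tokens.
    return (((batch ** 2 + lr_batches ** 2) / lr_batches ** 2) ** -0.25
            * ((tokens ** 2 + lr_tokens ** 2) / lr_tokens ** 2) ** -0.25)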
parser.add_argument(
"--seed",
type=int,
@@ -305,8 +282,6 @@ def get_parser():
periodically. We save checkpoint to exp-dir/ whenever
params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0.
""",
)
@@ -357,21 +332,9 @@ def get_params() -> AttributeDict:
Explanation of options saved in `params`:
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- batch_idx_train: It contains number of batches trained so far.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used to writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- num_tokens_seen: Total number of tokens that have been seen so far.
- log_interval: Print training loss if batch_idx % log_interval` is 0
@@ -393,11 +356,8 @@ def get_params() -> AttributeDict:
"""
params = AttributeDict(
{
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"num_tokens_seen": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 3000,
@@ -477,9 +437,7 @@ def load_checkpoint_if_available(
"""Load checkpoint from file.
If params.start_batch is positive, it will load the checkpoint from
`params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
params.start_epoch is larger than 1, it will load the checkpoint from
`params.start_epoch - 1`.
`params.exp_dir/checkpoint-{params.start_batch}.pt`.
Apart from loading state dict for `model` and `optimizer` it also updates
`best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
@@ -573,14 +531,6 @@ def save_checkpoint(
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def _encode_texts_as_bytes(texts: List[str], device: torch.device) -> Tuple[Tensor, Tensor, Tensor]:
"""
Encode texts as bytes and then integer tensors.
@@ -627,8 +577,8 @@ def compute_loss(
model:
The model for training. It is an instance of Subformer in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
A batch of data: a tensor of integers from 0 to 255, of shape
(num_sequences, sequence_length).
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
@@ -655,8 +605,14 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# this logprob can be treated as somewhat like the log of the 'ppl1' printed in SRILM:
# that is, the total log-probability of the sequence, divided by the
# number of non-terminating elements (treating \0 as
# a terminator, like EOF). In fact this is not 100% correct, since
# we may also pad with leading zeros in case the 'window' starts before
# the start of the file. But this is a small effect if the files are long.
info["frames"] = (
labels.numel()
(labels != 0).sum()
)
# Note: We use reduction=sum while computing the loss.
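Counting "frames" as the non-zero labels makes the reported loss-per-frame comparable to SRILM's ppl1 (perplexity excluding terminators). A small illustrative conversion from these sums to bits-per-byte and perplexity (not code from the commit):

import math

def bits_and_ppl(loss_sum_nats: float, num_nonzero_labels: int):
    # loss_sum_nats: summed negative log-likelihood over the batch (reduction=sum).
    nats_per_byte = loss_sum_nats / num_nonzero_labels
    bits_per_byte = nats_per_byte / math.log(2)
    ppl1_like = math.exp(nats_per_byte)
    return bits_per_byte, ppl1_like

# e.g. a summed loss of 1.5e6 nats over 1e6 non-'\0' labels is about
# 2.16 bits per byte, i.e. a per-byte perplexity of roughly 4.48.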
@@ -697,7 +653,7 @@ def compute_validation_loss(
return tot_loss
def train_one_epoch(
def train(
params: AttributeDict,
model: Union[nn.Module, DDP],
optimizer: torch.optim.Optimizer,
@@ -746,8 +702,6 @@ def train_one_epoch(
tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
saved_bad_model = False
def save_bad_model(suffix: str = ""):
save_checkpoint_impl(filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt",
@@ -761,11 +715,11 @@ def train_one_epoch(
for batch_idx_, batch in enumerate(train_dl):
batch_idx = batch_idx_ + batch_idx_offset
if batch_idx % 10 == 0:
params.batch_idx_train += 1
if params.batch_idx_train % 10 == 0:
set_batch_count(model, get_adjusted_batch_count(params))
params.batch_idx_train += 1
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
@@ -782,9 +736,12 @@ def train_one_epoch(
# in the batch and there is no normalization to it so far.
scaler.scale(loss).backward()
scheduler.step_batch(params.batch_idx_train)
tokens_seen = params.batch_idx_train * params.bytes_per_segment * params.batch_size * get_world_size()
# we make the formula depend on tokens not epochs, replacing lr_epochs with lr_tokens.
scheduler.step_epoch(tokens_seen)
scheduler.step_epoch(params.num_tokens_seen)
# this doesn't take into account padding, but it doesn't matter
# much; it is just used to determine when we terminate.
params.num_tokens_seen += params.bytes_per_segment * params.batch_size * get_world_size()
scaler.step(optimizer)
scaler.update()
@@ -812,7 +769,6 @@ def train_one_epoch(
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
@@ -824,12 +780,16 @@ def train_one_epoch(
scaler=scaler,
rank=rank,
)
del params.cur_batch_idx
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
rank=rank,
)
# wait till just after writing a checkpoint to finish training,
# to avoid wasted training iterations.
if params.num_tokens_seen > params.num_tokens:
break
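Because the counter is only compared against --num-tokens right after a checkpoint is written, training stops at the first save boundary past the target and can overshoot it by at most save_every_n steps' worth of tokens. A rough sketch with made-up values (none of these numbers are defaults from this file):

import math

def stopping_step(num_tokens: int, tokens_per_step: int, save_every_n: int) -> int:
    # tokens_per_step = bytes_per_segment * batch_size * world_size, as above.
    first_step_past_target = num_tokens // tokens_per_step + 1
    return math.ceil(first_step_past_target / save_every_n) * save_every_n

print(stopping_step(10_000_000_000, tokens_per_step=2048 * 16 * 4, save_every_n=2000))
# -> 78000 optimizer steps with these illustrative settings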
if batch_idx % 100 == 0 and params.use_fp16:
# If the grad scale was less than 1, try increasing it. The _growth_interval
@@ -853,7 +813,7 @@
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
logging.info(
f"Epoch {params.cur_epoch}, "
f"Epoch {params.num_tokens_seen / params.tokens_per_epoch:.3f}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], tokens: {tokens_seen} "
f"lr: {cur_lr:.2e}, " +
@@ -871,13 +831,12 @@ def train_one_epoch(
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", params.num_tokens_seen / params.tokens_per_epoch)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
)
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
@@ -887,7 +846,7 @@ def train_one_epoch(
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info(f"Epoch {params.num_tokens_seen/params.tokens_per_epoch:.3f}, batch {params.batch_idx_train}, validation: {valid_info}")
logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB")
if tb_writer is not None:
valid_info.write_summary(
@@ -896,9 +855,6 @@
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
@@ -991,6 +947,7 @@ def run(rank, world_size, args):
train = LmDataset(params.train_file_list,
bytes_per_segment=params.bytes_per_segment,
skip_to_batch_idx=getattr(params, 'cur_batch_idx', 0))
params.tokens_per_epoch = train.num_tokens() # helps us figure out epoch progress.
batch_size = params.batch_size // (6 if params.print_diagnostics else 1)
@@ -1015,47 +972,32 @@ def run(rank, world_size, args):
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs + 1):
# we don't do step_epoch per epoch as the dataset might be large, we do this
# to let it know how many tokens we have processed so far, and have a
# soft-cutoff lr_tokens measured in tokens.
# scheduler.step_epoch(epoch - 1)
fix_random_seed(params.seed + epoch)
# the above will affect random seeds in the dataloaders.
if tb_writer is not None:
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
# the "+ params.start_batch" is to ensure that we use a different random
# seed generator in the data loaders if we resume training using --start-batch;
# this will prevent us from using the exact same data as we used before, although
# at the expense of exact repeatability.
fix_random_seed(params.seed * 123456 + params.start_batch)
# the above will affect random seeds in the dataloaders.
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
batch_idx_offset=(getattr(params, 'cur_batch_idx', 0) if epoch == params.start_epoch else 0),
)
train(
params=params,
model=model,
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
)
if params.print_diagnostics:
diagnostic.print_diagnostics()
break
save_checkpoint(
params=params,
model=model,
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
scaler=scaler,
rank=rank,
)
if params.print_diagnostics:
diagnostic.print_diagnostics()
break
logging.info("Done!")