Step lr_scheduler on tokens not epoch; add some more debug output

Daniel Povey 2023-05-04 15:32:12 +08:00
parent 3574e7dbb5
commit 86c2c60100
3 changed files with 24 additions and 10 deletions

@@ -16,7 +16,8 @@
# limitations under the License.
import logging
import random
import torch
from torch import nn, Tensor
@@ -114,4 +115,10 @@ class ChunkDecoder(nn.Module):
logprobs = torch.gather(x, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) # (batch_size, seq_len)
if random.random() < 0.01:
# occasionally print out average logprob per position in the chunk.
l = logprobs.reshape(batch_size, num_chunks, chunk_size).mean(dim=(0, 1))
l = l.to('cpu').tolist()
logging.info("Logprobs per position in chunk: {l}")
return logprobs
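
For intuition, here is a minimal standalone sketch of the per-position averaging that this debug output performs; the shapes and names are illustrative, not taken from the module.

import torch

# Illustrative sizes only: 2 sequences, each made of 3 chunks of 4 positions.
batch_size, num_chunks, chunk_size = 2, 3, 4
logprobs = torch.randn(batch_size, num_chunks * chunk_size)

# Group positions by their index within the chunk, then average over the batch
# and chunk dimensions, leaving one mean logprob per within-chunk position.
per_position = logprobs.reshape(batch_size, num_chunks, chunk_size).mean(dim=(0, 1))
assert per_position.shape == (chunk_size,)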

@@ -81,7 +81,7 @@ from icefall.checkpoint import (
update_averaged_model,
)
from icefall.hooks import register_inf_check_hooks
from icefall.dist import cleanup_dist, setup_dist
from icefall.dist import cleanup_dist, setup_dist, get_world_size
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
@@ -295,10 +295,11 @@ def get_parser():
)
parser.add_argument(
"--lr-epochs",
"--lr-tokens",
type=float,
default=3.5,
help="""Number of epochs that affects how rapidly the learning rate decreases.
default=1000000000,
help="""Number of tokens beyond which the LR will start to decrease per token, defines
LR schedule, replacing lr-epochs
""",
)
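
For a sense of what --lr-tokens does to the schedule, here is a rough sketch of an Eden-style decay factor with the epoch term driven by a token count; the exact formula lives in optim.py and may differ in detail, so treat this as an approximation:

def eden_factor(batch: int, tokens_seen: float,
                lr_batches: float = 5000.0, lr_tokens: float = 1e9) -> float:
    # Nearly flat while batch << lr_batches and tokens_seen << lr_tokens,
    # then decaying smoothly once either threshold is passed.
    batch_term = ((batch ** 2 + lr_batches ** 2) / lr_batches ** 2) ** -0.25
    token_term = ((tokens_seen ** 2 + lr_tokens ** 2) / lr_tokens ** 2) ** -0.25
    return batch_term * token_term

At tokens_seen == lr_tokens the token term is 2 ** -0.25 ≈ 0.84, so the decrease is a soft cutoff rather than a hard knee.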
@@ -869,6 +870,9 @@ def train_one_epoch(
# in the batch and there is no normalization to it so far.
scaler.scale(loss).backward()
scheduler.step_batch(params.batch_idx_train)
tokens_seen = params.batch_idx_train * params.bytes_per_segment * params.batch_size * get_world_size()
# We make the formula depend on tokens, not epochs, replacing lr_epochs with lr_tokens.
scheduler.step_epoch(tokens_seen)
scaler.step(optimizer)
scaler.update()
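
A worked example of the token count above, with purely hypothetical settings:

bytes_per_segment = 2048   # tokens (bytes) per training segment -- hypothetical
batch_size = 16            # segments per batch, per process -- hypothetical
world_size = 4             # number of data-parallel processes -- hypothetical
batch_idx_train = 10_000   # batches seen so far -- hypothetical

tokens_seen = batch_idx_train * bytes_per_segment * batch_size * world_size
# = 1,310,720,000 tokens, i.e. just past the default --lr-tokens of 1e9,
# so the LR would only just have started to decay.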
@@ -939,7 +943,7 @@
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], "
f"tot_loss[{tot_loss}], tokens: {tokens_seen} "
f"lr: {cur_lr:.2e}, " +
(f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
)
@@ -1049,7 +1053,7 @@ def run(rank, world_size, args):
clipping_scale=2.0,
)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
scheduler = Eden(optimizer, params.lr_batches, params.lr_tokens)
if checkpoints and "optimizer" in checkpoints:
logging.info("Loading optimizer state dict")
@@ -1074,7 +1078,7 @@
train = LmDataset(params.train_file_list,
bytes_per_segment=params.bytes_per_segment)
bytes_per_segment=params.bytes_per_segment,)
train_dl = LmDataloader(train, batch_size=params.batch_size,
num_workers=params.num_workers)
@@ -1091,7 +1095,10 @@
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs + 1):
scheduler.step_epoch(epoch - 1)
# We don't call step_epoch once per epoch, as the dataset might be large; instead
# we call it per batch (see train_one_epoch) to tell the scheduler how many tokens
# we have processed so far, with the soft cutoff lr_tokens measured in tokens.
# scheduler.step_epoch(epoch - 1)
fix_random_seed(params.seed + epoch - 1)
# the above will affect random seeds in the dataloaders.

@@ -809,7 +809,7 @@ class LRScheduler(object):
self.batch = self.batch + 1
self._set_lrs()
def step_epoch(self, epoch: Optional[int] = None):
def step_epoch(self, epoch: Optional[Union[int, float]] = None):
# Step the epoch index, or just set it. If you provide the 'epoch' arg,
# you should call this at the start of the epoch; if you don't provide the 'epoch'
# arg, you should call it at the end of the epoch.
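
A hedged usage sketch of the widened signature (names like tokens_per_batch, train_one_batch and num_batches are illustrative, not part of the library): the scheduler can now be driven by a running, possibly fractional token count instead of an integer epoch index.

scheduler = Eden(optimizer, 5000.0, 1e9)  # second value plays the role of --lr-tokens
tokens_per_batch = 131072.0               # hypothetical: bytes_per_segment * batch_size * world_size
for batch_idx in range(1, num_batches + 1):
    train_one_batch(...)                  # hypothetical helper for the forward/backward step
    scheduler.step_batch(batch_idx)
    scheduler.step_epoch(batch_idx * tokens_per_batch)  # a float is accepted now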