Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Step lr_scheduler on tokens not epoch; add some more debug output
parent 3574e7dbb5
commit 86c2c60100
@@ -16,7 +16,8 @@
 # limitations under the License.
 
 
+import logging
 import random
 import torch
 from torch import nn, Tensor
 
@@ -114,4 +115,10 @@ class ChunkDecoder(nn.Module):
 
         logprobs = torch.gather(x, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)  # (batch_size, seq_len)
 
+        if random.random() < 0.01:
+            # occasionally print out the average logprob per position in the chunk.
+            l = logprobs.reshape(batch_size, num_chunks, chunk_size).mean(dim=(0, 1))
+            l = l.to('cpu').tolist()
+            logging.info(f"Logprobs per position in chunk: {l}")
+
         return logprobs
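The added block averages the gathered log-probabilities over the batch and chunk dimensions, so each entry in the logged list is the mean logprob at one position within a chunk. A minimal standalone sketch of the same pattern, reusing the shape names from the diff with random stand-in data:

    import logging
    import random

    import torch

    logging.basicConfig(level=logging.INFO)

    batch_size, num_chunks, chunk_size = 4, 3, 8
    # Stand-in for the gathered logprobs of shape (batch_size, seq_len),
    # where seq_len == num_chunks * chunk_size.
    logprobs = torch.randn(batch_size, num_chunks * chunk_size)

    if random.random() < 0.01:  # log on roughly 1% of calls to keep output sparse
        per_pos = logprobs.reshape(batch_size, num_chunks, chunk_size).mean(dim=(0, 1))
        logging.info(f"Logprobs per position in chunk: {per_pos.to('cpu').tolist()}")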
@@ -81,7 +81,7 @@ from icefall.checkpoint import (
     update_averaged_model,
 )
 from icefall.hooks import register_inf_check_hooks
-from icefall.dist import cleanup_dist, setup_dist
+from icefall.dist import cleanup_dist, setup_dist, get_world_size
 from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
@@ -295,10 +295,11 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--lr-epochs",
+        "--lr-tokens",
         type=float,
-        default=3.5,
-        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        default=1000000000,
+        help="""Number of tokens beyond which the LR will start to decrease per token; this
+        defines the LR schedule, replacing lr-epochs.
         """,
     )
 
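For context, Eden multiplies the base LR by two soft-cutoff factors of the same shape: one driven by the batch count, and one that this change drives by tokens instead of epochs. A sketch of that factor, assuming the formula given in the Eden docstring in icefall's optim.py (the base LR and step counts below are illustrative):

    def eden_factor(step: float, cutoff: float) -> float:
        # Roughly 1.0 while step << cutoff, decaying like step ** -0.5
        # once step passes the cutoff.
        return ((step ** 2 + cutoff ** 2) / cutoff ** 2) ** -0.25

    base_lr = 0.045                        # illustrative
    lr_batches, lr_tokens = 5000.0, 1.0e9  # --lr-batches, --lr-tokens
    batch_idx, tokens_seen = 20000, 4.0e9
    lr = base_lr * eden_factor(batch_idx, lr_batches) * eden_factor(tokens_seen, lr_tokens)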
@@ -869,6 +870,9 @@ def train_one_epoch(
             # in the batch and there is no normalization to it so far.
             scaler.scale(loss).backward()
             scheduler.step_batch(params.batch_idx_train)
+            tokens_seen = params.batch_idx_train * params.bytes_per_segment * params.batch_size * get_world_size()
+            # we make the formula depend on tokens, not epochs, replacing lr_epochs with lr_tokens.
+            scheduler.step_epoch(tokens_seen)
 
             scaler.step(optimizer)
             scaler.update()
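The token count here is global: each of the get_world_size() DDP ranks consumes batch_size segments of bytes_per_segment tokens per step, so the four factors multiply. A worked example with made-up values, using the same names as the diff:

    batch_idx_train = 5_000    # optimizer steps taken so far
    bytes_per_segment = 2_048  # tokens (bytes) per training segment
    batch_size = 16            # segments per rank per step
    world_size = 4             # number of DDP processes

    # Global number of tokens consumed so far, as passed to scheduler.step_epoch().
    tokens_seen = batch_idx_train * bytes_per_segment * batch_size * world_size
    print(tokens_seen)  # 655360000, still below the default lr_tokens of 1e9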
@@ -939,7 +943,7 @@ def train_one_epoch(
             logging.info(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
-                f"tot_loss[{tot_loss}], "
+                f"tot_loss[{tot_loss}], tokens: {tokens_seen} "
                 f"lr: {cur_lr:.2e}, " +
                 (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
             )
@@ -1049,7 +1053,7 @@ def run(rank, world_size, args):
         clipping_scale=2.0,
     )
 
-    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_tokens)
 
     if checkpoints and "optimizer" in checkpoints:
         logging.info("Loading optimizer state dict")
@@ -1074,7 +1078,7 @@ def run(rank, world_size, args):
 
 
     train = LmDataset(params.train_file_list,
-                      bytes_per_segment=params.bytes_per_segment)
+                      bytes_per_segment=params.bytes_per_segment,)
     train_dl = LmDataloader(train, batch_size=params.batch_size,
                             num_workers=params.num_workers)
 
@@ -1091,7 +1095,10 @@ def run(rank, world_size, args):
         scaler.load_state_dict(checkpoints["grad_scaler"])
 
     for epoch in range(params.start_epoch, params.num_epochs + 1):
-        scheduler.step_epoch(epoch - 1)
+        # We don't call step_epoch once per epoch, since the dataset might be large;
+        # instead we call it per batch to let the scheduler know how many tokens we
+        # have processed so far, with the soft cutoff lr_tokens measured in tokens.
+        # scheduler.step_epoch(epoch - 1)
         fix_random_seed(params.seed + epoch - 1)
         # the above will affect random seeds in the dataloaders.
 
@@ -809,7 +809,7 @@ class LRScheduler(object):
         self.batch = self.batch + 1
         self._set_lrs()
 
-    def step_epoch(self, epoch: Optional[int] = None):
+    def step_epoch(self, epoch: Optional[Union[int, float]] = None):
         # Step the epoch index, or just set it. If you provide the 'epoch' arg,
         # you should call this at the start of the epoch; if you don't provide the 'epoch'
         # arg, you should call it at the end of the epoch.
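Widening the annotation to Union[int, float] lets callers pass a large token count where an integer epoch index used to go; the scheduler only compares the value against its cutoff, so the units just have to agree. A minimal sketch of the set-or-increment behaviour the docstring describes (the body is an assumption for illustration, not copied from optim.py):

    from typing import Optional, Union

    class LRScheduler:
        def __init__(self) -> None:
            self.epoch: Union[int, float] = 0

        def step_epoch(self, epoch: Optional[Union[int, float]] = None) -> None:
            # With an argument: set the 'epoch' (now possibly a token count) directly.
            # Without one: advance the epoch index by one.
            self.epoch = self.epoch + 1 if epoch is None else epoch
            self._set_lrs()

        def _set_lrs(self) -> None:
            # Placeholder: a real scheduler recomputes per-group LRs here.
            pass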