diff --git a/egs/libriheavy/LM/zipformer1/chunk_decoder.py b/egs/libriheavy/LM/zipformer1/chunk_decoder.py
index 5f3a84be0..e5da99aa6 100644
--- a/egs/libriheavy/LM/zipformer1/chunk_decoder.py
+++ b/egs/libriheavy/LM/zipformer1/chunk_decoder.py
@@ -16,7 +16,8 @@
 # limitations under the License.
 
-
+import logging
+import random
 import torch
 from torch import nn, Tensor
@@ -114,4 +115,10 @@ class ChunkDecoder(nn.Module):
         logprobs = torch.gather(x, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)  # (batch_size, seq_len)
 
+        if random.random() < 0.01:
+            # Occasionally print out the average logprob per position in the chunk.
+            l = logprobs.reshape(batch_size, num_chunks, chunk_size).mean(dim=(0, 1))
+            l = l.to('cpu').tolist()
+            logging.info(f"Logprobs per position in chunk: {l}")
+
         return logprobs
diff --git a/egs/libriheavy/LM/zipformer1/train.py b/egs/libriheavy/LM/zipformer1/train.py
index 5f0715be6..bb9e00afc 100755
--- a/egs/libriheavy/LM/zipformer1/train.py
+++ b/egs/libriheavy/LM/zipformer1/train.py
@@ -81,7 +81,7 @@ from icefall.checkpoint import (
     update_averaged_model,
 )
 from icefall.hooks import register_inf_check_hooks
-from icefall.dist import cleanup_dist, setup_dist
+from icefall.dist import cleanup_dist, setup_dist, get_world_size
 from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
@@ -295,10 +295,11 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--lr-epochs",
+        "--lr-tokens",
         type=float,
-        default=3.5,
-        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        default=1000000000,
+        help="""Number of tokens beyond which the LR starts to decrease noticeably; this
+        defines the LR schedule in terms of tokens processed, replacing --lr-epochs.
         """,
     )
 
@@ -869,6 +870,9 @@ def train_one_epoch(
            # in the batch and there is no normalization to it so far.
            scaler.scale(loss).backward()
            scheduler.step_batch(params.batch_idx_train)
+           tokens_seen = params.batch_idx_train * params.bytes_per_segment * params.batch_size * get_world_size()
+           # We make the LR formula depend on tokens seen, not epochs, replacing lr_epochs with lr_tokens.
+           scheduler.step_epoch(tokens_seen)
            scaler.step(optimizer)
            scaler.update()
@@ -939,7 +943,7 @@
            logging.info(
                f"Epoch {params.cur_epoch}, "
                f"batch {batch_idx}, loss[{loss_info}], "
-               f"tot_loss[{tot_loss}], "
+               f"tot_loss[{tot_loss}], tokens: {tokens_seen}, "
                f"lr: {cur_lr:.2e}, "
                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
            )
@@ -1049,7 +1053,7 @@ def run(rank, world_size, args):
        clipping_scale=2.0,
    )
 
-   scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+   scheduler = Eden(optimizer, params.lr_batches, params.lr_tokens)
 
    if checkpoints and "optimizer" in checkpoints:
        logging.info("Loading optimizer state dict")
@@ -1074,7 +1078,7 @@
    train = LmDataset(params.train_file_list,
-                     bytes_per_segment=params.bytes_per_segment)
+                     bytes_per_segment=params.bytes_per_segment,)
 
    train_dl = LmDataloader(train, batch_size=params.batch_size, num_workers=params.num_workers)
@@ -1091,7 +1095,10 @@
    scaler.load_state_dict(checkpoints["grad_scaler"])
 
    for epoch in range(params.start_epoch, params.num_epochs + 1):
-       scheduler.step_epoch(epoch - 1)
+       # We no longer call step_epoch once per epoch, since the dataset may be very large;
+       # instead step_epoch is called per batch with the number of tokens processed so far,
+       # so lr_tokens acts as a soft cutoff measured in tokens.
+       # scheduler.step_epoch(epoch - 1)
 
        fix_random_seed(params.seed + epoch - 1)
        # the above will affect random seeds in the dataloaders.
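Editor's note: below is a minimal, self-contained sketch (not the icefall Eden class itself) of how a token-driven decay of this kind behaves once the "epoch" term is replaced by tokens seen. The function name eden_like_factor, the default constants, the -0.25 exponent, and the bytes_per_segment/batch_size/world_size values are assumptions for illustration; only the tokens_seen bookkeeping mirrors the formula added in train.py above.

def eden_like_factor(batch_idx: int, tokens_seen: float,
                     lr_batches: float = 5000.0,
                     lr_tokens: float = 1e9) -> float:
    """Multiplicative LR factor driven by batch index and tokens seen so far.

    The token term replaces the old epoch term: once tokens_seen exceeds
    lr_tokens, the factor starts to shrink noticeably (a soft cutoff in tokens).
    """
    batch_factor = ((batch_idx ** 2 + lr_batches ** 2) / lr_batches ** 2) ** -0.25
    token_factor = ((tokens_seen ** 2 + lr_tokens ** 2) / lr_tokens ** 2) ** -0.25
    return batch_factor * token_factor


if __name__ == "__main__":
    # Hypothetical bookkeeping mirroring the diff:
    # tokens_seen = batch_idx * bytes_per_segment * batch_size * world_size
    bytes_per_segment, batch_size, world_size = 2048, 16, 8
    for batch_idx in (0, 10_000, 100_000):
        tokens_seen = batch_idx * bytes_per_segment * batch_size * world_size
        print(batch_idx, tokens_seen, round(eden_like_factor(batch_idx, tokens_seen), 4))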
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
index ea9f82bb8..4d99983f6 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py
@@ -809,7 +809,7 @@ class LRScheduler(object):
         self.batch = self.batch + 1
         self._set_lrs()
 
-    def step_epoch(self, epoch: Optional[int] = None):
+    def step_epoch(self, epoch: Optional[Union[int, float]] = None):
         # Step the epoch index, or just set it.  If you provide the 'epoch' arg,
         # you should call this at the start of the epoch; if you don't provide the 'epoch'
         # arg, you should call it at the end of the epoch.
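Editor's note: the signature change above widens 'epoch' to accept floats because, with the token-based schedule, the argument is really a very large (and possibly fractional) token count rather than an integer epoch index. The tiny mock below, with the hypothetical name TokenAwareScheduler, only illustrates that widened contract; it is not the real icefall LRScheduler.

from typing import Optional, Union


class TokenAwareScheduler:
    """Mock scheduler showing the widened step_epoch signature."""

    def __init__(self) -> None:
        self.epoch: float = 0.0  # may now hold a token count, not just an epoch index

    def step_epoch(self, epoch: Optional[Union[int, float]] = None) -> None:
        # If an explicit value is given, record it; otherwise advance by one "epoch".
        self.epoch = float(epoch) if epoch is not None else self.epoch + 1.0


sched = TokenAwareScheduler()
sched.step_epoch(1_500_000_000.0)  # e.g. tokens seen so far
print(sched.epoch)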