diff --git a/egs/libriheavy/LM/zipformer1/chunk_decoder.py b/egs/libriheavy/LM/zipformer1/chunk_decoder.py index e5da99aa6..823df602e 100644 --- a/egs/libriheavy/LM/zipformer1/chunk_decoder.py +++ b/egs/libriheavy/LM/zipformer1/chunk_decoder.py @@ -115,10 +115,10 @@ class ChunkDecoder(nn.Module): logprobs = torch.gather(x, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) # (batch_size, seq_len) - if random.random() < 0.01: + if random.random() < 0.02: # occasionally print out average logprob per position in the chunk. l = logprobs.reshape(batch_size, num_chunks, chunk_size).mean(dim=(0, 1)) l = l.to('cpu').tolist() - logging.info("Logprobs per position in chunk: {l}") + logging.info(l"Logprobs per position in chunk: {l}") return logprobs diff --git a/egs/libriheavy/LM/zipformer1/train.py b/egs/libriheavy/LM/zipformer1/train.py index bb9e00afc..8f7abed90 100755 --- a/egs/libriheavy/LM/zipformer1/train.py +++ b/egs/libriheavy/LM/zipformer1/train.py @@ -870,7 +870,7 @@ def train_one_epoch( # in the batch and there is no normalization to it so far. scaler.scale(loss).backward() scheduler.step_batch(params.batch_idx_train) - tokens_seen = params.batch_idx_train * params.bytes_per_segment * params_batch_size * get_world_size() + tokens_seen = params.batch_idx_train * params.bytes_per_segment * params.batch_size * get_world_size() # we make the formula depend on tokens not epochs, replacing lr_epochs with lr_tokens. scheduler.step_epoch(tokens_seen)