Daniel Povey 2023-05-04 15:39:36 +08:00
parent 86c2c60100
commit 45f5e9981d
2 changed files with 3 additions and 3 deletions

@@ -115,10 +115,10 @@ class ChunkDecoder(nn.Module):
         logprobs = torch.gather(x, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)  # (batch_size, seq_len)
-        if random.random() < 0.01:
+        if random.random() < 0.02:
             # occasionally print out average logprob per position in the chunk.
             l = logprobs.reshape(batch_size, num_chunks, chunk_size).mean(dim=(0, 1))
             l = l.to('cpu').tolist()
-            logging.info("Logprobs per position in chunk: {l}")
+            logging.info(f"Logprobs per position in chunk: {l}")
         return logprobs
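For context, the `torch.gather` call above picks out the log-probability of each target label from the decoder's log-softmax output, and the diagnostic block averages those over batch and chunks. A minimal self-contained sketch of that pattern (the shapes and values here are illustrative, not taken from the repository):

```python
import torch

batch_size, num_chunks, chunk_size, vocab_size = 2, 3, 4, 256
seq_len = num_chunks * chunk_size

# x: per-position log-probabilities over the vocabulary, e.g. from
# log_softmax on the decoder output.
x = torch.randn(batch_size, seq_len, vocab_size).log_softmax(dim=-1)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))

# Select the log-prob of the correct label at each position; gather needs
# the index tensor to match x's rank, hence the unsqueeze/squeeze pair.
logprobs = torch.gather(x, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)  # (batch_size, seq_len)

# Average over batch and chunks to get one value per position in the chunk,
# matching the occasional diagnostic print in the hunk above.
per_position = logprobs.reshape(batch_size, num_chunks, chunk_size).mean(dim=(0, 1))
print(per_position.tolist())
```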

@@ -870,7 +870,7 @@ def train_one_epoch(
         # in the batch and there is no normalization to it so far.
         scaler.scale(loss).backward()
         scheduler.step_batch(params.batch_idx_train)
-        tokens_seen = params.batch_idx_train * params.bytes_per_segment * params_batch_size * get_world_size()
+        tokens_seen = params.batch_idx_train * params.bytes_per_segment * params.batch_size * get_world_size()
         # we make the formula depend on tokens not epochs, replacing lr_epochs with lr_tokens.
         scheduler.step_epoch(tokens_seen)
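This hunk fixes a typo (`params_batch_size` → `params.batch_size`) in the token count that drives the learning-rate schedule. A minimal sketch of the idea of scheduling on tokens rather than epochs, with made-up numbers and a generic square-root decay standing in for the scheduler's actual formula (`lr_base`, `lr_tokens`, and the exponent are illustrative assumptions):

```python
# Hypothetical values for illustration; in training these come from `params`.
batch_idx_train = 10_000    # optimizer steps taken so far
bytes_per_segment = 2048    # bytes (tokens) per training segment
batch_size = 16             # segments per worker per batch
world_size = 8              # number of data-parallel workers

# Total tokens consumed globally: each rank processes a different shard,
# so the per-rank count is multiplied by world_size.
tokens_seen = batch_idx_train * bytes_per_segment * batch_size * world_size

# An epoch-style schedule re-expressed in tokens: the decay point becomes a
# token budget (lr_tokens) instead of an epoch count (lr_epochs).
lr_base, lr_tokens = 0.045, 5e9
lr = lr_base * (lr_tokens / (tokens_seen + lr_tokens)) ** 0.5
print(f"tokens_seen={tokens_seen:,}  lr={lr:.5f}")
```

Scheduling on tokens seen makes the decay invariant to batch size, segment length, and worker count: doubling the world size halves the number of steps needed to reach the same point on the schedule.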