mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Bug fix
This commit is contained in:
parent 86c2c60100
commit 45f5e9981d
@@ -115,10 +115,10 @@ class ChunkDecoder(nn.Module):
         logprobs = torch.gather(x, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) # (batch_size, seq_len)
 
-        if random.random() < 0.01:
+        if random.random() < 0.02:
             # occasionally print out average logprob per position in the chunk.
             l = logprobs.reshape(batch_size, num_chunks, chunk_size).mean(dim=(0, 1))
             l = l.to('cpu').tolist()
-            logging.info("Logprobs per position in chunk: {l}")
+            logging.info(f"Logprobs per position in chunk: {l}")
 
         return logprobs
 
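For context, the branch above occasionally logs the mean log-probability at each position within a chunk. Below is a minimal, self-contained sketch of that diagnostic; the shapes and the random stand-in for the gathered log-probs are made up for illustration and are not icefall's actual data flow.

# Hypothetical sketch, not icefall code: shapes and the logprobs tensor are invented.
import logging
import random

import torch

logging.basicConfig(level=logging.INFO)

batch_size, num_chunks, chunk_size = 4, 3, 8                   # made-up sizes
logprobs = torch.randn(batch_size, num_chunks * chunk_size)    # stand-in for gathered log-probs

if random.random() < 0.02:
    # Average over the batch and chunk dimensions, leaving one value per position in a chunk.
    l = logprobs.reshape(batch_size, num_chunks, chunk_size).mean(dim=(0, 1))
    l = l.to('cpu').tolist()
    # The f-prefix added by this commit is what makes {l} interpolate instead of being printed literally.
    logging.info(f"Logprobs per position in chunk: {l}")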
@@ -870,7 +870,7 @@ def train_one_epoch(
         # in the batch and there is no normalization to it so far.
         scaler.scale(loss).backward()
         scheduler.step_batch(params.batch_idx_train)
-        tokens_seen = params.batch_idx_train * params.bytes_per_segment * params_batch_size * get_world_size()
+        tokens_seen = params.batch_idx_train * params.bytes_per_segment * params.batch_size * get_world_size()
         # we make the formula depend on tokens not epochs, replacing lr_epochs with lr_tokens.
         scheduler.step_epoch(tokens_seen)
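To illustrate the corrected tokens_seen bookkeeping, here is a toy sketch under stated assumptions: the numeric values, the Params container, the DummyScheduler class, and the single-process get_world_size stub are all invented for this example; only the attribute names and the step_batch/step_epoch calls mirror the diff above.

# Toy illustration of the tokens-seen formula; everything here is a stand-in, not icefall code.
def get_world_size() -> int:
    # Stand-in for the distributed world size (1 when training on a single GPU).
    return 1

class DummyScheduler:
    def step_batch(self, batch_idx: int) -> None:
        pass  # the real scheduler updates its per-batch state here

    def step_epoch(self, tokens_seen: float) -> None:
        # With lr_tokens replacing lr_epochs, the "epoch" argument is the total
        # number of tokens (bytes) processed so far rather than an epoch count.
        print(f"scheduler sees tokens_seen={tokens_seen}")

class Params:
    batch_idx_train = 1000    # batches processed so far (made-up value)
    bytes_per_segment = 512   # bytes per training segment (made-up value)
    batch_size = 32           # segments per batch on one GPU (made-up value)

params = Params()
scheduler = DummyScheduler()

scheduler.step_batch(params.batch_idx_train)
# The bug fix: attribute access params.batch_size, not an undefined params_batch_size.
tokens_seen = params.batch_idx_train * params.bytes_per_segment * params.batch_size * get_world_size()
scheduler.step_epoch(tokens_seen)  # 1000 * 512 * 32 * 1 = 16,384,000 tokens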