mirror of https://github.com/k2-fsa/icefall.git
Revert warmup_batches change; make code change to avoid nan in attn_weights
parent b0c87a93d2
commit f9f546968c
@@ -1119,8 +1119,7 @@ def run(rank, world_size, args):
         clipping_scale=2.0,
     )
 
-    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs,
-                     warmup_batches=1000.0)
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
 
     if checkpoints and "optimizer" in checkpoints:
         logging.info("Loading optimizer state dict")
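The first hunk drops the explicit warmup_batches=1000.0, so the scheduler falls back to Eden's default warmup. Below is a minimal standalone sketch of an Eden-style schedule, following the formula given in icefall's optim.py docstring; the linear 0.5-to-1.0 warmup shape and the 500-batch default are assumptions for illustration, not something this diff confirms.

```python
# Sketch of an Eden-style LR curve (formula per icefall's optim.py docstring).
# ASSUMPTIONS: the linear 0.5 -> 1.0 warmup shape and warmup_batches=500.0
# default are illustrative, not taken from this commit.
def eden_lr(base_lr, batch, epoch, lr_batches, lr_epochs, warmup_batches=500.0):
    batch_factor = ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25
    epoch_factor = ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
    warmup = 1.0 if batch >= warmup_batches else 0.5 + 0.5 * batch / warmup_batches
    return base_lr * batch_factor * epoch_factor * warmup

# With the explicit warmup_batches=1000.0 reverted, early batches reach the
# full-scale LR sooner than the longer warmup would have allowed.
for batch in (0, 250, 500, 1000, 5000):
    print(batch, eden_lr(0.045, batch, epoch=1, lr_batches=5000, lr_epochs=3.5))
```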
@@ -1412,13 +1412,17 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
 
         if attn_mask is not None:
             assert attn_mask.dtype == torch.bool
-            attn_scores = attn_scores.masked_fill(attn_mask, float("-inf"))
+            # use -1000 to avoid nan's where attn_mask and key_padding_mask make
+            # all scores zero.  It's important that this be large enough that exp(-1000)
+            # is exactly zero, for reasons related to const_attention_rate, it
+            # compares the final weights with zero.
+            attn_scores = attn_scores.masked_fill(attn_mask, -1000)
 
         if key_padding_mask is not None:
             assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
             attn_scores = attn_scores.masked_fill(
                 key_padding_mask.unsqueeze(1),
-                float("-inf"),
+                -1000,
             )
 
         # We use our own version of softmax, defined in scaling.py, which should
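The second hunk replaces float("-inf") with -1000 in both masked_fill calls. A quick illustration of why, using torch.softmax as a stand-in for the custom softmax defined in scaling.py (the tensors below are made up for the demo):

```python
import torch

# If attn_mask and key_padding_mask together mask out an entire row of
# attn_scores, softmax over an all--inf row yields nan (the internal
# max-subtraction computes -inf minus -inf), and the nan then propagates.
row = torch.full((1, 4), float("-inf"))
print(torch.softmax(row, dim=-1))  # tensor([[nan, nan, nan, nan]])

# With -1000, a fully masked row degrades to uniform weights instead of nan.
row = torch.full((1, 4), -1000.0)
print(torch.softmax(row, dim=-1))  # tensor([[0.2500, 0.2500, 0.2500, 0.2500]])

# In a partially masked row, exp(-1000 - max_score) underflows to exactly
# 0.0 in float32, so masked positions still receive exactly zero weight,
# which is what the comment about const_attention_rate relies on.
mixed = torch.tensor([[0.3, -1000.0, 1.2, -1000.0]])
print(torch.softmax(mixed, dim=-1))  # masked entries come out exactly 0.0
```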