Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Merge branch 'scaled_adam_exp106' into scaled_adam_exp108

# Conflicts:
#   egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py

Commit: 63334137ee
egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py

@@ -304,18 +304,15 @@ class ConformerEncoderLayer(nn.Module):
         self,
         src: Tensor,
         pos_emb: Tensor,
-        attn_scores_in: Optional[Tensor] = None,
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
-    ) -> Tuple[Tensor, Tensor]:
+    ) -> Tensor:
         """
         Pass the input through the encoder layer.

         Args:
             src: the sequence to the encoder layer (required).
             pos_emb: Positional embedding tensor (required).
-            attn_scores_in: something with the dimension of attention weights (bsz, len, len, num_heads) that is
-               passed from layer to layer.
             src_mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
             batch_split: if not None, this layer will only be applied to
@@ -333,10 +330,9 @@ class ConformerEncoderLayer(nn.Module):
         src = src + self.feed_forward1(src)

         # multi-headed self-attention module
-        src_att, attn_weights, attn_scores_out = self.self_attn(
+        src_att, attn_weights = self.self_attn(
             src,
             pos_emb=pos_emb,
-            attn_scores_in=attn_scores_in,
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
         )
@@ -362,7 +358,7 @@ class ConformerEncoderLayer(nn.Module):
             bypass_scale = bypass_scale.clamp(min=0.1, max=1.0)
         src = src_orig + delta * self.bypass_scale

-        return src, attn_scores_out
+        return src


 class ConformerEncoder(nn.Module):
@@ -520,7 +516,6 @@ class ConformerEncoder(nn.Module):
         output = src

         outputs = []
-        attn_scores = None


         rnd_seed = src.numel() + random.randint(0, 1000)
@@ -531,10 +526,9 @@ class ConformerEncoder(nn.Module):
         for i, mod in enumerate(self.layers):
             if i in layers_to_drop:
                 continue
-            output, attn_scores = mod(
+            output = mod(
                 output,
                 pos_emb,
-                attn_scores,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
             )
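
The mechanism removed above threaded attention scores from one encoder layer to the next: each layer returned a (batch, seq, seq, num_heads) score tensor alongside its output, and the loop fed that tensor into the following layer. Below is a stripped-down sketch of that data flow only (not the icefall implementation; ToyLayer, its sizes, and the use of post-softmax weights as a stand-in for the pre-softmax scores are illustrative assumptions):

from typing import Optional, Tuple

import torch
from torch import Tensor, nn


class ToyLayer(nn.Module):
    """Illustrative only: mimics the old forward() contract of ConformerEncoderLayer."""

    def __init__(self, d_model: int, num_heads: int) -> None:
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads)

    def forward(
        self, src: Tensor, attn_scores_in: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        # Input is (seq, batch, d_model); per-head weights come back as
        # (batch, num_heads, seq, seq).
        out, weights = self.self_attn(src, src, src, average_attn_weights=False)
        # Reorder to (batch, seq, seq, num_heads), the layout the removed
        # docstring gives for attn_scores_in.
        attn_scores_out = weights.permute(0, 2, 3, 1)
        if attn_scores_in is not None:
            attn_scores_out = attn_scores_out + attn_scores_in
        return src + out, attn_scores_out


layers = nn.ModuleList(ToyLayer(16, 4) for _ in range(3))
output = torch.randn(10, 2, 16)  # (seq, batch, d_model)
attn_scores = None               # mirrors the removed `attn_scores = None`
for mod in layers:
    output, attn_scores = mod(output, attn_scores)  # old calling convention
# After this merge each layer returns only its output, as the hunk above shows.
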
@@ -856,8 +850,6 @@ class RelPositionMultiheadAttention(nn.Module):
         self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True,
                                       initial_scale=0.05)

-        self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
-        self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))

         # linear transformation for positional encoding.
         self.linear_pos = nn.Linear(embed_dim, embed_dim // 2, bias=False)
@@ -875,7 +867,6 @@ class RelPositionMultiheadAttention(nn.Module):
         self,
         x: Tensor,
         pos_emb: Tensor,
-        attn_scores_in: Optional[Tensor],
         key_padding_mask: Optional[Tensor] = None,
         attn_mask: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor, Tensor]:
@@ -897,7 +888,6 @@ class RelPositionMultiheadAttention(nn.Module):
             the embedding dimension.
         - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
             the embedding dimension.
-        - attn_scores_in: :math:`(N, L, L, num_heads)`
         - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
             If a ByteTensor is provided, the non-zero positions will be ignored while the position
             with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
@@ -910,19 +900,16 @@ class RelPositionMultiheadAttention(nn.Module):
             is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
             is provided, it will be added to the attention weight.

-        - Returns: (attn_output, attn_weights, attn_scores)
+        - Returns: (attn_output, attn_weights)

         - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size,
             E is the embedding dimension.
         - attn_weights: :math:`(N * H, S, S)` where N is the batch size, H is the num-heads
             and S is the sequence length.
-        - attn_scores: :math:`(N, S, S, H)`, these are the attn weights
-            before softmax.
         """
-        x, weights, scores = self.multi_head_attention_forward(
+        x, weights = self.multi_head_attention_forward(
             self.in_max_eig(self.in_balancer(self.in_proj(x))),
             pos_emb,
-            None if attn_scores_in is None else torch.matmul(attn_scores_in, self.attn_scores_proj_in),
             self.embed_dim,
             self.num_heads,
             self.in_proj.weight,
@@ -934,15 +921,7 @@ class RelPositionMultiheadAttention(nn.Module):
             key_padding_mask=key_padding_mask,
             attn_mask=attn_mask,
         )
-        if attn_scores_in is not None:
-            attn_scores_out = torch.matmul(scores, self.attn_scores_proj_out)
-            attn_scores_out = attn_scores_out + attn_scores_in
-        else:
-            # Here, add self.attn_scores_proj_in in order to make sure it has
-            # a grad.
-            attn_scores_out = torch.matmul(scores, self.attn_scores_proj_out +
-                                           self.attn_scores_proj_in)
-        return x, weights, attn_scores_out
+        return x, weights

     def rel_shift(self, x: Tensor) -> Tensor:
         """Compute relative positional encoding.
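
For context on the block deleted just above: attn_scores_proj_in and attn_scores_proj_out were (num_heads, num_heads) matrices (initialised to the identity and to zeros in the __init__ hunk earlier), and a matmul over the last axis of a (batch, seq, seq, num_heads) tensor mixes the heads. A minimal sketch of that arithmetic, assuming those shapes (the tensors here are random stand-ins, not the real values):

import torch

batch, seq, num_heads = 2, 5, 4
scores = torch.randn(batch, seq, seq, num_heads)          # this layer's raw scores
attn_scores_in = torch.randn(batch, seq, seq, num_heads)  # scores from the previous layer

proj_in = torch.eye(num_heads)                # initialised as identity
proj_out = torch.zeros(num_heads, num_heads)  # initialised as zeros

# On the way in, the incoming scores were mixed across heads before being
# added to this layer's attention logits:
mixed_in = torch.matmul(attn_scores_in, proj_in)            # (B, S, S, H)

# On the way out, this layer's raw scores were mixed and the incoming
# scores added residually:
attn_scores_out = torch.matmul(scores, proj_out) + attn_scores_in

# First layer (nothing incoming): proj_in is folded in only so that it
# receives a gradient, per the deleted comment.
attn_scores_out_first = torch.matmul(scores, proj_out + proj_in)

# With the identity/zeros initialisation, the mechanism starts out as a
# pass-through: later layers forward what they received unchanged.
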
@@ -973,7 +952,6 @@ class RelPositionMultiheadAttention(nn.Module):
         self,
         x: Tensor,
         pos_emb: Tensor,
-        attn_scores_in: Optional[Tensor],
         embed_dim: int,
         num_heads: int,
         in_proj_weight: Tensor,
@@ -1028,8 +1006,6 @@ class RelPositionMultiheadAttention(nn.Module):
             E is the embedding dimension.
         - attn_weights: :math:`(N * H, S, S)` where N is the batch size,
             H is the num-heads, S is the sequence length.
-        - attn_scores: :math:`(N, S, S, H)` where N is the batch size,
-            S is the sequence length and H is the num-heads.
         """

         tgt_len, bsz, _ = x.size()
@@ -1142,10 +1118,6 @@ class RelPositionMultiheadAttention(nn.Module):
             matrix_ac + matrix_bd
         ) # (batch, head, time1, time2)

-        attn_scores_out = attn_output_weights.permute(0, 2, 3, 1) # (batch, time1, time2, head)
-
-        if attn_scores_in is not None:
-            attn_output_weights = attn_output_weights + attn_scores_in.permute(0, 3, 1, 2)


         attn_output_weights = attn_output_weights.view(
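
The two permutes in the deleted lines are pure layout changes: internally the attention weights live in (batch, head, time1, time2) order, while the scores handed between layers use (batch, time1, time2, head). A small shape-only check with made-up sizes:

import torch

batch, head, time1, time2 = 2, 4, 7, 7
attn_output_weights = torch.randn(batch, head, time1, time2)

attn_scores_out = attn_output_weights.permute(0, 2, 3, 1)  # (batch, time1, time2, head)
restored = attn_scores_out.permute(0, 3, 1, 2)             # back to (batch, head, time1, time2)
assert torch.equal(restored, attn_output_weights)          # the two permutes are inverses
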
@@ -1192,7 +1164,7 @@ class RelPositionMultiheadAttention(nn.Module):
             attn_output, out_proj_weight, out_proj_bias
         )

-        return attn_output, attn_output_weights, attn_scores_out
+        return attn_output, attn_output_weights


     def forward2(