issues fixed

JinZr 2023-07-24 17:30:24 +08:00
parent 739e2a22c6
commit 49e9d15733
3 changed files with 477 additions and 153 deletions

View File

@@ -15,7 +15,142 @@ from scaling import (
softmax,
)
from torch import Tensor, nn
-from zipformer import CompactRelPositionalEncoding, SelfAttention, _whitening_schedule
+from zipformer import CompactRelPositionalEncoding, _whitening_schedule
+from icefall.utils import make_pad_mask
class CrossAttention(nn.Module):
"""
The simplest possible attention module. This one works with already-computed attention
weights, e.g. as computed by RelPositionMultiheadAttentionWeights.
Args:
embed_dim: the input and output embedding dimension
num_heads: the number of attention heads
value_head_dim: the value dimension per head
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
value_head_dim: int,
) -> None:
super().__init__()
self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True)
self.out_proj = ScaledLinear(
num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05
)
self.whiten = Whiten(
num_groups=1,
whitening_limit=_whitening_schedule(7.5, ratio=3.0),
prob=(0.025, 0.25),
grad_scale=0.01,
)
def forward(
self,
x: Tensor,
attn_weights: Tensor,
) -> Tensor:
"""
Args:
x: input tensor, of shape (am_seq_len, batch_size, embed_dim)
attn_weights: a tensor of shape (num_heads, batch_size, lm_seq_len, am_seq_len),
with the last two dims interpreted as (tgt_seq_len, src_seq_len). Expect
attn_weights.sum(dim=-1) == 1.
Returns:
a tensor of shape (lm_seq_len, batch_size, embed_dim).
"""
(am_seq_len, batch_size, embed_dim) = x.shape
(_, _, lm_seq_len, _) = attn_weights.shape
num_heads = attn_weights.shape[0]
assert attn_weights.shape == (num_heads, batch_size, lm_seq_len, am_seq_len)
x = self.in_proj(x) # (am_seq_len, batch_size, num_heads * value_head_dim)
# print("projected x.shape", x.shape)
x = x.reshape(am_seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
# now x: (num_heads, batch_size, am_seq_len, value_head_dim)
# print("permuted x.shape", x.shape)
value_head_dim = x.shape[-1]
# todo: see whether there is benefit in overriding matmul
x = torch.matmul(attn_weights, x)
# v: (num_heads, batch_size, lm_seq_len, value_head_dim)
# print("attended x.shape", x.shape)
x = (
x.permute(2, 1, 0, 3)
.contiguous()
.view(lm_seq_len, batch_size, num_heads * value_head_dim)
)
# returned value is of shape (lm_seq_len, batch_size, embed_dim), like the input.
x = self.out_proj(x)
x = self.whiten(x)
# print("returned x.shape", x.shape)
return x
def streaming_forward(
self,
x: Tensor,
attn_weights: Tensor,
cached_val: Tensor,
left_context_len: int,
) -> Tuple[Tensor, Tensor]:
"""
Args:
x: input tensor, of shape (seq_len, batch_size, embed_dim)
attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect
attn_weights.sum(dim=-1) == 1.
cached_val: cached attention value tensor of left context,
of shape (left_context_len, batch_size, value_dim)
left_context_len: number of left context frames.
Returns:
- attention weighted output, a tensor with the same shape as x.
- updated cached attention value tensor of left context.
"""
(seq_len, batch_size, embed_dim) = x.shape
num_heads = attn_weights.shape[0]
seq_len2 = seq_len + left_context_len
assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2)
x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim)
# Pad cached left contexts
assert cached_val.shape[0] == left_context_len, (
cached_val.shape[0],
left_context_len,
)
x = torch.cat([cached_val, x], dim=0)
# Update cached left contexts
cached_val = x[-left_context_len:, ...]
x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3)
# now x: (num_heads, batch_size, seq_len, value_head_dim)
value_head_dim = x.shape[-1]
# todo: see whether there is benefit in overriding matmul
x = torch.matmul(attn_weights, x)
# v: (num_heads, batch_size, seq_len, value_head_dim)
x = (
x.permute(2, 1, 0, 3)
.contiguous()
.view(seq_len, batch_size, num_heads * value_head_dim)
)
# returned value is of shape (seq_len, batch_size, embed_dim), like the input.
x = self.out_proj(x)
return x, cached_val
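For orientation, the following snippet is a minimal, illustrative shape check of the CrossAttention module defined above; it is not part of the commit, and the 512/5/12 sizes and random inputs are assumptions chosen to mirror the commented example in __main__ further down, not values taken from the training recipe.

# Illustrative smoke test only; sizes are arbitrary.
import torch

cross_attn = CrossAttention(embed_dim=512, num_heads=5, value_head_dim=12)
am_seq_len, lm_seq_len, batch_size = 100, 25, 2
x = torch.rand(am_seq_len, batch_size, 512)  # acoustic (source) side
attn_weights = torch.softmax(
    torch.rand(5, batch_size, lm_seq_len, am_seq_len), dim=-1
)  # each row sums to 1 over the am (source) axis
out = cross_attn(x, attn_weights)
assert out.shape == (lm_seq_len, batch_size, 512)  # output follows the lm (target) length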
class RelPositionMultiheadAttentionWeights(nn.Module):
@@ -40,7 +175,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
def __init__(
self,
-embed_dim: int = 512,
+lm_embed_dim: int = 512,
+am_embed_dim: int = 512,
pos_dim: int = 192,
num_heads: int = 5,
query_head_dim: int = 32,
@@ -49,7 +185,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
) -> None:
super().__init__()
-self.embed_dim = embed_dim
+self.lm_embed_dim = lm_embed_dim
+self.am_embed_dim = am_embed_dim
self.num_heads = num_heads
self.query_head_dim = query_head_dim
self.pos_head_dim = pos_head_dim
@@ -67,10 +204,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
# to be used with the ScaledAdam optimizer; with most other optimizers,
# it would be necessary to apply the scaling factor in the forward function.
self.in_lm_proj = ScaledLinear(
-embed_dim, in_lm_dim, bias=True, initial_scale=query_head_dim**-0.25
+lm_embed_dim, in_lm_dim, bias=True, initial_scale=query_head_dim**-0.25
)
self.in_am_proj = ScaledLinear(
-embed_dim, in_am_dim, bias=True, initial_scale=query_head_dim**-0.25
+am_embed_dim, in_am_dim, bias=True, initial_scale=query_head_dim**-0.25
)
self.whiten_keys = Whiten(
@@ -117,10 +254,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
) -> Tensor:
r"""
Args:
-lm_pruned: input of shape (seq_len, batch_size * prune_range, decoder_embed_dim)
-am_pruned: input of shape (seq_len, batch_size * prune_range, encoder_embed_dim)
-pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
-key_padding_mask: a bool tensor of shape (batch_size * prune_range, seq_len). Positions
+lm_pruned: input of shape (lm_seq_len, batch_size * prune_range, decoder_embed_dim)
+am_pruned: input of shape (am_seq_len, batch_size * prune_range, encoder_embed_dim)
+pos_emb: Positional embedding tensor, of shape (1, 2*lm_seq_len - 1, pos_dim)
+key_padding_mask: a bool tensor of shape (batch_size * prune_range, am_seq_len). Positions
that are True in this mask will be ignored as sources in the attention weighting.
attn_mask: mask of shape (seq_len, seq_len)
or (seq_len, batch_size * prune_range, batch_size * prune_range),
@@ -137,38 +274,45 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
num_heads = self.num_heads
(
-seq_len,
+lm_seq_len,
b_p_dim,
_,
-) = lm_pruned.shape # actual dim: (seq_len, batch * prune_range, _)
+) = lm_pruned.shape # actual dim: (lm_seq_len, batch * prune_range, _)
+(
+am_seq_len,
+_,
+_,
+) = am_pruned.shape
query_dim = query_head_dim * num_heads
# self-attention
-q = lm_pruned[..., 0:query_dim] # (seq_len, batch * prune_range, query_dim)
-k = am_pruned # (seq_len, batch * prune_range, query_dim)
+q = lm_pruned[..., 0:query_dim] # (lm_seq_len, batch * prune_range, query_dim)
+k = am_pruned # (am_seq_len, batch * prune_range, query_dim)
# p is the position-encoding query
p = lm_pruned[
..., query_dim:
-] # (seq_len, batch * prune_range, pos_head_dim * num_heads)
+] # (lm_seq_len, batch * prune_range, pos_head_dim * num_heads)
assert p.shape[-1] == num_heads * pos_head_dim
q = self.copy_query(q) # for diagnostics only, does nothing.
k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass.
p = self.copy_pos_query(p) # for diagnostics only, does nothing.
-q = q.reshape(seq_len, b_p_dim, num_heads, query_head_dim)
-p = p.reshape(seq_len, b_p_dim, num_heads, pos_head_dim)
-k = k.reshape(seq_len, b_p_dim, num_heads, query_head_dim)
+q = q.reshape(lm_seq_len, b_p_dim, num_heads, query_head_dim)
+p = p.reshape(lm_seq_len, b_p_dim, num_heads, pos_head_dim)
+k = k.reshape(am_seq_len, b_p_dim, num_heads, query_head_dim)
-# time1 refers to target, time2 refers to source.
+# time1 refers to target (query: lm), time2 refers to source (key: am).
q = q.permute(
2, 1, 0, 3
-) # (head, batch * prune_range, seq_len, query_head_dim)
-p = p.permute(2, 1, 0, 3) # (head, batch * prune_range, seq_len, pos_head_dim)
-k = k.permute(2, 1, 3, 0) # (head, batch * prune_range, d_k, seq_len)
+) # (head, batch * prune_range, lm_seq_len, query_head_dim)
+p = p.permute(
+2, 1, 0, 3
+) # (head, batch * prune_range, lm_seq_len, pos_head_dim)
+k = k.permute(2, 1, 3, 0) # (head, batch * prune_range, d_k, am_seq_len)
-attn_scores = torch.matmul(q, k)
+attn_scores = torch.matmul(q, k) # (head, batch, lm_seq_len, am_seq_len)
use_pos_scores = False
if torch.jit.is_scripting() or torch.jit.is_tracing():
@@ -179,7 +323,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
if use_pos_scores:
pos_emb = self.linear_pos(pos_emb)
-seq_len2 = 2 * seq_len - 1
+seq_len2 = 2 * lm_seq_len - 1
pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
2, 0, 3, 1
)
@@ -194,24 +338,23 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
if torch.jit.is_tracing():
(num_heads, b_p_dim, time1, n) = pos_scores.shape
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
-cols = torch.arange(seq_len)
+cols = torch.arange(lm_seq_len)
rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1)
indexes = rows + cols
pos_scores = pos_scores.reshape(-1, n)
pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
-pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, seq_len)
+pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, lm_seq_len)
else:
pos_scores = pos_scores.as_strided(
-(num_heads, b_p_dim, seq_len, seq_len),
+(num_heads, b_p_dim, lm_seq_len, lm_seq_len),
(
pos_scores.stride(0),
pos_scores.stride(1),
pos_scores.stride(2) - pos_scores.stride(3),
pos_scores.stride(3),
),
-storage_offset=pos_scores.stride(3) * (seq_len - 1),
+storage_offset=pos_scores.stride(3) * (lm_seq_len - 1),
)
attn_scores = attn_scores + pos_scores
if torch.jit.is_scripting() or torch.jit.is_tracing():
@@ -232,8 +375,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
attn_scores = penalize_abs_values_gt(
attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
)
-assert attn_scores.shape == (num_heads, b_p_dim, seq_len, seq_len)
+assert attn_scores.shape == (num_heads, b_p_dim, lm_seq_len, am_seq_len)
if attn_mask is not None:
assert attn_mask.dtype == torch.bool
@@ -246,7 +388,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
if key_padding_mask is not None:
assert key_padding_mask.shape == (
b_p_dim,
-seq_len,
+am_seq_len,
), key_padding_mask.shape
attn_scores = attn_scores.masked_fill(
key_padding_mask.unsqueeze(1),
@@ -271,7 +413,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
return attn_weights
def _print_attn_entropy(self, attn_weights: Tensor):
-# attn_weights: (num_heads, batch_size, seq_len, seq_len)
+# attn_weights: (num_heads, batch_size, lm_seq_len, am_seq_len)
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
with torch.no_grad():
@@ -307,7 +449,7 @@ class AlignmentAttentionModule(nn.Module):
pos_head_dim=pos_head_dim,
dropout=dropout,
)
-self.cross_attn = SelfAttention(
+self.cross_attn = CrossAttention(
embed_dim=embed_dim,
num_heads=num_heads,
value_head_dim=value_head_dim,
@@ -317,57 +459,80 @@ class AlignmentAttentionModule(nn.Module):
)
def forward(self, am_pruned: Tensor, lm_pruned: Tensor) -> Tensor:
-# am_pruned : [B, T, prune_range, encoder_dim]
-# lm_pruned : [B, T, prune_range, decoder_dim]
-(batch_size, T, prune_range, encoder_dim) = am_pruned.shape
-(batch_size, T, prune_range, decoder_dim) = lm_pruned.shape
-# am_pruned : [T, B * prune_range, encoder_dim]
-# lm_pruned : [T, B * prune_range, decoder_dim]
-merged_am_pruned = am_pruned.permute(1, 0, 2, 3).reshape(
-T, batch_size * prune_range, encoder_dim
-)
-merged_lm_pruned = lm_pruned.permute(1, 0, 2, 3).reshape(
-T, batch_size * prune_range, decoder_dim
-)
-pos_emb = self.pos_encode(merged_am_pruned)
+if len(am_pruned.shape) == 4 and len(lm_pruned.shape) == 4:
+# src_key_padding_mask = make_pad_mask(am_pruned_lens)
+# am_pruned : [B, am_T, prune_range, encoder_dim]
+# lm_pruned : [B, lm_T, prune_range, decoder_dim]
+(batch_size, am_T, prune_range, encoder_dim) = am_pruned.shape
+(batch_size, lm_T, prune_range, decoder_dim) = lm_pruned.shape
+# merged_am_pruned : [am_T, B * prune_range, encoder_dim]
+# merged_lm_pruned : [lm_T, B * prune_range, decoder_dim]
+merged_am_pruned = am_pruned.permute(1, 0, 2, 3).reshape(
+am_T, batch_size * prune_range, encoder_dim
+)
+merged_lm_pruned = lm_pruned.permute(1, 0, 2, 3).reshape(
+lm_T, batch_size * prune_range, decoder_dim
+)
+pos_emb = self.pos_encode(merged_lm_pruned)
attn_weights = self.cross_attn_weights(
merged_lm_pruned, merged_am_pruned, pos_emb
)
-label_level_am_representation = self.cross_attn(merged_am_pruned, attn_weights)
-# (T, batch_size * prune_range, encoder_dim)
+# (num_heads, b_p_dim, lm_seq_len, am_seq_len)
+# print("attn_weights.shape", attn_weights.shape)
+label_level_am_representation = self.cross_attn(
+merged_am_pruned, attn_weights
+)
+# print(
+# "label_level_am_representation.shape",
+# label_level_am_representation.shape,
+# )
+# (lm_seq_len, batch_size * prune_range, encoder_dim)
return label_level_am_representation.reshape(
-T, batch_size, prune_range, encoder_dim
+lm_T, batch_size, prune_range, encoder_dim
).permute(1, 0, 2, 3)
+elif len(am_pruned.shape) == 3 and len(lm_pruned.shape) == 3:
+# am_pruned : [am_T, B, encoder_dim]
+# lm_pruned : [lm_T, B, decoder_dim]
+(am_T, batch_size, encoder_dim) = am_pruned.shape
+(lm_T, batch_size, decoder_dim) = lm_pruned.shape
+pos_emb = self.pos_encode(lm_pruned)
+attn_weights = self.cross_attn_weights(lm_pruned, am_pruned, pos_emb)
+label_level_am_representation = self.cross_attn(am_pruned, attn_weights)
+# (T, batch_size, encoder_dim)
+return label_level_am_representation
+else:
+raise NotImplementedError("Dim Error")
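As an aside on the 4-D training branch above: it folds the prune range into the batch dimension before running cross-attention and unfolds it again afterwards. The standalone snippet below (not part of the commit; sizes are arbitrary) only illustrates that this permute/reshape round trip is lossless.

# Illustrative only: fold (B, T, s_range, C) into (T, B * s_range, C) and back.
import torch

B, T, s_range, C = 2, 100, 5, 512
x = torch.rand(B, T, s_range, C)
merged = x.permute(1, 0, 2, 3).reshape(T, B * s_range, C)
restored = merged.reshape(T, B, s_range, C).permute(1, 0, 2, 3)
assert torch.equal(x, restored)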
if __name__ == "__main__":
-# am_pruned : [B, T, prune_range, encoder_dim]
-# lm_pruned : [B, T, prune_range, decoder_dim]
-# am_pruned = torch.rand(2, 100, 5, 512).transpose(1, 0).reshape(2 * 5, 100, 512)
-# lm_pruned = torch.rand(2, 100, 5, 512).transpose(1, 0).reshape(2 * 5, 100, 512)
-# # am_pruned : [B * prune_range, T, encoder_dim]
-# # lm_pruned : [B * prune_range, T, decoder_dim]
-# pos_emb = torch.rand(1, 19, 192)
+attn = AlignmentAttentionModule()
+print("__main__ === for inference ===")
+# am : [T, B, encoder_dim]
+# lm : [1, B, decoder_dim]
+am = torch.rand(100, 2, 512)
+lm = torch.rand(1, 2, 512)
+# q / K separate seq_len
# weights = RelPositionMultiheadAttentionWeights()
-# attn = SelfAttention(512, 5, 12)
-# attn_weights = weights(lm_pruned, am_pruned, pos_emb)
+# attn = CrossAttention(512, 5, 12)
+# attn_weights = weights(lm, am, pos_emb)
# print("weights(am_pruned, lm_pruned, pos_emb).shape", attn_weights.shape)
-# res = attn(am_pruned, attn_weights)
-# print("res", res.shape)
+# res = attn(am, attn_weights)
+res = attn(am, lm)
+print("__main__ res", res.shape)
+print("__main__ === for training ===")
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned = torch.rand(2, 100, 5, 512)
lm_pruned = torch.rand(2, 100, 5, 512)
-attn = AlignmentAttentionModule()
res = attn(am_pruned, lm_pruned)
-print("res", res.shape)
+print("__main__ res", res.shape)

View File

@@ -15,7 +15,142 @@ from scaling import (
softmax,
)
from torch import Tensor, nn
-from zipformer import CompactRelPositionalEncoding, SelfAttention, _whitening_schedule
+from zipformer import CompactRelPositionalEncoding, _whitening_schedule
+from icefall.utils import make_pad_mask
class CrossAttention(nn.Module):
"""
The simplest possible attention module. This one works with already-computed attention
weights, e.g. as computed by RelPositionMultiheadAttentionWeights.
Args:
embed_dim: the input and output embedding dimension
num_heads: the number of attention heads
value_head_dim: the value dimension per head
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
value_head_dim: int,
) -> None:
super().__init__()
self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True)
self.out_proj = ScaledLinear(
num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05
)
self.whiten = Whiten(
num_groups=1,
whitening_limit=_whitening_schedule(7.5, ratio=3.0),
prob=(0.025, 0.25),
grad_scale=0.01,
)
def forward(
self,
x: Tensor,
attn_weights: Tensor,
) -> Tensor:
"""
Args:
x: input tensor, of shape (am_seq_len, batch_size, embed_dim)
attn_weights: a tensor of shape (num_heads, batch_size, lm_seq_len, am_seq_len),
with the last two dims interpreted as (tgt_seq_len, src_seq_len). Expect
attn_weights.sum(dim=-1) == 1.
Returns:
a tensor of shape (lm_seq_len, batch_size, embed_dim).
"""
(am_seq_len, batch_size, embed_dim) = x.shape
(_, _, lm_seq_len, _) = attn_weights.shape
num_heads = attn_weights.shape[0]
assert attn_weights.shape == (num_heads, batch_size, lm_seq_len, am_seq_len)
x = self.in_proj(x) # (am_seq_len, batch_size, num_heads * value_head_dim)
# print("projected x.shape", x.shape)
x = x.reshape(am_seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
# now x: (num_heads, batch_size, am_seq_len, value_head_dim)
# print("permuted x.shape", x.shape)
value_head_dim = x.shape[-1]
# todo: see whether there is benefit in overriding matmul
x = torch.matmul(attn_weights, x)
# v: (num_heads, batch_size, lm_seq_len, value_head_dim)
# print("attended x.shape", x.shape)
x = (
x.permute(2, 1, 0, 3)
.contiguous()
.view(lm_seq_len, batch_size, num_heads * value_head_dim)
)
# returned value is of shape (lm_seq_len, batch_size, embed_dim), like the input.
x = self.out_proj(x)
x = self.whiten(x)
# print("returned x.shape", x.shape)
return x
def streaming_forward(
self,
x: Tensor,
attn_weights: Tensor,
cached_val: Tensor,
left_context_len: int,
) -> Tuple[Tensor, Tensor]:
"""
Args:
x: input tensor, of shape (seq_len, batch_size, embed_dim)
attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect
attn_weights.sum(dim=-1) == 1.
cached_val: cached attention value tensor of left context,
of shape (left_context_len, batch_size, value_dim)
left_context_len: number of left context frames.
Returns:
- attention weighted output, a tensor with the same shape as x.
- updated cached attention value tensor of left context.
"""
(seq_len, batch_size, embed_dim) = x.shape
num_heads = attn_weights.shape[0]
seq_len2 = seq_len + left_context_len
assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2)
x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim)
# Pad cached left contexts
assert cached_val.shape[0] == left_context_len, (
cached_val.shape[0],
left_context_len,
)
x = torch.cat([cached_val, x], dim=0)
# Update cached left contexts
cached_val = x[-left_context_len:, ...]
x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3)
# now x: (num_heads, batch_size, seq_len, value_head_dim)
value_head_dim = x.shape[-1]
# todo: see whether there is benefit in overriding matmul
x = torch.matmul(attn_weights, x)
# v: (num_heads, batch_size, seq_len, value_head_dim)
x = (
x.permute(2, 1, 0, 3)
.contiguous()
.view(seq_len, batch_size, num_heads * value_head_dim)
)
# returned value is of shape (seq_len, batch_size, embed_dim), like the input.
x = self.out_proj(x)
return x, cached_val
class RelPositionMultiheadAttentionWeights(nn.Module):
@@ -40,7 +175,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
def __init__(
self,
-embed_dim: int = 512,
+lm_embed_dim: int = 512,
+am_embed_dim: int = 512,
pos_dim: int = 192,
num_heads: int = 5,
query_head_dim: int = 32,
@@ -49,7 +185,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
) -> None:
super().__init__()
-self.embed_dim = embed_dim
+self.lm_embed_dim = lm_embed_dim
+self.am_embed_dim = am_embed_dim
self.num_heads = num_heads
self.query_head_dim = query_head_dim
self.pos_head_dim = pos_head_dim
@@ -67,10 +204,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
# to be used with the ScaledAdam optimizer; with most other optimizers,
# it would be necessary to apply the scaling factor in the forward function.
self.in_lm_proj = ScaledLinear(
-embed_dim, in_lm_dim, bias=True, initial_scale=query_head_dim**-0.25
+lm_embed_dim, in_lm_dim, bias=True, initial_scale=query_head_dim**-0.25
)
self.in_am_proj = ScaledLinear(
-embed_dim, in_am_dim, bias=True, initial_scale=query_head_dim**-0.25
+am_embed_dim, in_am_dim, bias=True, initial_scale=query_head_dim**-0.25
)
self.whiten_keys = Whiten(
@@ -117,10 +254,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
) -> Tensor:
r"""
Args:
-lm_pruned: input of shape (seq_len, batch_size * prune_range, decoder_embed_dim)
-am_pruned: input of shape (seq_len, batch_size * prune_range, encoder_embed_dim)
-pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
-key_padding_mask: a bool tensor of shape (batch_size * prune_range, seq_len). Positions
+lm_pruned: input of shape (lm_seq_len, batch_size * prune_range, decoder_embed_dim)
+am_pruned: input of shape (am_seq_len, batch_size * prune_range, encoder_embed_dim)
+pos_emb: Positional embedding tensor, of shape (1, 2*lm_seq_len - 1, pos_dim)
+key_padding_mask: a bool tensor of shape (batch_size * prune_range, am_seq_len). Positions
that are True in this mask will be ignored as sources in the attention weighting.
attn_mask: mask of shape (seq_len, seq_len)
or (seq_len, batch_size * prune_range, batch_size * prune_range),
@@ -135,52 +272,47 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
query_head_dim = self.query_head_dim
pos_head_dim = self.pos_head_dim
num_heads = self.num_heads
-print(
-"query_head_dim",
-query_head_dim,
-"pos_head_dim",
-pos_head_dim,
-"num_heads",
-num_heads,
-)
(
-seq_len,
+lm_seq_len,
b_p_dim,
_,
-) = lm_pruned.shape # actual dim: (batch * prune_range, seq_len, _)
+) = lm_pruned.shape # actual dim: (lm_seq_len, batch * prune_range, _)
+(
+am_seq_len,
+_,
+_,
+) = am_pruned.shape
query_dim = query_head_dim * num_heads
# self-attention
-q = lm_pruned[..., 0:query_dim] # (batch * prune_range, seq_len, query_dim)
-k = am_pruned # (batch * prune_range, seq_len, query_dim)
+q = lm_pruned[..., 0:query_dim] # (lm_seq_len, batch * prune_range, query_dim)
+k = am_pruned # (am_seq_len, batch * prune_range, query_dim)
# p is the position-encoding query
p = lm_pruned[
..., query_dim:
-] # (batch * prune_range, seq_len, pos_head_dim * num_heads)
+] # (lm_seq_len, batch * prune_range, pos_head_dim * num_heads)
assert p.shape[-1] == num_heads * pos_head_dim
q = self.copy_query(q) # for diagnostics only, does nothing.
k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass.
p = self.copy_pos_query(p) # for diagnostics only, does nothing.
-q = q.reshape(seq_len, b_p_dim, num_heads, query_head_dim)
-p = p.reshape(seq_len, b_p_dim, num_heads, pos_head_dim)
-k = k.reshape(seq_len, b_p_dim, num_heads, query_head_dim)
-print("q.shape after reshape", q.shape)
-print("p.shape after reshape", p.shape)
-print("k.shape after reshape", k.shape)
+q = q.reshape(lm_seq_len, b_p_dim, num_heads, query_head_dim)
+p = p.reshape(lm_seq_len, b_p_dim, num_heads, pos_head_dim)
+k = k.reshape(am_seq_len, b_p_dim, num_heads, query_head_dim)
-# time1 refers to target, time2 refers to source.
+# time1 refers to target (query: lm), time2 refers to source (key: am).
q = q.permute(
2, 1, 0, 3
-) # (head, seq_len, batch * prune_range, query_head_dim)
-p = p.permute(2, 1, 0, 3) # (head, seq_len, batch * prune_range, pos_head_dim)
-k = k.permute(2, 1, 3, 0) # (head, seq_len, d_k, batch * prune_range)
+) # (head, batch * prune_range, lm_seq_len, query_head_dim)
+p = p.permute(
+2, 1, 0, 3
+) # (head, batch * prune_range, lm_seq_len, pos_head_dim)
+k = k.permute(2, 1, 3, 0) # (head, batch * prune_range, d_k, am_seq_len)
-attn_scores = torch.matmul(q, k)
-print("attn_scores", attn_scores.shape)
+attn_scores = torch.matmul(q, k) # (head, batch, lm_seq_len, am_seq_len)
use_pos_scores = False
if torch.jit.is_scripting() or torch.jit.is_tracing():
@@ -191,14 +323,11 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
if use_pos_scores:
pos_emb = self.linear_pos(pos_emb)
-print("pos_emb before proj", pos_emb.shape)
-seq_len2 = 2 * seq_len - 1
+seq_len2 = 2 * lm_seq_len - 1
pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
2, 0, 3, 1
)
# pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
-print("p", p.shape)
-print("pos_emb after proj", pos_emb.shape)
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
# [where seq_len2 represents relative position.]
@@ -209,24 +338,23 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
if torch.jit.is_tracing():
(num_heads, b_p_dim, time1, n) = pos_scores.shape
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
-cols = torch.arange(seq_len)
+cols = torch.arange(lm_seq_len)
rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1)
indexes = rows + cols
pos_scores = pos_scores.reshape(-1, n)
pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
-pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, seq_len)
+pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, lm_seq_len)
else:
pos_scores = pos_scores.as_strided(
-(num_heads, b_p_dim, seq_len, seq_len),
+(num_heads, b_p_dim, lm_seq_len, lm_seq_len),
(
pos_scores.stride(0),
pos_scores.stride(1),
pos_scores.stride(2) - pos_scores.stride(3),
pos_scores.stride(3),
),
-storage_offset=pos_scores.stride(3) * (seq_len - 1),
+storage_offset=pos_scores.stride(3) * (lm_seq_len - 1),
)
attn_scores = attn_scores + pos_scores
if torch.jit.is_scripting() or torch.jit.is_tracing():
@@ -247,8 +375,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
attn_scores = penalize_abs_values_gt(
attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
)
-assert attn_scores.shape == (num_heads, b_p_dim, seq_len, seq_len)
+assert attn_scores.shape == (num_heads, b_p_dim, lm_seq_len, am_seq_len)
if attn_mask is not None:
assert attn_mask.dtype == torch.bool
@@ -261,7 +388,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
if key_padding_mask is not None:
assert key_padding_mask.shape == (
b_p_dim,
-seq_len,
+am_seq_len,
), key_padding_mask.shape
attn_scores = attn_scores.masked_fill(
key_padding_mask.unsqueeze(1),
@@ -286,7 +413,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
return attn_weights
def _print_attn_entropy(self, attn_weights: Tensor):
-# attn_weights: (num_heads, batch_size, seq_len, seq_len)
+# attn_weights: (num_heads, batch_size, lm_seq_len, am_seq_len)
(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
with torch.no_grad():
@@ -322,7 +449,7 @@ class AlignmentAttentionModule(nn.Module):
pos_head_dim=pos_head_dim,
dropout=dropout,
)
-self.cross_attn = SelfAttention(
+self.cross_attn = CrossAttention(
embed_dim=embed_dim,
num_heads=num_heads,
value_head_dim=value_head_dim,
@@ -332,52 +459,80 @@ class AlignmentAttentionModule(nn.Module):
)
def forward(self, am_pruned: Tensor, lm_pruned: Tensor) -> Tensor:
-# am_pruned : [B, T, prune_range, encoder_dim]
-# lm_pruned : [B, T, prune_range, decoder_dim]
-(batch_size, T, prune_range, encoder_dim) = am_pruned.shape
-(batch_size, T, prune_range, decoder_dim) = lm_pruned.shape
-# am_pruned : [B * prune_range, T, encoder_dim]
-# lm_pruned : [B * prune_range, T, decoder_dim]
-am_pruned = am_pruned.permute(1, 0, 2, 3).reshape(
-T, batch_size * prune_range, encoder_dim
-)
-lm_pruned = lm_pruned.permute(1, 0, 2, 3).reshape(
-T, batch_size * prune_range, decoder_dim
-)
-pos_emb = self.pos_encode(am_pruned)
-print("input pos_emb.shape", pos_emb.shape)
-attn_weights = self.cross_attn_weights(lm_pruned, am_pruned, pos_emb)
-label_level_am_representation = self.cross_attn(am_pruned, attn_weights)
-return label_level_am_representation
+if len(am_pruned.shape) == 4 and len(lm_pruned.shape) == 4:
+# src_key_padding_mask = make_pad_mask(am_pruned_lens)
+# am_pruned : [B, am_T, prune_range, encoder_dim]
+# lm_pruned : [B, lm_T, prune_range, decoder_dim]
+(batch_size, am_T, prune_range, encoder_dim) = am_pruned.shape
+(batch_size, lm_T, prune_range, decoder_dim) = lm_pruned.shape
+# merged_am_pruned : [am_T, B * prune_range, encoder_dim]
+# merged_lm_pruned : [lm_T, B * prune_range, decoder_dim]
+merged_am_pruned = am_pruned.permute(1, 0, 2, 3).reshape(
+am_T, batch_size * prune_range, encoder_dim
+)
+merged_lm_pruned = lm_pruned.permute(1, 0, 2, 3).reshape(
+lm_T, batch_size * prune_range, decoder_dim
+)
+pos_emb = self.pos_encode(merged_lm_pruned)
+attn_weights = self.cross_attn_weights(
+merged_lm_pruned, merged_am_pruned, pos_emb
+)
+# (num_heads, b_p_dim, lm_seq_len, am_seq_len)
+# print("attn_weights.shape", attn_weights.shape)
+label_level_am_representation = self.cross_attn(
+merged_am_pruned, attn_weights
+)
+# print(
+# "label_level_am_representation.shape",
+# label_level_am_representation.shape,
+# )
+# (lm_seq_len, batch_size * prune_range, encoder_dim)
+return label_level_am_representation.reshape(
+lm_T, batch_size, prune_range, encoder_dim
+).permute(1, 0, 2, 3)
+elif len(am_pruned.shape) == 3 and len(lm_pruned.shape) == 3:
+# am_pruned : [am_T, B, encoder_dim]
+# lm_pruned : [lm_T, B, decoder_dim]
+(am_T, batch_size, encoder_dim) = am_pruned.shape
+(lm_T, batch_size, decoder_dim) = lm_pruned.shape
+pos_emb = self.pos_encode(lm_pruned)
+attn_weights = self.cross_attn_weights(lm_pruned, am_pruned, pos_emb)
+label_level_am_representation = self.cross_attn(am_pruned, attn_weights)
+# (T, batch_size, encoder_dim)
+return label_level_am_representation
+else:
+raise NotImplementedError("Dim Error")
if __name__ == "__main__":
-# am_pruned : [B, T, prune_range, encoder_dim]
-# lm_pruned : [B, T, prune_range, decoder_dim]
-# am_pruned = torch.rand(2, 100, 5, 512).transpose(1, 0).reshape(2 * 5, 100, 512)
-# lm_pruned = torch.rand(2, 100, 5, 512).transpose(1, 0).reshape(2 * 5, 100, 512)
-# # am_pruned : [B * prune_range, T, encoder_dim]
-# # lm_pruned : [B * prune_range, T, decoder_dim]
-# pos_emb = torch.rand(1, 19, 192)
+attn = AlignmentAttentionModule()
+print("__main__ === for inference ===")
+# am : [T, B, encoder_dim]
+# lm : [1, B, decoder_dim]
+am = torch.rand(100, 2, 512)
+lm = torch.rand(1, 2, 512)
+# q / K separate seq_len
# weights = RelPositionMultiheadAttentionWeights()
-# attn = SelfAttention(512, 5, 12)
-# attn_weights = weights(lm_pruned, am_pruned, pos_emb)
+# attn = CrossAttention(512, 5, 12)
+# attn_weights = weights(lm, am, pos_emb)
# print("weights(am_pruned, lm_pruned, pos_emb).shape", attn_weights.shape)
-# res = attn(am_pruned, attn_weights)
-# print("res", res.shape)
+# res = attn(am, attn_weights)
+res = attn(am, lm)
+print("__main__ res", res.shape)
+print("__main__ === for training ===")
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned = torch.rand(2, 100, 5, 512)
lm_pruned = torch.rand(2, 100, 5, 512)
-attn = AlignmentAttentionModule()
res = attn(am_pruned, lm_pruned)
-print("res", res.shape)
+print("__main__ res", res.shape)

View File

@@ -39,6 +39,7 @@ class Joiner(nn.Module):
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
+apply_attn: bool = True,
project_input: bool = True,
) -> torch.Tensor:
"""
@@ -47,7 +48,9 @@ class Joiner(nn.Module):
Output from the encoder. Its shape is (N, T, s_range, C).
decoder_out:
Output from the decoder. Its shape is (N, T, s_range, C).
+encoder_out_lens:
+Encoder output lengths, of shape (N,).
project_input:
If true, apply input projections encoder_proj and decoder_proj.
If this is false, it is the user's responsibility to do this
manually.
@@ -59,7 +62,8 @@ class Joiner(nn.Module):
decoder_out.shape,
)
-encoder_out = self.label_level_am_attention(encoder_out, decoder_out)
+if apply_attn:
+encoder_out = self.label_level_am_attention(encoder_out, decoder_out)
if project_input:
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
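To make the new flag concrete, here is a small, self-contained sketch of the control flow the joiner gains from apply_attn. It is not the icefall Joiner itself; the ToyJoiner class, its dimensions, and the lambda stand-in for label_level_am_attention are illustrative assumptions only.

# Illustrative sketch of the apply_attn control flow; not the real icefall Joiner.
import torch
import torch.nn as nn

class ToyJoiner(nn.Module):
    def __init__(self, dim: int = 8, vocab_size: int = 16):
        super().__init__()
        # Placeholder for the alignment attention; just returns the encoder side.
        self.label_level_am_attention = lambda enc, dec: enc
        self.encoder_proj = nn.Linear(dim, vocab_size)
        self.decoder_proj = nn.Linear(dim, vocab_size)

    def forward(self, encoder_out, decoder_out, apply_attn=True, project_input=True):
        if apply_attn:
            # Only run the label-level AM attention when the caller asks for it.
            encoder_out = self.label_level_am_attention(encoder_out, decoder_out)
        if project_input:
            return self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
        return encoder_out + decoder_out

joiner = ToyJoiner()
enc = torch.rand(2, 100, 5, 8)  # (N, T, s_range, C)
dec = torch.rand(2, 100, 5, 8)
print(joiner(enc, dec, apply_attn=False).shape)  # torch.Size([2, 100, 5, 16])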