mirror of https://github.com/k2-fsa/icefall.git
Commit 49e9d15733 (parent 739e2a22c6): issues fixed
@@ -15,7 +15,142 @@ from scaling import (
    softmax,
)
from torch import Tensor, nn
-from zipformer import CompactRelPositionalEncoding, SelfAttention, _whitening_schedule
+from zipformer import CompactRelPositionalEncoding, _whitening_schedule
from icefall.utils import make_pad_mask


class CrossAttention(nn.Module):
    """
    The simplest possible attention module. This one works with already-computed attention
    weights, e.g. as computed by RelPositionMultiheadAttentionWeights.

    Args:
        embed_dim: the input and output embedding dimension
        num_heads: the number of attention heads
        value_head_dim: the value dimension per head
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        value_head_dim: int,
    ) -> None:
        super().__init__()
        self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True)

        self.out_proj = ScaledLinear(
            num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05
        )

        self.whiten = Whiten(
            num_groups=1,
            whitening_limit=_whitening_schedule(7.5, ratio=3.0),
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

    def forward(
        self,
        x: Tensor,
        attn_weights: Tensor,
    ) -> Tensor:
        """
        Args:
            x: input tensor, of shape (seq_len, batch_size, embed_dim)
            attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
                with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect
                attn_weights.sum(dim=-1) == 1.
        Returns:
            a tensor with the same shape as x.
        """
        (am_seq_len, batch_size, embed_dim) = x.shape
        (_, _, lm_seq_len, _) = attn_weights.shape
        num_heads = attn_weights.shape[0]
        assert attn_weights.shape == (num_heads, batch_size, lm_seq_len, am_seq_len)

        x = self.in_proj(x)  # (am_seq_len, batch_size, num_heads * value_head_dim)
        # print("projected x.shape", x.shape)

        x = x.reshape(am_seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
        # now x: (num_heads, batch_size, am_seq_len, value_head_dim)
        # print("permuted x.shape", x.shape)

        value_head_dim = x.shape[-1]

        # todo: see whether there is benefit in overriding matmul
        x = torch.matmul(attn_weights, x)
        # v: (num_heads, batch_size, lm_seq_len, value_head_dim)
        # print("attended x.shape", x.shape)

        x = (
            x.permute(2, 1, 0, 3)
            .contiguous()
            .view(lm_seq_len, batch_size, num_heads * value_head_dim)
        )

        # returned value is of shape (lm_seq_len, batch_size, embed_dim), like the input.
        x = self.out_proj(x)
        x = self.whiten(x)
        # print("returned x.shape", x.shape)

        return x

    def streaming_forward(
        self,
        x: Tensor,
        attn_weights: Tensor,
        cached_val: Tensor,
        left_context_len: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
            x: input tensor, of shape (seq_len, batch_size, embed_dim)
            attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
                with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect
                attn_weights.sum(dim=-1) == 1.
            cached_val: cached attention value tensor of left context,
                of shape (left_context_len, batch_size, value_dim)
            left_context_len: number of left context frames.

        Returns:
            - attention weighted output, a tensor with the same shape as x.
            - updated cached attention value tensor of left context.
        """
        (seq_len, batch_size, embed_dim) = x.shape
        num_heads = attn_weights.shape[0]
        seq_len2 = seq_len + left_context_len
        assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2)

        x = self.in_proj(x)  # (seq_len, batch_size, num_heads * value_head_dim)

        # Pad cached left contexts
        assert cached_val.shape[0] == left_context_len, (
            cached_val.shape[0],
            left_context_len,
        )
        x = torch.cat([cached_val, x], dim=0)
        # Update cached left contexts
        cached_val = x[-left_context_len:, ...]

        x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3)
        # now x: (num_heads, batch_size, seq_len, value_head_dim)
        value_head_dim = x.shape[-1]

        # todo: see whether there is benefit in overriding matmul
        x = torch.matmul(attn_weights, x)
        # v: (num_heads, batch_size, seq_len, value_head_dim)

        x = (
            x.permute(2, 1, 0, 3)
            .contiguous()
            .view(seq_len, batch_size, num_heads * value_head_dim)
        )

        # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
        x = self.out_proj(x)

        return x, cached_val


class RelPositionMultiheadAttentionWeights(nn.Module):
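
Editorial note: the new CrossAttention block only reshapes, applies precomputed weights, and projects back, so the shape bookkeeping is the part worth double-checking. Below is a minimal, self-contained sketch of that shape flow using plain nn.Linear in place of icefall's ScaledLinear/Whiten (those wrappers live in scaling.py and are not reproduced here); the sizes (5 heads, head dim 12, embed dim 512) mirror the defaults that appear elsewhere in this diff and are otherwise assumptions.

import torch
import torch.nn as nn

# Simplified stand-in for the CrossAttention.forward pass above.
num_heads, value_head_dim, embed_dim = 5, 12, 512
am_seq_len, lm_seq_len, batch_size = 100, 25, 2

in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True)
out_proj = nn.Linear(num_heads * value_head_dim, embed_dim, bias=True)

x = torch.rand(am_seq_len, batch_size, embed_dim)            # acoustic (source) side
scores = torch.rand(num_heads, batch_size, lm_seq_len, am_seq_len)
attn_weights = scores.softmax(dim=-1)                        # rows sum to 1 over am_seq_len

v = in_proj(x)                                               # (am_seq_len, B, H * d_v)
v = v.reshape(am_seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
out = torch.matmul(attn_weights, v)                          # (H, B, lm_seq_len, d_v)
out = out.permute(2, 1, 0, 3).contiguous().view(
    lm_seq_len, batch_size, num_heads * value_head_dim
)
out = out_proj(out)
print(out.shape)                                             # torch.Size([25, 2, 512])
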
@@ -40,7 +175,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):

    def __init__(
        self,
-        embed_dim: int = 512,
+        lm_embed_dim: int = 512,
+        am_embed_dim: int = 512,
        pos_dim: int = 192,
        num_heads: int = 5,
        query_head_dim: int = 32,
@@ -49,7 +185,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
    ) -> None:
        super().__init__()
-        self.embed_dim = embed_dim
+        self.lm_embed_dim = lm_embed_dim
+        self.am_embed_dim = am_embed_dim
        self.num_heads = num_heads
        self.query_head_dim = query_head_dim
        self.pos_head_dim = pos_head_dim
@@ -67,10 +204,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        # to be used with the ScaledAdam optimizer; with most other optimizers,
        # it would be necessary to apply the scaling factor in the forward function.
        self.in_lm_proj = ScaledLinear(
-            embed_dim, in_lm_dim, bias=True, initial_scale=query_head_dim**-0.25
+            lm_embed_dim, in_lm_dim, bias=True, initial_scale=query_head_dim**-0.25
        )
        self.in_am_proj = ScaledLinear(
-            embed_dim, in_am_dim, bias=True, initial_scale=query_head_dim**-0.25
+            am_embed_dim, in_am_dim, bias=True, initial_scale=query_head_dim**-0.25
        )

        self.whiten_keys = Whiten(
@@ -117,10 +254,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
    ) -> Tensor:
        r"""
        Args:
-            lm_pruned: input of shape (seq_len, batch_size * prune_range, decoder_embed_dim)
-            am_pruned: input of shape (seq_len, batch_size * prune_range, encoder_embed_dim)
-            pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
-            key_padding_mask: a bool tensor of shape (batch_size * prune_range, seq_len). Positions
+            lm_pruned: input of shape (lm_seq_len, batch_size * prune_range, decoder_embed_dim)
+            am_pruned: input of shape (am_seq_len, batch_size * prune_range, encoder_embed_dim)
+            pos_emb: Positional embedding tensor, of shape (1, 2*lm_seq_len - 1, pos_dim)
+            key_padding_mask: a bool tensor of shape (batch_size * prune_range, am_seq_len). Positions
                that are True in this mask will be ignored as sources in the attention weighting.
            attn_mask: mask of shape (seq_len, seq_len)
                or (seq_len, batch_size * prune_range, batch_size * prune_range),
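
Editorial note: the sizes fed to in_lm_proj and in_am_proj can be read off the slicing in forward() below (q = lm_pruned[..., 0:query_dim], p = lm_pruned[..., query_dim:], k = am_pruned): the lm side must supply both the query and the positional query, while the am side supplies only the key. The exact constructor values are not visible in this capture, so the sketch below only illustrates the arithmetic; pos_head_dim = 4 is an assumed value.

# Inferred bookkeeping, not code from the commit.
num_heads = 5
query_head_dim = 32
pos_head_dim = 4  # assumed value for illustration

in_lm_dim = query_head_dim * num_heads + pos_head_dim * num_heads  # query + pos query
in_am_dim = query_head_dim * num_heads                             # key only
print(in_lm_dim, in_am_dim)  # 180 160
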
@@ -137,38 +274,45 @@
        num_heads = self.num_heads

        (
-            seq_len,
+            lm_seq_len,
            b_p_dim,
            _,
-        ) = lm_pruned.shape  # actual dim: (seq_len, batch * prune_range, _)
+        ) = lm_pruned.shape  # actual dim: (lm_seq_len, batch * prune_range, _)
        (
            am_seq_len,
            _,
            _,
        ) = am_pruned.shape

        query_dim = query_head_dim * num_heads

        # self-attention
-        q = lm_pruned[..., 0:query_dim]  # (seq_len, batch * prune_range, query_dim)
-        k = am_pruned  # (seq_len, batch * prune_range, query_dim)
+        q = lm_pruned[..., 0:query_dim]  # (lm_seq_len, batch * prune_range, query_dim)
+        k = am_pruned  # (am_seq_len, batch * prune_range, query_dim)
        # p is the position-encoding query
        p = lm_pruned[
            ..., query_dim:
-        ]  # (seq_len, batch * prune_range, pos_head_dim * num_heads)
+        ]  # (lm_seq_len, batch * prune_range, pos_head_dim * num_heads)
        assert p.shape[-1] == num_heads * pos_head_dim

        q = self.copy_query(q)  # for diagnostics only, does nothing.
        k = self.whiten_keys(self.balance_keys(k))  # does nothing in the forward pass.
        p = self.copy_pos_query(p)  # for diagnostics only, does nothing.

-        q = q.reshape(seq_len, b_p_dim, num_heads, query_head_dim)
-        p = p.reshape(seq_len, b_p_dim, num_heads, pos_head_dim)
-        k = k.reshape(seq_len, b_p_dim, num_heads, query_head_dim)
+        q = q.reshape(lm_seq_len, b_p_dim, num_heads, query_head_dim)
+        p = p.reshape(lm_seq_len, b_p_dim, num_heads, pos_head_dim)
+        k = k.reshape(am_seq_len, b_p_dim, num_heads, query_head_dim)

-        # time1 refers to target, time2 refers to source.
+        # time1 refers to target (query: lm), time2 refers to source (key: am).
        q = q.permute(
            2, 1, 0, 3
-        )  # (head, batch * prune_range, seq_len, query_head_dim)
-        p = p.permute(2, 1, 0, 3)  # (head, batch * prune_range, seq_len, pos_head_dim)
-        k = k.permute(2, 1, 3, 0)  # (head, batch * prune_range, d_k, seq_len)
+        )  # (head, batch * prune_range, lm_seq_len, query_head_dim)
+        p = p.permute(
+            2, 1, 0, 3
+        )  # (head, batch * prune_range, lm_seq_len, pos_head_dim)
+        k = k.permute(2, 1, 3, 0)  # (head, batch * prune_range, d_k, am_seq_len)

-        attn_scores = torch.matmul(q, k)
+        attn_scores = torch.matmul(q, k)  # (head, batch, lm_seq_len, am_seq_len)

        use_pos_scores = False
        if torch.jit.is_scripting() or torch.jit.is_tracing():
@@ -179,7 +323,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):

        if use_pos_scores:
            pos_emb = self.linear_pos(pos_emb)
-            seq_len2 = 2 * seq_len - 1
+            seq_len2 = 2 * lm_seq_len - 1
            pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
                2, 0, 3, 1
            )
@@ -194,24 +338,23 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
            if torch.jit.is_tracing():
                (num_heads, b_p_dim, time1, n) = pos_scores.shape
                rows = torch.arange(start=time1 - 1, end=-1, step=-1)
-                cols = torch.arange(seq_len)
+                cols = torch.arange(lm_seq_len)
                rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1)
                indexes = rows + cols
                pos_scores = pos_scores.reshape(-1, n)
                pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
-                pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, seq_len)
+                pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, lm_seq_len)
            else:
                pos_scores = pos_scores.as_strided(
-                    (num_heads, b_p_dim, seq_len, seq_len),
+                    (num_heads, b_p_dim, lm_seq_len, lm_seq_len),
                    (
                        pos_scores.stride(0),
                        pos_scores.stride(1),
                        pos_scores.stride(2) - pos_scores.stride(3),
                        pos_scores.stride(3),
                    ),
-                    storage_offset=pos_scores.stride(3) * (seq_len - 1),
+                    storage_offset=pos_scores.stride(3) * (lm_seq_len - 1),
                )

            attn_scores = attn_scores + pos_scores

        if torch.jit.is_scripting() or torch.jit.is_tracing():
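
Editorial note: the tracing branch above turns relative-position scores of width 2*lm_seq_len - 1 into an (lm_seq_len x lm_seq_len) matrix by gathering one shifted band per query position. A tiny standalone sketch of that index computation, with num_heads = b_p_dim = 1 and lm_seq_len = 3 so the indices are easy to read (sizes are mine, the indexing mirrors the code above):

import torch

num_heads, b_p_dim, time1 = 1, 1, 3
lm_seq_len = time1
n = 2 * lm_seq_len - 1                                   # width of the relative-position axis

pos_scores = torch.arange(num_heads * b_p_dim * time1 * n, dtype=torch.float32)
pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, n)

rows = torch.arange(start=time1 - 1, end=-1, step=-1)    # [2, 1, 0]
cols = torch.arange(lm_seq_len)                          # [0, 1, 2]
rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1)    # (3, 1)
indexes = rows + cols                                    # each row selects a shifted band

flat = pos_scores.reshape(-1, n)                         # (H * B * time1, n)
gathered = torch.gather(flat, dim=1, index=indexes)      # (H * B * time1, lm_seq_len)
print(gathered.reshape(num_heads, b_p_dim, time1, lm_seq_len))
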
@@ -232,8 +375,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
            attn_scores = penalize_abs_values_gt(
                attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
            )

-        assert attn_scores.shape == (num_heads, b_p_dim, seq_len, seq_len)
+        assert attn_scores.shape == (num_heads, b_p_dim, lm_seq_len, am_seq_len)

        if attn_mask is not None:
            assert attn_mask.dtype == torch.bool
@@ -246,7 +388,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        if key_padding_mask is not None:
            assert key_padding_mask.shape == (
                b_p_dim,
-                seq_len,
+                am_seq_len,
            ), key_padding_mask.shape
            attn_scores = attn_scores.masked_fill(
                key_padding_mask.unsqueeze(1),
@@ -271,7 +413,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        return attn_weights

    def _print_attn_entropy(self, attn_weights: Tensor):
-        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+        # attn_weights: (num_heads, batch_size, lm_seq_len, am_seq_len)
        (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape

        with torch.no_grad():
@@ -307,7 +449,7 @@ class AlignmentAttentionModule(nn.Module):
            pos_head_dim=pos_head_dim,
            dropout=dropout,
        )
-        self.cross_attn = SelfAttention(
+        self.cross_attn = CrossAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            value_head_dim=value_head_dim,
@@ -317,57 +459,80 @@
        )

    def forward(self, am_pruned: Tensor, lm_pruned: Tensor) -> Tensor:
-        # am_pruned : [B, T, prune_range, encoder_dim]
-        # lm_pruned : [B, T, prune_range, decoder_dim]
-        (batch_size, T, prune_range, encoder_dim) = am_pruned.shape
-        (batch_size, T, prune_range, decoder_dim) = lm_pruned.shape
+        if len(am_pruned.shape) == 4 and len(lm_pruned.shape) == 4:
            # src_key_padding_mask = make_pad_mask(am_pruned_lens)

-        # am_pruned : [T, B * prune_range, encoder_dim]
-        # lm_pruned : [T, B * prune_range, decoder_dim]
-        merged_am_pruned = am_pruned.permute(1, 0, 2, 3).reshape(
-            T, batch_size * prune_range, encoder_dim
-        )
-        merged_lm_pruned = lm_pruned.permute(1, 0, 2, 3).reshape(
-            T, batch_size * prune_range, decoder_dim
-        )
+            # am_pruned : [B, am_T, prune_range, encoder_dim]
+            # lm_pruned : [B, lm_T, prune_range, decoder_dim]
+            (batch_size, am_T, prune_range, encoder_dim) = am_pruned.shape
+            (batch_size, lm_T, prune_range, decoder_dim) = lm_pruned.shape

-        pos_emb = self.pos_encode(merged_am_pruned)
+            # merged_am_pruned : [am_T, B * prune_range, encoder_dim]
+            # merged_lm_pruned : [lm_T, B * prune_range, decoder_dim]
+            merged_am_pruned = am_pruned.permute(1, 0, 2, 3).reshape(
+                am_T, batch_size * prune_range, encoder_dim
+            )
+            merged_lm_pruned = lm_pruned.permute(1, 0, 2, 3).reshape(
+                lm_T, batch_size * prune_range, decoder_dim
+            )
+            pos_emb = self.pos_encode(merged_lm_pruned)

-        attn_weights = self.cross_attn_weights(
-            merged_lm_pruned, merged_am_pruned, pos_emb
-        )
-        label_level_am_representation = self.cross_attn(merged_am_pruned, attn_weights)
-        # (T, batch_size * prune_range, encoder_dim)
+            attn_weights = self.cross_attn_weights(
+                merged_lm_pruned, merged_am_pruned, pos_emb
+            )
+            # (num_heads, b_p_dim, lm_seq_len, am_seq_len)
+            # print("attn_weights.shape", attn_weights.shape)
+            label_level_am_representation = self.cross_attn(
+                merged_am_pruned, attn_weights
+            )
+            # print(
+            #     "label_level_am_representation.shape",
+            #     label_level_am_representation.shape,
+            # )
+            # (lm_seq_len, batch_size * prune_range, encoder_dim)

-        return label_level_am_representation.reshape(
-            T, batch_size, prune_range, encoder_dim
-        ).permute(1, 0, 2, 3)
+            return label_level_am_representation.reshape(
+                lm_T, batch_size, prune_range, encoder_dim
+            ).permute(1, 0, 2, 3)
+        elif len(am_pruned.shape) == 3 and len(lm_pruned.shape) == 3:
+            # am_pruned : [am_T, B, encoder_dim]
+            # lm_pruned : [lm_T, B, decoder_dim]
+            (am_T, batch_size, encoder_dim) = am_pruned.shape
+            (lm_T, batch_size, decoder_dim) = lm_pruned.shape

+            pos_emb = self.pos_encode(lm_pruned)

+            attn_weights = self.cross_attn_weights(lm_pruned, am_pruned, pos_emb)
+            label_level_am_representation = self.cross_attn(am_pruned, attn_weights)
+            # (T, batch_size, encoder_dim)

+            return label_level_am_representation
+        else:
+            raise NotImplementedError("Dim Error")


if __name__ == "__main__":
    # am_pruned : [B, T, prune_range, encoder_dim]
    # lm_pruned : [B, T, prune_range, decoder_dim]
    # am_pruned = torch.rand(2, 100, 5, 512).transpose(1, 0).reshape(2 * 5, 100, 512)
    # lm_pruned = torch.rand(2, 100, 5, 512).transpose(1, 0).reshape(2 * 5, 100, 512)
    attn = AlignmentAttentionModule()

    # # am_pruned : [B * prune_range, T, encoder_dim]
    # # lm_pruned : [B * prune_range, T, decoder_dim]

    # pos_emb = torch.rand(1, 19, 192)
    print("__main__ === for inference ===")
    # am : [T, B, encoder_dim]
    # lm : [1, B, decoder_dim]
    am = torch.rand(100, 2, 512)
    lm = torch.rand(1, 2, 512)
    # q / K separate seq_len

    # weights = RelPositionMultiheadAttentionWeights()
    # attn = SelfAttention(512, 5, 12)

    # attn_weights = weights(lm_pruned, am_pruned, pos_emb)
    # attn = CrossAttention(512, 5, 12)
    # attn_weights = weights(lm, am, pos_emb)
    # print("weights(am_pruned, lm_pruned, pos_emb).shape", attn_weights.shape)
    # res = attn(am_pruned, attn_weights)
    # print("res", res.shape)
    # res = attn(am, attn_weights)
    res = attn(am, lm)
    print("__main__ res", res.shape)

    print("__main__ === for training ===")
    # am_pruned : [B, T, prune_range, encoder_dim]
    # lm_pruned : [B, T, prune_range, decoder_dim]
    am_pruned = torch.rand(2, 100, 5, 512)
    lm_pruned = torch.rand(2, 100, 5, 512)

    attn = AlignmentAttentionModule()
    res = attn(am_pruned, lm_pruned)
-    print("res", res.shape)
+    print("__main__ res", res.shape)
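
Editorial note: the 4-D branch of AlignmentAttentionModule.forward merges the pruned (B, T, prune_range, C) tensors into (T, B * prune_range, C) before attention and undoes that reshape afterwards. A small shape-only sketch of that round trip, using torch only and sizes borrowed from the __main__ test above:

import torch

B, am_T, lm_T, prune_range, C = 2, 100, 100, 5, 512

am_pruned = torch.rand(B, am_T, prune_range, C)
# (B, am_T, prune_range, C) -> (am_T, B * prune_range, C), as in the 4-D branch
merged = am_pruned.permute(1, 0, 2, 3).reshape(am_T, B * prune_range, C)
print(merged.shape)  # torch.Size([100, 10, 512])

# After cross-attention the result is (lm_T, B * prune_range, C); the module
# reshapes it back to (B, lm_T, prune_range, C) before returning.
attended = torch.rand(lm_T, B * prune_range, C)
restored = attended.reshape(lm_T, B, prune_range, C).permute(1, 0, 2, 3)
print(restored.shape)  # torch.Size([2, 100, 5, 512])
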
@@ -15,7 +15,142 @@ from scaling import (
    softmax,
)
from torch import Tensor, nn
-from zipformer import CompactRelPositionalEncoding, SelfAttention, _whitening_schedule
+from zipformer import CompactRelPositionalEncoding, _whitening_schedule
from icefall.utils import make_pad_mask


class CrossAttention(nn.Module):
    """
    The simplest possible attention module. This one works with already-computed attention
    weights, e.g. as computed by RelPositionMultiheadAttentionWeights.

    Args:
        embed_dim: the input and output embedding dimension
        num_heads: the number of attention heads
        value_head_dim: the value dimension per head
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        value_head_dim: int,
    ) -> None:
        super().__init__()
        self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True)

        self.out_proj = ScaledLinear(
            num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05
        )

        self.whiten = Whiten(
            num_groups=1,
            whitening_limit=_whitening_schedule(7.5, ratio=3.0),
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

    def forward(
        self,
        x: Tensor,
        attn_weights: Tensor,
    ) -> Tensor:
        """
        Args:
            x: input tensor, of shape (seq_len, batch_size, embed_dim)
            attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
                with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect
                attn_weights.sum(dim=-1) == 1.
        Returns:
            a tensor with the same shape as x.
        """
        (am_seq_len, batch_size, embed_dim) = x.shape
        (_, _, lm_seq_len, _) = attn_weights.shape
        num_heads = attn_weights.shape[0]
        assert attn_weights.shape == (num_heads, batch_size, lm_seq_len, am_seq_len)

        x = self.in_proj(x)  # (am_seq_len, batch_size, num_heads * value_head_dim)
        # print("projected x.shape", x.shape)

        x = x.reshape(am_seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
        # now x: (num_heads, batch_size, am_seq_len, value_head_dim)
        # print("permuted x.shape", x.shape)

        value_head_dim = x.shape[-1]

        # todo: see whether there is benefit in overriding matmul
        x = torch.matmul(attn_weights, x)
        # v: (num_heads, batch_size, lm_seq_len, value_head_dim)
        # print("attended x.shape", x.shape)

        x = (
            x.permute(2, 1, 0, 3)
            .contiguous()
            .view(lm_seq_len, batch_size, num_heads * value_head_dim)
        )

        # returned value is of shape (lm_seq_len, batch_size, embed_dim), like the input.
        x = self.out_proj(x)
        x = self.whiten(x)
        # print("returned x.shape", x.shape)

        return x

    def streaming_forward(
        self,
        x: Tensor,
        attn_weights: Tensor,
        cached_val: Tensor,
        left_context_len: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
            x: input tensor, of shape (seq_len, batch_size, embed_dim)
            attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
                with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect
                attn_weights.sum(dim=-1) == 1.
            cached_val: cached attention value tensor of left context,
                of shape (left_context_len, batch_size, value_dim)
            left_context_len: number of left context frames.

        Returns:
            - attention weighted output, a tensor with the same shape as x.
            - updated cached attention value tensor of left context.
        """
        (seq_len, batch_size, embed_dim) = x.shape
        num_heads = attn_weights.shape[0]
        seq_len2 = seq_len + left_context_len
        assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len2)

        x = self.in_proj(x)  # (seq_len, batch_size, num_heads * value_head_dim)

        # Pad cached left contexts
        assert cached_val.shape[0] == left_context_len, (
            cached_val.shape[0],
            left_context_len,
        )
        x = torch.cat([cached_val, x], dim=0)
        # Update cached left contexts
        cached_val = x[-left_context_len:, ...]

        x = x.reshape(seq_len2, batch_size, num_heads, -1).permute(2, 1, 0, 3)
        # now x: (num_heads, batch_size, seq_len, value_head_dim)
        value_head_dim = x.shape[-1]

        # todo: see whether there is benefit in overriding matmul
        x = torch.matmul(attn_weights, x)
        # v: (num_heads, batch_size, seq_len, value_head_dim)

        x = (
            x.permute(2, 1, 0, 3)
            .contiguous()
            .view(seq_len, batch_size, num_heads * value_head_dim)
        )

        # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
        x = self.out_proj(x)

        return x, cached_val


class RelPositionMultiheadAttentionWeights(nn.Module):
@@ -40,7 +175,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):

    def __init__(
        self,
-        embed_dim: int = 512,
+        lm_embed_dim: int = 512,
+        am_embed_dim: int = 512,
        pos_dim: int = 192,
        num_heads: int = 5,
        query_head_dim: int = 32,
@@ -49,7 +185,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
    ) -> None:
        super().__init__()
-        self.embed_dim = embed_dim
+        self.lm_embed_dim = lm_embed_dim
+        self.am_embed_dim = am_embed_dim
        self.num_heads = num_heads
        self.query_head_dim = query_head_dim
        self.pos_head_dim = pos_head_dim
@@ -67,10 +204,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        # to be used with the ScaledAdam optimizer; with most other optimizers,
        # it would be necessary to apply the scaling factor in the forward function.
        self.in_lm_proj = ScaledLinear(
-            embed_dim, in_lm_dim, bias=True, initial_scale=query_head_dim**-0.25
+            lm_embed_dim, in_lm_dim, bias=True, initial_scale=query_head_dim**-0.25
        )
        self.in_am_proj = ScaledLinear(
-            embed_dim, in_am_dim, bias=True, initial_scale=query_head_dim**-0.25
+            am_embed_dim, in_am_dim, bias=True, initial_scale=query_head_dim**-0.25
        )

        self.whiten_keys = Whiten(
@@ -117,10 +254,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
    ) -> Tensor:
        r"""
        Args:
-            lm_pruned: input of shape (seq_len, batch_size * prune_range, decoder_embed_dim)
-            am_pruned: input of shape (seq_len, batch_size * prune_range, encoder_embed_dim)
-            pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
-            key_padding_mask: a bool tensor of shape (batch_size * prune_range, seq_len). Positions
+            lm_pruned: input of shape (lm_seq_len, batch_size * prune_range, decoder_embed_dim)
+            am_pruned: input of shape (am_seq_len, batch_size * prune_range, encoder_embed_dim)
+            pos_emb: Positional embedding tensor, of shape (1, 2*lm_seq_len - 1, pos_dim)
+            key_padding_mask: a bool tensor of shape (batch_size * prune_range, am_seq_len). Positions
                that are True in this mask will be ignored as sources in the attention weighting.
            attn_mask: mask of shape (seq_len, seq_len)
                or (seq_len, batch_size * prune_range, batch_size * prune_range),
@@ -135,52 +272,47 @@
        query_head_dim = self.query_head_dim
        pos_head_dim = self.pos_head_dim
        num_heads = self.num_heads
-        print(
-            "query_head_dim",
-            query_head_dim,
-            "pos_head_dim",
-            pos_head_dim,
-            "num_heads",
-            num_heads,
-        )

        (
-            seq_len,
+            lm_seq_len,
            b_p_dim,
            _,
-        ) = lm_pruned.shape  # actual dim: (batch * prune_range, seq_len, _)
+        ) = lm_pruned.shape  # actual dim: (lm_seq_len, batch * prune_range, _)
        (
            am_seq_len,
            _,
            _,
        ) = am_pruned.shape

        query_dim = query_head_dim * num_heads

        # self-attention
-        q = lm_pruned[..., 0:query_dim]  # (batch * prune_range, seq_len, query_dim)
-        k = am_pruned  # (batch * prune_range, seq_len, query_dim)
+        q = lm_pruned[..., 0:query_dim]  # (lm_seq_len, batch * prune_range, query_dim)
+        k = am_pruned  # (am_seq_len, batch * prune_range, query_dim)
        # p is the position-encoding query
        p = lm_pruned[
            ..., query_dim:
-        ]  # (batch * prune_range, seq_len, pos_head_dim * num_heads)
+        ]  # (lm_seq_len, batch * prune_range, pos_head_dim * num_heads)
        assert p.shape[-1] == num_heads * pos_head_dim

        q = self.copy_query(q)  # for diagnostics only, does nothing.
        k = self.whiten_keys(self.balance_keys(k))  # does nothing in the forward pass.
        p = self.copy_pos_query(p)  # for diagnostics only, does nothing.

-        q = q.reshape(seq_len, b_p_dim, num_heads, query_head_dim)
-        p = p.reshape(seq_len, b_p_dim, num_heads, pos_head_dim)
-        k = k.reshape(seq_len, b_p_dim, num_heads, query_head_dim)
-        print("q.shape after reshape", q.shape)
-        print("p.shape after reshape", p.shape)
-        print("k.shape after reshape", k.shape)
+        q = q.reshape(lm_seq_len, b_p_dim, num_heads, query_head_dim)
+        p = p.reshape(lm_seq_len, b_p_dim, num_heads, pos_head_dim)
+        k = k.reshape(am_seq_len, b_p_dim, num_heads, query_head_dim)

-        # time1 refers to target, time2 refers to source.
+        # time1 refers to target (query: lm), time2 refers to source (key: am).
        q = q.permute(
            2, 1, 0, 3
-        )  # (head, seq_len, batch * prune_range, query_head_dim)
-        p = p.permute(2, 1, 0, 3)  # (head, seq_len, batch * prune_range, pos_head_dim)
-        k = k.permute(2, 1, 3, 0)  # (head, seq_len, d_k, batch * prune_range)
+        )  # (head, batch * prune_range, lm_seq_len, query_head_dim)
+        p = p.permute(
+            2, 1, 0, 3
+        )  # (head, batch * prune_range, lm_seq_len, pos_head_dim)
+        k = k.permute(2, 1, 3, 0)  # (head, batch * prune_range, d_k, am_seq_len)

-        attn_scores = torch.matmul(q, k)
-        print("attn_scores", attn_scores.shape)
+        attn_scores = torch.matmul(q, k)  # (head, batch, lm_seq_len, am_seq_len)

        use_pos_scores = False
        if torch.jit.is_scripting() or torch.jit.is_tracing():
@@ -191,14 +323,11 @@ class RelPositionMultiheadAttentionWeights(nn.Module):

        if use_pos_scores:
            pos_emb = self.linear_pos(pos_emb)
-            print("pos_emb before proj", pos_emb.shape)
-            seq_len2 = 2 * seq_len - 1
+            seq_len2 = 2 * lm_seq_len - 1
            pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
                2, 0, 3, 1
            )
            # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
-            print("p", p.shape)
-            print("pos_emb after proj", pos_emb.shape)

            # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
            # [where seq_len2 represents relative position.]
@@ -209,24 +338,23 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
            if torch.jit.is_tracing():
                (num_heads, b_p_dim, time1, n) = pos_scores.shape
                rows = torch.arange(start=time1 - 1, end=-1, step=-1)
-                cols = torch.arange(seq_len)
+                cols = torch.arange(lm_seq_len)
                rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1)
                indexes = rows + cols
                pos_scores = pos_scores.reshape(-1, n)
                pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
-                pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, seq_len)
+                pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, lm_seq_len)
            else:
                pos_scores = pos_scores.as_strided(
-                    (num_heads, b_p_dim, seq_len, seq_len),
+                    (num_heads, b_p_dim, lm_seq_len, lm_seq_len),
                    (
                        pos_scores.stride(0),
                        pos_scores.stride(1),
                        pos_scores.stride(2) - pos_scores.stride(3),
                        pos_scores.stride(3),
                    ),
-                    storage_offset=pos_scores.stride(3) * (seq_len - 1),
+                    storage_offset=pos_scores.stride(3) * (lm_seq_len - 1),
                )

            attn_scores = attn_scores + pos_scores

        if torch.jit.is_scripting() or torch.jit.is_tracing():
@@ -247,8 +375,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
            attn_scores = penalize_abs_values_gt(
                attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
            )

-        assert attn_scores.shape == (num_heads, b_p_dim, seq_len, seq_len)
+        assert attn_scores.shape == (num_heads, b_p_dim, lm_seq_len, am_seq_len)

        if attn_mask is not None:
            assert attn_mask.dtype == torch.bool
@@ -261,7 +388,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        if key_padding_mask is not None:
            assert key_padding_mask.shape == (
                b_p_dim,
-                seq_len,
+                am_seq_len,
            ), key_padding_mask.shape
            attn_scores = attn_scores.masked_fill(
                key_padding_mask.unsqueeze(1),
@@ -286,7 +413,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        return attn_weights

    def _print_attn_entropy(self, attn_weights: Tensor):
-        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+        # attn_weights: (num_heads, batch_size, lm_seq_len, am_seq_len)
        (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape

        with torch.no_grad():
@@ -322,7 +449,7 @@ class AlignmentAttentionModule(nn.Module):
            pos_head_dim=pos_head_dim,
            dropout=dropout,
        )
-        self.cross_attn = SelfAttention(
+        self.cross_attn = CrossAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            value_head_dim=value_head_dim,
@@ -332,52 +459,80 @@
        )

    def forward(self, am_pruned: Tensor, lm_pruned: Tensor) -> Tensor:
-        # am_pruned : [B, T, prune_range, encoder_dim]
-        # lm_pruned : [B, T, prune_range, decoder_dim]
-        (batch_size, T, prune_range, encoder_dim) = am_pruned.shape
-        (batch_size, T, prune_range, decoder_dim) = lm_pruned.shape
+        if len(am_pruned.shape) == 4 and len(lm_pruned.shape) == 4:
            # src_key_padding_mask = make_pad_mask(am_pruned_lens)

-        # am_pruned : [B * prune_range, T, encoder_dim]
-        # lm_pruned : [B * prune_range, T, decoder_dim]
-        am_pruned = am_pruned.permute(1, 0, 2, 3).reshape(
-            T, batch_size * prune_range, encoder_dim
-        )
-        lm_pruned = lm_pruned.permute(1, 0, 2, 3).reshape(
-            T, batch_size * prune_range, decoder_dim
-        )
+            # am_pruned : [B, am_T, prune_range, encoder_dim]
+            # lm_pruned : [B, lm_T, prune_range, decoder_dim]
+            (batch_size, am_T, prune_range, encoder_dim) = am_pruned.shape
+            (batch_size, lm_T, prune_range, decoder_dim) = lm_pruned.shape

-        pos_emb = self.pos_encode(am_pruned)
-        print("input pos_emb.shape", pos_emb.shape)
+            # merged_am_pruned : [am_T, B * prune_range, encoder_dim]
+            # merged_lm_pruned : [lm_T, B * prune_range, decoder_dim]
+            merged_am_pruned = am_pruned.permute(1, 0, 2, 3).reshape(
+                am_T, batch_size * prune_range, encoder_dim
+            )
+            merged_lm_pruned = lm_pruned.permute(1, 0, 2, 3).reshape(
+                lm_T, batch_size * prune_range, decoder_dim
+            )
+            pos_emb = self.pos_encode(merged_lm_pruned)

-        attn_weights = self.cross_attn_weights(lm_pruned, am_pruned, pos_emb)
-        label_level_am_representation = self.cross_attn(am_pruned, attn_weights)
-        return label_level_am_representation
+            attn_weights = self.cross_attn_weights(
+                merged_lm_pruned, merged_am_pruned, pos_emb
+            )
+            # (num_heads, b_p_dim, lm_seq_len, am_seq_len)
+            # print("attn_weights.shape", attn_weights.shape)
+            label_level_am_representation = self.cross_attn(
+                merged_am_pruned, attn_weights
+            )
+            # print(
+            #     "label_level_am_representation.shape",
+            #     label_level_am_representation.shape,
+            # )
+            # (lm_seq_len, batch_size * prune_range, encoder_dim)

+            return label_level_am_representation.reshape(
+                lm_T, batch_size, prune_range, encoder_dim
+            ).permute(1, 0, 2, 3)
+        elif len(am_pruned.shape) == 3 and len(lm_pruned.shape) == 3:
+            # am_pruned : [am_T, B, encoder_dim]
+            # lm_pruned : [lm_T, B, decoder_dim]
+            (am_T, batch_size, encoder_dim) = am_pruned.shape
+            (lm_T, batch_size, decoder_dim) = lm_pruned.shape

+            pos_emb = self.pos_encode(lm_pruned)

+            attn_weights = self.cross_attn_weights(lm_pruned, am_pruned, pos_emb)
+            label_level_am_representation = self.cross_attn(am_pruned, attn_weights)
+            # (T, batch_size, encoder_dim)

+            return label_level_am_representation
+        else:
+            raise NotImplementedError("Dim Error")


if __name__ == "__main__":
    # am_pruned : [B, T, prune_range, encoder_dim]
    # lm_pruned : [B, T, prune_range, decoder_dim]
    # am_pruned = torch.rand(2, 100, 5, 512).transpose(1, 0).reshape(2 * 5, 100, 512)
    # lm_pruned = torch.rand(2, 100, 5, 512).transpose(1, 0).reshape(2 * 5, 100, 512)
    attn = AlignmentAttentionModule()

    # # am_pruned : [B * prune_range, T, encoder_dim]
    # # lm_pruned : [B * prune_range, T, decoder_dim]

    # pos_emb = torch.rand(1, 19, 192)
    print("__main__ === for inference ===")
    # am : [T, B, encoder_dim]
    # lm : [1, B, decoder_dim]
    am = torch.rand(100, 2, 512)
    lm = torch.rand(1, 2, 512)
    # q / K separate seq_len

    # weights = RelPositionMultiheadAttentionWeights()
    # attn = SelfAttention(512, 5, 12)

    # attn_weights = weights(lm_pruned, am_pruned, pos_emb)
    # attn = CrossAttention(512, 5, 12)
    # attn_weights = weights(lm, am, pos_emb)
    # print("weights(am_pruned, lm_pruned, pos_emb).shape", attn_weights.shape)
    # res = attn(am_pruned, attn_weights)
    # print("res", res.shape)
    # res = attn(am, attn_weights)
    res = attn(am, lm)
    print("__main__ res", res.shape)

    print("__main__ === for training ===")
    # am_pruned : [B, T, prune_range, encoder_dim]
    # lm_pruned : [B, T, prune_range, decoder_dim]
    am_pruned = torch.rand(2, 100, 5, 512)
    lm_pruned = torch.rand(2, 100, 5, 512)

    attn = AlignmentAttentionModule()
    res = attn(am_pruned, lm_pruned)
-    print("res", res.shape)
+    print("__main__ res", res.shape)
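
Editorial note: both copies of the module handle padded acoustic frames the same way: key_padding_mask has shape (batch * prune_range, am_seq_len) and is broadcast across heads and query positions before the softmax. The snippet below is only an illustration of that masked_fill step with made-up sizes and a -inf fill value (the actual fill value used by the commit is not visible in this capture):

import torch

num_heads, b_p_dim, lm_seq_len, am_seq_len = 5, 10, 25, 100

attn_scores = torch.rand(num_heads, b_p_dim, lm_seq_len, am_seq_len)
# True marks padded (ignored) source positions, matching the docstring above.
key_padding_mask = torch.zeros(b_p_dim, am_seq_len, dtype=torch.bool)
key_padding_mask[:, 80:] = True  # pretend the last 20 frames are padding

attn_scores = attn_scores.masked_fill(
    key_padding_mask.unsqueeze(1),  # (b_p_dim, 1, am_seq_len), broadcasts over heads/queries
    float("-inf"),
)
attn_weights = attn_scores.softmax(dim=-1)
print(attn_weights[..., 80:].sum().item())  # 0.0: no weight lands on padded frames
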
@@ -39,6 +39,7 @@ class Joiner(nn.Module):
        self,
        encoder_out: torch.Tensor,
        decoder_out: torch.Tensor,
+        apply_attn: bool = True,
        project_input: bool = True,
    ) -> torch.Tensor:
        """
@@ -47,7 +48,9 @@ class Joiner(nn.Module):
            Output from the encoder. Its shape is (N, T, s_range, C).
          decoder_out:
            Output from the decoder. Its shape is (N, T, s_range, C).
-          project_input:
+          encoder_out_lens:
+            Encoder output lengths, of shape (N,).
+          project_input:
            If true, apply input projections encoder_proj and decoder_proj.
            If this is false, it is the user's responsibility to do this
            manually.
@@ -59,7 +62,8 @@ class Joiner(nn.Module):
            decoder_out.shape,
        )

-        encoder_out = self.label_level_am_attention(encoder_out, decoder_out)
+        if apply_attn:
+            encoder_out = self.label_level_am_attention(encoder_out, decoder_out)

        if project_input:
            logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
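
Editorial note: the joiner change gates the new alignment attention behind an apply_attn flag before the usual projection-and-sum. Below is a rough, standalone sketch of that control flow; only the names apply_attn, project_input, encoder_proj, decoder_proj and label_level_am_attention come from the diff, and the pass-through lambda and TinyJoiner class are hypothetical stand-ins, not the icefall implementation.

import torch
import torch.nn as nn


class TinyJoiner(nn.Module):
    """Shape-level sketch of the modified Joiner.forward, not the icefall class."""

    def __init__(self, encoder_dim: int, decoder_dim: int, joiner_dim: int) -> None:
        super().__init__()
        self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
        self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
        # Stand-in for AlignmentAttentionModule; the real module attends over frames.
        self.label_level_am_attention = lambda enc, dec: enc

    def forward(self, encoder_out, decoder_out, apply_attn=True, project_input=True):
        if apply_attn:
            encoder_out = self.label_level_am_attention(encoder_out, decoder_out)
        if project_input:
            return self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
        return encoder_out + decoder_out


joiner = TinyJoiner(encoder_dim=512, decoder_dim=512, joiner_dim=512)
enc = torch.rand(2, 100, 5, 512)   # (N, T, s_range, C)
dec = torch.rand(2, 100, 5, 512)
print(joiner(enc, dec, apply_attn=True).shape)  # torch.Size([2, 100, 5, 512])
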