minor updates

JinZr 2023-08-11 10:19:55 +08:00
parent 96b7c7aecf
commit 29f2228675
4 changed files with 37 additions and 3073 deletions

File diff suppressed because it is too large


@@ -260,6 +260,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         pos_emb: Tensor,
         key_padding_mask: Optional[Tensor] = None,
         attn_mask: Optional[Tensor] = None,
+        for_reference: bool = False,
+        offset: Optional[int] = None,
     ) -> Tensor:
         r"""
         Args:
@@ -276,6 +278,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             a tensor of attention weights, of shape
             (num_heads, seq_len, batch_size * prune_range, batch_size * prune_range)
         """
+        assert not for_reference or (
+            for_reference and offset is not None
+        ), f"for_reference: {for_reference}"
         lm_pruned = self.in_lm_proj(lm_pruned)  # lm_pruned as query
         am_pruned = self.in_am_proj(am_pruned)  # am_pruned as key
         query_head_dim = self.query_head_dim
@@ -326,6 +331,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         k = k.permute(2, 1, 3, 0)  # (head, batch * prune_range, d_k, am_seq_len)
         attn_scores = torch.matmul(q, k)  # (head, batch, lm_seq_len, am_seq_len)
+        # print(attn_scores.shape)
         # use_pos_scores = False
         # if torch.jit.is_scripting() or torch.jit.is_tracing():
@@ -336,15 +342,26 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         # if use_pos_scores:
         #     pos_emb = self.linear_pos(pos_emb)
-        #     seq_len2 = 2 * lm_seq_len - 1
+        #     seq_len2 = 2 * am_seq_len - 1
         #     pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
         #         2, 0, 3, 1
         #     )
         #     # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
+        #     if for_reference:
+        #         pos_offset = am_seq_len - 1 - offset
+        #         pos_emb = pos_emb[
+        #             :,
+        #             :,
+        #             :,
+        #             pos_offset : pos_offset + am_seq_len,
+        #         ]
+        #         # (head, 1, pos_dim, seq_len2)
         #     # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
         #     # [where seq_len2 represents relative position.]
         #     pos_scores = torch.matmul(p, pos_emb)
+        #     # print(pos_scores.shape, attn_scores.shape)
         #     # the following .as_strided() expression converts the last axis of pos_scores from relative
         #     # to absolute position. I don't know whether I might have got the time-offsets backwards or
         #     # not, but let this code define which way round it is supposed to be.
@@ -368,6 +385,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         #         ),
         #         storage_offset=pos_scores.stride(3) * (lm_seq_len - 1),
         #     )
+        #     # print(pos_scores.shape)
         #     attn_scores = attn_scores + pos_scores

         if torch.jit.is_scripting() or torch.jit.is_tracing():
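
Note (not part of the diff): the for_reference branch added above sits inside a commented-out block, so it does not run in this commit; the sketch below only restates the indexing it would perform, with toy shapes chosen here for illustration. With relative-position embeddings laid out over 2 * am_seq_len - 1 relative offsets, slicing at pos_offset = am_seq_len - 1 - offset keeps an am_seq_len-wide window of relative positions (those from -offset up to am_seq_len - 1 - offset), i.e. the ones a query at absolute position offset would see.

import torch

# Toy values, not taken from the recipe.
num_heads, pos_dim, am_seq_len, offset = 2, 4, 5, 3
seq_len2 = 2 * am_seq_len - 1  # number of relative positions
pos_emb = torch.randn(num_heads, 1, pos_dim, seq_len2)

pos_offset = am_seq_len - 1 - offset
windowed = pos_emb[:, :, :, pos_offset : pos_offset + am_seq_len]
print(windowed.shape)  # torch.Size([2, 1, 4, 5])
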
@@ -496,11 +514,20 @@ class AlignmentAttentionModule(nn.Module):
         self.dropout = nn.Dropout(p=0.5)

     def forward(
-        self, am_pruned: Tensor, lm_pruned: Tensor, lengths: torch.Tensor
+        self,
+        am_pruned: Tensor,
+        lm_pruned: Tensor,
+        lengths: torch.Tensor,
+        for_reference: bool = False,
+        offset: Optional[int] = None,
     ) -> Tensor:
         src_key_padding_mask = make_pad_mask(lengths) if lengths is not None else None
         # (batch, max_len)
+        assert not for_reference or (
+            for_reference and offset is not None
+        ), f"for_reference: {for_reference}"
         if am_pruned.ndim == 4 and lm_pruned.ndim == 4:
             # am_pruned : [B, am_T, prune_range, encoder_dim]
             # lm_pruned : [B, lm_T, prune_range, decoder_dim]
@@ -515,17 +542,20 @@ class AlignmentAttentionModule(nn.Module):
             merged_lm_pruned = lm_pruned.permute(1, 0, 2, 3).reshape(
                 lm_T, batch_size * prune_range, decoder_dim
             )
-            pos_emb = self.pos_encode(merged_lm_pruned)
+            pos_emb = self.pos_encode(merged_am_pruned)
             attn_weights = self.cross_attn_weights(
                 merged_lm_pruned,
                 merged_am_pruned,
                 pos_emb,
                 key_padding_mask=src_key_padding_mask,
+                for_reference=for_reference,
+                offset=offset,
             )
             # (num_heads, b_p_dim, lm_seq_len, am_seq_len)
             label_level_am_representation = self.cross_attn(
-                merged_am_pruned, attn_weights
+                merged_am_pruned,
+                attn_weights,
             )
             return label_level_am_representation.reshape(
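
Note (not part of the diff): the assert added to both forward methods above encodes the precondition "if for_reference is True, an offset must be given". A standalone restatement of that check, using a hypothetical helper name, just to make the contract explicit:

from typing import Optional


def check_reference_args(for_reference: bool = False, offset: Optional[int] = None) -> None:
    # Same condition as in the diff; `not A or (A and B)` reduces to `not A or B`.
    assert not for_reference or (
        for_reference and offset is not None
    ), f"for_reference: {for_reference}"


check_reference_args()                              # training-style call: fine
check_reference_args(for_reference=True, offset=3)  # reference call with an offset: fine
# check_reference_args(for_reference=True)          # would raise AssertionError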


@@ -67,7 +67,7 @@ class Joiner(nn.Module):
             decoder_out.shape,
         )
-        if apply_attn:
+        if apply_attn and attn_encoder_out is None:
             if not self.enable_attn:
                 self.enable_attn = True
                 logging.info("enabling ATTN!")
@@ -79,7 +79,7 @@ class Joiner(nn.Module):
         logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
         if apply_attn:
-            # print(attn_encoder_out)
+            # print(torch.mean(attn_encoder_out, dim=0))
             logit = encoder_out + decoder_out + attn_encoder_out
         else:
             logging.info("disabling cross attn mdl")
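
Note (not part of the diff): a toy restatement of the two ways the Joiner hunk above forms the logit; the linear layers stand in for self.encoder_proj and self.decoder_proj, and all names and sizes here are placeholders rather than the recipe's real configuration.

import torch
import torch.nn as nn

joiner_dim = 8
encoder_proj = nn.Linear(joiner_dim, joiner_dim)  # stand-in for self.encoder_proj
decoder_proj = nn.Linear(joiner_dim, joiner_dim)  # stand-in for self.decoder_proj

encoder_out = torch.randn(2, 3, joiner_dim)
decoder_out = torch.randn(2, 3, joiner_dim)
attn_encoder_out = torch.randn(2, 3, joiner_dim)

apply_attn = True
logit = encoder_proj(encoder_out) + decoder_proj(decoder_out)
if apply_attn:
    # As in the diff: replace the projected sum with the raw encoder/decoder
    # outputs plus the label-level attention output.
    logit = encoder_out + decoder_out + attn_encoder_out
print(logit.shape)  # torch.Size([2, 3, 8])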


@@ -84,7 +84,7 @@ class AsrModel(nn.Module):
         self.encoder_embed = encoder_embed
         self.encoder = encoder
-        self.dropout = nn.Dropout(p=0.5)
+        self.dropout = nn.Dropout(p=0.1)
         self.use_transducer = use_transducer
         if use_transducer: