Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-18 21:44:18 +00:00)

commit 29f2228675
parent 96b7c7aecf

    minor updates
(File diff suppressed because it is too large.)
@@ -260,6 +260,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         pos_emb: Tensor,
         key_padding_mask: Optional[Tensor] = None,
         attn_mask: Optional[Tensor] = None,
+        for_reference: bool = False,
+        offset: Optional[int] = None,
     ) -> Tensor:
         r"""
         Args:
@@ -276,6 +278,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
            a tensor of attention weights, of shape
            (num_heads, seq_len, batch_size * prune_range, batch_size * prune_range)
         """
+        assert not for_reference or (
+            for_reference and offset is not None
+        ), f"for_reference: {for_reference}"
         lm_pruned = self.in_lm_proj(lm_pruned)  # lm_pruned as query
         am_pruned = self.in_am_proj(am_pruned)  # am_pruned as key
         query_head_dim = self.query_head_dim
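The guard added here (and repeated in AlignmentAttentionModule.forward further down) only requires offset when for_reference is set. A standalone sketch of that condition; check_reference_args is a hypothetical helper name, not part of the diff:

from typing import Optional


def check_reference_args(for_reference: bool = False, offset: Optional[int] = None) -> None:
    # Same condition as in the diff: vacuously true when for_reference is False,
    # otherwise offset must be provided.
    assert not for_reference or (
        for_reference and offset is not None
    ), f"for_reference: {for_reference}"


check_reference_args()                              # passes: reference mode off
check_reference_args(for_reference=True, offset=3)  # passes: offset supplied
# check_reference_args(for_reference=True)          # would raise AssertionError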
@@ -326,6 +331,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         k = k.permute(2, 1, 3, 0)  # (head, batch * prune_range, d_k, am_seq_len)
 
         attn_scores = torch.matmul(q, k)  # (head, batch, lm_seq_len, am_seq_len)
+        # print(attn_scores.shape)
 
         # use_pos_scores = False
         # if torch.jit.is_scripting() or torch.jit.is_tracing():
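A shape-only sketch of the score computation in the hunk above, with made-up sizes; in the real code q comes from in_lm_proj, k comes from in_am_proj, and k has already been permuted as shown:

import torch

num_heads, d_k = 4, 32
batch_size, prune_range = 2, 5
lm_seq_len, am_seq_len = 7, 9
bp = batch_size * prune_range

q = torch.randn(num_heads, bp, lm_seq_len, d_k)
k = torch.randn(num_heads, bp, d_k, am_seq_len)  # as after k.permute(2, 1, 3, 0)

attn_scores = torch.matmul(q, k)
print(attn_scores.shape)  # torch.Size([4, 10, 7, 9]): (head, batch * prune_range, lm_seq_len, am_seq_len)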
@@ -336,15 +342,26 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
 
         # if use_pos_scores:
         #     pos_emb = self.linear_pos(pos_emb)
-        #     seq_len2 = 2 * lm_seq_len - 1
+
+        #     seq_len2 = 2 * am_seq_len - 1
         #     pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
         #         2, 0, 3, 1
         #     )
         #     # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
+
+        #     if for_reference:
+        #         pos_offset = am_seq_len - 1 - offset
+        #         pos_emb = pos_emb[
+        #             :,
+        #             :,
+        #             :,
+        #             pos_offset : pos_offset + am_seq_len,
+        #         ]
+        #         # (head, 1, pos_dim, seq_len2)
         #     # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
         #     # [where seq_len2 represents relative position.]
         #     pos_scores = torch.matmul(p, pos_emb)
         #     # print(pos_scores.shape, attn_scores.shape)
         #     # the following .as_strided() expression converts the last axis of pos_scores from relative
         #     # to absolute position. I don't know whether I might have got the time-offsets backwards or
         #     # not, but let this code define which way round it is supposed to be.
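A small index sketch of the pos_emb slice introduced in the (commented-out) for_reference branch above: for a given offset, an am_seq_len-wide window is taken from the 2 * am_seq_len - 1 relative-position axis. Sizes are made up and the orientation of the relative axis is illustrative only:

import torch

am_seq_len, pos_dim = 4, 3
offset = 1

# Label each entry on the relative axis by its index so the slice is easy to read.
pos_emb = torch.arange(2 * am_seq_len - 1).repeat(1, 1, pos_dim, 1)  # (1, 1, pos_dim, 2 * am_seq_len - 1)

pos_offset = am_seq_len - 1 - offset
window = pos_emb[:, :, :, pos_offset : pos_offset + am_seq_len]
print(window.shape)     # torch.Size([1, 1, 3, 4])
print(window[0, 0, 0])  # tensor([2, 3, 4, 5]): indices pos_offset .. pos_offset + am_seq_len - 1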
@@ -368,6 +385,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         #         ),
         #         storage_offset=pos_scores.stride(3) * (lm_seq_len - 1),
         #     )
+        #     # print(pos_scores.shape)
         #     attn_scores = attn_scores + pos_scores
 
         if torch.jit.is_scripting() or torch.jit.is_tracing():
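The commented-out block above (spanning this hunk and the previous one) relies on a .as_strided() view to turn relative-position scores into absolute-position scores. A tiny self-contained sketch of that indexing trick, with made-up sizes and scores chosen so the shift is visible; it is not the diff's actual tensors:

import torch

num_heads, batch_size, time1, seq_len = 1, 1, 3, 3
n_rel = 2 * seq_len - 1  # relative positions from -(seq_len - 1) to +(seq_len - 1)

# Use the relative index itself as the "score" so the re-indexing is easy to read.
pos_scores = (
    torch.arange(n_rel, dtype=torch.float32)
    .repeat(num_heads, batch_size, time1, 1)
    .contiguous()
)

abs_scores = pos_scores.as_strided(
    (num_heads, batch_size, time1, seq_len),
    (
        pos_scores.stride(0),
        pos_scores.stride(1),
        pos_scores.stride(2) - pos_scores.stride(3),
        pos_scores.stride(3),
    ),
    storage_offset=pos_scores.stride(3) * (time1 - 1),
)
print(abs_scores[0, 0])
# tensor([[2., 3., 4.],
#         [1., 2., 3.],
#         [0., 1., 2.]])
# i.e. abs_scores[..., t, s] picks relative index s - t + (time1 - 1).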
@@ -496,11 +514,20 @@ class AlignmentAttentionModule(nn.Module):
         self.dropout = nn.Dropout(p=0.5)
 
     def forward(
-        self, am_pruned: Tensor, lm_pruned: Tensor, lengths: torch.Tensor
+        self,
+        am_pruned: Tensor,
+        lm_pruned: Tensor,
+        lengths: torch.Tensor,
+        for_reference: bool = False,
+        offset: Optional[int] = None,
     ) -> Tensor:
         src_key_padding_mask = make_pad_mask(lengths) if lengths is not None else None
         # (batch, max_len)
 
+        assert not for_reference or (
+            for_reference and offset is not None
+        ), f"for_reference: {for_reference}"
+
         if am_pruned.ndim == 4 and lm_pruned.ndim == 4:
             # am_pruned : [B, am_T, prune_range, encoder_dim]
             # lm_pruned : [B, lm_T, prune_range, decoder_dim]
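make_pad_mask(lengths) above builds the key-padding mask from the per-utterance lengths. A minimal equivalent sketch, assuming the usual icefall convention that True marks padded positions; make_pad_mask_sketch is a stand-in, not the repository's implementation:

import torch


def make_pad_mask_sketch(lengths: torch.Tensor) -> torch.Tensor:
    max_len = int(lengths.max())
    # (batch, max_len): True where the frame index is at or beyond the sequence length.
    return torch.arange(max_len, device=lengths.device).unsqueeze(0) >= lengths.unsqueeze(1)


print(make_pad_mask_sketch(torch.tensor([4, 2, 3])))
# tensor([[False, False, False, False],
#         [False, False,  True,  True],
#         [False, False, False,  True]])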
@@ -515,17 +542,20 @@ class AlignmentAttentionModule(nn.Module):
             merged_lm_pruned = lm_pruned.permute(1, 0, 2, 3).reshape(
                 lm_T, batch_size * prune_range, decoder_dim
             )
-            pos_emb = self.pos_encode(merged_lm_pruned)
+            pos_emb = self.pos_encode(merged_am_pruned)
 
             attn_weights = self.cross_attn_weights(
                 merged_lm_pruned,
                 merged_am_pruned,
                 pos_emb,
                 key_padding_mask=src_key_padding_mask,
+                for_reference=for_reference,
+                offset=offset,
             )
             # (num_heads, b_p_dim, lm_seq_len, am_seq_len)
             label_level_am_representation = self.cross_attn(
-                merged_am_pruned, attn_weights
+                merged_am_pruned,
+                attn_weights,
             )
 
             return label_level_am_representation.reshape(
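A shape-only sketch of how the pruned tensors are folded before cross-attention in the hunk above, with made-up sizes: (B, lm_T, prune_range, decoder_dim) becomes (lm_T, B * prune_range, decoder_dim).

import torch

B, lm_T, prune_range, decoder_dim = 2, 6, 5, 512
lm_pruned = torch.randn(B, lm_T, prune_range, decoder_dim)

batch_size = lm_pruned.size(0)
merged_lm_pruned = lm_pruned.permute(1, 0, 2, 3).reshape(
    lm_T, batch_size * prune_range, decoder_dim
)
print(merged_lm_pruned.shape)  # torch.Size([6, 10, 512])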
@@ -67,7 +67,7 @@ class Joiner(nn.Module):
                 decoder_out.shape,
             )
 
-        if apply_attn:
+        if apply_attn and attn_encoder_out is None:
             if not self.enable_attn:
                 self.enable_attn = True
                 logging.info("enabling ATTN!")
@@ -79,7 +79,7 @@ class Joiner(nn.Module):
             logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
 
         if apply_attn:
-            # print(attn_encoder_out)
+            # print(torch.mean(attn_encoder_out, dim=0))
             logit = encoder_out + decoder_out + attn_encoder_out
         else:
             logging.info("disabling cross attn mdl")
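The first of the two Joiner hunks above tightens the logging guard from "if apply_attn:" to "if apply_attn and attn_encoder_out is None:". A purely illustrative truth-table sketch of that condition; should_log_enable is a hypothetical name, not the repository's code:

from typing import Optional

import torch


def should_log_enable(apply_attn: bool, attn_encoder_out: Optional[torch.Tensor]) -> bool:
    # Mirrors the tightened condition in the diff.
    return apply_attn and attn_encoder_out is None


print(should_log_enable(True, None))            # True:  attention requested, no tensor passed
print(should_log_enable(True, torch.zeros(1)))  # False: tensor supplied (the old guard would still fire)
print(should_log_enable(False, None))           # False: attention not requested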
@@ -84,7 +84,7 @@ class AsrModel(nn.Module):
 
         self.encoder_embed = encoder_embed
         self.encoder = encoder
-        self.dropout = nn.Dropout(p=0.5)
+        self.dropout = nn.Dropout(p=0.1)
 
         self.use_transducer = use_transducer
         if use_transducer: