mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

commit 90cb518398 ("update"), parent 49e9d15733
@@ -4,6 +4,7 @@ import random
 from typing import Optional, Tuple

 import torch
+import torch.nn.functional as F
 from scaling import (
     Balancer,
     FloatLike,
@@ -67,7 +68,12 @@ class CrossAttention(nn.Module):
         (am_seq_len, batch_size, embed_dim) = x.shape
         (_, _, lm_seq_len, _) = attn_weights.shape
         num_heads = attn_weights.shape[0]
-        assert attn_weights.shape == (num_heads, batch_size, lm_seq_len, am_seq_len)
+        assert attn_weights.shape == (
+            num_heads,
+            batch_size,
+            lm_seq_len,
+            am_seq_len,
+        ), f"{attn_weights.shape}"

         x = self.in_proj(x)  # (am_seq_len, batch_size, num_heads * value_head_dim)
         # print("projected x.shape", x.shape)
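Note: the hunk above only reformats the shape assertion, but it documents the contract of `CrossAttention.forward`: `x` is `(am_seq_len, batch_size, embed_dim)` and `attn_weights` is `(num_heads, batch_size, lm_seq_len, am_seq_len)`. Below is a minimal sketch of how weights of that shape pull acoustic frames up to label level, using toy sizes and a plain `matmul`; the real module applies its own learned projections, so treat every number and name here as an assumption.

```python
import torch

# Toy sizes (assumptions, not taken from the model config).
num_heads, batch, lm_seq_len, am_seq_len, value_head_dim = 4, 2, 7, 25, 12

# (num_heads, batch, lm_seq_len, am_seq_len); each row sums to 1 after softmax.
attn_weights = torch.softmax(
    torch.randn(num_heads, batch, lm_seq_len, am_seq_len), dim=-1
)

# Projected acoustic frames, split into per-head value vectors:
# (am_seq_len, batch, num_heads * value_head_dim)
#   -> (num_heads, batch, am_seq_len, value_head_dim)
x = torch.randn(am_seq_len, batch, num_heads * value_head_dim)
v = x.reshape(am_seq_len, batch, num_heads, value_head_dim).permute(2, 1, 0, 3)

# Weighted sum over the acoustic axis yields one vector per LM position.
out = torch.matmul(attn_weights, v)
assert out.shape == (num_heads, batch, lm_seq_len, value_head_dim)
```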
@@ -181,6 +187,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         num_heads: int = 5,
         query_head_dim: int = 32,
         pos_head_dim: int = 4,
+        prune_range: int = 5,
         dropout: float = 0.0,
         pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
     ) -> None:
@@ -190,6 +197,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self.num_heads = num_heads
         self.query_head_dim = query_head_dim
         self.pos_head_dim = pos_head_dim
+        self.prune_range = prune_range
         self.dropout = dropout
         self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
         self.name = None  # will be overwritten in training code; for diagnostics.
@@ -201,7 +209,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         # the initial_scale is supposed to take over the "scaling" factor of
         # head_dim ** -0.5 that has been used in previous forms of attention,
         # dividing it between the query and key. Note: this module is intended
-        # to be used with the ScaledAdam optimizer; with most other optimizers,
+        # to be used with the ScaledAdam optimizer; with most other optimizers
+        # ,
         # it would be necessary to apply the scaling factor in the forward function.
         self.in_lm_proj = ScaledLinear(
             lm_embed_dim, in_lm_dim, bias=True, initial_scale=query_head_dim**-0.25
@@ -294,6 +303,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             ..., query_dim:
         ]  # (lm_seq_len, batch * prune_range, pos_head_dim * num_heads)
         assert p.shape[-1] == num_heads * pos_head_dim
+        # print("q.shape", q.shape)
+        # print("p.shape", p.shape)
+        # print("k.shape", k.shape)

         q = self.copy_query(q)  # for diagnostics only, does nothing.
         k = self.whiten_keys(self.balance_keys(k))  # does nothing in the forward pass.
@@ -303,7 +315,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         p = p.reshape(lm_seq_len, b_p_dim, num_heads, pos_head_dim)
         k = k.reshape(am_seq_len, b_p_dim, num_heads, query_head_dim)

-        # time1 refers to target (query: lm), time2 refers to source (key: am).
+        # time1 refers to target (query: lm), tim
+        # e2 refers to source (key: am).
         q = q.permute(
             2, 1, 0, 3
         )  # (head, batch * prune_range, lm_seq_len, query_head_dim)
@@ -314,48 +327,48 @@ class RelPositionMultiheadAttentionWeights(nn.Module):

         attn_scores = torch.matmul(q, k)  # (head, batch, lm_seq_len, am_seq_len)

-        use_pos_scores = False
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
-            # We can't put random.random() in the same line
-            use_pos_scores = True
-        elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
-            use_pos_scores = True
+        # use_pos_scores = False
+        # if torch.jit.is_scripting() or torch.jit.is_tracing():
+        #     # We can't put random.random() in the same line
+        #     use_pos_scores = True
+        # elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
+        #     use_pos_scores = True

-        if use_pos_scores:
-            pos_emb = self.linear_pos(pos_emb)
-            seq_len2 = 2 * lm_seq_len - 1
-            pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
-                2, 0, 3, 1
-            )
-            # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
+        # if use_pos_scores:
+        #     pos_emb = self.linear_pos(pos_emb)
+        #     seq_len2 = 2 * lm_seq_len - 1
+        #     pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
+        #         2, 0, 3, 1
+        #     )
+        #     # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)

-            # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
-            # [where seq_len2 represents relative position.]
-            pos_scores = torch.matmul(p, pos_emb)
-            # the following .as_strided() expression converts the last axis of pos_scores from relative
-            # to absolute position.  I don't know whether I might have got the time-offsets backwards or
-            # not, but let this code define which way round it is supposed to be.
-            if torch.jit.is_tracing():
-                (num_heads, b_p_dim, time1, n) = pos_scores.shape
-                rows = torch.arange(start=time1 - 1, end=-1, step=-1)
-                cols = torch.arange(lm_seq_len)
-                rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1)
-                indexes = rows + cols
-                pos_scores = pos_scores.reshape(-1, n)
-                pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
-                pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, lm_seq_len)
-            else:
-                pos_scores = pos_scores.as_strided(
-                    (num_heads, b_p_dim, lm_seq_len, lm_seq_len),
-                    (
-                        pos_scores.stride(0),
-                        pos_scores.stride(1),
-                        pos_scores.stride(2) - pos_scores.stride(3),
-                        pos_scores.stride(3),
-                    ),
-                    storage_offset=pos_scores.stride(3) * (lm_seq_len - 1),
-                )
-            attn_scores = attn_scores + pos_scores
+        #     # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
+        #     # [where seq_len2 represents relative position.]
+        #     pos_scores = torch.matmul(p, pos_emb)
+        #     # the following .as_strided() expression converts the last axis of pos_scores from relative
+        #     # to absolute position.  I don't know whether I might have got the time-offsets backwards or
+        #     # not, but let this code define which way round it is supposed to be.
+        #     if torch.jit.is_tracing():
+        #         (num_heads, b_p_dim, time1, n) = pos_scores.shape
+        #         rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+        #         cols = torch.arange(lm_seq_len)
+        #         rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1)
+        #         indexes = rows + cols
+        #         pos_scores = pos_scores.reshape(-1, n)
+        #         pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
+        #         pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, lm_seq_len)
+        #     else:
+        #         pos_scores = pos_scores.as_strided(
+        #             (num_heads, b_p_dim, lm_seq_len, lm_seq_len),
+        #             (
+        #                 pos_scores.stride(0),
+        #                 pos_scores.stride(1),
+        #                 pos_scores.stride(2) - pos_scores.stride(3),
+        #                 pos_scores.stride(3),
+        #             ),
+        #             storage_offset=pos_scores.stride(3) * (lm_seq_len - 1),
+        #         )
+        #     attn_scores = attn_scores + pos_scores

         if torch.jit.is_scripting() or torch.jit.is_tracing():
             pass
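The block commented out above is the relative-positional-score path inherited from the Zipformer attention; its own comments say the `.as_strided()` expression converts the last axis of `pos_scores` from relative to absolute position, with a `torch.gather` branch used under tracing. For reference, here is a self-contained check with toy sizes and my own variable names (an illustrative sketch, not repo code) that the two branches compute the same thing:

```python
import torch

num_heads, batch, time1 = 2, 3, 4
seq_len = time1
n = 2 * time1 - 1  # number of relative positions

pos_scores = torch.randn(num_heads, batch, time1, n)

# gather-based conversion (mirrors the torch.jit.is_tracing() branch)
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
indexes = rows.repeat(batch * num_heads).unsqueeze(-1) + cols
abs_gather = torch.gather(pos_scores.reshape(-1, n), dim=1, index=indexes)
abs_gather = abs_gather.reshape(num_heads, batch, time1, seq_len)

# as_strided-based conversion (mirrors the else branch)
abs_strided = pos_scores.as_strided(
    (num_heads, batch, seq_len, seq_len),
    (
        pos_scores.stride(0),
        pos_scores.stride(1),
        pos_scores.stride(2) - pos_scores.stride(3),
        pos_scores.stride(3),
    ),
    storage_offset=pos_scores.stride(3) * (seq_len - 1),
)

# Both pick element (i, j) from relative index (seq_len - 1) - i + j.
assert torch.equal(abs_gather, abs_strided)
```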
@@ -375,7 +388,12 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             attn_scores = penalize_abs_values_gt(
                 attn_scores, limit=25.0, penalty=1.0e-04, name=self.name
             )
-        assert attn_scores.shape == (num_heads, b_p_dim, lm_seq_len, am_seq_len)
+        assert attn_scores.shape == (
+            num_heads,
+            b_p_dim,
+            lm_seq_len,
+            am_seq_len,
+        ), attn_scores.shape

         if attn_mask is not None:
             assert attn_mask.dtype == torch.bool
@@ -386,12 +404,20 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             attn_scores = attn_scores.masked_fill(attn_mask, -1000)

         if key_padding_mask is not None:
-            assert key_padding_mask.shape == (
-                b_p_dim,
-                am_seq_len,
-            ), key_padding_mask.shape
+            # (batch, max_len)
+            key_padding_mask = (
+                (
+                    key_padding_mask.unsqueeze(0)
+                    .repeat(1, self.prune_range, 1)
+                    .unsqueeze(2)
+                )
+                if key_padding_mask.shape[0] != attn_scores.shape[1]
+                else key_padding_mask.unsqueeze(0).unsqueeze(2)
+            )
+
             attn_scores = attn_scores.masked_fill(
-                key_padding_mask.unsqueeze(1),
+                key_padding_mask,
                 -1000,
             )

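The replacement above drops the strict shape assert and instead adapts the incoming `(batch, am_seq_len)` padding mask to whatever batch axis `attn_scores` has: when the scores carry a `batch * prune_range` axis, the mask is tiled `prune_range` times and given singleton head and query dimensions so `masked_fill` can broadcast. A small sketch of that shape arithmetic with assumed toy sizes; note (as a fact about `Tensor.repeat`, not a claim about the surrounding merge code) that tiling copies the whole batch block, so the combined axis is ordered prune-major, i.e. index `p * batch + b`:

```python
import torch

batch, prune_range, num_heads, lm_len, am_len = 2, 5, 4, 7, 25

# Boolean padding mask, True at padded frames: shape (batch, am_len).
lengths = torch.tensor([25, 20])
key_padding_mask = torch.arange(am_len).unsqueeze(0) >= lengths.unsqueeze(1)

# Expansion used when attn_scores has a batch * prune_range axis.
expanded = key_padding_mask.unsqueeze(0).repeat(1, prune_range, 1).unsqueeze(2)
assert expanded.shape == (1, batch * prune_range, 1, am_len)

attn_scores = torch.randn(num_heads, batch * prune_range, lm_len, am_len)
attn_scores = attn_scores.masked_fill(expanded, -1000)
# Broadcasts over num_heads (dim 0) and the lm positions (dim 2).

# Under prune-major ordering, the slices for the second utterance (length 20)
# sit at indices 1, 3, 5, ... and have frames 20..24 masked.
assert torch.all(attn_scores[:, 1::batch, :, 20:] == -1000)
```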
@@ -438,6 +464,7 @@ class AlignmentAttentionModule(nn.Module):
         query_head_dim: int = 32,
         value_head_dim: int = 12,
         pos_head_dim: int = 4,
+        prune_range: int = 5,
         dropout: float = 0.0,
     ):
         super().__init__()
@@ -458,10 +485,13 @@ class AlignmentAttentionModule(nn.Module):
             embed_dim=pos_dim, dropout_rate=0.15
         )

-    def forward(self, am_pruned: Tensor, lm_pruned: Tensor) -> Tensor:
-        if len(am_pruned.shape) == 4 and len(lm_pruned.shape) == 4:
-            # src_key_padding_mask = make_pad_mask(am_pruned_lens)
+    def forward(
+        self, am_pruned: Tensor, lm_pruned: Tensor, lengths: torch.Tensor
+    ) -> Tensor:
+        src_key_padding_mask = make_pad_mask(lengths)
+        # (batch, max_len)
+
+        if am_pruned.ndim == 4 and lm_pruned.ndim == 4:
             # am_pruned : [B, am_T, prune_range, encoder_dim]
             # lm_pruned : [B, lm_T, prune_range, decoder_dim]
             (batch_size, am_T, prune_range, encoder_dim) = am_pruned.shape
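`make_pad_mask` is icefall's usual length-to-mask helper: given per-utterance lengths it returns a boolean `(batch, max_len)` mask with `True` at padded positions, which is what the `# (batch, max_len)` comment records. A self-contained equivalent for readers without icefall on hand (a sketch of the semantics, not the library code):

```python
import torch


def make_pad_mask_sketch(lengths: torch.Tensor) -> torch.Tensor:
    """Return a bool mask of shape (batch, max_len); True marks padding."""
    max_len = int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) >= lengths.unsqueeze(1)


lengths = torch.tensor([3, 5, 2])
print(make_pad_mask_sketch(lengths))
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False],
#         [False, False,  True,  True,  True]])
```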
@@ -478,35 +508,40 @@ class AlignmentAttentionModule(nn.Module):
             pos_emb = self.pos_encode(merged_lm_pruned)

             attn_weights = self.cross_attn_weights(
-                merged_lm_pruned, merged_am_pruned, pos_emb
+                merged_lm_pruned,
+                merged_am_pruned,
+                pos_emb,
+                key_padding_mask=src_key_padding_mask,
             )
             # (num_heads, b_p_dim, lm_seq_len, am_seq_len)
-            # print("attn_weights.shape", attn_weights.shape)
             label_level_am_representation = self.cross_attn(
                 merged_am_pruned, attn_weights
             )
-            # print(
-            #     "label_level_am_representation.shape",
-            #     label_level_am_representation.shape,
-            # )
-            # (lm_seq_len, batch_size * prune_range, encoder_dim)

             return label_level_am_representation.reshape(
                 lm_T, batch_size, prune_range, encoder_dim
             ).permute(1, 0, 2, 3)
-        elif len(am_pruned.shape) == 3 and len(lm_pruned.shape) == 3:
-            # am_pruned : [am_T, B, encoder_dim]
-            # lm_pruned : [lm_T, B, decoder_dim]
-            (am_T, batch_size, encoder_dim) = am_pruned.shape
-            (lm_T, batch_size, decoder_dim) = lm_pruned.shape
-
-            pos_emb = self.pos_encode(lm_pruned)
-            attn_weights = self.cross_attn_weights(lm_pruned, am_pruned, pos_emb)
-            label_level_am_representation = self.cross_attn(am_pruned, attn_weights)
-            # (T, batch_size, encoder_dim)
-
-            return label_level_am_representation
+        # elif len(am_pruned.shape) == 3 and len(lm_pruned.shape) == 3:
+        #     am_pruned = am_pruned.permute(1, 0, 2)
+        #     lm_pruned = lm_pruned.permute(1, 0, 2)
+
+        #     # am_pruned : [am_T, B, encoder_dim]
+        #     # lm_pruned : [lm_T, B, decoder_dim]
+        #     (am_T, batch_size, encoder_dim) = am_pruned.shape
+        #     (lm_T, batch_size, decoder_dim) = lm_pruned.shape
+
+        #     pos_emb = self.pos_encode(lm_pruned)
+
+        #     attn_weights = self.cross_attn_weights(
+        #         lm_pruned,
+        #         am_pruned,
+        #         pos_emb,
+        #         key_padding_mask=src_key_padding_mask,
+        #     )
+        #     label_level_am_representation = self.cross_attn(am_pruned, attn_weights)
+        #     # (T, batch_size, encoder_dim)
+
+        #     return label_level_am_representation
         else:
             raise NotImplementedError("Dim Error")

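A short shape walk-through of the return statement kept at the end of the 4-D branch: per the comment the diff removes, the attended representation has shape `(lm_seq_len, batch_size * prune_range, encoder_dim)`, and `reshape(...).permute(1, 0, 2, 3)` folds the merged axis back out (interpreting it as batch-major, i.e. index `b * prune_range + p`) to recover the `[B, lm_T, prune_range, encoder_dim]` layout documented for the pruned inputs. Toy sizes below mirror the `__main__` test at the bottom of the file:

```python
import torch

batch_size, lm_T, prune_range, encoder_dim = 2, 100, 5, 512

# Shape coming out of the cross-attention step:
# (lm_seq_len, batch_size * prune_range, encoder_dim)
label_level = torch.randn(lm_T, batch_size * prune_range, encoder_dim)

# Split the merged axis (read as b * prune_range + p) and move batch first.
out = label_level.reshape(lm_T, batch_size, prune_range, encoder_dim)
out = out.permute(1, 0, 2, 3)
assert out.shape == (batch_size, lm_T, prune_range, encoder_dim)
```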
@@ -526,7 +561,7 @@ if __name__ == "__main__":
    # attn_weights = weights(lm, am, pos_emb)
    # print("weights(am_pruned, lm_pruned, pos_emb).shape", attn_weights.shape)
    # res = attn(am, attn_weights)
-    res = attn(am, lm)
+    res = attn(am, lm, torch.Tensor([70, 80]))
    print("__main__ res", res.shape)

    print("__main__ === for training ===")
@@ -534,5 +569,6 @@ if __name__ == "__main__":
    # lm_pruned : [B, T, prune_range, decoder_dim]
    am_pruned = torch.rand(2, 100, 5, 512)
    lm_pruned = torch.rand(2, 100, 5, 512)
-    res = attn(am_pruned, lm_pruned)
+    lengths = Tensor([100, 100])
+    res = attn(am_pruned, lm_pruned, lengths)
    print("__main__ res", res.shape)
@@ -39,6 +39,7 @@ class Joiner(nn.Module):
         self,
         encoder_out: torch.Tensor,
         decoder_out: torch.Tensor,
+        lengths: torch.Tensor,
         apply_attn: bool = True,
         project_input: bool = True,
     ) -> torch.Tensor:
@@ -62,8 +63,10 @@ class Joiner(nn.Module):
             decoder_out.shape,
         )

-        if apply_attn:
-            encoder_out = self.label_level_am_attention(encoder_out, decoder_out)
+        if apply_attn and lengths is not None:
+            encoder_out = self.label_level_am_attention(
+                encoder_out, decoder_out, lengths
+            )

         if project_input:
             logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
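`lengths` is added as a positional parameter with no default, so every caller of `Joiner.forward` must now pass it (the model-level hunk below does exactly that), and the attention branch is skipped when it is `None`. A stripped-down stand-in with assumed dimensions and the attention sub-module stubbed out, just to show the new call shape; this is a sketch, not the real `Joiner`:

```python
import torch
import torch.nn as nn


class JoinerSketch(nn.Module):
    """Minimal stand-in for the Joiner signature change (not the real module)."""

    def __init__(self, encoder_dim: int = 512, decoder_dim: int = 512, joiner_dim: int = 512):
        super().__init__()
        self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
        self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)

    def forward(
        self,
        encoder_out: torch.Tensor,
        decoder_out: torch.Tensor,
        lengths: torch.Tensor,
        apply_attn: bool = True,
        project_input: bool = True,
    ) -> torch.Tensor:
        if apply_attn and lengths is not None:
            # Real module: encoder_out = self.label_level_am_attention(
            #     encoder_out, decoder_out, lengths)
            pass
        if project_input:
            return self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
        return encoder_out + decoder_out


joiner = JoinerSketch()
am_pruned = torch.rand(2, 100, 5, 512)  # [B, T, prune_range, encoder_dim]
lm_pruned = torch.rand(2, 100, 5, 512)  # [B, T, prune_range, decoder_dim]
lengths = torch.tensor([100, 100])
print(joiner(am_pruned, lm_pruned, lengths).shape)  # torch.Size([2, 100, 5, 512])
```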
@@ -264,7 +264,9 @@ class AsrModel(nn.Module):

         # project_input=False since we applied the decoder's input projections
         # prior to do_rnnt_pruning (this is an optimization for speed).
-        logits = self.joiner(am_pruned, lm_pruned, project_input=False)
+        logits = self.joiner(
+            am_pruned, lm_pruned, encoder_out_lens, project_input=False
+        )

         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
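The only change on the model side is threading `encoder_out_lens` through to the joiner so the new cross-attention can mask padded acoustic frames. The motivation is the usual one for key-padding masks, sketched below with toy numbers (my own example, not repo code): without the `-1000` fill used in the attention hunk above, softmax leaks weight onto padding.

```python
import torch

batch, lm_len, am_len = 1, 1, 6
valid = 4  # pretend frames 4 and 5 are padding for this utterance

scores = torch.randn(batch, lm_len, am_len)
pad_mask = torch.arange(am_len) >= valid  # True at padded frames

unmasked = scores.softmax(dim=-1)
masked = scores.masked_fill(pad_mask, -1000).softmax(dim=-1)

print(unmasked[0, 0, valid:].sum())  # some weight lands on padded frames
print(masked[0, 0, valid:].sum())    # effectively zero after masking
```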