mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)

Add cat to zipformer

parent 0ad037d076
commit aaec7c299f
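For orientation before the diff: this commit splits each Zipformer2EncoderLayer input into two half-dimension streams (src_a, src_b), replaces the layer's self-attention with two cross-attention paths in which each stream attends to the other, and concatenates the streams back together at the end of the layer (the "cat" of the title). The sketch below is a minimal, hypothetical illustration of that dataflow only: it uses torch.nn.MultiheadAttention as a stand-in for the icefall RelPositionMultiheadAttentionWeights/CrossAttention modules and omits the feed-forward, convolution, nonlinear-attention, bypass and balancer branches, so it is not the actual implementation.

import torch
import torch.nn as nn

class ToyCatLayer(nn.Module):
    """Toy illustration only: split channels, cross-attend a<->b, then concatenate."""

    def __init__(self, embed_dim: int = 256, num_heads: int = 4):
        super().__init__()
        assert embed_dim % 2 == 0
        self.half = embed_dim // 2  # mirrors `embed_dim = embed_dim >> 1` in the diff
        # each half-dimension stream attends to the other one
        self.cross_a = nn.MultiheadAttention(self.half, num_heads)
        self.cross_b = nn.MultiheadAttention(self.half, num_heads)

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        # src: (seq_len, batch, embed_dim), the layout used in zipformer.py
        src_a, src_b = torch.split(src, self.half, dim=2)
        # stream a is updated from stream b (queries from a, keys/values from b), and vice versa
        attn_a, _ = self.cross_a(src_a, src_b, src_b)
        attn_b, _ = self.cross_b(src_b, src_a, src_a)
        src_a = src_a + attn_a
        src_b = src_b + attn_b
        return torch.cat([src_a, src_b], dim=2)  # the final "cat"

x = torch.randn(10, 2, 256)      # (seq_len, batch, embed_dim)
print(ToyCatLayer()(x).shape)    # torch.Size([10, 2, 256])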
@@ -133,7 +133,6 @@ class Zipformer2(EncoderInterface):
         self.encoder_dim = encoder_dim = _to_tuple(encoder_dim)  # tuple
         self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(encoder_unmasked_dim)  # tuple
         num_encoder_layers = _to_tuple(num_encoder_layers)
-        self.num_encoder_layers = num_encoder_layers
         self.query_head_dim = query_head_dim = _to_tuple(query_head_dim)
         self.value_head_dim = value_head_dim = _to_tuple(value_head_dim)
         pos_head_dim = _to_tuple(pos_head_dim)
@@ -259,7 +258,7 @@ class Zipformer2(EncoderInterface):
         if not self.causal:
             return -1, -1
 
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        if torch.jit.is_scripting():
             assert len(self.chunk_size) == 1, self.chunk_size
             chunk_size = self.chunk_size[0]
         else:
@@ -268,7 +267,7 @@ class Zipformer2(EncoderInterface):
         if chunk_size == -1:
             left_context_chunks = -1
         else:
-            if torch.jit.is_scripting() or torch.jit.is_tracing():
+            if torch.jit.is_scripting():
                 assert len(self.left_context_frames) == 1, self.left_context_frames
                 left_context_frames = self.left_context_frames[0]
             else:
@@ -302,14 +301,14 @@ class Zipformer2(EncoderInterface):
               of frames in `embeddings` before padding.
         """
         outputs = []
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        if torch.jit.is_scripting():
             feature_masks = [1.0] * len(self.encoder_dim)
         else:
             feature_masks = self.get_feature_masks(x)
 
         chunk_size, left_context_chunks = self.get_chunk_info()
 
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        if torch.jit.is_scripting():
             # Not support exporting a model for simulating streaming decoding
             attn_mask = None
         else:
@@ -335,7 +334,7 @@ class Zipformer2(EncoderInterface):
         x = self.downsample_output(x)
         # class Downsample has this rounding behavior..
         assert self.output_downsampling_factor == 2
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        if torch.jit.is_scripting():
             lengths = (x_lens + 1) // 2
         else:
             with warnings.catch_warnings():
@@ -373,7 +372,7 @@ class Zipformer2(EncoderInterface):
         # t is frame index, shape (seq_len,)
         t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
         # c is chunk index for each frame, shape (seq_len,)
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        if torch.jit.is_scripting():
             c = t // chunk_size
         else:
             with warnings.catch_warnings():
@@ -545,13 +544,15 @@ class Zipformer2EncoderLayer(nn.Module):
             bypass_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0),
     ) -> None:
         super(Zipformer2EncoderLayer, self).__init__()
+        embed_dim = embed_dim >> 1
         self.embed_dim = embed_dim
 
         # self.bypass implements layer skipping as well as bypass; see its default values.
-        self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate,
+        self.bypass = BypassModule(embed_dim * 2, skip_rate=bypass_skip_rate,
                                    straight_through_rate=0)
        # bypass_mid is bypass used in the middle of the layer.
-        self.bypass_mid = BypassModule(embed_dim, straight_through_rate=0)
+        self.bypass_mid_a = BypassModule(embed_dim, straight_through_rate=0)
+        self.bypass_mid_b = BypassModule(embed_dim, straight_through_rate=0)
 
         # skip probability for dynamic modules (meaning: anything but feedforward).
         self.attention_skip_rate = copy.deepcopy(attention_skip_rate)
@@ -566,48 +567,71 @@ class Zipformer2EncoderLayer(nn.Module):
 
         self.const_attention_rate = copy.deepcopy(const_attention_rate)
 
-        self.self_attn_weights = RelPositionMultiheadAttentionWeights(
+        self.cross_attn_weights_a = RelPositionMultiheadAttentionWeights(
             embed_dim, pos_dim=pos_dim, num_heads=num_heads,
             query_head_dim=query_head_dim, pos_head_dim=pos_head_dim,
             dropout=0.0,
         )
+        self.cross_attn_weights_b = RelPositionMultiheadAttentionWeights(
+            embed_dim, pos_dim=pos_dim, num_heads=num_heads,
+            query_head_dim=query_head_dim, pos_head_dim=pos_head_dim,
+            dropout=0.0,
+        )
 
-        self.self_attn1 = SelfAttention(embed_dim, num_heads,
+        self.cross_attn1_a = CrossAttention(embed_dim, num_heads,
                                         value_head_dim)
+        self.cross_attn1_b = CrossAttention(embed_dim, num_heads,
+                                            value_head_dim)
 
-        self.self_attn2 = SelfAttention(embed_dim, num_heads,
+        self.cross_attn2_a = CrossAttention(embed_dim, num_heads,
                                         value_head_dim)
+        self.cross_attn2_b = CrossAttention(embed_dim, num_heads,
+                                            value_head_dim)
 
-        self.feed_forward1 = FeedforwardModule(embed_dim,
+        self.feed_forward1_a = FeedforwardModule(embed_dim,
                                                (feedforward_dim * 3) // 4,
                                                dropout)
+        self.feed_forward1_b = FeedforwardModule(embed_dim,
+                                                 (feedforward_dim * 3) // 4,
+                                                 dropout)
 
-        self.feed_forward2 = FeedforwardModule(embed_dim,
+        self.feed_forward2_a = FeedforwardModule(embed_dim,
                                                feedforward_dim,
                                                dropout)
+        self.feed_forward2_b = FeedforwardModule(embed_dim,
+                                                 feedforward_dim,
+                                                 dropout)
 
-        self.feed_forward3 = FeedforwardModule(embed_dim,
+        self.feed_forward3_a = FeedforwardModule(embed_dim,
                                                (feedforward_dim * 5) // 4,
                                                dropout)
+        self.feed_forward3_b = FeedforwardModule(embed_dim,
+                                                 (feedforward_dim * 5) // 4,
+                                                 dropout)
 
-        self.nonlin_attention = NonlinAttention(embed_dim,
+        self.nonlin_attention_a = NonlinAttention(embed_dim,
                                                 hidden_channels=3 * embed_dim // 4)
+        self.nonlin_attention_b = NonlinAttention(embed_dim,
+                                                  hidden_channels=3 * embed_dim // 4)
 
-        self.conv_module1 = ConvolutionModule(embed_dim,
+        self.conv_module1_a = ConvolutionModule(embed_dim,
                                               cnn_module_kernel,
                                               causal=causal)
+        self.conv_module1_b = ConvolutionModule(embed_dim,
+                                                cnn_module_kernel,
+                                                causal=causal)
 
-        self.conv_module2 = ConvolutionModule(embed_dim,
+        self.conv_module2_a = ConvolutionModule(embed_dim,
                                               cnn_module_kernel,
                                               causal=causal)
+        self.conv_module2_b = ConvolutionModule(embed_dim,
+                                                cnn_module_kernel,
+                                                causal=causal)
 
-        # TODO: remove it
-        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
-
-        self.norm = BiasNorm(embed_dim)
+        self.norm = BiasNorm(embed_dim * 2)
 
         self.balancer1 = Balancer(
-            embed_dim, channel_dim=-1,
+            embed_dim * 2, channel_dim=-1,
             min_positive=0.45, max_positive=0.55,
             min_abs=0.2, max_abs=4.0,
         )
@@ -645,13 +669,13 @@ class Zipformer2EncoderLayer(nn.Module):
                              grad_scale=0.01)
 
         self.balancer2 = Balancer(
-            embed_dim, channel_dim=-1,
+            embed_dim * 2, channel_dim=-1,
             min_positive=0.45, max_positive=0.55,
             min_abs=0.1, max_abs=4.0,
         )
 
     def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]:
-        if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
+        if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting():
             return None
         batch_size = x.shape[1]
         mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
@@ -693,86 +717,120 @@ class Zipformer2EncoderLayer(nn.Module):
         Returns:
            A tensor which has the same shape as src
         """
-        src_orig = src
+        src_a, src_b = torch.split(src, self.embed_dim, 2)
+        src_orig_a, src_orig_b = src_a, src_b
 
         # dropout rate for non-feedforward submodules
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        if torch.jit.is_scripting():
             attention_skip_rate = 0.0
         else:
             attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0
 
         # attn_weights: (num_heads, batch_size, seq_len, seq_len)
-        attn_weights = self.self_attn_weights(
-            src,
+        attn_weights_a = self.cross_attn_weights_a(
+            src_a,
+            src_b,
             pos_emb=pos_emb,
             attn_mask=attn_mask,
             key_padding_mask=src_key_padding_mask,
         )
+        attn_weights_b = self.cross_attn_weights_b(
+            src_b,
+            src_a,
+            pos_emb=pos_emb,
+            attn_mask=attn_mask,
+            key_padding_mask=src_key_padding_mask,
+        )
 
-        src = src + self.feed_forward1(src)
+        src_a = src_a + self.feed_forward1_a(src_a)
+        src_b = src_b + self.feed_forward1_b(src_b)
 
-        self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
+        cross_attn_dropout_mask = self.get_sequence_dropout_mask(src_a, attention_skip_rate)
 
-        selected_attn_weights = attn_weights[0:1]
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        selected_attn_weights_a = attn_weights_a[0:1]
+        selected_attn_weights_b = attn_weights_b[0:1]
+        if torch.jit.is_scripting():
             pass
         elif not self.training and random.random() < float(self.const_attention_rate):
             # Make attention weights constant.  The intention is to
             # encourage these modules to do something similar to an
             # averaging-over-time operation.
             # only need the mask, can just use the 1st one and expand later
-            selected_attn_weights = selected_attn_weights[0:1]
-            selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype)
-            selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))
+            selected_attn_weights_a = selected_attn_weights_a[0:1]
+            selected_attn_weights_b = selected_attn_weights_b[0:1]
+            selected_attn_weights_a = (selected_attn_weights_a > 0.0).to(selected_attn_weights_a.dtype)
+            selected_attn_weights_b = (selected_attn_weights_b > 0.0).to(selected_attn_weights_b.dtype)
+            selected_attn_weights_a = selected_attn_weights_a * (1.0 / selected_attn_weights_a.sum(dim=-1, keepdim=True))
+            selected_attn_weights_b = selected_attn_weights_b * (1.0 / selected_attn_weights_b.sum(dim=-1, keepdim=True))
 
-        na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))
+        na_a = self.balancer_na(self.nonlin_attention_a(src_b, selected_attn_weights_a))
+        na_b = self.balancer_na(self.nonlin_attention_b(src_a, selected_attn_weights_b))
 
-        src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask)
+        src_a = src_a + (na_a if cross_attn_dropout_mask is None else na_a * cross_attn_dropout_mask)
+        src_b = src_b + (na_b if cross_attn_dropout_mask is None else na_b * cross_attn_dropout_mask)
 
-        self_attn = self.self_attn1(src, attn_weights)
+        cross_attn_a = self.cross_attn1_a(src_b, attn_weights_a)
+        cross_attn_b = self.cross_attn1_b(src_a, attn_weights_b)
 
-        src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
+        src_a = src_a + (cross_attn_a if cross_attn_dropout_mask is None else cross_attn_a * cross_attn_dropout_mask)
+        src_b = src_b + (cross_attn_b if cross_attn_dropout_mask is None else cross_attn_b * cross_attn_dropout_mask)
 
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        if torch.jit.is_scripting():
             conv_skip_rate = 0.0
         else:
             conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
-        src = src + self.sequence_dropout(self.conv_module1(src, chunk_size=chunk_size,
+        src_a = src_a + self.sequence_dropout(self.conv_module1_a(src_a, chunk_size=chunk_size,
                                                             src_key_padding_mask=src_key_padding_mask),
                                           conv_skip_rate)
+        src_b = src_b + self.sequence_dropout(self.conv_module1_b(src_b, chunk_size=chunk_size,
+                                                                  src_key_padding_mask=src_key_padding_mask),
+                                              conv_skip_rate)
 
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        if torch.jit.is_scripting():
             ff2_skip_rate = 0.0
         else:
             ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
-        src = src + self.sequence_dropout(self.balancer_ff2(self.feed_forward2(src)),
+        src_a = src_a + self.sequence_dropout(self.balancer_ff2(self.feed_forward2_a(src_a)),
                                           ff2_skip_rate)
+        src_b = src_b + self.sequence_dropout(self.balancer_ff2(self.feed_forward2_b(src_b)),
+                                              ff2_skip_rate)
 
         # bypass in the middle of the layer.
-        src = self.bypass_mid(src_orig, src)
+        src_a = self.bypass_mid_a(src_orig_a, src_a)
+        src_b = self.bypass_mid_b(src_orig_b, src_b)
 
-        self_attn = self.self_attn2(src, attn_weights)
+        cross_attn_a = self.cross_attn2_a(src_b, attn_weights_a)
+        cross_attn_b = self.cross_attn2_b(src_a, attn_weights_b)
 
-        src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
+        src_a = src_a + (cross_attn_a if cross_attn_dropout_mask is None else cross_attn_a * cross_attn_dropout_mask)
+        src_b = src_b + (cross_attn_b if cross_attn_dropout_mask is None else cross_attn_b * cross_attn_dropout_mask)
 
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        if torch.jit.is_scripting():
             conv_skip_rate = 0.0
         else:
             conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
-        src = src + self.sequence_dropout(self.conv_module2(src, chunk_size=chunk_size,
+        src_a = src_a + self.sequence_dropout(self.conv_module2_a(src_a, chunk_size=chunk_size,
                                                             src_key_padding_mask=src_key_padding_mask),
                                           conv_skip_rate)
+        src_b = src_b + self.sequence_dropout(self.conv_module2_b(src_b, chunk_size=chunk_size,
+                                                                  src_key_padding_mask=src_key_padding_mask),
+                                              conv_skip_rate)
 
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        if torch.jit.is_scripting():
             ff3_skip_rate = 0.0
         else:
             ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
-        src = src + self.sequence_dropout(self.balancer_ff3(self.feed_forward3(src)),
+        src_a = src_a + self.sequence_dropout(self.balancer_ff3(self.feed_forward3_a(src_a)),
                                           ff3_skip_rate)
+        src_b = src_b + self.sequence_dropout(self.balancer_ff3(self.feed_forward3_b(src_b)),
+                                              ff3_skip_rate)
 
+        src = torch.cat([src_a, src_b], 2)
+
         src = self.balancer1(src)
         src = self.norm(src)
 
+        src_orig = torch.cat([src_orig_a, src_orig_b], 2)
         src = self.bypass(src_orig, src)
 
         src = self.balancer2(src)
@@ -828,7 +886,7 @@ class Zipformer2EncoderLayer(nn.Module):
         src_orig = src
 
         # attn_weights: (num_heads, batch_size, seq_len, seq_len)
-        attn_weights, cached_key = self.self_attn_weights.streaming_forward(
+        attn_weights, cached_key = self.cross_attn_weights.streaming_forward(
             src,
             pos_emb=pos_emb,
             cached_key=cached_key,
@@ -846,13 +904,13 @@ class Zipformer2EncoderLayer(nn.Module):
         )
         src = src + na
 
-        self_attn, cached_val1 = self.self_attn1.streaming_forward(
+        cross_attn, cached_val1 = self.cross_attn1.streaming_forward(
             src,
             attn_weights=attn_weights,
             cached_val=cached_val1,
             left_context_len=left_context_len,
         )
-        src = src + self_attn
+        src = src + cross_attn
 
         src_conv, cached_conv1 = self.conv_module1.streaming_forward(
             src,
@@ -866,13 +924,13 @@ class Zipformer2EncoderLayer(nn.Module):
         # bypass in the middle of the layer.
         src = self.bypass_mid(src_orig, src)
 
-        self_attn, cached_val2 = self.self_attn2.streaming_forward(
+        cross_attn, cached_val2 = self.cross_attn2.streaming_forward(
             src,
             attn_weights=attn_weights,
             cached_val=cached_val2,
             left_context_len=left_context_len,
         )
-        src = src + self_attn
+        src = src + cross_attn
 
         src_conv, cached_conv2 = self.conv_module2.streaming_forward(
             src,
@@ -969,7 +1027,7 @@ class Zipformer2Encoder(nn.Module):
         pos_emb = self.encoder_pos(src)
         output = src
 
-        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+        if not torch.jit.is_scripting():
             output = output * feature_mask
 
         for i, mod in enumerate(self.layers):
@@ -981,7 +1039,7 @@ class Zipformer2Encoder(nn.Module):
                 src_key_padding_mask=src_key_padding_mask,
             )
 
-            if not torch.jit.is_scripting() and not torch.jit.is_tracing():
+            if not torch.jit.is_scripting():
                 output = output * feature_mask
 
         return output
@@ -1074,7 +1132,7 @@ class BypassModule(nn.Module):
         # or (batch_size, num_channels,).  This is actually the
         # scale on the non-residual term, so 0 correponds to bypassing
         # this module.
-        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
+        if torch.jit.is_scripting() or not self.training:
             return self.bypass_scale
         else:
             ans = limit_param_value(self.bypass_scale,
@@ -1230,11 +1288,12 @@ class SimpleDownsample(torch.nn.Module):
         d_seq_len = (seq_len + ds - 1) // ds
 
         # Pad to an exact multiple of self.downsample
-        # right-pad src, repeating the last element.
-        pad = d_seq_len * ds - seq_len
-        src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2])
-        src = torch.cat((src, src_extra), dim=0)
-        assert src.shape[0] == d_seq_len * ds
+        if seq_len != d_seq_len * ds:
+            # right-pad src, repeating the last element.
+            pad = d_seq_len * ds - seq_len
+            src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2])
+            src = torch.cat((src, src_extra), dim=0)
+            assert src.shape[0] == d_seq_len * ds
 
         src = src.reshape(d_seq_len, ds, batch_size, in_channels)
 
@@ -1322,7 +1381,11 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         # self.pe contains both positive and negative parts
         # the length of self.pe is 2 * input_len - 1
         if self.pe.size(0) >= T * 2 - 1:
-            self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+            # Note: TorchScript doesn't implement operator== for torch.Device
+            if self.pe.dtype != x.dtype or str(self.pe.device) != str(
+                x.device
+            ):
+                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
             return
 
         # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]
@@ -1434,7 +1497,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         # dividing it between the query and key.  Note: this module is intended
         # to be used with the ScaledAdam optimizer; with most other optimizers,
         # it would be necessary to apply the scaling factor in the forward function.
-        self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True,
+        self.in_proj_a = ScaledLinear(embed_dim, in_proj_dim, bias=True,
                                     initial_scale=query_head_dim**-0.25)
+        self.in_proj_b = ScaledLinear(embed_dim, in_proj_dim, bias=True,
+                                      initial_scale=query_head_dim**-0.25)
 
         self.whiten_keys = Whiten(num_groups=num_heads,
@ -1471,6 +1536,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
y: Tensor,
|
||||
pos_emb: Tensor,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
@@ -1478,6 +1544,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         r"""
         Args:
             x: input of shape (seq_len, batch_size, embed_dim)
+            y: input of shape (seq_len, batch_size, embed_dim)
             pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
             key_padding_mask: a bool tensor of shape (batch_size, seq_len).  Positions that
                are True in this mask will be ignored as sources in the attention weighting.
@@ -1488,7 +1555,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
            a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len)
            interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len).
         """
-        x = self.in_proj(x)
+        x = self.in_proj_a(x)
+        y = self.in_proj_b(y)
         query_head_dim = self.query_head_dim
         pos_head_dim = self.pos_head_dim
         num_heads = self.num_heads
@@ -1499,7 +1567,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
 
         # self-attention
         q = x[...,0:query_dim]
-        k = x[...,query_dim:2*query_dim]
+        k = y[...,query_dim:2*query_dim]
         # p is the position-encoding query
         p = x[...,2*query_dim:]
         assert p.shape[-1] == num_heads * pos_head_dim
@@ -1520,7 +1588,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         attn_scores = torch.matmul(q, k)
 
         use_pos_scores = False
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        if torch.jit.is_scripting():
             # We can't put random.random() in the same line
             use_pos_scores = True
         elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
@@ -1538,26 +1606,16 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             # the following .as_strided() expression converts the last axis of pos_scores from relative
             # to absolute position.  I don't know whether I might have got the time-offsets backwards or
             # not, but let this code define which way round it is supposed to be.
-            if torch.jit.is_tracing():
-                (num_heads, batch_size, time1, n) = pos_scores.shape
-                rows = torch.arange(start=time1 - 1, end=-1, step=-1)
-                cols = torch.arange(seq_len)
-                rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
-                indexes = rows + cols
-                pos_scores = pos_scores.reshape(-1, n)
-                pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
-                pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
-            else:
-                pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len),
-                                                   (pos_scores.stride(0),
-                                                    pos_scores.stride(1),
-                                                    pos_scores.stride(2)-pos_scores.stride(3),
-                                                    pos_scores.stride(3)),
-                                                   storage_offset=pos_scores.stride(3) * (seq_len - 1))
+            pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len),
+                                               (pos_scores.stride(0),
+                                                pos_scores.stride(1),
+                                                pos_scores.stride(2)-pos_scores.stride(3),
+                                                pos_scores.stride(3)),
+                                               storage_offset=pos_scores.stride(3) * (seq_len - 1))
 
             attn_scores = attn_scores + pos_scores
 
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        if torch.jit.is_scripting():
             pass
         elif self.training and random.random() < 0.1:
             # This is a harder way of limiting the attention scores to not be
@@ -1600,7 +1658,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         # half-precision output for backprop purposes.
         attn_weights = softmax(attn_scores, dim=-1)
 
-        if torch.jit.is_scripting() or torch.jit.is_tracing():
+        if torch.jit.is_scripting():
             pass
         elif random.random() < 0.001 and not self.training:
             self._print_attn_entropy(attn_weights)
@@ -1678,26 +1736,15 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
         #  [where seq_len2 represents relative position.]
         pos_scores = torch.matmul(p, pos_emb)
-
-        if torch.jit.is_tracing():
-            (num_heads, batch_size, time1, n) = pos_scores.shape
-            rows = torch.arange(start=time1 - 1, end=-1, step=-1)
-            cols = torch.arange(k_len)
-            rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
-            indexes = rows + cols
-            pos_scores = pos_scores.reshape(-1, n)
-            pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
-            pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len)
         # the following .as_strided() expression converts the last axis of pos_scores from relative
         # to absolute position.  I don't know whether I might have got the time-offsets backwards or
         # not, but let this code define which way round it is supposed to be.
-        else:
-            pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len),
-                                               (pos_scores.stride(0),
-                                                pos_scores.stride(1),
-                                                pos_scores.stride(2)-pos_scores.stride(3),
-                                                pos_scores.stride(3)),
-                                               storage_offset=pos_scores.stride(3) * (seq_len - 1))
+        pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len),
+                                           (pos_scores.stride(0),
+                                            pos_scores.stride(1),
+                                            pos_scores.stride(2)-pos_scores.stride(3),
+                                            pos_scores.stride(3)),
+                                           storage_offset=pos_scores.stride(3) * (seq_len - 1))
 
         attn_scores = attn_scores + pos_scores
 
@@ -1728,7 +1775,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
                 logging.info(f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}")
 
 
-class SelfAttention(nn.Module):
+class CrossAttention(nn.Module):
     """
     The simplest possible attention module.  This one works with already-computed attention
     weights, e.g. as computed by RelPositionMultiheadAttentionWeights.
@@ -2153,7 +2200,7 @@ class ConvolutionModule(nn.Module):
         if src_key_padding_mask is not None:
             x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
 
-        if not torch.jit.is_scripting() and not torch.jit.is_tracing() and chunk_size >= 0:
+        if not torch.jit.is_scripting() and chunk_size >= 0:
             # Not support exporting a model for simulated streaming decoding
             assert self.causal, "Must initialize model with causal=True if you use chunk_size"
             x = self.depthwise_conv(x, chunk_size=chunk_size)
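As a footnote to the RelPositionMultiheadAttentionWeights change above (forward now takes both x and y, projected by in_proj_a and in_proj_b): the attention queries and the position-encoding query are sliced from x's projection, while the keys are sliced from y's projection. The toy sketch below shows only that slicing; the dimensions are illustrative assumptions, nn.Linear stands in for icefall's ScaledLinear, and the relative-position score that the real module adds is omitted.

import torch
import torch.nn as nn

embed_dim, num_heads, query_head_dim, pos_head_dim = 128, 4, 32, 4
query_dim = num_heads * query_head_dim
in_proj_dim = query_dim * 2 + num_heads * pos_head_dim  # q + k + positional query

in_proj_a = nn.Linear(embed_dim, in_proj_dim)  # stand-in for ScaledLinear applied to x
in_proj_b = nn.Linear(embed_dim, in_proj_dim)  # stand-in for ScaledLinear applied to y

seq_len, batch = 10, 2
x = torch.randn(seq_len, batch, embed_dim)  # supplies the queries
y = torch.randn(seq_len, batch, embed_dim)  # supplies the keys

px, py = in_proj_a(x), in_proj_b(y)
q = px[..., 0:query_dim]              # q comes from x, as in the diff
k = py[..., query_dim:2 * query_dim]  # k now comes from y
p = px[..., 2 * query_dim:]           # position-encoding query (unused in this toy sketch)

q = q.reshape(seq_len, batch, num_heads, query_head_dim).permute(2, 1, 0, 3)
k = k.reshape(seq_len, batch, num_heads, query_head_dim).permute(2, 1, 3, 0)
attn_scores = torch.matmul(q, k)      # (num_heads, batch, seq_len, seq_len)
print(attn_scores.shape)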