Cosmetic fixes

Author: Daniel Povey
Date:   2022-11-09 17:14:18 +08:00
Parent: f8210e1d80
Commit: 308059edba


@@ -105,7 +105,8 @@ class Zipformer(EncoderInterface):
         super(Zipformer, self).__init__()
 
         def _to_tuple(x):
-            """ Converts a single int or a 1-tuple of an int to a tuple with the same length as output_downsampling_factor"""
+            """ Converts a single int or a 1-tuple of an int to a tuple with the same length
+            as downsampling_factor"""
             if isinstance(x, int):
                 x = (x,)
             if len(x) == 1:
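The hunk above cuts off before the end of _to_tuple. As a rough, self-contained sketch of the behaviour the revised docstring describes (the real helper is a closure inside Zipformer.__init__ and replicates to len(downsampling_factor); the explicit n parameter here is an illustrative stand-in):

def _to_tuple_sketch(x, n):
    # Convert an int or a 1-tuple to an n-tuple by replication; a longer
    # tuple is assumed to already have length n. Sketch only, not the
    # exact code from the file above.
    if isinstance(x, int):
        x = (x,)
    if len(x) == 1:
        x = x * n
    else:
        assert len(x) == n
    return x

assert _to_tuple_sketch(3, 4) == (3, 3, 3, 3)
assert _to_tuple_sketch((2,), 3) == (2, 2, 2)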
@@ -152,15 +153,15 @@ class Zipformer(EncoderInterface):
         num_encoders = len(downsampling_factor)
         for i in range(num_encoders):
             encoder_layer = ZipformerEncoderLayer(
-                encoder_dim[i],
-                pos_dim,
-                num_heads[i],
-                query_head_dim[i],
-                pos_head_dim[i],
-                value_head_dim[i],
-                feedforward_dim[i],
-                dropout,
-                cnn_module_kernel[i],
+                embed_dim=encoder_dim[i],
+                pos_dim=pos_dim,
+                num_heads=num_heads[i],
+                query_head_dim=query_head_dim[i],
+                pos_head_dim=pos_head_dim[i],
+                value_head_dim=value_head_dim[i],
+                feedforward_dim=feedforward_dim[i],
+                dropout=dropout,
+                cnn_module_kernel=cnn_module_kernel[i],
             )
 
             # For the segment of the warmup period, we let the Conv2dSubsampling
@@ -168,8 +169,8 @@ class Zipformer(EncoderInterface):
             encoder = ZipformerEncoder(
                 encoder_layer,
                 num_encoder_layers[i],
-                pos_dim,
-                dropout,
+                pos_dim=pos_dim,
+                dropout=dropout,
                 warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                 warmup_end=warmup_batches * (i + 2) / (num_encoders + 1)
             )
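The warmup_begin/warmup_end arguments above stagger the warm-up window of each encoder stack across the first warmup_batches batches. A small numerical illustration (the values of warmup_batches and num_encoders below are made up, not taken from this commit):

warmup_batches = 4000.0
num_encoders = 5  # illustrative values only
for i in range(num_encoders):
    warmup_begin = warmup_batches * (i + 1) / (num_encoders + 1)
    warmup_end = warmup_batches * (i + 2) / (num_encoders + 1)
    print(f"stack {i}: warmup from batch {warmup_begin:.0f} to {warmup_end:.0f}")
# stack 0: warmup from batch 667 to 1333
# ...
# stack 4: warmup from batch 3333 to 4000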
@@ -372,8 +373,9 @@ class ZipformerEncoderLayer(nn.Module):
         self.batch_count = 0
 
         self.self_attn_weights = RelPositionMultiheadAttentionWeights(
-            embed_dim, pos_dim, num_heads,
-            query_head_dim, pos_head_dim, dropout=0.0,
+            embed_dim, pos_dim=pos_dim, num_heads=num_heads,
+            query_head_dim=query_head_dim, pos_head_dim=pos_head_dim,
+            dropout=0.0,
         )
 
         self.self_attn1 = SelfAttention(embed_dim, num_heads,
@@ -1099,17 +1101,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
                                        bias=False,
                                        initial_scale=0.05)
 
-        # this is to stop a failure mode where the output gets very small and is
-        # dominated by the mean (the min_positive and max_positive will stop the mean
-        # being much larger than the variance). Make min_abs very small because even for normal,
-        # functional self_attn layers, the output rms can be very small.
-        self.out_balancer = ActivationBalancer(embed_dim,
-                                               channel_dim=-1,
-                                               min_positive=0.33,
-                                               max_positive=0.66,
-                                               min_abs=0.005, max_abs=1.0,
-                                               min_prob=0.05)
-
         # the following are for diagnosics only, see --print-diagnostics option
         self.copy_pos_query = Identity()
@@ -1141,8 +1132,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         query_head_dim = self.query_head_dim
         pos_head_dim = self.pos_head_dim
         num_heads = self.num_heads
-        dropout = self.dropout,
-        training = self.training,
 
         seq_len, batch_size, _ = x.shape
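The two deleted lines were local copies that the rest of forward() no longer uses (see the later change from "if training" to "if self.training"). Note also that their trailing commas made them 1-tuples rather than plain values; a one-line reminder of that Python behaviour:

dropout_value = 0.1
dropout = dropout_value,   # trailing comma: dropout is now the tuple (0.1,)
assert dropout == (0.1,) and dropout != 0.1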
@@ -1165,6 +1154,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
         k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)
 
+        # time1 refers to target, time2 refers to source.
         q = q.permute(2, 1, 0, 3)  # (head, batch, time1, query_head_dim)
         p = p.permute(2, 1, 0, 3)  # (head, batch, time1, pos_head_dim)
         k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, time2)
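A quick shape check for the reshape/permute pattern above, using toy sizes (the numbers are illustrative, not from the model config):

import torch

seq_len, batch_size, num_heads = 6, 2, 4      # toy sizes
query_head_dim, pos_head_dim = 8, 3

q = torch.randn(seq_len, batch_size, num_heads, query_head_dim)
p = torch.randn(seq_len, batch_size, num_heads, pos_head_dim)
k = torch.randn(seq_len, batch_size, num_heads, query_head_dim)

q = q.permute(2, 1, 0, 3)  # (head, batch, time1, query_head_dim)
p = p.permute(2, 1, 0, 3)  # (head, batch, time1, pos_head_dim)
k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, time2)

assert q.shape == (num_heads, batch_size, seq_len, query_head_dim)
assert k.shape == (num_heads, batch_size, query_head_dim, seq_len)
# torch.matmul(q, k) therefore has shape (head, batch, time1, time2).
assert torch.matmul(q, k).shape == (num_heads, batch_size, seq_len, seq_len)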
@@ -1189,7 +1179,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         attn_scores = torch.matmul(q, k) + pos_scores
 
-        if training and random.random() < 0.1:
+        if self.training and random.random() < 0.1:
             # This is a harder way of limiting the attention scores to not be
             # too large. It incurs a penalty if any of them has an absolute
             # value greater than 50.0. this should be outside the normal range
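The comment above describes a soft limit: instead of clipping, a penalty term discourages attention scores whose absolute value exceeds 50.0. The snippet below is only a generic sketch of that idea, not the helper actually used in this code:

import torch

def abs_value_penalty(x: torch.Tensor, limit: float = 50.0) -> torch.Tensor:
    # Mean amount by which |x| exceeds `limit`; zero when all values are in range.
    return (x.abs() - limit).clamp(min=0.0).mean()

scores = torch.randn(4, 2, 6, 6) * 60.0
scores.requires_grad_(True)
abs_value_penalty(scores).backward()
# scores.grad is nonzero only where |scores| > 50.0, so the gradient pushes
# out-of-range scores back toward the limit instead of hard-clipping them.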
@@ -1214,10 +1204,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         if key_padding_mask is not None:
             assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
-            attn_scores = attn_scores.view(
-                num_heads, batch_size, seq_len, seq_len
-            )
             attn_scores = attn_scores.masked_fill(
                 key_padding_mask.unsqueeze(1),
                 float("-inf"),