Try to implement test mode; fix issue where middle stack had not been
downsampled.
This commit is contained in:
parent
30ace76fbc
commit
53410608a6
@ -161,11 +161,9 @@ class Subformer(EncoderInterface):
|
|||||||
mid = len(encoders) // 2
|
mid = len(encoders) // 2
|
||||||
encoder = DownsampledSubformerEncoder(
|
encoder = DownsampledSubformerEncoder(
|
||||||
[ encoders[mid] ],
|
[ encoders[mid] ],
|
||||||
input_num_channels=encoder_dim[mid],
|
input_num_channels=encoder_dim[mid-1],
|
||||||
downsample=2
|
downsample=2
|
||||||
)
|
)
|
||||||
|
|
||||||
encoder = encoders[mid]
|
|
||||||
for i in range(1, mid+1):
|
for i in range(1, mid+1):
|
||||||
this_list = [ encoders[mid-i],
|
this_list = [ encoders[mid-i],
|
||||||
encoder,
|
encoder,
|
||||||
@ -670,8 +668,7 @@ class SubformerEncoder(nn.Module):
|
|||||||
chunk_indexes: a list of indexes into chunk_sizes, one per layer.
|
chunk_indexes: a list of indexes into chunk_sizes, one per layer.
|
||||||
"""
|
"""
|
||||||
seq_len = src.shape[0]
|
seq_len = src.shape[0]
|
||||||
assert seq_len < self.chunk_size or seq_len % self.chunk_size == 0
|
if seq_len <= self.chunk_size or seq_len % self.chunk_size != 0:
|
||||||
if seq_len <= self.chunk_size:
|
|
||||||
return [ seq_len ], [ 0 ] * len(self.layers)
|
return [ seq_len ], [ 0 ] * len(self.layers)
|
||||||
else:
|
else:
|
||||||
assert seq_len % self.chunk_size == 0, (seq_len, self.chunk_size)
|
assert seq_len % self.chunk_size == 0, (seq_len, self.chunk_size)
|
||||||
@ -828,8 +825,8 @@ class LearnedDownsamplingModule(nn.Module):
|
|||||||
# these drifting around.
|
# these drifting around.
|
||||||
# largish range used to keep grads relatively small and avoid overflow in grads.
|
# largish range used to keep grads relatively small and avoid overflow in grads.
|
||||||
self.score_balancer = Balancer(1, channel_dim=-1,
|
self.score_balancer = Balancer(1, channel_dim=-1,
|
||||||
min_positive=0.4, max_positive=0.6,
|
min_positive=1/(2*downsampling_factor),
|
||||||
min_abs=1.0, max_abs=1.2)
|
min_abs=1.0)
|
||||||
|
|
||||||
self.copy_weights1 = nn.Identity()
|
self.copy_weights1 = nn.Identity()
|
||||||
self.copy_weights2 = nn.Identity()
|
self.copy_weights2 = nn.Identity()
|
||||||
@ -863,50 +860,73 @@ class LearnedDownsamplingModule(nn.Module):
|
|||||||
# sscores, indexes: (batch_size, seq_len)
|
# sscores, indexes: (batch_size, seq_len)
|
||||||
sscores, indexes = scores.sort(dim=-1, descending=True)
|
sscores, indexes = scores.sort(dim=-1, descending=True)
|
||||||
|
|
||||||
d = self.downsampling_factor
|
|
||||||
seq_len_reduced = (seq_len + d - 1) // d
|
|
||||||
|
|
||||||
# TODO: if seq_len / downsampling_factor <= 2, do something special.
|
if self.training:
|
||||||
|
d = self.downsampling_factor
|
||||||
|
|
||||||
intermediate_rate = float(self.intermediate_rate)
|
seq_len_reduced = (seq_len + d - 1) // d
|
||||||
|
|
||||||
# 'right' is the rightmost of the 2 limits; we want the scores indexed
|
intermediate_rate = float(self.intermediate_rate)
|
||||||
# 'upper' to be mapped to around 0.0
|
|
||||||
right = seq_len_reduced
|
|
||||||
# we want scores around 'left' to be mapped to around 1.0.
|
|
||||||
left = int(seq_len_reduced * (1.0 - intermediate_rate))
|
|
||||||
|
|
||||||
# 'collar' determines the range of positions in the sorted list that we use to
|
# 'right' is the rightmost of the 2 limits; we want the scores indexed
|
||||||
# compute the average. We could let collar be 0.0, which would more exactly
|
# 'upper' to be mapped to around 0.0
|
||||||
# accomplish what we want; but we don't, because this would cause too-noisy
|
right = seq_len_reduced
|
||||||
# gradients, with too much gradient going to one frame.
|
# we want scores around 'left' to be mapped to around 1.0.
|
||||||
collar = max(1, int(seq_len_reduced * 0.5 * intermediate_rate))
|
left = int(seq_len_reduced * (1.0 - intermediate_rate))
|
||||||
|
|
||||||
# right_avg: shape (batch_size,), this is to be mapped to 0.0
|
# 'collar' determines the range of positions in the sorted list that we use to
|
||||||
right_avg = sscores[:, right-collar:right+collar+1].mean(dim=-1, keepdim=True)
|
# compute the average. We could let collar be 0.0, which would more exactly
|
||||||
|
# accomplish what we want; but we don't, because this would cause too-noisy
|
||||||
|
# gradients, with too much gradient going to one frame.
|
||||||
|
collar = max(1, int(seq_len_reduced * 0.5 * intermediate_rate))
|
||||||
|
|
||||||
# left_avg: shape (batch_size,), this is to be mapped to 1.0
|
# right_avg: shape (batch_size,), this is to be mapped to 0.0
|
||||||
left_avg = sscores[:, left-collar:left+collar+1].mean(dim=-1, keepdim=True)
|
right_avg = sscores[:, right-collar:right+collar+1].mean(dim=-1, keepdim=True)
|
||||||
|
|
||||||
# the + 0.001 is to avoid possible division by zero in case of ties.
|
# we only shift the scores left (decrease them, to ensure no more than `intermediate_rate`
|
||||||
sscores = self.copy_weights1(sscores)
|
# proportion of the scores are >0). This lets us have batch-independence in test-mode,
|
||||||
|
# the idea is that the model will "learn" the right distribution of scores.
|
||||||
|
right_avg_clamped = right_avg.clamp(min=0.0)
|
||||||
|
|
||||||
|
# left_avg: shape (batch_size,), this is to be mapped to 1.0
|
||||||
|
left_avg = sscores[:, left-collar:left+collar+1].mean(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# the + 0.001 is to avoid possible division by zero in case of ties.
|
||||||
|
sscores = self.copy_weights1(sscores)
|
||||||
|
|
||||||
|
# divide by den: only decrease the scores' value.
|
||||||
|
den = (left_avg - right_avg_clamped).clamp(min=1.0)
|
||||||
|
|
||||||
|
#logging.info(f"den = {den}")
|
||||||
|
weights = (sscores - right_avg_clamped) / den
|
||||||
|
else:
|
||||||
|
# in test mode, no normalization (we can't have batch-dependent
|
||||||
|
# effects because this would be "seeing the future"). But we trainin such
|
||||||
|
# a way that, hopefully, it will most of the time give us not much more
|
||||||
|
# nonzero scores than in training time.
|
||||||
|
weights = sscores
|
||||||
|
|
||||||
den = (left_avg - right_avg)
|
|
||||||
# the following is to avoid division by near-zero.
|
|
||||||
den = 0.75 * den + 0.25 * den.mean()
|
|
||||||
|
|
||||||
#logging.info(f"den = {den}")
|
|
||||||
weights = (sscores - right_avg) / den
|
|
||||||
weights = weights.clamp(min=0.0, max=1.0)
|
weights = weights.clamp(min=0.0, max=1.0)
|
||||||
|
|
||||||
|
if not self.training:
|
||||||
|
# need to work out seq_len_reduced.
|
||||||
|
seq_len_reduced = max(1,
|
||||||
|
(weights > 0.0).to(torch.int32).sum(dim=-1).max().item())
|
||||||
|
|
||||||
|
|
||||||
indexes = indexes[:, :seq_len_reduced]
|
indexes = indexes[:, :seq_len_reduced]
|
||||||
weights = weights[:, :seq_len_reduced]
|
weights = weights[:, :seq_len_reduced]
|
||||||
|
|
||||||
weights = self.copy_weights2(weights)
|
weights = self.copy_weights2(weights)
|
||||||
|
|
||||||
|
if random.random() < 0.01 or __name__ == '__main__':
|
||||||
|
logging.info(f"Mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
|
||||||
|
|
||||||
# re-sort the indexes we kept, on index value, so that
|
# re-sort the indexes we kept, on index value, so that
|
||||||
# masking for causal models will be in the correct order.
|
# masking for causal models will be in the correct order.
|
||||||
|
# (actually this may not really matter, TODO: see whether we
|
||||||
|
# can remove this??)
|
||||||
indexes, reorder = indexes.sort(dim=-1)
|
indexes, reorder = indexes.sort(dim=-1)
|
||||||
weights = torch.gather(weights, dim=-1, index=reorder)
|
weights = torch.gather(weights, dim=-1, index=reorder)
|
||||||
|
|
||||||
@ -1046,7 +1066,6 @@ class DownsampledSubformerEncoder(nn.Module):
|
|||||||
input_num_channels: int,
|
input_num_channels: int,
|
||||||
downsample: int):
|
downsample: int):
|
||||||
super(DownsampledSubformerEncoder, self).__init__()
|
super(DownsampledSubformerEncoder, self).__init__()
|
||||||
|
|
||||||
if downsample != 1:
|
if downsample != 1:
|
||||||
self.downsampler = LearnedDownsamplingModule(input_num_channels,
|
self.downsampler = LearnedDownsamplingModule(input_num_channels,
|
||||||
downsample)
|
downsample)
|
||||||
@ -1085,8 +1104,8 @@ class DownsampledSubformerEncoder(nn.Module):
|
|||||||
Returns: a Tensor with the same shape as src.
|
Returns: a Tensor with the same shape as src.
|
||||||
"""
|
"""
|
||||||
src_orig = src
|
src_orig = src
|
||||||
|
|
||||||
if hasattr(self, 'downsampler'):
|
if hasattr(self, 'downsampler'):
|
||||||
|
print("b")
|
||||||
indexes, weights, src = self.downsampler(src)
|
indexes, weights, src = self.downsampler(src)
|
||||||
|
|
||||||
pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes)
|
pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes)
|
||||||
|
|||||||
Reference in New Issue
Block a user