From 53410608a693147bec3d09fb81bdc4107da611ce Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 17 May 2023 13:02:59 +0800
Subject: [PATCH] Try to implement test mode; fix issue where middle stack had
 not been downsampled.

---
 egs/libriheavy/LM/zipformer1/subformer.py | 89 ++++++++++++++---------
 1 file changed, 54 insertions(+), 35 deletions(-)

diff --git a/egs/libriheavy/LM/zipformer1/subformer.py b/egs/libriheavy/LM/zipformer1/subformer.py
index a87e0d8a9..769c2d553 100644
--- a/egs/libriheavy/LM/zipformer1/subformer.py
+++ b/egs/libriheavy/LM/zipformer1/subformer.py
@@ -161,11 +161,9 @@ class Subformer(EncoderInterface):
         mid = len(encoders) // 2
         encoder = DownsampledSubformerEncoder(
             [ encoders[mid] ],
-            input_num_channels=encoder_dim[mid],
+            input_num_channels=encoder_dim[mid-1],
             downsample=2)
-
-        encoder = encoders[mid]
 
         for i in range(1, mid+1):
             this_list = [ encoders[mid-i],
                           encoder,
@@ -670,8 +668,7 @@ class SubformerEncoder(nn.Module):
           chunk_indexes: a list of indexes into chunk_sizes, one per layer.
         """
         seq_len = src.shape[0]
-        assert seq_len < self.chunk_size or seq_len % self.chunk_size == 0
-        if seq_len <= self.chunk_size:
+        if seq_len <= self.chunk_size or seq_len % self.chunk_size != 0:
             return [ seq_len ], [ 0 ] * len(self.layers)
         else:
             assert seq_len % self.chunk_size == 0, (seq_len, self.chunk_size)
@@ -828,8 +825,8 @@ class LearnedDownsamplingModule(nn.Module):
         # these drifting around.
         # largish range used to keep grads relatively small and avoid overflow in grads.
         self.score_balancer = Balancer(1, channel_dim=-1,
-                                       min_positive=0.4, max_positive=0.6,
-                                       min_abs=1.0, max_abs=1.2)
+                                       min_positive=1/(2*downsampling_factor),
+                                       min_abs=1.0)
 
         self.copy_weights1 = nn.Identity()
         self.copy_weights2 = nn.Identity()
@@ -863,50 +860,73 @@ class LearnedDownsamplingModule(nn.Module):
         # sscores, indexes: (batch_size, seq_len)
         sscores, indexes = scores.sort(dim=-1, descending=True)
 
-        d = self.downsampling_factor
-        seq_len_reduced = (seq_len + d - 1) // d
-        # TODO: if seq_len / downsampling_factor <= 2, do something special.
+        if self.training:
+            d = self.downsampling_factor
 
-        intermediate_rate = float(self.intermediate_rate)
+            seq_len_reduced = (seq_len + d - 1) // d
 
-        # 'right' is the rightmost of the 2 limits; we want the scores indexed
-        # 'upper' to be mapped to around 0.0
-        right = seq_len_reduced
-        # we want scores around 'left' to be mapped to around 1.0.
-        left = int(seq_len_reduced * (1.0 - intermediate_rate))
+            intermediate_rate = float(self.intermediate_rate)
 
-        # 'collar' determines the range of positions in the sorted list that we use to
-        # compute the average.  We could let collar be 0.0, which would more exactly
-        # accomplish what we want; but we don't, because this would cause too-noisy
-        # gradients, with too much gradient going to one frame.
-        collar = max(1, int(seq_len_reduced * 0.5 * intermediate_rate))
+            # 'right' is the rightmost of the 2 limits; we want the scores indexed
+            # 'upper' to be mapped to around 0.0
+            right = seq_len_reduced
+            # we want scores around 'left' to be mapped to around 1.0.
+            left = int(seq_len_reduced * (1.0 - intermediate_rate))
 
-        # right_avg: shape (batch_size,), this is to be mapped to 0.0
-        right_avg = sscores[:, right-collar:right+collar+1].mean(dim=-1, keepdim=True)
+            # 'collar' determines the range of positions in the sorted list that we use to
+            # compute the average.  We could let collar be 0.0, which would more exactly
+            # accomplish what we want; but we don't, because this would cause too-noisy
+            # gradients, with too much gradient going to one frame.
+            collar = max(1, int(seq_len_reduced * 0.5 * intermediate_rate))
 
-        # left_avg: shape (batch_size,), this is to be mapped to 1.0
-        left_avg = sscores[:, left-collar:left+collar+1].mean(dim=-1, keepdim=True)
+            # right_avg: shape (batch_size,), this is to be mapped to 0.0
+            right_avg = sscores[:, right-collar:right+collar+1].mean(dim=-1, keepdim=True)
 
-        # the + 0.001 is to avoid possible division by zero in case of ties.
-        sscores = self.copy_weights1(sscores)
+            # we only shift the scores left (decrease them, to ensure no more than `intermediate_rate`
+            # proportion of the scores are >0).  This lets us have batch-independence in test mode;
+            # the idea is that the model will "learn" the right distribution of scores.
+            right_avg_clamped = right_avg.clamp(min=0.0)
+
+            # left_avg: shape (batch_size,), this is to be mapped to 1.0
+            left_avg = sscores[:, left-collar:left+collar+1].mean(dim=-1, keepdim=True)
+
+            # the + 0.001 is to avoid possible division by zero in case of ties.
+            sscores = self.copy_weights1(sscores)
+
+            # divide by den: only decrease the scores' value.
+            den = (left_avg - right_avg_clamped).clamp(min=1.0)
+
+            #logging.info(f"den = {den}")
+            weights = (sscores - right_avg_clamped) / den
+        else:
+            # in test mode, no normalization (we can't have batch-dependent
+            # effects because this would be "seeing the future").  But we train in such
+            # a way that, hopefully, it will most of the time give us not many more
+            # nonzero scores than at training time.
+            weights = sscores
 
-        den = (left_avg - right_avg)
-        # the following is to avoid division by near-zero.
-        den = 0.75 * den + 0.25 * den.mean()
-        #logging.info(f"den = {den}")
-        weights = (sscores - right_avg) / den
         weights = weights.clamp(min=0.0, max=1.0)
 
+        if not self.training:
+            # need to work out seq_len_reduced.
+            seq_len_reduced = max(1,
+                                  (weights > 0.0).to(torch.int32).sum(dim=-1).max().item())
+
+
         indexes = indexes[:, :seq_len_reduced]
         weights = weights[:, :seq_len_reduced]
 
         weights = self.copy_weights2(weights)
 
+        if random.random() < 0.01 or __name__ == '__main__':
+            logging.info(f"Mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
+
         # re-sort the indexes we kept, on index value, so that
         # masking for causal models will be in the correct order.
-
+        # (actually this may not really matter, TODO: see whether we
+        # can remove this??)
         indexes, reorder = indexes.sort(dim=-1)
         weights = torch.gather(weights, dim=-1, index=reorder)
 
@@ -1046,7 +1066,6 @@ class DownsampledSubformerEncoder(nn.Module):
                  input_num_channels: int,
                  downsample: int):
         super(DownsampledSubformerEncoder, self).__init__()
-
         if downsample != 1:
             self.downsampler = LearnedDownsamplingModule(input_num_channels,
                                                          downsample)
@@ -1085,8 +1104,8 @@ class DownsampledSubformerEncoder(nn.Module):
         Returns:
             a Tensor with the same shape as src.
         """
         src_orig = src
-
         if hasattr(self, 'downsampler'):
+            print("b")
             indexes, weights, src = self.downsampler(src)
             pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes)
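
Note on the new weight computation: the code below is a minimal, self-contained sketch (not code from this repository) of the training/test-mode logic the patch adds to LearnedDownsamplingModule.forward(). The function name keep_weights and the default intermediate_rate=0.2 are illustrative assumptions. In training mode the sorted scores are shifted down and rescaled so that roughly the top 1/downsampling_factor of frames get nonzero weights; in test mode the raw scores are used unchanged, so the number of kept frames cannot depend on other sequences in the batch.

# Standalone sketch of the patched weight computation (illustrative only; assumes
# seq_len is at least a few times downsampling_factor so the collar slices are valid).
import torch


def keep_weights(scores: torch.Tensor,
                 downsampling_factor: int,
                 intermediate_rate: float = 0.2,
                 training: bool = True):
    """scores: (batch_size, seq_len) frame-keeping scores.
    Returns (indexes, weights), each of shape (batch_size, seq_len_reduced)."""
    _, seq_len = scores.shape
    # sort scores so the frames most worth keeping come first
    sscores, indexes = scores.sort(dim=-1, descending=True)

    if training:
        d = downsampling_factor
        seq_len_reduced = (seq_len + d - 1) // d
        # scores around position 'left' map to ~1.0, around 'right' to ~0.0
        right = seq_len_reduced
        left = int(seq_len_reduced * (1.0 - intermediate_rate))
        # average over a small collar so the gradient is spread over several frames
        collar = max(1, int(seq_len_reduced * 0.5 * intermediate_rate))
        right_avg = sscores[:, right - collar:right + collar + 1].mean(dim=-1, keepdim=True)
        left_avg = sscores[:, left - collar:left + collar + 1].mean(dim=-1, keepdim=True)
        # only shift the scores down (offset clamped at 0) and only shrink them
        # (den >= 1), so that un-normalized scores at test time behave similarly
        right_avg = right_avg.clamp(min=0.0)
        den = (left_avg - right_avg).clamp(min=1.0)
        weights = (sscores - right_avg) / den
    else:
        # test mode: no batch-dependent normalization ("seeing the future")
        weights = sscores

    weights = weights.clamp(min=0.0, max=1.0)

    if not training:
        # keep as many frames as the sequence with the most positive scores needs
        seq_len_reduced = max(1, int((weights > 0.0).to(torch.int32).sum(dim=-1).max().item()))

    indexes = indexes[:, :seq_len_reduced]
    weights = weights[:, :seq_len_reduced]

    # put the kept indexes back in time order so causal masking stays consistent
    indexes, reorder = indexes.sort(dim=-1)
    weights = torch.gather(weights, dim=-1, index=reorder)
    return indexes, weights


if __name__ == "__main__":
    scores = torch.randn(2, 16)
    idx, w = keep_weights(scores, downsampling_factor=2, training=True)
    print(idx.shape, w.shape)  # torch.Size([2, 8]) torch.Size([2, 8])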