From ed7e01448c1d4e150e89060a5409a3961ae6d1f4 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 17 Dec 2022 13:44:08 +0800
Subject: [PATCH 1/3] Remove query in AttentionDownsample, rename to SimpleDownsample.

---
 .../pruned_transducer_stateless7/zipformer.py | 22 ++++------------------
 1 file changed, 4 insertions(+), 18 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index da2ce25ae..b000e9062 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -798,7 +798,7 @@ class AttentionDownsample(torch.nn.Module):
         Require out_channels > in_channels.
         """
         super(AttentionDownsample, self).__init__()
-        self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
+
         self.bias = nn.Parameter(torch.zeros(downsample))
 
         self.name = None # will be set from training code
@@ -833,24 +833,10 @@ class AttentionDownsample(torch.nn.Module):
             assert src.shape[0] == d_seq_len * ds
 
         src = src.reshape(d_seq_len, ds, batch_size, in_channels)
-        # scores: (d_seq_len, downsample, batch_size, 1)
-        scores = (src * self.query).sum(dim=-1, keepdim=True)
-        scores = scores + self.bias.unsqueeze(-1).unsqueeze(-1)
-        scores = penalize_abs_values_gt(scores,
-                                        limit=20.0,
-                                        penalty=1.0e-04,
-                                        name=self.name)
-
-        dropout = float(self.dropout)
-        if dropout > 0.0:
-            # the 0:1, done on the axis of size 'downsample', selects just
-            # one dimension while keeping the dim. We'll then broadcast when
-            # we multiply.
-            dropout_mask = torch.rand_like(scores[:, 0:1]) > dropout
-            scores = scores * dropout_mask
-
-        weights = scores.softmax(dim=1)
+        weights = self.bias.softmax(dim=0)
+        # weights: (downsample, 1, 1)
+        weights = weights.unsqueeze(-1).unsqueeze(-1)
 
         # ans1 is the first `in_channels` channels of the output
        ans = (src * weights).sum(dim=1)
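Note on the change above: once the per-frame attention query is removed, the downsampling that remains is a position-independent weighted average of each group of `downsample` frames, with weights given by a softmax over the learned `self.bias`. A minimal standalone sketch of that computation in plain PyTorch (function and variable names here are illustrative, and the handling of sequence lengths that are not a multiple of `downsample` is omitted):

import torch

def simple_downsample(src: torch.Tensor, bias: torch.Tensor, downsample: int) -> torch.Tensor:
    # src: (seq_len, batch_size, in_channels); seq_len assumed divisible by downsample.
    seq_len, batch_size, in_channels = src.shape
    d_seq_len = seq_len // downsample
    src = src.reshape(d_seq_len, downsample, batch_size, in_channels)
    # weights: (downsample, 1, 1); the same weights are applied at every output position.
    weights = bias.softmax(dim=0).unsqueeze(-1).unsqueeze(-1)
    return (src * weights).sum(dim=1)  # (d_seq_len, batch_size, in_channels)

x = torch.randn(8, 2, 4)
bias = torch.zeros(2)              # zero init, as in the patch: starts as a uniform average
y = simple_downsample(x, bias, 2)
print(y.shape)                     # torch.Size([4, 2, 4])

With `self.bias` initialized to zeros the weights start out uniform, and training can learn to emphasize particular frames within each group.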
From 86bb0623e9428822e3c79bb1bebc250581f75eb5 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 17 Dec 2022 13:45:30 +0800
Subject: [PATCH 2/3] Remove query from AttentionDownsample, rename to SimpleDownsample

---
 .../ASR/pruned_transducer_stateless7/zipformer.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index b000e9062..1475bda5c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -209,7 +209,7 @@ class Zipformer(EncoderInterface):
         # initializes self.skip_layers and self.skip_modules
         self._init_skip_modules()
 
-        self.downsample_output = AttentionDownsample(encoder_dim[-1],
+        self.downsample_output = SimpleDownsample(encoder_dim[-1],
                                                      encoder_dim[-1],
                                                      downsample=output_downsampling_factor,
                                                      dropout=dropout)
@@ -677,7 +677,7 @@ class DownsampledZipformerEncoder(nn.Module):
                  dropout: FloatLike):
         super(DownsampledZipformerEncoder, self).__init__()
         self.downsample_factor = downsample
-        self.downsample = AttentionDownsample(input_dim, output_dim,
+        self.downsample = SimpleDownsample(input_dim, output_dim,
                                               downsample, dropout)
         self.encoder = encoder
         self.upsample = SimpleUpsample(output_dim, downsample)
@@ -741,7 +741,7 @@ class DownsamplingZipformerEncoder(nn.Module):
                  downsample: int):
         super(DownsampledZipformerEncoder, self).__init__()
         self.downsample_factor = downsample
-        self.downsample = AttentionDownsample(input_dim, output_dim, downsample)
+        self.downsample = SimpleDownsample(input_dim, output_dim, downsample)
         self.encoder = encoder
 
@@ -785,7 +785,7 @@ class DownsamplingZipformerEncoder(nn.Module):
         return src
 
 
-class AttentionDownsample(torch.nn.Module):
+class SimpleDownsample(torch.nn.Module):
     """
     Does downsampling with attention, by weighted sum, and a projection..
     """
@@ -797,7 +797,7 @@ class AttentionDownsample(torch.nn.Module):
         """
         Require out_channels > in_channels.
        """
-        super(AttentionDownsample, self).__init__()
+        super(SimpleDownsample, self).__init__()
 
         self.bias = nn.Parameter(torch.zeros(downsample))

From cc739b193a6190cf6235a06af8b80ce614caf732 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Sat, 17 Dec 2022 18:21:00 +0800
Subject: [PATCH 3/3] Implement memory-saving measure in randomized modules

---
 .../pruned_transducer_stateless7/scaling.py | 116 ++++++++++++++++--
 1 file changed, 107 insertions(+), 9 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 43d4e7bec..61171be00 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -20,7 +20,7 @@ from itertools import repeat
 from typing import Optional, Tuple, Union
 from functools import reduce
 import logging
-
+from torch.cuda.amp import custom_fwd, custom_bwd
 import random
 import torch
 import torch.nn as nn
@@ -230,6 +230,96 @@ def random_cast_to_half(x: Tensor,
     return torch.where(is_too_small, random_val, x).to(torch.float16)
+
+
+class CutoffEstimator:
+    """
+    Estimates cutoffs of an arbitrary numerical quantity such that a specified
+    proportion of items will be above the cutoff on average.
+    p is the proportion of items that should be above the cutoff.
+
+    """
+    def __init__(self, p: float):
+        self.p = p
+        # total count of items
+        self.count = 0
+        # total count of items that were above the cutoff
+        self.count_above = 0
+        # initial cutoff value
+        self.cutoff = 0
+
+
+    def __call__(self, x: float) -> bool:
+        """
+        Returns true if x is above the cutoff.
+        """
+        ans = (x > self.cutoff)
+        self.count += 1
+        if ans:
+            self.count_above += 1
+        cur_p = self.count_above / self.count
+        delta_p = cur_p - self.p
+        if (delta_p > 0) == ans:
+            q = abs(delta_p)
+            self.cutoff = x * q + self.cutoff * (1-q)
+        return ans
+
+
+class CachingEvalFunction(torch.autograd.Function):
+    # @custom_fwd and @custom_bwd relate to automatic mixed precision (amp) and ensure
+    # that the backward pass runs with the same autocast context as the forward pass.
+    @staticmethod
+    @custom_fwd
+    def forward(ctx, x: Tensor, m) -> Tensor:
+        """
+        m might be an nn.Module
+        """
+        ctx.x_requires_grad = x.requires_grad
+        ctx.m = m
+        # we need any random numbers used in this evaluation and the next evaluation to be identical.
+        # Caution: this assumes you are not going to use any random numbers from torch (for any purpose
+        # that matters in the forward pass), e.g. there should be no dropout.
+        ctx.random_state = random.getstate()
+        # we are inside torch.no_grad() here, so the following won't create the computation graph.
+        with torch.no_grad():
+            y = m(x)
+        ctx.save_for_backward(x, y)
+        return y
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, y_grad: Tensor):
+        x, y = ctx.saved_tensors
+        x = x.detach()
+        x.requires_grad = ctx.x_requires_grad
+        m = ctx.m  # e.g. a nn.Module
+
+        random_state = random.getstate()
+        # set the state to what we used in the 1st forward pass.
+        random.setstate(ctx.random_state)
+        with torch.enable_grad():
+            y2 = m(x)
+            assert torch.allclose(y, y2, atol=1.0e-02)
+            # this call to backward() should create grads in the module's parameters
+            y2.backward(gradient=y_grad)
+
+        # restore the state from before we entered this function
+        random.setstate(random_state)
+
+        return x.grad, None  # x.grad will be None if x.requires_grad is False.
+
+
+def caching_eval(x: Tensor, m: nn.Module) -> Tensor:
+    if m.training:
+        # The purpose of this code is to make all parameters of m reachable in
+        # the computation graph, so that if we give find_unused_parameters=True
+        # to PyTorch's autograd code it does not assign them zero gradient.
+        tot = 0.0
+        for p in m.parameters():
+            tot = tot + 0.0 * p.flatten()[0]
+        x = x + tot  # tot will be 0, this does nothing.
+    return CachingEvalFunction.apply(x, m)
+
 class RandomGradFunction(torch.autograd.Function):
     """
     Does nothing in forward pass; in backward pass, gets rid of very small grads using
@@ -549,6 +639,9 @@ class ActivationBalancer(torch.nn.Module):
         if prob is None:
             prob = ScheduledFloat((0.0, 0.4), (8000.0, 0.1), default=0.4)
         self.prob = prob
+        # 10% of the time we will return and do nothing because memory usage
+        # is too high.
+        self.mem_cutoff = CutoffEstimator(0.1)
 
         # actually self.num_channels is no longer needed except for an assertion.
         self.num_channels = num_channels
@@ -562,11 +655,9 @@ class ActivationBalancer(torch.nn.Module):
         self.scale_gain_factor = scale_gain_factor
 
 
-
-
-
     def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or not x.requires_grad:
+        if (torch.jit.is_scripting() or not x.requires_grad or
+            (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))):
             return _no_op(x)
 
         prob = float(self.prob)
@@ -720,7 +811,7 @@ class Whiten(nn.Module):
                  num_groups: int,
                  whitening_limit: FloatLike,
                  prob: Union[float, Tuple[float,float]],
-                 grad_scale: float):
+                 grad_scale: FloatLike):
         """
         Args:
           num_groups: the number of groups to divide the channel dim into before
@@ -745,6 +836,12 @@ class Whiten(nn.Module):
         assert grad_scale >= 0
         self.num_groups = num_groups
         self.whitening_limit = whitening_limit
+        self.grad_scale = grad_scale
+
+        # 10% of the time we will return and do nothing because memory usage
+        # is too high.
+        self.mem_cutoff = CutoffEstimator(0.1)
+
         if isinstance(prob, float):
             assert 0 < prob <= 1
             self.prob = prob
@@ -753,7 +850,6 @@ class Whiten(nn.Module):
             assert 0 < self.min_prob < self.max_prob <= 1
             self.prob = self.max_prob
         self.name = None # will be set in training loop
-        self.grad_scale = grad_scale
 
     def forward(self, x: Tensor) -> Tensor:
@@ -773,7 +869,9 @@ class Whiten(nn.Module):
           you use the returned value, or the graph will be freed
          and nothing will happen in backprop.
        """
-        if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
+        grad_scale = float(self.grad_scale)
+        if (not x.requires_grad or random.random() > self.prob or grad_scale == 0
+            or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))):
             return _no_op(x)
         else:
             whitening_limit = float(self.whitening_limit)
@@ -789,7 +887,7 @@ class Whiten(nn.Module):
             return WhiteningPenaltyFunction.apply(x,
                                                   self.num_groups,
                                                   whitening_limit,
-                                                  self.grad_scale,
+                                                  grad_scale,
                                                   self.name)
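A note on CutoffEstimator from patch 3: it adapts its threshold so that, on average, roughly a proportion p of the values passed to it are reported as above the cutoff; ActivationBalancer and Whiten feed it torch.cuda.memory_allocated() so that they skip their work about 10% of the time, at the moments when memory usage is highest. A small sanity-check sketch, assuming scaling.py with this patch applied is importable (the printed values are approximate):

import random
from scaling import CutoffEstimator  # added to scaling.py by patch 3

# Feed uniform random values in [0, 1) and check that the fraction judged
# "above the cutoff" settles near the requested proportion p.
est = CutoffEstimator(0.1)
above = sum(est(random.random()) for _ in range(10000))
print(above / 10000)  # roughly 0.1 once the estimate has settled
print(est.cutoff)     # roughly the 90th percentile of the inputs, i.e. about 0.9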
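caching_eval / CachingEvalFunction is the other memory-saving piece: the wrapped module runs under torch.no_grad() in the forward pass and only its input and output are saved, so the module's intermediate activations are discarded; the module is re-run inside backward() with the Python random state restored, which is why it must not consume torch-level randomness such as dropout. A usage sketch, again assuming the patched scaling.py is importable; the small Sequential model is only for illustration:

import torch
import torch.nn as nn
from scaling import caching_eval  # added to scaling.py by patch 3

m = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))
x = torch.randn(8, 16, requires_grad=True)

y = caching_eval(x, m)   # like m(x), but intermediate activations are not kept
loss = y.pow(2).sum()
loss.backward()          # m(x) is recomputed here to obtain the gradients

print(x.grad.shape)                  # torch.Size([8, 16])
print(m[0].weight.grad is not None)  # True: grads reach the module's parameters

This trades one extra forward computation of m for not storing its activations, similar in spirit to activation checkpointing.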