From 92f61281270926700d7ca8cf3320d072fcf7fd33 Mon Sep 17 00:00:00 2001 From: Desh Raj Date: Wed, 14 Jun 2023 03:58:07 -0400 Subject: [PATCH] replace some files with symlinks --- .../SURT/dprnn_zipformer/beam_search.py | 6 +- egs/libricss/SURT/dprnn_zipformer/scaling.py | 1534 +-------- .../SURT/dprnn_zipformer/scaling_converter.py | 115 +- .../SURT/dprnn_zipformer/zipformer.py | 2892 +---------------- 4 files changed, 6 insertions(+), 4541 deletions(-) mode change 100644 => 120000 egs/libricss/SURT/dprnn_zipformer/scaling.py mode change 100644 => 120000 egs/libricss/SURT/dprnn_zipformer/scaling_converter.py mode change 100644 => 120000 egs/libricss/SURT/dprnn_zipformer/zipformer.py diff --git a/egs/libricss/SURT/dprnn_zipformer/beam_search.py b/egs/libricss/SURT/dprnn_zipformer/beam_search.py index c30687c08..021641eaa 100644 --- a/egs/libricss/SURT/dprnn_zipformer/beam_search.py +++ b/egs/libricss/SURT/dprnn_zipformer/beam_search.py @@ -120,7 +120,7 @@ def fast_beam_search_nbest_LG( - (5) The path with the largest score is used as the decoding output. Args: model: - An instance of `Transducer`. + An instance of `SURT`. decoding_graph: Decoding graph used for decoding, may be a TrivialGraph or a LG. encoder_out: @@ -705,8 +705,8 @@ def modified_beam_search_LODR( external language model. Args: - model (Transducer): - The transducer model + model (SURT): + The SURT model encoder_out (torch.Tensor): Encoder output in (N,T,C) encoder_out_lens (torch.Tensor): diff --git a/egs/libricss/SURT/dprnn_zipformer/scaling.py b/egs/libricss/SURT/dprnn_zipformer/scaling.py deleted file mode 100644 index 835bf72ca..000000000 --- a/egs/libricss/SURT/dprnn_zipformer/scaling.py +++ /dev/null @@ -1,1533 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging -import random -from typing import Optional, Tuple, Union - -import torch -import torch.backends.cudnn.rnn as rnn -import torch.nn as nn -import torch.nn.functional as F -from torch import _VF, Tensor - -from icefall.utils import is_jit_tracing - - -class ActivationBalancerFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - scale_factor: Tensor, - sign_factor: Optional[Tensor], - channel_dim: int, - ) -> Tensor: - if channel_dim < 0: - channel_dim += x.ndim - ctx.channel_dim = channel_dim - xgt0 = x > 0 - if sign_factor is None: - ctx.save_for_backward(xgt0, scale_factor) - else: - ctx.save_for_backward(xgt0, scale_factor, sign_factor) - return x - - @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: - if len(ctx.saved_tensors) == 3: - xgt0, scale_factor, sign_factor = ctx.saved_tensors - for _ in range(ctx.channel_dim, x_grad.ndim - 1): - scale_factor = scale_factor.unsqueeze(-1) - sign_factor = sign_factor.unsqueeze(-1) - factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) - else: - xgt0, scale_factor = ctx.saved_tensors - for _ in range(ctx.channel_dim, x_grad.ndim - 1): - scale_factor = scale_factor.unsqueeze(-1) - factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) - neg_delta_grad = x_grad.abs() * factor - return ( - x_grad - neg_delta_grad, - None, - None, - None, - ) - - -def _compute_scale_factor( - x: Tensor, - channel_dim: int, - min_abs: float, - max_abs: float, - gain_factor: float, - max_factor: float, -) -> Tensor: - if channel_dim < 0: - channel_dim += x.ndim - sum_dims = [d for d in range(x.ndim) if d != channel_dim] - x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32) - - if min_abs == 0.0: - below_threshold = 0.0 - else: - # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if - # x_abs)_mean , min_abs. - below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( - min=0, max=max_factor - ) - - above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( - min=0, max=max_factor - ) - - return below_threshold - above_threshold - - -def _compute_sign_factor( - x: Tensor, - channel_dim: int, - min_positive: float, - max_positive: float, - gain_factor: float, - max_factor: float, -) -> Tensor: - if channel_dim < 0: - channel_dim += x.ndim - sum_dims = [d for d in range(x.ndim) if d != channel_dim] - proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims) - if min_positive == 0.0: - factor1 = 0.0 - else: - # 0 if proportion_positive >= min_positive, else can be - # as large as max_factor. - factor1 = ( - (min_positive - proportion_positive) * (gain_factor / min_positive) - ).clamp_(min=0, max=max_factor) - - if max_positive == 1.0: - factor2 = 0.0 - else: - # 0 if self.proportion_positive <= max_positive, else can be - # as large as -max_factor. - factor2 = ( - (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive)) - ).clamp_(min=0, max=max_factor) - sign_factor = factor1 - factor2 - # require min_positive != 0 or max_positive != 1: - assert not isinstance(sign_factor, float) - return sign_factor - - -class ActivationScaleBalancerFunction(torch.autograd.Function): - """ - This object is used in class ActivationBalancer when the user specified - min_positive=0, max_positive=1, so there are no constraints on the signs - of the activations and only the absolute value has a constraint. - """ - - @staticmethod - def forward( - ctx, - x: Tensor, - sign_factor: Tensor, - scale_factor: Tensor, - channel_dim: int, - ) -> Tensor: - if channel_dim < 0: - channel_dim += x.ndim - ctx.channel_dim = channel_dim - xgt0 = x > 0 - ctx.save_for_backward(xgt0, sign_factor, scale_factor) - return x - - @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: - xgt0, sign_factor, scale_factor = ctx.saved_tensors - for _ in range(ctx.channel_dim, x_grad.ndim - 1): - sign_factor = sign_factor.unsqueeze(-1) - scale_factor = scale_factor.unsqueeze(-1) - - factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) - neg_delta_grad = x_grad.abs() * factor - return ( - x_grad - neg_delta_grad, - None, - None, - None, - ) - - -class RandomClampFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - min: Optional[float], - max: Optional[float], - prob: float, - reflect: float, - ) -> Tensor: - x_clamped = torch.clamp(x, min=min, max=max) - mask = torch.rand_like(x) < prob - ans = torch.where(mask, x_clamped, x) - if x.requires_grad: - ctx.save_for_backward(ans == x) - ctx.reflect = reflect - if reflect != 0.0: - ans = ans * (1.0 + reflect) - (x * reflect) - return ans - - @staticmethod - def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: - (is_same,) = ctx.saved_tensors - x_grad = ans_grad * is_same.to(ans_grad.dtype) - reflect = ctx.reflect - if reflect != 0.0: - x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect) - return x_grad, None, None, None, None - - -def random_clamp( - x: Tensor, - min: Optional[float] = None, - max: Optional[float] = None, - prob: float = 0.5, - reflect: float = 0.0, -): - return RandomClampFunction.apply(x, min, max, prob, reflect) - - -def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: - """ - A randomized way of casting a floating point value to half precision. - """ - if x.dtype == torch.float16: - return x - x_abs = x.abs() - is_too_small = x_abs < min_abs - # for elements where is_too_small is true, random_val will contain +-min_abs with - # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, - # for those elements]. - random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs) - return torch.where(is_too_small, random_val, x).to(torch.float16) - - -class RandomGradFunction(torch.autograd.Function): - """ - Does nothing in forward pass; in backward pass, gets rid of very small grads using - randomized approach that preserves expectations (intended to reduce roundoff). - """ - - @staticmethod - def forward(ctx, x: Tensor, min_abs: float) -> Tensor: - ctx.min_abs = min_abs - return x - - @staticmethod - def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]: - if ans_grad.dtype == torch.float16: - return ( - random_cast_to_half(ans_grad.to(torch.float32), min_abs=ctx.min_abs), - None, - ) - else: - return ans_grad, None - - -class RandomGrad(torch.nn.Module): - """ - Gets rid of very small gradients using an expectation-preserving method, intended to increase - accuracy of training when using amp (automatic mixed precision) - """ - - def __init__(self, min_abs: float = 5.0e-06): - super(RandomGrad, self).__init__() - self.min_abs = min_abs - - def forward(self, x: Tensor): - if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing(): - return x - else: - return RandomGradFunction.apply(x, self.min_abs) - - -class SoftmaxFunction(torch.autograd.Function): - """ - Tries to handle half-precision derivatives in a randomized way that should - be more accurate for training than the default behavior. - """ - - @staticmethod - def forward(ctx, x: Tensor, dim: int): - ans = x.softmax(dim=dim) - # if x dtype is float16, x.softmax() returns a float32 because - # (presumably) that op does not support float16, and autocast - # is enabled. - if torch.is_autocast_enabled(): - ans = ans.to(torch.float16) - ctx.save_for_backward(ans) - ctx.x_dtype = x.dtype - ctx.dim = dim - return ans - - @staticmethod - def backward(ctx, ans_grad: Tensor): - (ans,) = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): - ans_grad = ans_grad.to(torch.float32) - ans = ans.to(torch.float32) - x_grad = ans_grad * ans - x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) - return x_grad, None - - -def softmax(x: Tensor, dim: int): - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return x.softmax(dim) - - return SoftmaxFunction.apply(x, dim) - - -class MaxEigLimiterFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float, - ) -> Tensor: - ctx.channel_dim = channel_dim - ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) - return x - - @staticmethod - def backward(ctx, x_grad, *args): - with torch.enable_grad(): - (x_orig, coeffs, new_direction) = ctx.saved_tensors - x_orig.requires_grad = True - num_channels = x_orig.shape[ctx.channel_dim] - x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) - new_direction.requires_grad = False - x = x - x.mean(dim=0) - x_var = (x**2).mean() - x_residual = x - coeffs * new_direction - x_residual_var = (x_residual**2).mean() - # `variance_proportion` is the proportion of the variance accounted for - # by the top eigen-direction. This is to be minimized. - variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) - variance_proportion.backward() - x_orig_grad = x_orig.grad - x_extra_grad = ( - x_orig.grad - * ctx.grad_scale - * x_grad.norm() - / (x_orig_grad.norm() + 1.0e-20) - ) - return x_grad + x_extra_grad.detach(), None, None, None, None - - -class GradientFilterFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - batch_dim: int, # e.g., 1 - threshold: float, # e.g., 10.0 - *params: Tensor, # module parameters - ) -> Tuple[Tensor, ...]: - if x.requires_grad: - if batch_dim < 0: - batch_dim += x.ndim - ctx.batch_dim = batch_dim - ctx.threshold = threshold - return (x,) + params - - @staticmethod - def backward( - ctx, - x_grad: Tensor, - *param_grads: Tensor, - ) -> Tuple[Tensor, ...]: - eps = 1.0e-20 - dim = ctx.batch_dim - norm_dims = [d for d in range(x_grad.ndim) if d != dim] - norm_of_batch = (x_grad**2).mean(dim=norm_dims, keepdim=True).sqrt() - median_norm = norm_of_batch.median() - - cutoff = median_norm * ctx.threshold - inv_mask = (cutoff + norm_of_batch) / (cutoff + eps) - mask = 1.0 / (inv_mask + eps) - x_grad = x_grad * mask - - avg_mask = 1.0 / (inv_mask.mean() + eps) - param_grads = [avg_mask * g for g in param_grads] - - return (x_grad, None, None) + tuple(param_grads) - - -class GradientFilter(torch.nn.Module): - """This is used to filter out elements that have extremely large gradients - in batch and the module parameters with soft masks. - Args: - batch_dim (int): - The batch dimension. - threshold (float): - For each element in batch, its gradient will be - filtered out if the gradient norm is larger than - `grad_norm_threshold * median`, where `median` is the median - value of gradient norms of all elememts in batch. - """ - - def __init__(self, batch_dim: int = 1, threshold: float = 10.0): - super(GradientFilter, self).__init__() - self.batch_dim = batch_dim - self.threshold = threshold - - def forward(self, x: Tensor, *params: Tensor) -> Tuple[Tensor, ...]: - if torch.jit.is_scripting() or is_jit_tracing(): - return (x,) + params - else: - return GradientFilterFunction.apply( - x, - self.batch_dim, - self.threshold, - *params, - ) - - -class BasicNorm(torch.nn.Module): - """ - This is intended to be a simpler, and hopefully cheaper, replacement for - LayerNorm. The observation this is based on, is that Transformer-type - networks, especially with pre-norm, sometimes seem to set one of the - feature dimensions to a large constant value (e.g. 50), which "defeats" - the LayerNorm because the output magnitude is then not strongly dependent - on the other (useful) features. Presumably the weight and bias of the - LayerNorm are required to allow it to do this. - - So the idea is to introduce this large constant value as an explicit - parameter, that takes the role of the "eps" in LayerNorm, so the network - doesn't have to do this trick. We make the "eps" learnable. - - Args: - num_channels: the number of channels, e.g. 512. - channel_dim: the axis/dimension corresponding to the channel, - interprted as an offset from the input's ndim if negative. - shis is NOT the num_channels; it should typically be one of - {-2, -1, 0, 1, 2, 3}. - eps: the initial "epsilon" that we add as ballast in: - scale = ((input_vec**2).mean() + epsilon)**-0.5 - Note: our epsilon is actually large, but we keep the name - to indicate the connection with conventional LayerNorm. - learn_eps: if true, we learn epsilon; if false, we keep it - at the initial value. - eps_min: float - eps_max: float - """ - - def __init__( - self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.25, - learn_eps: bool = True, - eps_min: float = -3.0, - eps_max: float = 3.0, - ) -> None: - super(BasicNorm, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - if learn_eps: - self.eps = nn.Parameter(torch.tensor(eps).log().detach()) - else: - self.register_buffer("eps", torch.tensor(eps).log().detach()) - self.eps_min = eps_min - self.eps_max = eps_max - - def forward(self, x: Tensor) -> Tensor: - assert x.shape[self.channel_dim] == self.num_channels - eps = self.eps - if self.training and random.random() < 0.25: - # with probability 0.25, in training mode, clamp eps between the min - # and max; this will encourage it to learn parameters within the - # allowed range by making parameters that are outside the allowed - # range noisy. - - # gradients to allow the parameter to get back into the allowed - # region if it happens to exit it. - eps = eps.clamp(min=self.eps_min, max=self.eps_max) - scales = ( - torch.mean(x**2, dim=self.channel_dim, keepdim=True) + eps.exp() - ) ** -0.5 - return x * scales - - -def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: - """ - Behaves like a constructor of a modified version of nn.Linear - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Linear(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) - return ans - - -def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: - """ - Behaves like a constructor of a modified version of nn.Conv1d - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Conv1d(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) - return ans - - -class ScaledConv2d(nn.Conv2d): - # See docs for ScaledLinear - def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs, - ): - super(ScaledConv2d, self).__init__(*args, **kwargs) - initial_scale = torch.tensor(initial_scale).log() - self.weight_scale = nn.Parameter(initial_scale.clone().detach()) - if self.bias is not None: - self.bias_scale = nn.Parameter(initial_scale.clone().detach()) - else: - self.register_parameter("bias_scale", None) - self._reset_parameters( - initial_speed - ) # Overrides the reset_parameters in base class - - def _reset_parameters(self, initial_speed: float): - std = 0.1 / initial_speed - a = (3**0.5) * std - nn.init.uniform_(self.weight, -a, a) - if self.bias is not None: - nn.init.constant_(self.bias, 0.0) - fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) - with torch.no_grad(): - self.weight_scale += torch.tensor(scale / std).log() - - def get_weight(self): - return self.weight * self.weight_scale.exp() - - def get_bias(self): - # see https://github.com/pytorch/pytorch/issues/24135 - bias = self.bias - bias_scale = self.bias_scale - if bias is None or bias_scale is None: - return None - else: - return bias * bias_scale.exp() - - def _conv_forward(self, input, weight): - if self.padding_mode != "zeros": - return F.conv2d( - F.pad( - input, - self._reversed_padding_repeated_twice, - mode=self.padding_mode, - ), - weight, - self.get_bias(), - self.stride, - (0, 0), - self.dilation, - self.groups, - ) - return F.conv2d( - input, - weight, - self.get_bias(), - self.stride, - self.padding, - self.dilation, - self.groups, - ) - - def forward(self, input: Tensor) -> Tensor: - return self._conv_forward(input, self.get_weight()) - - -class ScaledLSTM(nn.LSTM): - # See docs for ScaledLinear. - # This class implements LSTM with scaling mechanism, using `torch._VF.lstm` - # Please refer to https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py - def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - grad_norm_threshold: float = 10.0, - **kwargs, - ): - super(ScaledLSTM, self).__init__(*args, **kwargs) - initial_scale = torch.tensor(initial_scale).log() - self._scales_names = [] - self._scales = [] - self.batch_dim = int(not self.batch_first) - for name in self._flat_weights_names: - scale_name = name + "_scale" - self._scales_names.append(scale_name) - param = nn.Parameter(initial_scale.clone().detach()) - setattr(self, scale_name, param) - self._scales.append(param) - - self.grad_filter = GradientFilter( - batch_dim=self.batch_dim, threshold=grad_norm_threshold - ) - - self._reset_parameters( - initial_speed - ) # Overrides the reset_parameters in base class - - def _reset_parameters(self, initial_speed: float): - std = 0.1 / initial_speed - a = (3**0.5) * std - scale = self.hidden_size**-0.5 - v = scale / std - for idx, name in enumerate(self._flat_weights_names): - if "weight" in name: - nn.init.uniform_(self._flat_weights[idx], -a, a) - with torch.no_grad(): - self._scales[idx] += torch.tensor(v).log() - elif "bias" in name: - nn.init.constant_(self._flat_weights[idx], 0.0) - - def _flatten_parameters(self, flat_weights) -> None: - """Resets parameter data pointer so that they can use faster code paths. - - Right now, this works only if the module is on the GPU and cuDNN is enabled. - Otherwise, it's a no-op. - - This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa - """ - # Short-circuits if _flat_weights is only partially instantiated - if len(flat_weights) != len(self._flat_weights_names): - return - - for w in flat_weights: - if not isinstance(w, Tensor): - return - # Short-circuits if any tensor in flat_weights is not acceptable to cuDNN - # or the tensors in flat_weights are of different dtypes - - first_fw = flat_weights[0] - dtype = first_fw.dtype - for fw in flat_weights: - if ( - not isinstance(fw.data, Tensor) - or not (fw.data.dtype == dtype) - or not fw.data.is_cuda - or not torch.backends.cudnn.is_acceptable(fw.data) - ): - return - - # If any parameters alias, we fall back to the slower, copying code path. This is - # a sufficient check, because overlapping parameter buffers that don't completely - # alias would break the assumptions of the uniqueness check in - # Module.named_parameters(). - unique_data_ptrs = set(p.data_ptr() for p in flat_weights) - if len(unique_data_ptrs) != len(flat_weights): - return - - with torch.cuda.device_of(first_fw): - - # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is - # an inplace operation on self._flat_weights - with torch.no_grad(): - if torch._use_cudnn_rnn_flatten_weight(): - num_weights = 4 if self.bias else 2 - if self.proj_size > 0: - num_weights += 1 - torch._cudnn_rnn_flatten_weight( - flat_weights, - num_weights, - self.input_size, - rnn.get_cudnn_mode(self.mode), - self.hidden_size, - self.proj_size, - self.num_layers, - self.batch_first, - bool(self.bidirectional), - ) - - def _get_flat_weights(self): - """Get scaled weights, and resets their data pointer.""" - flat_weights = [] - for idx in range(len(self._flat_weights_names)): - flat_weights.append(self._flat_weights[idx] * self._scales[idx].exp()) - self._flatten_parameters(flat_weights) - return flat_weights - - def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None): - # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa - # The change for calling `_VF.lstm()` is: - # self._flat_weights -> self._get_flat_weights() - if hx is None: - num_directions = 2 if self.bidirectional else 1 - h_zeros = torch.zeros( - self.num_layers * num_directions, - input.size(self.batch_dim), - self.proj_size if self.proj_size > 0 else self.hidden_size, - dtype=input.dtype, - device=input.device, - ) - c_zeros = torch.zeros( - self.num_layers * num_directions, - input.size(self.batch_dim), - self.hidden_size, - dtype=input.dtype, - device=input.device, - ) - hx = (h_zeros, c_zeros) - - self.check_forward_args(input, hx, None) - - flat_weights = self._get_flat_weights() - input, *flat_weights = self.grad_filter(input, *flat_weights) - - result = _VF.lstm( - input, - hx, - flat_weights, - self.bias, - self.num_layers, - self.dropout, - self.training, - self.bidirectional, - self.batch_first, - ) - - output = result[0] - hidden = result[1:] - return output, hidden - - -class ActivationBalancer(torch.nn.Module): - """ - Modifies the backpropped derivatives of a function to try to encourage, for - each channel, that it is positive at least a proportion `threshold` of the - time. It does this by multiplying negative derivative values by up to - (1+max_factor), and positive derivative values by up to (1-max_factor), - interpolated from 1 at the threshold to those extremal values when none - of the inputs are positive. - - Args: - num_channels: the number of channels - channel_dim: the dimension/axis corresponding to the channel, e.g. - -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. - min_positive: the minimum, per channel, of the proportion of the time - that (x > 0), below which we start to modify the derivatives. - max_positive: the maximum, per channel, of the proportion of the time - that (x > 0), above which we start to modify the derivatives. - max_factor: the maximum factor by which we modify the derivatives for - either the sign constraint or the magnitude constraint; - e.g. with max_factor=0.02, the the derivatives would be multiplied by - values in the range [0.98..1.02]. - sign_gain_factor: determines the 'gain' with which we increase the - change in gradient once the constraints on min_positive and max_positive - are violated. - scale_gain_factor: determines the 'gain' with which we increase the - change in gradient once the constraints on min_abs and max_abs - are violated. - min_abs: the minimum average-absolute-value difference from the mean - value per channel, which we allow, before we start to modify - the derivatives to prevent this. - max_abs: the maximum average-absolute-value difference from the mean - value per channel, which we allow, before we start to modify - the derivatives to prevent this. - min_prob: determines the minimum probability with which we modify the - gradients for the {min,max}_positive and {min,max}_abs constraints, - on each forward(). This is done randomly to prevent all layers - from doing it at the same time. Early in training we may use - higher probabilities than this; it will decay to this value. - """ - - def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.04, - sign_gain_factor: float = 0.01, - scale_gain_factor: float = 0.02, - min_abs: float = 0.2, - max_abs: float = 100.0, - min_prob: float = 0.1, - ): - super(ActivationBalancer, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - self.min_positive = min_positive - self.max_positive = max_positive - self.max_factor = max_factor - self.min_abs = min_abs - self.max_abs = max_abs - self.min_prob = min_prob - self.sign_gain_factor = sign_gain_factor - self.scale_gain_factor = scale_gain_factor - - # count measures how many times the forward() function has been called. - # We occasionally sync this to a tensor called `count`, that exists to - # make sure it is synced to disk when we load and save the model. - self.cpu_count = 0 - self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) - - def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing(): - return _no_op(x) - - count = self.cpu_count - self.cpu_count += 1 - - if random.random() < 0.01: - # Occasionally sync self.cpu_count with self.count. - # count affects the decay of 'prob'. don't do this on every iter, - # because syncing with the GPU is slow. - self.cpu_count = max(self.cpu_count, self.count.item()) - self.count.fill_(self.cpu_count) - - # the prob of doing some work exponentially decreases from 0.5 till it hits - # a floor at min_prob (==0.1, by default) - prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0))) - - if random.random() < prob: - sign_gain_factor = 0.5 - if self.min_positive != 0.0 or self.max_positive != 1.0: - sign_factor = _compute_sign_factor( - x, - self.channel_dim, - self.min_positive, - self.max_positive, - gain_factor=self.sign_gain_factor / prob, - max_factor=self.max_factor, - ) - else: - sign_factor = None - - scale_factor = _compute_scale_factor( - x.detach(), - self.channel_dim, - min_abs=self.min_abs, - max_abs=self.max_abs, - gain_factor=self.scale_gain_factor / prob, - max_factor=self.max_factor, - ) - return ActivationBalancerFunction.apply( - x, - scale_factor, - sign_factor, - self.channel_dim, - ) - else: - return _no_op(x) - - -def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor: - """ - Returns x unmodified, but in backprop will put a penalty for the excess of - the absolute values of elements of x over the limit "limit". E.g. if - limit == 10.0, then if x has any values over 10 it will get a penalty. - - Caution: the value of this penalty will be affected by grad scaling used - in automatic mixed precision training. For this reasons we use this, - it shouldn't really matter, or may even be helpful; we just use this - to disallow really implausible values of scores to be given to softmax. - """ - x_sign = x.sign() - over_limit = (x.abs() - limit) > 0 - # The following is a memory efficient way to penalize the absolute values of - # x that's over the limit. (The memory efficiency comes when you think - # about which items torch needs to cache for the autograd, and which ones it - # can throw away). The numerical value of aux_loss as computed here will - # actually be larger than it should be, by limit * over_limit.sum(), but it - # has the same derivative as the real aux_loss which is penalty * (x.abs() - - # limit).relu(). - aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) - # note: we don't do sum() here on aux)_loss, but it's as if we had done - # sum() due to how with_loss() works. - x = with_loss(x, aux_loss) - # you must use x for something, or this will be ineffective. - return x - - -def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. - if x.ndim == 2: - return x.diag() - else: - (batch, dim, dim) = x.shape - x = x.reshape(batch, dim * dim) - x = x[:, :: dim + 1] - assert x.shape == (batch, dim) - return x - - -def _whitening_metric(x: Tensor, num_groups: int): - """ - Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of - of the centered feature covariance are the same within each group's covariance matrix - and also between groups. - Args: - x: a Tensor of shape (*, num_channels) - num_groups: the number of groups of channels, a number >=1 that divides num_channels - Returns: - Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and - greater than 1.0 otherwise. - """ - assert x.dtype != torch.float16 - x = x.reshape(-1, x.shape[-1]) - (num_frames, num_channels) = x.shape - assert num_channels % num_groups == 0 - channels_per_group = num_channels // num_groups - x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) - # x now has shape (num_groups, num_frames, channels_per_group) - # subtract the mean so we use the centered, not uncentered, covariance. - # My experience has been that when we "mess with the gradients" like this, - # it's better not do anything that tries to move the mean around, because - # that can easily cause instability. - x = x - x.mean(dim=1, keepdim=True) - # x_covar: (num_groups, channels_per_group, channels_per_group) - x_covar = torch.matmul(x.transpose(1, 2), x) - x_covar_mean_diag = _diag(x_covar).mean() - # the following expression is what we'd get if we took the matrix product - # of each covariance and measured the mean of its trace, i.e. - # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) - # this metric will be >= 1.0; the larger it is, the less 'white' the data was. - metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) - return metric - - -class WhiteningPenaltyFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, x: Tensor, num_groups: int, whitening_limit: float, grad_scale: float - ) -> Tensor: - ctx.save_for_backward(x) - ctx.num_groups = num_groups - ctx.whitening_limit = whitening_limit - ctx.grad_scale = grad_scale - return x - - @staticmethod - def backward(ctx, x_grad: Tensor): - (x_orig,) = ctx.saved_tensors - with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): - x_detached = x_orig.to(torch.float32).detach() - x_detached.requires_grad = True - - metric = _whitening_metric(x_detached, ctx.num_groups) - - if random.random() < 0.005 or __name__ == "__main__": - logging.info( - f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, " - f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}" - ) - - (metric - ctx.whitening_limit).relu().backward() - penalty_grad = x_detached.grad - scale = ctx.grad_scale * ( - x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) - ) - penalty_grad = penalty_grad * scale - return x_grad + penalty_grad.to(x_grad.dtype), None, None, None - - -class Whiten(nn.Module): - def __init__( - self, - num_groups: int, - whitening_limit: float, - prob: Union[float, Tuple[float, float]], - grad_scale: float, - ): - """ - Args: - num_groups: the number of groups to divide the channel dim into before - whitening. We will attempt to make the feature covariance - within each group, after mean subtraction, as "white" as possible, - while having the same trace across all groups. - whitening_limit: a value greater than 1.0, that dictates how much - freedom we have to violate the constraints. 1.0 would mean perfectly - white, with exactly the same trace across groups; larger values - give more freedom. E.g. 2.0. - prob: the probability with which we apply the gradient modification - (also affects the grad scale). May be supplied as a float, - or as a pair (min_prob, max_prob) - - grad_scale: determines the scale on the gradient term from this object, - relative to the rest of the gradient on the attention weights. - E.g. 0.02 (you may want to use smaller values than this if prob is large) - """ - super(Whiten, self).__init__() - assert num_groups >= 1 - assert whitening_limit >= 1 - assert grad_scale >= 0 - self.num_groups = num_groups - self.whitening_limit = whitening_limit - if isinstance(prob, float): - assert 0 < prob <= 1 - self.prob = prob - else: - (self.min_prob, self.max_prob) = prob - assert 0 < self.min_prob < self.max_prob <= 1 - self.prob = self.max_prob - - self.grad_scale = grad_scale - - def forward(self, x: Tensor) -> Tensor: - """ - In the forward pass, this function just returns the input unmodified. - In the backward pass, it will modify the gradients to ensure that the - distribution in each group has close to (lambda times I) as the covariance - after mean subtraction, with the same lambda across groups. - For whitening_limit > 1, there will be more freedom to violate this - constraint. - - Args: - x: the input of shape (*, num_channels) - - Returns: - x, unmodified. You should make sure - you use the returned value, or the graph will be freed - and nothing will happen in backprop. - """ - if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0: - return _no_op(x) - else: - if hasattr(self, "min_prob") and random.random() < 0.25: - # occasionally switch between min_prob and max_prob, based on whether - # we are above or below the threshold. - if ( - _whitening_metric(x.to(torch.float32), self.num_groups) - > self.whitening_limit - ): - # there would be a change to the grad. - self.prob = self.max_prob - else: - self.prob = self.min_prob - - return WhiteningPenaltyFunction.apply( - x, self.num_groups, self.whitening_limit, self.grad_scale - ) - - -class WithLoss(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, y: Tensor): - ctx.y_shape = y.shape - return x - - @staticmethod - def backward(ctx, ans_grad: Tensor): - return ans_grad, torch.ones( - ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device - ) - - -def with_loss(x, y): - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return x - # returns x but adds y.sum() to the loss function. - return WithLoss.apply(x, y) - - -def _no_op(x: Tensor) -> Tensor: - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return x - else: - # a no-op function that will have a node in the autograd graph, - # to avoid certain bugs relating to backward hooks - return x.chunk(1, dim=-1)[0] - - -class Identity(torch.nn.Module): - def __init__(self): - super(Identity, self).__init__() - - def forward(self, x): - return _no_op(x) - - -class MaxEig(torch.nn.Module): - """ - Modifies the backpropped derivatives of a function to try to discourage - that any given direction in activation space accounts for more than - a specified proportion of the covariance (e.g. 0.2). - - - Args: - num_channels: the number of channels - channel_dim: the dimension/axis corresponding to the channel, e.g. - -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. - max_var_per_eig: the maximum proportion of the variance of the - features/channels, after mean subtraction, that can come from - any given eigenvalue. - min_prob: the minimum probability with which we apply this during any invocation - of forward(), assuming last time we applied the constraint it was - not active; supplied for speed. - scale: determines the scale with which we modify the gradients, relative - to the existing / unmodified gradients - """ - - def __init__( - self, - num_channels: int, - channel_dim: int, - max_var_per_eig: float = 0.2, - min_prob: float = 0.01, - scale: float = 0.01, - ): - super(MaxEig, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - self.scale = scale - assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels - self.max_var_per_eig = max_var_per_eig - - # we figure out the dominant direction using the power method: starting with - # a random vector, keep multiplying by the covariance and renormalizing. - with torch.no_grad(): - # arbitrary.. would use randn() but want to leave the rest of the model's - # random parameters unchanged for comparison - direction = torch.arange(num_channels).to(torch.float) - direction = direction / direction.norm() - self.register_buffer("max_eig_direction", direction) - - self.min_prob = min_prob - # cur_prob is the current probability we'll use to apply the ActivationBalancer. - # We'll regress this towards prob, each tiem we try to apply it and it is not - # active. - self.cur_prob = 1.0 - - def forward(self, x: Tensor) -> Tensor: - if ( - torch.jit.is_scripting() - or self.max_var_per_eig <= 0 - or random.random() > self.cur_prob - or torch.jit.is_tracing() - ): - return _no_op(x) - - with torch.cuda.amp.autocast(enabled=False): - eps = 1.0e-20 - orig_x = x - x = x.to(torch.float32) - with torch.no_grad(): - x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels) - x = x - x.mean(dim=0) - new_direction, coeffs = self._find_direction_coeffs( - x, self.max_eig_direction - ) - x_var = (x**2).mean() - x_residual = x - coeffs * new_direction - x_residual_var = (x_residual**2).mean() - - # `variance_proportion` is the proportion of the variance accounted for - # by the top eigen-direction. - variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) - - # ensure new direction is nonzero even if x == 0, by including `direction`. - self._set_direction(0.1 * self.max_eig_direction + new_direction) - - if random.random() < 0.01 or __name__ == "__main__": - logging.info( - f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}" - ) - - if variance_proportion >= self.max_var_per_eig: - # The constraint is active. Note, we should quite rarely - # reach here, only near the beginning of training if we are - # starting to diverge, should this constraint be active. - cur_prob = self.cur_prob - self.cur_prob = 1.0 # next time, do the update with probability 1.0. - return MaxEigLimiterFunction.apply( - orig_x, coeffs, new_direction, self.channel_dim, self.scale - ) - else: - # let self.cur_prob exponentially approach self.min_prob, as - # long as the constraint is inactive. - self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob - return orig_x - - def _set_direction(self, direction: Tensor): - """ - Sets self.max_eig_direction to a normalized version of `direction` - """ - direction = direction.detach() - direction = direction / direction.norm() - direction_sum = direction.sum().item() - if direction_sum - direction_sum == 0: # no inf/nan - self.max_eig_direction[:] = direction - else: - logging.info( - f"Warning: sum of direction in MaxEig is {direction_sum}, " - "num_channels={self.num_channels}, channel_dim={self.channel_dim}" - ) - - def _find_direction_coeffs( - self, x: Tensor, prev_direction: Tensor - ) -> Tuple[Tensor, Tensor, Tensor]: - """ - Figure out (an approximation to) the proportion of the variance of a set of - feature vectors that can be attributed to the top eigen-direction. - Args: - x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. - prev_direction: a Tensor of shape (num_channels,), that is our previous estimate - of the top eigen-direction, or a random direction if this is the first - iteration. Does not have to be normalized, but should be nonzero. - - Returns: (cur_direction, coeffs), where: - cur_direction: a Tensor of shape (num_channels,) that is the current - estimate of the top eigen-direction. - coeffs: a Tensor of shape (num_frames, 1) that minimizes, or - approximately minimizes, (x - coeffs * cur_direction).norm() - """ - (num_frames, num_channels) = x.shape - assert num_channels > 1 and num_frames > 1 - assert prev_direction.shape == (num_channels,) - # `coeffs` are the coefficients of `prev_direction` in x. - # actually represent the coeffs up to a constant positive factor. - coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10 - cur_direction = (x * coeffs).sum(dim=0) / ((coeffs**2).sum() + 1.0e-20) - return cur_direction, coeffs - - -class DoubleSwishFunction(torch.autograd.Function): - """ - double_swish(x) = x * torch.sigmoid(x-1) - This is a definition, originally motivated by its close numerical - similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). - - Memory-efficient derivative computation: - double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) - double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). - Now, s'(x) = s(x) * (1-s(x)). - double_swish'(x) = x * s'(x) + s(x). - = x * s(x) * (1-s(x)) + s(x). - = double_swish(x) * (1-s(x)) + s(x) - ... so we just need to remember s(x) but not x itself. - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - x_dtype = x.dtype - if x.dtype == torch.float16: - x = x.to(torch.float32) - - s = torch.sigmoid(x - 1.0) - y = x * s - - if requires_grad: - deriv = y * (1 - s) + s - # notes on derivative of x * sigmoid(x - 1): - # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 - # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund - # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. - # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which - # floors), should be expectation-preserving. - floor = -0.043637 - ceil = 1.2 - d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( - deriv - ) - if __name__ == "__main__": - # for self-testing only. - assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - (d,) = ctx.saved_tensors - # the same constants as used in forward pass. - floor = -0.043637 - ceil = 1.2 - d = d * ((ceil - floor) / 255.0) + floor - return y_grad * d - - -class DoubleSwish(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). - """ - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return x * torch.sigmoid(x - 1.0) - return DoubleSwishFunction.apply(x) - - -def _test_max_eig(): - for proportion in [0.1, 0.5, 10.0]: - logging.info(f"proportion = {proportion}") - x = torch.randn(100, 128) - direction = torch.randn(128) - coeffs = torch.randn(100, 1) - x += proportion * direction * coeffs - - x.requires_grad = True - - num_channels = 128 - m = MaxEig( - num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig - ) # grad_scale - - for _ in range(4): - y = m(x) - - y_grad = torch.randn_like(x) - y.backward(gradient=y_grad) - - if proportion < 0.2: - assert torch.allclose(x.grad, y_grad, atol=1.0e-02) - elif proportion > 1.0: - assert not torch.allclose(x.grad, y_grad) - - -def _test_whiten(): - for proportion in [0.1, 0.5, 10.0]: - logging.info(f"_test_whiten(): proportion = {proportion}") - x = torch.randn(100, 128) - direction = torch.randn(128) - coeffs = torch.randn(100, 1) - x += proportion * direction * coeffs - - x.requires_grad = True - - num_channels = 128 - m = Whiten( - 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, - ) # grad_scale - - for _ in range(4): - y = m(x) - - y_grad = torch.randn_like(x) - y.backward(gradient=y_grad) - - if proportion < 0.2: - assert torch.allclose(x.grad, y_grad) - elif proportion > 1.0: - assert not torch.allclose(x.grad, y_grad) - - -def _test_activation_balancer_sign(): - probs = torch.arange(0, 1, 0.01) - N = 1000 - x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0) - x = x.detach() - x.requires_grad = True - m = ActivationBalancer( - probs.numel(), - channel_dim=0, - min_positive=0.05, - max_positive=0.95, - max_factor=0.2, - min_abs=0.0, - ) - - y_grad = torch.sign(torch.randn(probs.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_activation_balancer_sign: x = ", x) - print("_test_activation_balancer_sign: y grad = ", y_grad) - print("_test_activation_balancer_sign: x grad = ", x.grad) - - -def _test_activation_balancer_magnitude(): - magnitudes = torch.arange(0, 1, 0.01) - N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) - x = x.detach() - x.requires_grad = True - m = ActivationBalancer( - magnitudes.numel(), - channel_dim=0, - min_positive=0.0, - max_positive=1.0, - max_factor=0.2, - min_abs=0.2, - max_abs=0.8, - min_prob=1.0, - ) - - y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_activation_balancer_magnitude: x = ", x) - print("_test_activation_balancer_magnitude: y grad = ", y_grad) - print("_test_activation_balancer_magnitude: x grad = ", x.grad) - - -def _test_basic_norm(): - num_channels = 128 - m = BasicNorm(num_channels=num_channels, channel_dim=1) - - x = torch.randn(500, num_channels) - - y = m(x) - - assert y.shape == x.shape - x_rms = (x**2).mean().sqrt() - y_rms = (y**2).mean().sqrt() - print("x rms = ", x_rms) - print("y rms = ", y_rms) - assert y_rms < x_rms - assert y_rms > 0.5 * x_rms - - -def _test_double_swish_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 3.0 - x.requires_grad = True - m = DoubleSwish() - - tol = (1.2 - (-0.043637)) / 255.0 - torch.autograd.gradcheck(m, x, atol=tol) - - # for self-test. - x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 - x.requires_grad = True - y = m(x) - - -def _test_softmax(): - a = torch.randn(2, 10, dtype=torch.float64) - b = a.clone() - a.requires_grad = True - b.requires_grad = True - a.softmax(dim=1)[:, 0].sum().backward() - print("a grad = ", a.grad) - softmax(b, dim=1)[:, 0].sum().backward() - print("b grad = ", b.grad) - assert torch.allclose(a.grad, b.grad) - - -def _test_scaled_lstm(): - N, L = 2, 30 - dim_in, dim_hidden = 10, 20 - m = ScaledLSTM(input_size=dim_in, hidden_size=dim_hidden, bias=True) - x = torch.randn(L, N, dim_in) - h0 = torch.randn(1, N, dim_hidden) - c0 = torch.randn(1, N, dim_hidden) - y, (h, c) = m(x, (h0, c0)) - assert y.shape == (L, N, dim_hidden) - assert h.shape == (1, N, dim_hidden) - assert c.shape == (1, N, dim_hidden) - - -def _test_grad_filter(): - threshold = 50.0 - time, batch, channel = 200, 5, 128 - grad_filter = GradientFilter(batch_dim=1, threshold=threshold) - - for i in range(2): - x = torch.randn(time, batch, channel, requires_grad=True) - w = nn.Parameter(torch.ones(5)) - b = nn.Parameter(torch.zeros(5)) - - x_out, w_out, b_out = grad_filter(x, w, b) - - w_out_grad = torch.randn_like(w) - b_out_grad = torch.randn_like(b) - x_out_grad = torch.rand_like(x) - if i % 2 == 1: - # The gradient norm of the first element must be larger than - # `threshold * median`, where `median` is the median value - # of gradient norms of all elements in batch. - x_out_grad[:, 0, :] = torch.full((time, channel), threshold) - - torch.autograd.backward( - [x_out, w_out, b_out], [x_out_grad, w_out_grad, b_out_grad] - ) - - print( - "_test_grad_filter: for gradient norms, the first element > median * threshold ", # noqa - i % 2 == 1, - ) - - print( - "_test_grad_filter: x_out_grad norm = ", - (x_out_grad**2).mean(dim=(0, 2)).sqrt(), - ) - print( - "_test_grad_filter: x.grad norm = ", - (x.grad**2).mean(dim=(0, 2)).sqrt(), - ) - print("_test_grad_filter: w_out_grad = ", w_out_grad) - print("_test_grad_filter: w.grad = ", w.grad) - print("_test_grad_filter: b_out_grad = ", b_out_grad) - print("_test_grad_filter: b.grad = ", b.grad) - - -if __name__ == "__main__": - logging.getLogger().setLevel(logging.INFO) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - _test_softmax() - _test_whiten() - _test_max_eig() - _test_activation_balancer_sign() - _test_activation_balancer_magnitude() - _test_basic_norm() - _test_double_swish_deriv() - _test_scaled_lstm() - _test_grad_filter() diff --git a/egs/libricss/SURT/dprnn_zipformer/scaling.py b/egs/libricss/SURT/dprnn_zipformer/scaling.py new file mode 120000 index 000000000..5f9be9fe0 --- /dev/null +++ b/egs/libricss/SURT/dprnn_zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/libricss/SURT/dprnn_zipformer/scaling_converter.py b/egs/libricss/SURT/dprnn_zipformer/scaling_converter.py deleted file mode 100644 index 56165d1f9..000000000 --- a/egs/libricss/SURT/dprnn_zipformer/scaling_converter.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file replaces various modules in a model. -Specifically, ActivationBalancer is replaced with an identity operator; -Whiten is also replaced with an identity operator; -BasicNorm is replaced by a module with `exp` removed. -""" - -import copy -from typing import List - -import torch -import torch.nn as nn -from scaling import ActivationBalancer, BasicNorm, Whiten - - -class NonScaledNorm(nn.Module): - """See BasicNorm for doc""" - - def __init__( - self, - num_channels: int, - eps_exp: float, - channel_dim: int = -1, # CAUTION: see documentation. - ): - super().__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - self.eps_exp = eps_exp - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if not torch.jit.is_tracing(): - assert x.shape[self.channel_dim] == self.num_channels - scales = ( - torch.mean(x * x, dim=self.channel_dim, keepdim=True) + self.eps_exp - ).pow(-0.5) - return x * scales - - -def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm: - assert isinstance(basic_norm, BasicNorm), type(BasicNorm) - norm = NonScaledNorm( - num_channels=basic_norm.num_channels, - eps_exp=basic_norm.eps.data.exp().item(), - channel_dim=basic_norm.channel_dim, - ) - return norm - - -# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa -# get_submodule was added to nn.Module at v1.9.0 -def get_submodule(model, target): - if target == "": - return model - atoms: List[str] = target.split(".") - mod: torch.nn.Module = model - for item in atoms: - if not hasattr(mod, item): - raise AttributeError( - mod._get_name() + " has no " "attribute `" + item + "`" - ) - mod = getattr(mod, item) - if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not " "an nn.Module") - return mod - - -def convert_scaled_to_non_scaled( - model: nn.Module, - inplace: bool = False, -): - """ - Args: - model: - The model to be converted. - inplace: - If True, the input model is modified inplace. - If False, the input model is copied and we modify the copied version. - Return: - Return a model without scaled layers. - """ - if not inplace: - model = copy.deepcopy(model) - - d = {} - for name, m in model.named_modules(): - if isinstance(m, BasicNorm): - d[name] = convert_basic_norm(m) - elif isinstance(m, (ActivationBalancer, Whiten)): - d[name] = nn.Identity() - - for k, v in d.items(): - if "." in k: - parent, child = k.rsplit(".", maxsplit=1) - setattr(get_submodule(model, parent), child, v) - else: - setattr(model, k, v) - - return model diff --git a/egs/libricss/SURT/dprnn_zipformer/scaling_converter.py b/egs/libricss/SURT/dprnn_zipformer/scaling_converter.py new file mode 120000 index 000000000..f9960e5c6 --- /dev/null +++ b/egs/libricss/SURT/dprnn_zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/libricss/SURT/dprnn_zipformer/zipformer.py b/egs/libricss/SURT/dprnn_zipformer/zipformer.py deleted file mode 100644 index a5c422959..000000000 --- a/egs/libricss/SURT/dprnn_zipformer/zipformer.py +++ /dev/null @@ -1,2891 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import itertools -import logging -import math -import random -import warnings -from typing import List, Optional, Tuple, Union - -import torch -from encoder_interface import EncoderInterface -from scaling import ( - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. -) -from scaling import ( - ActivationBalancer, - BasicNorm, - DoubleSwish, - Identity, - MaxEig, - ScaledConv1d, - Whiten, - _diag, - penalize_abs_values_gt, - random_clamp, - softmax, -) -from torch import Tensor, nn - -from icefall.utils import make_pad_mask, subsequent_chunk_mask - - -def stack_states(state_list: List[List[Tensor]]) -> List[Tensor]: - """Stack list of zipformer states that correspond to separate utterances - into a single emformer state, so that it can be used as an input for - zipformer when those utterances are formed into a batch. - - Note: - It is the inverse of :func:`unstack_states`. - - Args: - state_list: - Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. - ``states[i]`` is a list of 7 * num_encoders elements of i-th utterance. - ``states[i][0:num_encoders]`` is the cached numbers of past frames. - ``states[i][num_encoders:2*num_encoders]`` is the cached average tensors. - ``states[i][2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. - ``states[i][3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. - ``states[i][4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. - ``states[i][5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. - ``states[i][6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. - - Returns: - A new state corresponding to a batch of utterances. - See the input argument of :func:`unstack_states` for the meaning - of the returned tensor. - """ - batch_size = len(state_list) - assert len(state_list[0]) % 7 == 0, len(state_list[0]) - num_encoders = len(state_list[0]) // 7 - - cached_len = [] - cached_avg = [] - cached_key = [] - cached_val = [] - cached_val2 = [] - cached_conv1 = [] - cached_conv2 = [] - - # For cached_len - len_list = [state_list[n][0:num_encoders] for n in range(batch_size)] - for i in range(num_encoders): - # len_avg: (num_layers, batch_size) - len_avg = torch.cat([len_list[n][i] for n in range(batch_size)], dim=1) - cached_len.append(len_avg) - - # For cached_avg - avg_list = [ - state_list[n][num_encoders : 2 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # avg: (num_layers, batch_size, D) - avg = torch.cat([avg_list[n][i] for n in range(batch_size)], dim=1) - cached_avg.append(avg) - - # For cached_key - key_list = [ - state_list[n][2 * num_encoders : 3 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # key: (num_layers, left_context_size, batch_size, D) - key = torch.cat([key_list[n][i] for n in range(batch_size)], dim=2) - cached_key.append(key) - - # For cached_val - val_list = [ - state_list[n][3 * num_encoders : 4 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # val: (num_layers, left_context_size, batch_size, D) - val = torch.cat([val_list[n][i] for n in range(batch_size)], dim=2) - cached_val.append(val) - - # For cached_val2 - val2_list = [ - state_list[n][4 * num_encoders : 5 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # val2: (num_layers, left_context_size, batch_size, D) - val2 = torch.cat([val2_list[n][i] for n in range(batch_size)], dim=2) - cached_val2.append(val2) - - # For cached_conv1 - conv1_list = [ - state_list[n][5 * num_encoders : 6 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # conv1: (num_layers, batch_size, D, kernel-1) - conv1 = torch.cat([conv1_list[n][i] for n in range(batch_size)], dim=1) - cached_conv1.append(conv1) - - # For cached_conv2 - conv2_list = [ - state_list[n][6 * num_encoders : 7 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # conv2: (num_layers, batch_size, D, kernel-1) - conv2 = torch.cat([conv2_list[n][i] for n in range(batch_size)], dim=1) - cached_conv2.append(conv2) - - states = ( - cached_len - + cached_avg - + cached_key - + cached_val - + cached_val2 - + cached_conv1 - + cached_conv2 - ) - return states - - -def unstack_states(states: List[Tensor]) -> List[List[Tensor]]: - """Unstack the zipformer state corresponding to a batch of utterances - into a list of states, where the i-th entry is the state from the i-th - utterance in the batch. - - Note: - It is the inverse of :func:`stack_states`. - - Args: - states: - A list of 7 * num_encoders elements: - ``states[0:num_encoders]`` is the cached numbers of past frames. - ``states[num_encoders:2*num_encoders]`` is the cached average tensors. - ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. - ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. - ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. - ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. - ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. - - Returns: - A list of states. - ``states[i]`` is a list of 7 * num_encoders elements of i-th utterance. - """ - assert len(states) % 7 == 0, len(states) - num_encoders = len(states) // 7 - ( - cached_len, - cached_avg, - cached_key, - cached_val, - cached_val2, - cached_conv1, - cached_conv2, - ) = (states[i * num_encoders : (i + 1) * num_encoders] for i in range(7)) - - batch_size = cached_len[0].shape[1] - - len_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_len[i]: (num_layers, batch_size) - len_avg = cached_len[i].chunk(chunks=batch_size, dim=1) - for n in range(batch_size): - len_list[n].append(len_avg[n]) - - avg_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_avg[i]: (num_layers, batch_size, D) - avg = cached_avg[i].chunk(chunks=batch_size, dim=1) - for n in range(batch_size): - avg_list[n].append(avg[n]) - - key_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_key[i]: (num_layers, left_context, batch_size, D) - key = cached_key[i].chunk(chunks=batch_size, dim=2) - for n in range(batch_size): - key_list[n].append(key[n]) - - val_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_val[i]: (num_layers, left_context, batch_size, D) - val = cached_val[i].chunk(chunks=batch_size, dim=2) - for n in range(batch_size): - val_list[n].append(val[n]) - - val2_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_val2[i]: (num_layers, left_context, batch_size, D) - val2 = cached_val2[i].chunk(chunks=batch_size, dim=2) - for n in range(batch_size): - val2_list[n].append(val2[n]) - - conv1_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_conv1[i]: (num_layers, batch_size, D, kernel-1) - conv1 = cached_conv1[i].chunk(chunks=batch_size, dim=1) - for n in range(batch_size): - conv1_list[n].append(conv1[n]) - - conv2_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_conv2[i]: (num_layers, batch_size, D, kernel-1) - conv2 = cached_conv2[i].chunk(chunks=batch_size, dim=1) - for n in range(batch_size): - conv2_list[n].append(conv2[n]) - - state_list = [ - ( - len_list[i] - + avg_list[i] - + key_list[i] - + val_list[i] - + val2_list[i] - + conv1_list[i] - + conv2_list[i] - ) - for i in range(batch_size) - ] - return state_list - - -class Zipformer(EncoderInterface): - """ - Args: - num_features (int): Number of input features - d_model: (int,int): embedding dimension of 2 encoder stacks - attention_dim: (int,int): attention dimension of 2 encoder stacks - nhead (int, int): number of heads - dim_feedforward (int, int): feedforward dimension in 2 encoder stacks - num_encoder_layers (int): number of encoder layers - dropout (float): dropout rate - cnn_module_kernels (int): Kernel size of convolution module - warmup_batches (float): number of batches to warm up over - """ - - def __init__( - self, - num_features: int, - output_downsampling_factor: int = 2, - encoder_dims: Tuple[int] = (384, 384), - attention_dim: Tuple[int] = (256, 256), - encoder_unmasked_dims: Tuple[int] = (256, 256), - zipformer_downsampling_factors: Tuple[int] = (2, 4), - nhead: Tuple[int] = (8, 8), - feedforward_dim: Tuple[int] = (1536, 2048), - num_encoder_layers: Tuple[int] = (12, 12), - dropout: float = 0.1, - cnn_module_kernels: Tuple[int] = (31, 31), - pos_dim: int = 4, - num_left_chunks: int = 4, - short_chunk_threshold: float = 0.75, - short_chunk_size: int = 50, - decode_chunk_size: int = 16, - warmup_batches: float = 4000.0, - ) -> None: - super(Zipformer, self).__init__() - - self.num_features = num_features - assert 0 < encoder_dims[0] <= encoder_dims[1] - self.encoder_dims = encoder_dims - self.encoder_unmasked_dims = encoder_unmasked_dims - self.zipformer_downsampling_factors = zipformer_downsampling_factors - self.output_downsampling_factor = output_downsampling_factor - - self.num_left_chunks = num_left_chunks - self.short_chunk_threshold = short_chunk_threshold - self.short_chunk_size = short_chunk_size - - # Used in decoding - self.decode_chunk_size = decode_chunk_size - - self.left_context_len = self.decode_chunk_size * self.num_left_chunks - - # will be written to, see set_batch_count() - self.batch_count = 0 - self.warmup_end = warmup_batches - - for u, d in zip(encoder_unmasked_dims, encoder_dims): - assert u <= d, (u, d) - - # self.encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7)//2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7)//2 - # (2) embedding: num_features -> encoder_dims - self.encoder_embed = Conv2dSubsampling( - num_features, encoder_dims[0], dropout=dropout - ) - - # each one will be ZipformerEncoder or DownsampledZipformerEncoder - encoders = [] - - self.num_encoder_layers = num_encoder_layers - self.num_encoders = len(encoder_dims) - self.attention_dims = attention_dim - self.cnn_module_kernels = cnn_module_kernels - for i in range(self.num_encoders): - encoder_layer = ZipformerEncoderLayer( - encoder_dims[i], - attention_dim[i], - nhead[i], - feedforward_dim[i], - dropout, - cnn_module_kernels[i], - pos_dim, - ) - - # For the segment of the warmup period, we let the Conv2dSubsampling - # layer learn something. Then we start to warm up the other encoders. - encoder = ZipformerEncoder( - encoder_layer, - num_encoder_layers[i], - dropout, - warmup_begin=warmup_batches * (i + 1) / (self.num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (self.num_encoders + 1), - ) - - if zipformer_downsampling_factors[i] != 1: - encoder = DownsampledZipformerEncoder( - encoder, - input_dim=encoder_dims[i - 1] if i > 0 else encoder_dims[0], - output_dim=encoder_dims[i], - downsample=zipformer_downsampling_factors[i], - ) - encoders.append(encoder) - self.encoders = nn.ModuleList(encoders) - - # initializes self.skip_layers and self.skip_modules - self._init_skip_modules() - - self.downsample_output = AttentionDownsample( - encoder_dims[-1], encoder_dims[-1], downsample=output_downsampling_factor - ) - - def _get_layer_skip_dropout_prob(self): - if not self.training: - return 0.0 - batch_count = self.batch_count - min_dropout_prob = 0.025 - - if batch_count > self.warmup_end: - return min_dropout_prob - else: - return 0.5 - (batch_count / self.warmup_end) * (0.5 - min_dropout_prob) - - def _init_skip_modules(self): - """ - If self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer - indexed 4 (in zero indexing), which has subsampling_factor=4, we combine the output of - layers 2 and 3; and at the input of layer indexed 5, which has subsampling_factor=2, - we combine the outputs of layers 1 and 4. - """ - skip_layers = [] - skip_modules = [] - z = self.zipformer_downsampling_factors - for i in range(len(z)): - if i <= 1 or z[i - 1] <= z[i]: - skip_layers.append(None) - skip_modules.append(SimpleCombinerIdentity()) - else: - # TEMP - for j in range(i - 2, -1, -1): - if z[j] <= z[i] or j == 0: - # TEMP logging statement. - logging.info( - f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " - f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}." - ) - skip_layers.append(j) - skip_modules.append( - SimpleCombiner( - self.encoder_dims[j], - self.encoder_dims[i - 1], - min_weight=(0.0, 0.25), - ) - ) - break - self.skip_layers = skip_layers - self.skip_modules = nn.ModuleList(skip_modules) - - def get_feature_masks(self, x: torch.Tensor) -> List[float]: - # Note: The actual return type is Union[List[float], List[Tensor]], - # but to make torch.jit.script() work, we use List[float] - """ - In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of - randomized feature masks, one per encoder. - On e.g. 15% of frames, these masks will zero out all encoder dims larger than - some supplied number, e.g. >256, so in effect on those frames we are using - a smaller encoder dim. - - We generate the random masks at this level because we want the 2 masks to 'agree' - all the way up the encoder stack. This will mean that the 1st mask will have - mask values repeated self.zipformer_downsampling_factors times. - - Args: - x: the embeddings (needed for the shape and dtype and device), of shape - (num_frames, batch_size, encoder_dims0) - """ - num_encoders = len(self.encoder_dims) - if torch.jit.is_scripting() or not self.training: - return [1.0] * num_encoders - - (num_frames0, batch_size, _encoder_dims0) = x.shape - - assert self.encoder_dims[0] == _encoder_dims0, ( - self.encoder_dims, - _encoder_dims0, - ) - - max_downsampling_factor = max(self.zipformer_downsampling_factors) - - num_frames_max = num_frames0 + max_downsampling_factor - 1 - - feature_mask_dropout_prob = 0.15 - - # frame_mask_max shape: (num_frames_max, batch_size, 1) - frame_mask_max = ( - torch.rand(num_frames_max, batch_size, 1, device=x.device) - > feature_mask_dropout_prob - ).to(x.dtype) - - feature_masks = [] - for i in range(num_encoders): - ds = self.zipformer_downsampling_factors[i] - upsample_factor = max_downsampling_factor // ds - - frame_mask = ( - frame_mask_max.unsqueeze(1) - .expand(num_frames_max, upsample_factor, batch_size, 1) - .reshape(num_frames_max * upsample_factor, batch_size, 1) - ) - num_frames = (num_frames0 + ds - 1) // ds - frame_mask = frame_mask[:num_frames] - feature_mask = torch.ones( - num_frames, - batch_size, - self.encoder_dims[i], - dtype=x.dtype, - device=x.device, - ) - u = self.encoder_unmasked_dims[i] - feature_mask[:, :, u:] *= frame_mask - feature_masks.append(feature_mask) - - return feature_masks - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - The input tensor. Its shape is (batch_size, seq_len, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - chunk_size: - The chunk size used in evaluation mode. - Returns: - Return a tuple containing 2 tensors: - - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1]) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. - """ - x = self.encoder_embed(x) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - lengths = (x_lens - 7) >> 1 - assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) - mask = make_pad_mask(lengths) - - outputs = [] - feature_masks = self.get_feature_masks(x) - - if self.training: - # Training mode - max_ds = max(self.zipformer_downsampling_factors) - # Generate dynamic chunk-wise attention mask during training - max_len = x.size(0) // max_ds - short_chunk_size = self.short_chunk_size // max_ds - chunk_size = torch.randint(1, max_len, (1,)).item() - if chunk_size > (max_len * self.short_chunk_threshold): - # Full attention - chunk_size = x.size(0) - else: - # Chunk-wise attention - chunk_size = chunk_size % short_chunk_size + 1 - chunk_size *= max_ds - else: - chunk_size = self.decode_chunk_size - # Evaluation mode - for ds in self.zipformer_downsampling_factors: - assert chunk_size % ds == 0, (chunk_size, ds) - - attn_mask = ~subsequent_chunk_mask( - size=x.size(0), - chunk_size=chunk_size, - num_left_chunks=self.num_left_chunks, - device=x.device, - ) - - for i, (module, skip_module) in enumerate( - zip(self.encoders, self.skip_modules) - ): - ds = self.zipformer_downsampling_factors[i] - k = self.skip_layers[i] - if isinstance(k, int): - layer_skip_dropout_prob = self._get_layer_skip_dropout_prob() - if torch.jit.is_scripting(): - x = skip_module(outputs[k], x) - elif (not self.training) or random.random() > layer_skip_dropout_prob: - x = skip_module(outputs[k], x) - x = module( - x, - feature_mask=feature_masks[i], - src_key_padding_mask=None if mask is None else mask[..., ::ds], - attn_mask=attn_mask[::ds, ::ds], - ) - outputs.append(x) - - x = self.downsample_output(x) - # class Downsample has this rounding behavior.. - assert self.output_downsampling_factor == 2, self.output_downsampling_factor - lengths = (lengths + 1) >> 1 - - x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - return x, lengths - - def streaming_forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - states: List[Tensor], - ) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Args: - x: - The input tensor. Its shape is (batch_size, seq_len, feature_dim). - seq_len is the input chunk length. - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - states: - A list of 7 * num_encoders elements: - ``states[0:num_encoders]`` is the cached numbers of past frames. - ``states[num_encoders:2*num_encoders]`` is the cached average tensors. - ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. - ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. - ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. - ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. - ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. - - Returns: - Return a tuple containing 3 tensors: - - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1]) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. - - updated states. - """ - assert len(states) == 7 * self.num_encoders, (len(states), self.num_encoders) - - cached_len = states[: self.num_encoders] - cached_avg = states[self.num_encoders : 2 * self.num_encoders] - cached_key = states[2 * self.num_encoders : 3 * self.num_encoders] - cached_val = states[3 * self.num_encoders : 4 * self.num_encoders] - cached_val2 = states[4 * self.num_encoders : 5 * self.num_encoders] - cached_conv1 = states[5 * self.num_encoders : 6 * self.num_encoders] - cached_conv2 = states[6 * self.num_encoders : 7 * self.num_encoders] - - x = self.encoder_embed(x) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - lengths = (x_lens - 7) >> 1 - assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) - - outputs = [] - new_cached_len = [] - new_cached_avg = [] - new_cached_key = [] - new_cached_val = [] - new_cached_val2 = [] - new_cached_conv1 = [] - new_cached_conv2 = [] - - for i, (module, skip_module) in enumerate( - zip(self.encoders, self.skip_modules) - ): - k = self.skip_layers[i] - if isinstance(k, int): - x = skip_module(outputs[k], x) - x, len_avg, avg, key, val, val2, conv1, conv2 = module.streaming_forward( - x, - cached_len=cached_len[i], - cached_avg=cached_avg[i], - cached_key=cached_key[i], - cached_val=cached_val[i], - cached_val2=cached_val2[i], - cached_conv1=cached_conv1[i], - cached_conv2=cached_conv2[i], - ) - outputs.append(x) - # Update caches - new_cached_len.append(len_avg) - new_cached_avg.append(avg) - new_cached_key.append(key) - new_cached_val.append(val) - new_cached_val2.append(val2) - new_cached_conv1.append(conv1) - new_cached_conv2.append(conv2) - - x = self.downsample_output(x) - # class Downsample has this rounding behavior.. - assert self.output_downsampling_factor == 2, self.output_downsampling_factor - lengths = (lengths + 1) >> 1 - - x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = ( - new_cached_len - + new_cached_avg - + new_cached_key - + new_cached_val - + new_cached_val2 - + new_cached_conv1 - + new_cached_conv2 - ) - return x, lengths, new_states - - @torch.jit.export - def get_init_state( - self, - device: torch.device = torch.device("cpu"), - ) -> List[Tensor]: - """Get initial states. - A list of 7 * num_encoders elements: - ``states[0:num_encoders]`` is the cached numbers of past frames. - ``states[num_encoders:2*num_encoders]`` is the cached average tensors. - ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. - ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. - ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. - ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. - ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. - """ - cached_len = [] - cached_avg = [] - cached_key = [] - cached_val = [] - cached_val2 = [] - cached_conv1 = [] - cached_conv2 = [] - - left_context_len = self.decode_chunk_size * self.num_left_chunks - - for i, encoder in enumerate(self.encoders): - num_layers = encoder.num_layers - ds = self.zipformer_downsampling_factors[i] - - len_avg = torch.zeros(num_layers, 1, dtype=torch.int64, device=device) - cached_len.append(len_avg) - - avg = torch.zeros(num_layers, 1, encoder.d_model, device=device) - cached_avg.append(avg) - - key = torch.zeros( - num_layers, - left_context_len // ds, - 1, - encoder.attention_dim, - device=device, - ) - cached_key.append(key) - - val = torch.zeros( - num_layers, - left_context_len // ds, - 1, - encoder.attention_dim // 2, - device=device, - ) - cached_val.append(val) - - val2 = torch.zeros( - num_layers, - left_context_len // ds, - 1, - encoder.attention_dim // 2, - device=device, - ) - cached_val2.append(val2) - - conv1 = torch.zeros( - num_layers, - 1, - encoder.d_model, - encoder.cnn_module_kernel - 1, - device=device, - ) - cached_conv1.append(conv1) - - conv2 = torch.zeros( - num_layers, - 1, - encoder.d_model, - encoder.cnn_module_kernel - 1, - device=device, - ) - cached_conv2.append(conv2) - - states = ( - cached_len - + cached_avg - + cached_key - + cached_val - + cached_val2 - + cached_conv1 - + cached_conv2 - ) - return states - - -class ZipformerEncoderLayer(nn.Module): - """ - ZipformerEncoderLayer is made up of self-attn, feedforward and convolution networks. - - Args: - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - feedforward_dim: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module. - - Examples:: - >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) - """ - - def __init__( - self, - d_model: int, - attention_dim: int, - nhead: int, - feedforward_dim: int = 2048, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - pos_dim: int = 4, - ) -> None: - super(ZipformerEncoderLayer, self).__init__() - - self.d_model = d_model - self.attention_dim = attention_dim - self.cnn_module_kernel = cnn_module_kernel - - # will be written to, see set_batch_count() - self.batch_count = 0 - - self.self_attn = RelPositionMultiheadAttention( - d_model, - attention_dim, - nhead, - pos_dim, - dropout=0.0, - ) - - self.pooling = PoolingModule(d_model) - - self.feed_forward1 = FeedforwardModule(d_model, feedforward_dim, dropout) - - self.feed_forward2 = FeedforwardModule(d_model, feedforward_dim, dropout) - - self.feed_forward3 = FeedforwardModule(d_model, feedforward_dim, dropout) - - self.conv_module1 = ConvolutionModule(d_model, cnn_module_kernel) - - self.conv_module2 = ConvolutionModule(d_model, cnn_module_kernel) - - self.norm_final = BasicNorm(d_model) - - self.bypass_scale = nn.Parameter(torch.tensor(0.5)) - - # try to ensure the output is close to zero-mean (or at least, zero-median). - self.balancer = ActivationBalancer( - d_model, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - max_abs=6.0, - ) - self.whiten = Whiten( - num_groups=1, whitening_limit=5.0, prob=(0.025, 0.25), grad_scale=0.01 - ) - - def get_bypass_scale(self): - if torch.jit.is_scripting() or not self.training: - return self.bypass_scale - if random.random() < 0.1: - # ensure we get grads if self.bypass_scale becomes out of range - return self.bypass_scale - # hardcode warmup period for bypass scale - warmup_period = 20000.0 - initial_clamp_min = 0.75 - final_clamp_min = 0.25 - if self.batch_count > warmup_period: - clamp_min = final_clamp_min - else: - clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * ( - initial_clamp_min - final_clamp_min - ) - return self.bypass_scale.clamp(min=clamp_min, max=1.0) - - def get_dynamic_dropout_rate(self): - # return dropout rate for the dynamic modules (self_attn, pooling, convolution); this - # starts at 0.2 and rapidly decreases to 0. Its purpose is to keep the training stable - # at the beginning, by making the network focus on the feedforward modules. - if torch.jit.is_scripting() or not self.training: - return 0.0 - warmup_period = 2000.0 - initial_dropout_rate = 0.2 - final_dropout_rate = 0.0 - if self.batch_count > warmup_period: - return final_dropout_rate - else: - return initial_dropout_rate - ( - initial_dropout_rate * final_dropout_rate - ) * (self.batch_count / warmup_period) - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - pos_emb: Positional embedding tensor (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - batch_split: if not None, this layer will only be applied to - - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - src_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, N is the batch size, E is the feature number - """ - src_orig = src - - # macaron style feed forward module - src = src + self.feed_forward1(src) - - # dropout rate for submodules that interact with time. - dynamic_dropout = self.get_dynamic_dropout_rate() - - # pooling module - if torch.jit.is_scripting(): - src = src + self.pooling(src, src_key_padding_mask=src_key_padding_mask) - elif random.random() >= dynamic_dropout: - src = src + self.pooling(src, src_key_padding_mask=src_key_padding_mask) - - if torch.jit.is_scripting(): - src_att, attn_weights = self.self_attn( - src, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - ) - src = src + src_att - - src = src + self.conv_module1( - src, src_key_padding_mask=src_key_padding_mask - ) - - src = src + self.feed_forward2(src) - - src = src + self.self_attn.forward2(src, attn_weights) - - src = src + self.conv_module2( - src, src_key_padding_mask=src_key_padding_mask - ) - else: - use_self_attn = random.random() >= dynamic_dropout - if use_self_attn: - src_att, attn_weights = self.self_attn( - src, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - ) - src = src + src_att - - if random.random() >= dynamic_dropout: - src = src + self.conv_module1( - src, src_key_padding_mask=src_key_padding_mask - ) - - src = src + self.feed_forward2(src) - - if use_self_attn: - src = src + self.self_attn.forward2(src, attn_weights) - - if random.random() >= dynamic_dropout: - src = src + self.conv_module2( - src, src_key_padding_mask=src_key_padding_mask - ) - - src = src + self.feed_forward3(src) - - src = self.norm_final(self.balancer(src)) - - delta = src - src_orig - - src = src_orig + delta * self.get_bypass_scale() - - return self.whiten(src) - - def streaming_forward( - self, - src: Tensor, - pos_emb: Tensor, - cached_len: Tensor, - cached_avg: Tensor, - cached_key: Tensor, - cached_val: Tensor, - cached_val2: Tensor, - cached_conv1: Tensor, - cached_conv2: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - pos_emb: Positional embedding tensor (required). - cached_len: processed number of past frames. - cached_avg: cached average of past frames. - cached_key: cached key tensor of left context for the first attention module. - cached_val: cached value tensor of left context for the first attention module. - cached_val2: cached value tensor of left context for the second attention module. - cached_conv1: cached left context for the first convolution module. - cached_conv2: cached left context for the second convolution module. - - Shape: - src: (S, N, E). - pos_emb: (N, left_context_len+2*S-1, E) - cached_len: (N,) - N is the batch size. - cached_avg: (N, C). - N is the batch size, C is the feature dimension. - cached_key: (left_context_len, N, K). - N is the batch size, K is the key dimension. - cached_val: (left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_val2: (left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_conv1: (N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - cached_conv2: (N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - """ - src_orig = src - - # macaron style feed forward module - src = src + self.feed_forward1(src) - - src_pool, cached_len, cached_avg = self.pooling.streaming_forward( - src, - cached_len=cached_len, - cached_avg=cached_avg, - ) - src = src + src_pool - - ( - src_attn, - attn_weights, - cached_key, - cached_val, - ) = self.self_attn.streaming_forward( - src, - pos_emb=pos_emb, - cached_key=cached_key, - cached_val=cached_val, - ) - src = src + src_attn - - src_conv, cached_conv1 = self.conv_module1.streaming_forward( - src, - cache=cached_conv1, - ) - src = src + src_conv - - src = src + self.feed_forward2(src) - - src_attn, cached_val2 = self.self_attn.streaming_forward2( - src, - attn_weights, - cached_val=cached_val2, - ) - src = src + src_attn - - src_conv, cached_conv2 = self.conv_module2.streaming_forward( - src, - cache=cached_conv2, - ) - src = src + src_conv - - src = src + self.feed_forward3(src) - - src = self.norm_final(self.balancer(src)) - - delta = src - src_orig - - src = src_orig + delta * self.bypass_scale - - return ( - src, - cached_len, - cached_avg, - cached_key, - cached_val, - cached_val2, - cached_conv1, - cached_conv2, - ) - - -class ZipformerEncoder(nn.Module): - r"""ZipformerEncoder is a stack of N encoder layers - - Args: - encoder_layer: an instance of the ZipformerEncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). - - Examples:: - >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) - >>> zipformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> out = zipformer_encoder(src) - """ - - def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - dropout: float, - warmup_begin: float, - warmup_end: float, - ) -> None: - super().__init__() - # will be written to, see set_batch_count() Note: in inference time this - # may be zero but should be treated as large, we can check if - # self.training is true. - self.batch_count = 0 - self.warmup_begin = warmup_begin - self.warmup_end = warmup_end - # module_seed is for when we need a random number that is unique to the module but - # shared across jobs. It's used to randomly select how many layers to drop, - # so that we can keep this consistent across worker tasks (for efficiency). - self.module_seed = torch.randint(0, 1000, ()).item() - - self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, dropout) - - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for i in range(num_layers)] - ) - self.num_layers = num_layers - - self.d_model = encoder_layer.d_model - self.attention_dim = encoder_layer.attention_dim - self.cnn_module_kernel = encoder_layer.cnn_module_kernel - - assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) - - delta = (1.0 / num_layers) * (warmup_end - warmup_begin) - cur_begin = warmup_begin - for i in range(num_layers): - self.layers[i].warmup_begin = cur_begin - cur_begin += delta - self.layers[i].warmup_end = cur_begin - - def get_layers_to_drop(self, rnd_seed: int): - ans = set() - if not self.training: - return ans - - batch_count = self.batch_count - num_layers = len(self.layers) - - def get_layerdrop_prob(layer: int) -> float: - layer_warmup_begin = self.layers[layer].warmup_begin - layer_warmup_end = self.layers[layer].warmup_end - - initial_layerdrop_prob = 0.5 - final_layerdrop_prob = 0.05 - - if batch_count == 0: - # As a special case, if batch_count == 0, return 0 (drop no - # layers). This is rather ugly, I'm afraid; it is intended to - # enable our scan_pessimistic_batches_for_oom() code to work correctly - # so if we are going to get OOM it will happen early. - # also search for 'batch_count' with quotes in this file to see - # how we initialize the warmup count to a random number between - # 0 and 10. - return 0.0 - elif batch_count < layer_warmup_begin: - return initial_layerdrop_prob - elif batch_count > layer_warmup_end: - return final_layerdrop_prob - else: - # linearly interpolate - t = (batch_count - layer_warmup_begin) / layer_warmup_end - assert 0.0 <= t < 1.001, t - return initial_layerdrop_prob + t * ( - final_layerdrop_prob - initial_layerdrop_prob - ) - - shared_rng = random.Random(batch_count + self.module_seed) - independent_rng = random.Random(rnd_seed) - - layerdrop_probs = [get_layerdrop_prob(i) for i in range(num_layers)] - tot = sum(layerdrop_probs) - # Instead of drawing the samples independently, we first randomly decide - # how many layers to drop out, using the same random number generator between - # jobs so that all jobs drop out the same number (this is for speed). - # Then we use an approximate approach to drop out the individual layers - # with their specified probs while reaching this exact target. - num_to_drop = int(tot) + int(shared_rng.random() < (tot - int(tot))) - - layers = list(range(num_layers)) - independent_rng.shuffle(layers) - - # go through the shuffled layers until we get the required number of samples. - if num_to_drop > 0: - for layer in itertools.cycle(layers): - if independent_rng.random() < layerdrop_probs[layer]: - ans.add(layer) - if len(ans) == num_to_drop: - break - if shared_rng.random() < 0.005 or __name__ == "__main__": - logging.info( - f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " - f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}" - ) - return ans - - def forward( - self, - src: Tensor, - # Note: The type of feature_mask should be Union[float, Tensor], - # but to make torch.jit.script() work, we use `float` here - feature_mask: float = 1.0, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required). - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer. - mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number - - Returns: (x, x_no_combine), both of shape (S, N, E) - """ - pos_emb = self.encoder_pos(src) - output = src - - if torch.jit.is_scripting(): - layers_to_drop = [] - else: - rnd_seed = src.numel() + random.randint(0, 1000) - layers_to_drop = self.get_layers_to_drop(rnd_seed) - - output = output * feature_mask - - for i, mod in enumerate(self.layers): - if not torch.jit.is_scripting(): - if i in layers_to_drop: - continue - output = mod( - output, - pos_emb, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - ) - - output = output * feature_mask - - return output - - @torch.jit.export - def streaming_forward( - self, - src: Tensor, - cached_len: Tensor, - cached_avg: Tensor, - cached_key: Tensor, - cached_val: Tensor, - cached_val2: Tensor, - cached_conv1: Tensor, - cached_conv2: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required). - cached_len: number of past frames. - cached_avg: cached average of past frames. - cached_key: cached key tensor for first attention module. - cached_val: cached value tensor for first attention module. - cached_val2: cached value tensor for second attention module. - cached_conv1: cached left contexts for the first convolution module. - cached_conv2: cached left contexts for the second convolution module. - - Shape: - src: (S, N, E). - cached_len: (num_layers,) - cached_avg: (num_layers, N, C). - N is the batch size, C is the feature dimension. - cached_key: (num_layers, left_context_len, N, K). - N is the batch size, K is the key dimension. - cached_val: (num_layers, left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_val2: (num_layers, left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_conv1: (num_layers, N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - cached_conv2: (num_layers, N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - - Returns: A tuple of 8 tensors: - - output tensor - - updated cached number of past frames. - - updated cached average of past frames. - - updated cached key tensor of of the first attention module. - - updated cached value tensor of of the first attention module. - - updated cached value tensor of of the second attention module. - - updated cached left contexts of the first convolution module. - - updated cached left contexts of the second convolution module. - """ - assert cached_len.size(0) == self.num_layers, ( - cached_len.size(0), - self.num_layers, - ) - assert cached_avg.size(0) == self.num_layers, ( - cached_avg.size(0), - self.num_layers, - ) - assert cached_key.size(0) == self.num_layers, ( - cached_key.size(0), - self.num_layers, - ) - assert cached_val.size(0) == self.num_layers, ( - cached_val.size(0), - self.num_layers, - ) - assert cached_val2.size(0) == self.num_layers, ( - cached_val2.size(0), - self.num_layers, - ) - assert cached_conv1.size(0) == self.num_layers, ( - cached_conv1.size(0), - self.num_layers, - ) - assert cached_conv2.size(0) == self.num_layers, ( - cached_conv2.size(0), - self.num_layers, - ) - - left_context_len = cached_key.shape[1] - pos_emb = self.encoder_pos(src, left_context_len) - output = src - - new_cached_len = [] - new_cached_avg = [] - new_cached_key = [] - new_cached_val = [] - new_cached_val2 = [] - new_cached_conv1 = [] - new_cached_conv2 = [] - for i, mod in enumerate(self.layers): - output, len_avg, avg, key, val, val2, conv1, conv2 = mod.streaming_forward( - output, - pos_emb, - cached_len=cached_len[i], - cached_avg=cached_avg[i], - cached_key=cached_key[i], - cached_val=cached_val[i], - cached_val2=cached_val2[i], - cached_conv1=cached_conv1[i], - cached_conv2=cached_conv2[i], - ) - # Update caches - new_cached_len.append(len_avg) - new_cached_avg.append(avg) - new_cached_key.append(key) - new_cached_val.append(val) - new_cached_val2.append(val2) - new_cached_conv1.append(conv1) - new_cached_conv2.append(conv2) - - return ( - output, - torch.stack(new_cached_len, dim=0), - torch.stack(new_cached_avg, dim=0), - torch.stack(new_cached_key, dim=0), - torch.stack(new_cached_val, dim=0), - torch.stack(new_cached_val2, dim=0), - torch.stack(new_cached_conv1, dim=0), - torch.stack(new_cached_conv2, dim=0), - ) - - -class DownsampledZipformerEncoder(nn.Module): - r""" - DownsampledZipformerEncoder is a zipformer encoder evaluated at a reduced frame rate, - after convolutional downsampling, and then upsampled again at the output, and combined - with the origin input, so that the output has the same shape as the input. - """ - - def __init__( - self, encoder: nn.Module, input_dim: int, output_dim: int, downsample: int - ): - super(DownsampledZipformerEncoder, self).__init__() - self.downsample_factor = downsample - self.downsample = AttentionDownsample(input_dim, output_dim, downsample) - self.encoder = encoder - self.num_layers = encoder.num_layers - self.d_model = encoder.d_model - self.attention_dim = encoder.attention_dim - self.cnn_module_kernel = encoder.cnn_module_kernel - self.upsample = SimpleUpsample(output_dim, downsample) - self.out_combiner = SimpleCombiner( - input_dim, output_dim, min_weight=(0.0, 0.25) - ) - - def forward( - self, - src: Tensor, - # Note: the type of feature_mask should be Unino[float, Tensor], - # but to make torch.jit.script() happ, we use float here - feature_mask: float = 1.0, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Downsample, go through encoder, upsample. - - Args: - src: the sequence to the encoder (required). - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer. feature_mask is expected to be already downsampled by - self.downsample_factor. - attn_mask: attention mask (optional). Should be downsampled already. - src_key_padding_mask: the mask for the src keys per batch (optional). Should be downsampled already. - - Shape: - src: (S, N, E). - attn_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number - - Returns: output of shape (S, N, F) where F is the number of output features - (output_dim to constructor) - """ - src_orig = src - src = self.downsample(src) - - src = self.encoder( - src, - feature_mask=feature_mask, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - ) - src = self.upsample(src) - # remove any extra frames that are not a multiple of downsample_factor - src = src[: src_orig.shape[0]] - - return self.out_combiner(src_orig, src) - - def streaming_forward( - self, - src: Tensor, - cached_len: Tensor, - cached_avg: Tensor, - cached_key: Tensor, - cached_val: Tensor, - cached_val2: Tensor, - cached_conv1: Tensor, - cached_conv2: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - r"""Downsample, go through encoder, upsample. - - Args: - src: the sequence to the encoder (required). - cached_avg: cached average value of past frames. - cached_len: length of past frames. - cached_key: cached key tensor for the first attention module. - cached_val: cached value tensor for the first attention module. - cached_val2: cached value tensor for the second attention module. - cached_conv1: cached left context for the first convolution module. - cached_conv2: cached left context for the second convolution module. - - Shape: - src: (S, N, E). - cached_len: (N,) - N is the batch size. - cached_avg: (num_layers, N, C). - N is the batch size, C is the feature dimension. - cached_key: (num_layers, left_context_len, N, K). - N is the batch size, K is the key dimension. - cached_val: (num_layers, left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_val2: (num_layers, left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_conv1: (num_layers, N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - cached_conv2: (num_layers, N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - Returns: output of shape (S, N, F) where F is the number of output features - (output_dim to constructor) - """ - src_orig = src - src = self.downsample(src) - - ( - src, - cached_len, - cached_avg, - cached_key, - cached_val, - cached_val2, - cached_conv1, - cached_conv2, - ) = self.encoder.streaming_forward( - src, - cached_len=cached_len, - cached_avg=cached_avg, - cached_key=cached_key, - cached_val=cached_val, - cached_val2=cached_val2, - cached_conv1=cached_conv1, - cached_conv2=cached_conv2, - ) - src = self.upsample(src) - # remove any extra frames that are not a multiple of downsample_factor - src = src[: src_orig.shape[0]] - - return ( - self.out_combiner(src_orig, src), - cached_len, - cached_avg, - cached_key, - cached_val, - cached_val2, - cached_conv1, - cached_conv2, - ) - - -class AttentionDownsample(torch.nn.Module): - """ - Does downsampling with attention, by weighted sum, and a projection.. - """ - - def __init__(self, in_channels: int, out_channels: int, downsample: int): - super(AttentionDownsample, self).__init__() - self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5)) - - # fill in the extra dimensions with a projection of the input - if out_channels > in_channels: - self.extra_proj = nn.Linear( - in_channels * downsample, out_channels - in_channels, bias=False - ) - else: - self.extra_proj = None - self.downsample = downsample - - def forward(self, src: Tensor) -> Tensor: - """ - x: (seq_len, 1, in_channels) - Returns a tensor of shape - ( (seq_len+downsample-1)//downsample, batch_size, out_channels) - """ - (seq_len, batch_size, in_channels) = src.shape - ds = self.downsample - d_seq_len = (seq_len + ds - 1) // ds - - # Pad to an exact multiple of self.downsample - if seq_len != d_seq_len * ds: - # right-pad src, repeating the last element. - pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) - src = torch.cat((src, src_extra), dim=0) - assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) - - src = src.reshape(d_seq_len, ds, batch_size, in_channels) - scores = (src * self.query).sum(dim=-1, keepdim=True) - - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) - - weights = scores.softmax(dim=1) - - # ans1 is the first `in_channels` channels of the output - ans = (src * weights).sum(dim=1) - src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels) - - if self.extra_proj is not None: - ans2 = self.extra_proj(src) - ans = torch.cat((ans, ans2), dim=2) - return ans - - -class SimpleUpsample(torch.nn.Module): - """ - A very simple form of upsampling that mostly just repeats the input, but - also adds a position-specific bias. - """ - - def __init__(self, num_channels: int, upsample: int): - super(SimpleUpsample, self).__init__() - self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) - - def forward(self, src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, num_channels) - Returns a tensor of shape - ( (seq_len*upsample), batch_size, num_channels) - """ - upsample = self.bias.shape[0] - (seq_len, batch_size, num_channels) = src.shape - src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) - src = src + self.bias.unsqueeze(1) - src = src.reshape(seq_len * upsample, batch_size, num_channels) - return src - - -class SimpleCombinerIdentity(nn.Module): - def __init__(self, *args, **kwargs): - super().__init__() - - def forward(self, src1: Tensor, src2: Tensor) -> Tensor: - return src1 - - -class SimpleCombiner(torch.nn.Module): - """ - A very simple way of combining 2 vectors of 2 different dims, via a - learned weighted combination in the shared part of the dim. - Args: - dim1: the dimension of the first input, e.g. 256 - dim2: the dimension of the second input, e.g. 384. - The output will have the same dimension as dim2. - """ - - def __init__(self, dim1: int, dim2: int, min_weight: Tuple[float] = (0.0, 0.0)): - super(SimpleCombiner, self).__init__() - assert dim2 >= dim1, (dim2, dim1) - self.weight1 = nn.Parameter(torch.zeros(())) - self.min_weight = min_weight - - def forward(self, src1: Tensor, src2: Tensor) -> Tensor: - """ - src1: (*, dim1) - src2: (*, dim2) - - Returns: a tensor of shape (*, dim2) - """ - assert src1.shape[:-1] == src2.shape[:-1], (src1.shape, src2.shape) - - weight1 = self.weight1 - if not torch.jit.is_scripting(): - if ( - self.training - and random.random() < 0.25 - and self.min_weight != (0.0, 0.0) - ): - weight1 = weight1.clamp( - min=self.min_weight[0], max=1.0 - self.min_weight[1] - ) - - src1 = src1 * weight1 - src2 = src2 * (1.0 - weight1) - - src1_dim = src1.shape[-1] - src2_dim = src2.shape[-1] - if src1_dim != src2_dim: - if src1_dim < src2_dim: - src1 = torch.nn.functional.pad(src1, (0, src2_dim - src1_dim)) - else: - src1 = src1[:src2_dim] - - return src1 + src2 - - -class RelPositionalEncoding(torch.nn.Module): - """Relative positional encoding module. - - See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py - - Args: - d_model: Embedding dimension. - dropout_rate: Dropout rate. - max_len: Maximum input length. - - """ - - def __init__( - self, - d_model: int, - dropout_rate: float, - max_len: int = 5000, - ) -> None: - """Construct a PositionalEncoding object.""" - super(RelPositionalEncoding, self).__init__() - self.d_model = d_model - self.dropout = torch.nn.Dropout(dropout_rate) - self.pe = None - self.extend_pe(torch.tensor(0.0).expand(max_len)) - - def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: - """Reset the positional encodings.""" - x_size_left = x.size(0) + left_context_len - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(1) >= x_size_left * 2 - 1: - # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - # Suppose `i` means to the position of query vector and `j` means the - # position of key vector. We use positive relative positions when keys - # are to the left (i>j) and negative relative positions otherwise (i Tensor: - """Add positional encoding. - - Args: - x (torch.Tensor): Input tensor (time, batch, `*`). - left_context_len: (int): Length of cached left context. - - Returns: - torch.Tensor: Encoded tensor (batch, left_context_len + 2*time-1, `*`). - - """ - self.extend_pe(x, left_context_len) - x_size_left = x.size(0) + left_context_len - pos_emb = self.pe[ - :, - self.pe.size(1) // 2 - - x_size_left - + 1 : self.pe.size(1) // 2 # noqa E203 - + x.size(0), - ] - return self.dropout(pos_emb) - - -class RelPositionMultiheadAttention(nn.Module): - r"""Multi-Head Attention layer with relative position encoding - - This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", - we have to write up the differences. - - - Args: - embed_dim: total dimension of the model. - attention_dim: dimension in the attention module, may be less or more than embed_dim - but must be a multiple of num_heads. - num_heads: parallel attention heads. - dropout: a Dropout layer on attn_output_weights. Default: 0.0. - - Examples:: - - >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) - """ - - def __init__( - self, - embed_dim: int, - attention_dim: int, - num_heads: int, - pos_dim: int, - dropout: float = 0.0, - ) -> None: - super(RelPositionMultiheadAttention, self).__init__() - self.embed_dim = embed_dim - self.attention_dim = attention_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = attention_dim // num_heads - self.pos_dim = pos_dim - assert self.head_dim % 2 == 0, self.head_dim - assert self.head_dim * num_heads == attention_dim, ( - self.head_dim, - num_heads, - attention_dim, - ) - - # the initial_scale is supposed to take over the "scaling" factor of - # head_dim ** -0.5, dividing it between the query and key. - in_proj_dim = ( - 2 * attention_dim # query, key - + attention_dim // 2 # value - + pos_dim * num_heads # positional encoding query - ) - - self.in_proj = ScaledLinear( - embed_dim, in_proj_dim, bias=True, initial_scale=self.head_dim**-0.25 - ) - - # self.whiten_values is applied on the values in forward(); - # it just copies the keys but prevents low-rank distribution by modifying grads. - self.whiten_values = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) - self.whiten_keys = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) - - # linear transformation for positional encoding. - self.linear_pos = ScaledLinear( - embed_dim, num_heads * pos_dim, bias=False, initial_scale=0.05 - ) - - # the following are for diagnosics only, see --print-diagnostics option. - # they only copy their inputs. - self.copy_pos_query = Identity() - self.copy_query = Identity() - - self.out_proj = ScaledLinear( - attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 - ) - - self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False) - self.out_proj2 = ScaledLinear( - attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 - ) - # self.whiten_values2 is applied on the values in forward2() - self.whiten_values2 = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) - - def forward( - self, - x: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - r""" - Args: - x: input to be projected to query, key, value - pos_emb: Positional embedding tensor - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. When given a binary mask and a value is True, - the corresponding value on the attention layer will be ignored. When given - a byte mask and a value is non-zero, the corresponding value on the attention - layer will be ignored - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - - Inputs: - - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the position - with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - - Returns: (attn_output, attn_weights) - - - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size, - E is the embedding dimension. - - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads - and S is the sequence length. - """ - x, weights = self.multi_head_attention_forward( - self.in_proj(x), - self.linear_pos(pos_emb), - self.attention_dim, - self.num_heads, - self.dropout, - self.out_proj.weight, - self.out_proj.bias, - training=self.training, - key_padding_mask=key_padding_mask, - attn_mask=attn_mask, - ) - return x, weights - - def streaming_forward( - self, - x: Tensor, - pos_emb: Tensor, - cached_key: Tensor, - cached_val: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - r""" - Args: - x: input to be projected to query, key, value - pos_emb: Positional embedding tensor - - Shape: - - Inputs: - - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - cached_key: :math:`(left_context_len, N, K)`, where N is the batch size, K is the key dimension. - - cached_val: :math:`(left_context_len, N, V)`, where N is the batch size, V is the value dimension. - - - Returns: (attn_output, attn_weights, cached_key, cached_val) - - - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size, - E is the embedding dimension. - - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads - and S is the sequence length. - - cached_key: :math:`(left_context_len, N, K)`, updated cached attention key tensor of - left context - - cached_val: :math:`(left_context_len, N, K)`, updated cached attention value tensor of - """ - ( - x, - weights, - cached_key, - cached_val, - ) = self.streaming_multi_head_attention_forward( - self.in_proj(x), - self.linear_pos(pos_emb), - self.attention_dim, - self.num_heads, - self.out_proj.weight, - self.out_proj.bias, - cached_key=cached_key, - cached_val=cached_val, - ) - return x, weights, cached_key, cached_val - - def multi_head_attention_forward( - self, - x_proj: Tensor, - pos: Tensor, - attention_dim: int, - num_heads: int, - dropout_p: float, - out_proj_weight: Tensor, - out_proj_bias: Tensor, - training: bool = True, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - r""" - Args: - x_proj: the projected input, to be split into query, key, value. - pos: head-specific biases arising from the positional embeddings. - attention_dim: dimension inside attention mechanism - num_heads: parallel attention heads. - dropout_p: probability of an element to be zeroed. - out_proj_weight, out_proj_bias: the output projection weight and bias. - training: apply dropout if is ``True``. - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. This is an binary mask. When the value is True, - the corresponding value on the attention layer will be filled with -inf. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - Inputs: - - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is - the attention dimension. Will be split into (query, key, value, pos). - - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence - length, N is the batch size, and A is the attention dim. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions - will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_weights: :math:`(N * H, S, S)` where N is the batch size, - H is the num-heads, S is the sequence length. - """ - - seq_len, bsz, _ = x_proj.size() - - head_dim = attention_dim // num_heads - pos_dim = self.pos_dim # positional-encoding dim per head - assert ( - head_dim * num_heads == attention_dim - ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" - - # self-attention - q = x_proj[..., 0:attention_dim] - k = x_proj[..., attention_dim : 2 * attention_dim] - value_dim = attention_dim // 2 - v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] - # p is the position-encoding query, its dimension is num_heads*pos_dim.. - p = x_proj[..., 2 * attention_dim + value_dim :] - - k = self.whiten_keys(k) # does nothing in the forward pass. - v = self.whiten_values(v) # does nothing in the forward pass. - q = self.copy_query(q) # for diagnostics only, does nothing. - p = self.copy_pos_query(p) # for diagnostics only, does nothing. - - if attn_mask is not None: - assert ( - attn_mask.dtype == torch.float32 - or attn_mask.dtype == torch.float64 - or attn_mask.dtype == torch.float16 - or attn_mask.dtype == torch.uint8 - or attn_mask.dtype == torch.bool - ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( - attn_mask.dtype - ) - if attn_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for attn_mask is deprecated. Use bool tensor instead." - ) - attn_mask = attn_mask.to(torch.bool) - - if attn_mask.dim() == 2: - attn_mask = attn_mask.unsqueeze(0) - if list(attn_mask.size()) != [1, seq_len, seq_len]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") - elif attn_mask.dim() == 3: - if list(attn_mask.size()) != [ - bsz * num_heads, - seq_len, - seq_len, - ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") - else: - raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) - ) - # attn_mask's dim is 3 now. - - # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." - ) - key_padding_mask = key_padding_mask.to(torch.bool) - - q = q.reshape(seq_len, bsz, num_heads, head_dim) - p = p.reshape(seq_len, bsz, num_heads, pos_dim) - k = k.reshape(seq_len, bsz, num_heads, head_dim) - v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) - - if key_padding_mask is not None: - assert key_padding_mask.size(0) == bsz, "{} == {}".format( - key_padding_mask.size(0), bsz - ) - assert key_padding_mask.size(1) == seq_len, "{} == {}".format( - key_padding_mask.size(1), seq_len - ) - - q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) - p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) - k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - - seq_len2 = 2 * seq_len - 1 - pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) - # pos shape now: (batch, head, pos_dim, seq_len2) - - # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) - # [where seq_len2 represents relative position.] - pos_weights = torch.matmul(p, pos) - # the following .as_strided() expression converts the last axis of pos_weights from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. - if torch.jit.is_tracing(): - (batch_size, num_heads, time1, n) = pos_weights.shape - rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(seq_len) - rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) - indexes = rows + cols - pos_weights = pos_weights.reshape(-1, n) - pos_weights = torch.gather(pos_weights, dim=1, index=indexes) - pos_weights = pos_weights.reshape(batch_size, num_heads, time1, seq_len) - else: - pos_weights = pos_weights.as_strided( - (bsz, num_heads, seq_len, seq_len), - ( - pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2) - pos_weights.stride(3), - pos_weights.stride(3), - ), - storage_offset=pos_weights.stride(3) * (seq_len - 1), - ) - - # caution: they are really scores at this point. - attn_output_weights = torch.matmul(q, k) + pos_weights - - if not torch.jit.is_scripting(): - if training and random.random() < 0.1: - # This is a harder way of limiting the attention scores to not be too large. - # It incurs a penalty if any of them has an absolute value greater than 50.0. - # this should be outside the normal range of the attention scores. We use - # this mechanism instead of, say, a limit on entropy, because once the entropy - # gets very small gradients through the softmax can become very small, and - # some mechanisms like that become ineffective. - attn_output_weights = penalize_abs_values_gt( - attn_output_weights, limit=25.0, penalty=1.0e-04 - ) - - # attn_output_weights: (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, seq_len, seq_len - ) - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_output_weights = attn_output_weights.masked_fill( - attn_mask, float("-inf") - ) - else: - attn_output_weights = attn_output_weights + attn_mask - - if key_padding_mask is not None: - attn_output_weights = attn_output_weights.view( - bsz, num_heads, seq_len, seq_len - ) - attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - float("-inf"), - ) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, seq_len, seq_len - ) - - # Using this version of softmax, defined in scaling.py, - # should save a little of the memory used in backprop by, if - # we are in automatic mixed precision mode (amp) == autocast, - # only storing the half-precision output for backprop purposes. - attn_output_weights = softmax(attn_output_weights, dim=-1) - - # If we are using chunk-wise attention mask and setting a limited - # num_left_chunks, the attention may only see the padding values which - # will also be masked out by `key_padding_mask`. At this circumstances, - # the whole column of `attn_output_weights` will be `-inf` - # (i.e. be `nan` after softmax). So we fill `0.0` at the masking - # positions to avoid invalid loss value below. - if ( - attn_mask is not None - and attn_mask.dtype == torch.bool - and key_padding_mask is not None - ): - if attn_mask.size(0) != 1: - attn_mask = attn_mask.view(bsz, num_heads, seq_len, seq_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) - else: - # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( - 1 - ).unsqueeze(2) - - attn_output_weights = attn_output_weights.view( - bsz, num_heads, seq_len, seq_len - ) - attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, seq_len, seq_len - ) - - attn_output_weights = nn.functional.dropout( - attn_output_weights, p=dropout_p, training=training - ) - - attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(seq_len, bsz, attention_dim // 2) - ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) - - return attn_output, attn_output_weights - - def streaming_multi_head_attention_forward( - self, - x_proj: Tensor, - pos: Tensor, - attention_dim: int, - num_heads: int, - out_proj_weight: Tensor, - out_proj_bias: Tensor, - cached_key: Tensor, - cached_val: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - r""" - Args: - x_proj: the projected input, to be split into query, key, value. - pos: head-specific biases arising from the positional embeddings. - attention_dim: dimension inside attention mechanism - num_heads: parallel attention heads. - out_proj_weight, out_proj_bias: the output projection weight and bias. - cached_key: cached attention key tensor of left context. - cached_val: cached attention value tensor of left context. - - Shape: - Inputs: - - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is - the attention dimension. Will be split into (query, key, value, pos). - - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence - length, N is the batch size, and A is the attention dim. - If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions - will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_weights: :math:`(N * H, S, S)` where N is the batch size, - H is the num-heads, S is the sequence length. - - cached_key: :math:`(left_context_len, N, K)`, updated cached attention key tensor of left context. - - cached_val: :math:`(left_context_len, N, K)`, updated cached attention value tensor of left context. - """ - - seq_len, bsz, _ = x_proj.size() - - head_dim = attention_dim // num_heads - pos_dim = self.pos_dim # positional-encoding dim per head - assert ( - head_dim * num_heads == attention_dim - ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" - - # self-attention - q = x_proj[..., 0:attention_dim] - k = x_proj[..., attention_dim : 2 * attention_dim] - value_dim = attention_dim // 2 - v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] - # p is the position-encoding query, its dimension is num_heads*pos_dim.. - p = x_proj[..., 2 * attention_dim + value_dim :] - - left_context_len = cached_key.shape[0] - assert left_context_len > 0, left_context_len - assert cached_key.shape[0] == cached_val.shape[0], ( - cached_key.shape, - cached_val.shape, - ) - # Pad cached left contexts - k = torch.cat([cached_key, k], dim=0) - v = torch.cat([cached_val, v], dim=0) - # Update cached left contexts - cached_key = k[-left_context_len:, ...] - cached_val = v[-left_context_len:, ...] - - # The length of key and value - kv_len = k.shape[0] - - q = q.reshape(seq_len, bsz, num_heads, head_dim) - p = p.reshape(seq_len, bsz, num_heads, pos_dim) - k = k.reshape(kv_len, bsz, num_heads, head_dim) - v = v.reshape(kv_len, bsz * num_heads, head_dim // 2).transpose(0, 1) - - q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) - p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) - k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - - seq_len2 = 2 * seq_len - 1 + left_context_len - pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) - # pos shape now: (batch, head, pos_dim, seq_len2) - - # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) - # [where seq_len2 represents relative position.] - pos_weights = torch.matmul(p, pos) - # the following .as_strided() expression converts the last axis of pos_weights from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. - if torch.jit.is_tracing(): - (batch_size, num_heads, time1, n) = pos_weights.shape - rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(kv_len) - rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) - indexes = rows + cols - pos_weights = pos_weights.reshape(-1, n) - pos_weights = torch.gather(pos_weights, dim=1, index=indexes) - pos_weights = pos_weights.reshape(batch_size, num_heads, time1, kv_len) - else: - pos_weights = pos_weights.as_strided( - (bsz, num_heads, seq_len, kv_len), - ( - pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2) - pos_weights.stride(3), - pos_weights.stride(3), - ), - storage_offset=pos_weights.stride(3) * (seq_len - 1), - ) - - # caution: they are really scores at this point. - attn_output_weights = torch.matmul(q, k) + pos_weights - - # attn_output_weights: (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, seq_len, kv_len) - - # Using this version of softmax, defined in scaling.py, - # should save a little of the memory used in backprop by, if - # we are in automatic mixed precision mode (amp) == autocast, - # only storing the half-precision output for backprop purposes. - attn_output_weights = softmax(attn_output_weights, dim=-1) - - attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(seq_len, bsz, attention_dim // 2) - ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) - - return attn_output, attn_output_weights, cached_key, cached_val - - def forward2( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: - """ - Second forward function, where we re-use the attn_weights returned by the first forward function - but with different input. - Args: - x: input, of shape (seq_len, batch_size, embed_dim) - attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) - Returns: - output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) - """ - num_heads = self.num_heads - (seq_len, bsz, embed_dim) = x.shape - head_dim = self.attention_dim // num_heads - # v: (tgt_len, bsz, embed_dim // 2) - v = self.in_proj2(x) - v = self.whiten_values2(v) # does nothing in the forward pass. - v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) - - # now v: (bsz * num_heads, seq_len, head_dim // 2) - attn_output = torch.bmm(attn_weights, v) - - if not torch.jit.is_scripting(): - if random.random() < 0.001 or __name__ == "__main__": - self._print_attn_stats(attn_weights, attn_output) - - # attn_output: (bsz * num_heads, seq_len, head_dim) - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(seq_len, bsz, self.attention_dim // 2) - ) - # returned value is of shape (seq_len, bsz, embed_dim), like x. - return self.out_proj2(attn_output) - - def streaming_forward2( - self, - x: Tensor, - attn_weights: Tensor, - cached_val: Tensor, - ) -> Tuple[Tensor, Tensor]: - """ - Second forward function, where we re-use the attn_weights returned by the first forward function - but with different input. - Args: - x: input, of shape (seq_len, batch_size, embed_dim) - attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) - cached_val: cached attention value tensor of left context. - Returns: - - output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) - - updated cached attention value tensor of left context. - """ - num_heads = self.num_heads - (seq_len, bsz, embed_dim) = x.shape - head_dim = self.attention_dim // num_heads - # v: (tgt_len, bsz, embed_dim // 2) - v = self.in_proj2(x) - - left_context_len = cached_val.shape[0] - assert left_context_len > 0, left_context_len - v = torch.cat([cached_val, v], dim=0) - cached_val = v[-left_context_len:] - - seq_len2 = left_context_len + seq_len - v = v.reshape(seq_len2, bsz * num_heads, head_dim // 2).transpose(0, 1) - - # now v: (bsz * num_heads, seq_len, head_dim // 2) - attn_output = torch.bmm(attn_weights, v) - - # attn_output: (bsz * num_heads, seq_len, head_dim) - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(seq_len, bsz, self.attention_dim // 2) - ) - # returned value is of shape (seq_len, bsz, embed_dim), like x. - return self.out_proj2(attn_output), cached_val - - def _print_attn_stats(self, attn_weights: Tensor, attn_output: Tensor): - # attn_weights: (batch_size * num_heads, seq_len, seq_len) - # attn_output: (bsz * num_heads, seq_len, head_dim) - (n, seq_len, head_dim) = attn_output.shape - num_heads = self.num_heads - bsz = n // num_heads - - with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): - attn_weights = attn_weights.to(torch.float32) - attn_output = attn_output.to(torch.float32) - attn_weights_entropy = ( - -((attn_weights + 1.0e-20).log() * attn_weights) - .sum(dim=-1) - .reshape(bsz, num_heads, seq_len) - .mean(dim=(0, 2)) - ) - attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim) - attn_output = attn_output.permute(1, 0, 2, 3).reshape( - num_heads, bsz * seq_len, head_dim - ) - attn_output_mean = attn_output.mean(dim=1, keepdim=True) - attn_output = attn_output - attn_output_mean - attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / ( - bsz * seq_len - ) - # attn_covar: (num_heads, head_dim, head_dim) - # eigs, _ = torch.symeig(attn_covar) - # logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") - - attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,) - embed_dim = self.in_proj2.weight.shape[1] - in_proj_covar = ( - self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2 - ).mean(dim=(1, 2)) - out_proj_covar = ( - self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2 - ).mean(dim=(0, 2)) - logging.info( - f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}" - ) - - -class PoolingModule(nn.Module): - """ - Averages the input over the time dimension and project with a square matrix. - """ - - def __init__(self, d_model: int): - super().__init__() - self.proj = ScaledLinear(d_model, d_model, initial_scale=0.1, bias=False) - - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Args: - x: a Tensor of shape (T, N, C) - src_key_padding_mask: a Tensor of bool, of shape (N, T), with True in masked - positions. - - Returns: - - output, a Tensor of shape (T, N, C). - """ - if src_key_padding_mask is not None: - # False in padding positions - padding_mask = src_key_padding_mask.logical_not().to(x.dtype) # (N, T) - # Cumulated numbers of frames from start - cum_mask = padding_mask.cumsum(dim=1) # (N, T) - x = x.cumsum(dim=0) # (T, N, C) - pooling_mask = padding_mask / cum_mask - pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1) - # now pooling_mask: (T, N, 1) - x = x * pooling_mask # (T, N, C) - else: - num_frames = x.shape[0] - cum_mask = torch.arange(1, num_frames + 1).unsqueeze(1) # (T, 1) - x = x.cumsum(dim=0) # (T, N, C) - pooling_mask = (1.0 / cum_mask).unsqueeze(2) - # now pooling_mask: (T, N, 1) - x = x * pooling_mask - - x = self.proj(x) - return x - - def streaming_forward( - self, - x: Tensor, - cached_len: Tensor, - cached_avg: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - x: a Tensor of shape (T, N, C) - cached_len: a Tensor of int, of shape (N,), containing the number of - past frames in batch. - cached_avg: a Tensor of shape (N, C), the average over all past frames - in batch. - - Returns: - A tuple of 2 tensors: - - output, a Tensor of shape (T, N, C). - - updated cached_avg, a Tensor of shape (N, C). - """ - x = x.cumsum(dim=0) # (T, N, C) - x = x + (cached_avg * cached_len.unsqueeze(1)).unsqueeze(0) - # Cumulated numbers of frames from start - cum_mask = torch.arange(1, x.size(0) + 1, device=x.device) - cum_mask = cum_mask.unsqueeze(1) + cached_len.unsqueeze(0) # (T, N) - pooling_mask = (1.0 / cum_mask).unsqueeze(2) - # now pooling_mask: (T, N, 1) - x = x * pooling_mask # (T, N, C) - - cached_len = cached_len + x.size(0) - cached_avg = x[-1] - - x = self.proj(x) - return x, cached_len, cached_avg - - -class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer model.""" - - def __init__(self, d_model: int, feedforward_dim: int, dropout: float): - super(FeedforwardModule, self).__init__() - self.in_proj = nn.Linear(d_model, feedforward_dim) - self.balancer = ActivationBalancer( - feedforward_dim, channel_dim=-1, max_abs=10.0, min_prob=0.25 - ) - self.activation = DoubleSwish() - self.dropout = nn.Dropout(dropout) - self.out_proj = ScaledLinear(feedforward_dim, d_model, initial_scale=0.01) - - def forward(self, x: Tensor): - x = self.in_proj(x) - x = self.balancer(x) - x = self.activation(x) - x = self.dropout(x) - x = self.out_proj(x) - return x - - -class ConvolutionModule(nn.Module): - """ConvolutionModule in Zipformer model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). - - """ - - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: - """Construct an ConvolutionModule object.""" - super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0, kernel_size - - self.pointwise_conv1 = nn.Conv1d( - channels, - 2 * channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - - # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu). - # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, - # but sometimes, for some reason, for layer 0 the rms ends up being very large, - # between 50 and 100 for different channels. This will cause very peaky and - # sparse derivatives for the sigmoid gating function, which will tend to make - # the loss function not learn effectively. (for most layers the average absolute values - # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, - # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different - # layers, which likely breaks down as 0.5 for the "linear" half and - # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we - # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, - # it will be in a better position to start learning something, i.e. to latch onto - # the correct range. - self.deriv_balancer1 = ActivationBalancer( - 2 * channels, - channel_dim=1, - max_abs=10.0, - min_positive=0.05, - max_positive=1.0, - ) - - # Will pad cached left context - self.lorder = kernel_size - 1 - self.depthwise_conv = nn.Conv1d( - channels, - channels, - kernel_size, - stride=1, - padding=0, - groups=channels, - bias=bias, - ) - - self.deriv_balancer2 = ActivationBalancer( - channels, - channel_dim=1, - min_positive=0.05, - max_positive=1.0, - max_abs=20.0, - ) - - self.activation = DoubleSwish() - - self.pointwise_conv2 = ScaledConv1d( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - initial_scale=0.05, - ) - - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch (optional): - (batch, #time), contains bool in masked positions. - - Returns: - - Output tensor (#time, batch, channels). - """ - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channels, time) - - x = self.deriv_balancer1(x) - x = nn.functional.glu(x, dim=1) # (batch, channels, time) - - if src_key_padding_mask is not None: - x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - - # 1D Depthwise Conv - # Make depthwise_conv causal by - # manualy padding self.lorder zeros to the left - x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) - x = self.depthwise_conv(x) - - x = self.deriv_balancer2(x) - x = self.activation(x) - - x = self.pointwise_conv2(x) # (batch, channel, time) - - return x.permute(2, 0, 1) - - def streaming_forward( - self, - x: Tensor, - cache: Tensor, - ) -> Tuple[Tensor, Tensor]: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch: - (batch, #time), contains bool in masked positions. - cache: Cached left context for depthwise_conv, with shape of - (batch, channels, #kernel_size-1). Only used in real streaming decoding. - - Returns: - A tuple of 2 tensors: - - Output tensor (#time, batch, channels). - - New cached left context, with shape of (batch, channels, #kernel_size-1). - """ - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channels, time) - - x = self.deriv_balancer1(x) - x = nn.functional.glu(x, dim=1) # (batch, channels, time) - - # 1D Depthwise Conv - assert cache.shape == (x.size(0), x.size(1), self.lorder), ( - cache.shape, - (x.size(0), x.size(1), self.lorder), - ) - x = torch.cat([cache, x], dim=2) - # Update cache - cache = x[:, :, -self.lorder :] - x = self.depthwise_conv(x) - - x = self.deriv_balancer2(x) - x = self.activation(x) - - x = self.pointwise_conv2(x) # (batch, channel, time) - - return x.permute(2, 0, 1), cache - - -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/4 length). - - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where - T' = (T-3)//2 - 2 == (T-7)//2 - - It is based on - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - layer1_channels: int = 8, - layer2_channels: int = 32, - layer3_channels: int = 128, - dropout: float = 0.1, - ) -> None: - """ - Args: - in_channels: - Number of channels in. The input shape is (N, T, in_channels). - Caution: It requires: T >=7, in_channels >=7 - out_channels - Output dim. The output shape is (N, (T-7)//2, out_channels) - layer1_channels: - Number of channels in layer1 - layer2_channels: - Number of channels in layer2 - layer3_channels: - Number of channels in layer3 - """ - assert in_channels >= 7, in_channels - super().__init__() - - self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, - out_channels=layer1_channels, - kernel_size=3, - padding=(0, 1), # (time, freq) - ), - ActivationBalancer(layer1_channels, channel_dim=1), - DoubleSwish(), - nn.Conv2d( - in_channels=layer1_channels, - out_channels=layer2_channels, - kernel_size=3, - stride=2, - padding=0, - ), - ActivationBalancer(layer2_channels, channel_dim=1), - DoubleSwish(), - nn.Conv2d( - in_channels=layer2_channels, - out_channels=layer3_channels, - kernel_size=3, - stride=(1, 2), # (time, freq) - ), - ActivationBalancer(layer3_channels, channel_dim=1), - DoubleSwish(), - ) - out_height = (((in_channels - 1) // 2) - 1) // 2 - self.out = ScaledLinear(out_height * layer3_channels, out_channels) - self.dropout = nn.Dropout(dropout) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - - Returns: - Return a tensor of shape (N, (T-7)//2, odim) - """ - # On entry, x is (N, T, idim) - x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - x = self.conv(x) - # Now x is of shape (N, odim, (T-7)//2, ((idim-1)//2 - 1)//2) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).reshape(b, t, c * f)) - # Now x is of shape (N, (T-7)//2, odim) - x = self.dropout(x) - return x - - -def _test_zipformer_main(): - feature_dim = 50 - batch_size = 5 - seq_len = 47 - feature_dim = 50 - # Just make sure the forward pass runs. - - c = Zipformer( - num_features=feature_dim, - encoder_dims=(64, 96), - encoder_unmasked_dims=(48, 64), - nhead=(4, 4), - decode_chunk_size=4, - ) - # Just make sure the forward pass runs. - f = c( - torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - assert ((seq_len - 7) // 2 + 1) // 2 == f[0].shape[1], (seq_len, f.shape[1]) - f[0].sum().backward() - c.eval() - f = c( - torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - f # to remove flake8 warnings - - -def _test_conv2d_subsampling(): - num_features = 80 - encoder_dims = 384 - dropout = 0.1 - encoder_embed = Conv2dSubsampling(num_features, encoder_dims, dropout=dropout) - for i in range(20, 40): - x = torch.rand(2, i, num_features) - y = encoder_embed(x) - assert (x.shape[1] - 7) // 2 == y.shape[1], (x.shape[1], y.shape[1]) - - -def _test_pooling_module(): - N, S, C = 2, 12, 32 - chunk_len = 4 - m = PoolingModule(d_model=C) - - # test chunk-wise forward with padding_mask - x = torch.randn(S, N, C) - y = m(x) - cached_len = torch.zeros(N, dtype=torch.int32) - cached_avg = torch.zeros(N, C) - for i in range(S // chunk_len): - start = i * chunk_len - end = start + chunk_len - x_chunk = x[start:end] - y_chunk, cached_len, cached_avg = m.streaming_forward( - x_chunk, - cached_len=cached_len, - cached_avg=cached_avg, - ) - assert torch.allclose(y_chunk, y[start:end]), (y_chunk, y[start:end]) - - -def _test_state_stack_unstack(): - m = Zipformer( - num_features=80, - encoder_dims=(64, 96), - encoder_unmasked_dims=(48, 64), - nhead=(4, 4), - zipformer_downsampling_factors=(4, 8), - num_left_chunks=2, - decode_chunk_size=8, - ) - s1 = m.get_init_state() - s2 = m.get_init_state() - states = stack_states([s1, s2]) - new_s1, new_s2 = unstack_states(states) - for i in range(m.num_encoders * 7): - for x, y in zip(s1[i], new_s1[i]): - assert torch.equal(x, y) - for x, y in zip(s2[i], new_s2[i]): - assert torch.equal(x, y) - - -if __name__ == "__main__": - logging.getLogger().setLevel(logging.INFO) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - _test_zipformer_main() - _test_conv2d_subsampling() - _test_pooling_module() - _test_state_stack_unstack() diff --git a/egs/libricss/SURT/dprnn_zipformer/zipformer.py b/egs/libricss/SURT/dprnn_zipformer/zipformer.py new file mode 120000 index 000000000..f2f66041e --- /dev/null +++ b/egs/libricss/SURT/dprnn_zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py \ No newline at end of file