From 35aab79e3dce9d6ef15cb13ea8c34a84d2b3107f Mon Sep 17 00:00:00 2001 From: marcoyang Date: Tue, 10 Oct 2023 16:40:54 +0800 Subject: [PATCH] support exporting model state dict --- .../zipformer_prompt_asr/export_PromptASR.py | 255 ++++++ .../ASR/zipformer_prompt_asr/scaling.py | 824 +++++++++--------- 2 files changed, 674 insertions(+), 405 deletions(-) create mode 100644 egs/libriheavy/ASR/zipformer_prompt_asr/export_PromptASR.py diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/export_PromptASR.py b/egs/libriheavy/ASR/zipformer_prompt_asr/export_PromptASR.py new file mode 100644 index 000000000..b0cb1c46f --- /dev/null +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/export_PromptASR.py @@ -0,0 +1,255 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. + +""" +Export `model.state_dict()` + +- For non-streaming model: + +./zipformer/export_PromptASR.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +- For streaming model: + +./zipformer/export_PromptASR.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +""" + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +import k2 +import torch +from torch import Tensor, nn +from train_bert_encoder import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. 
If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named jit_script.pt. + Check ./jit_pretrained.py for how to use it. + """, + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + # if torch.cuda.is_available(): + # device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + 
model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.eval() + + assert params.jit is False, "Jit is not supported yet" + + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py b/egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py index 0f318e33c..0e6764ba0 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py @@ -16,45 +16,45 @@ import collections +import logging +import math +import random +from functools import reduce from itertools import repeat from typing import Optional, Tuple, Union -from functools import reduce -import logging + import k2 -from torch.cuda.amp import custom_fwd, custom_bwd -import random import torch -import math import torch.nn as nn import torch.nn.functional as F from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd from torch.nn import Embedding as ScaledEmbedding - class PiecewiseLinear(object): """ Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with the x values in order. x values <[initial x] or >[final x] are map to [initial y], [final y] respectively. """ + def __init__(self, *args): assert len(args) >= 1 if len(args) == 1 and isinstance(args[0], PiecewiseLinear): self.pairs = list(args[0].pairs) else: - self.pairs = [ (float(x), float(y)) for x,y in args ] - for (x,y) in self.pairs: + self.pairs = [(float(x), float(y)) for x, y in args] + for (x, y) in self.pairs: assert isinstance(x, float) or isinstance(x, int) assert isinstance(y, float) or isinstance(y, int) for i in range(len(self.pairs) - 1): assert self.pairs[i + 1][0] > self.pairs[i][0], self.pairs - def __str__(self): # e.g. 
'PiecewiseLinear((0., 10.), (100., 0.))' - return f'PiecewiseLinear({str(self.pairs)[1:-1]})' + return f"PiecewiseLinear({str(self.pairs)[1:-1]})" def __call__(self, x): if x <= self.pairs[0][0]: @@ -71,37 +71,36 @@ class PiecewiseLinear(object): assert False def __mul__(self, alpha): - return PiecewiseLinear( - * [(x, y * alpha) for x, y in self.pairs]) + return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs]) def __add__(self, x): if isinstance(x, float) or isinstance(x, int): - return PiecewiseLinear( - * [(p[0], p[1] + x) for p in self.pairs]) + return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs]) s, x = self.get_common_basis(x) return PiecewiseLinear( - * [(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)]) + *[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)] + ) def max(self, x): if isinstance(x, float) or isinstance(x, int): - x = PiecewiseLinear( (0, x) ) + x = PiecewiseLinear((0, x)) s, x = self.get_common_basis(x, include_crossings=True) return PiecewiseLinear( - * [(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]) + *[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] + ) def min(self, x): if isinstance(x, float) or isinstance(x, int): - x = PiecewiseLinear( (0, x) ) + x = PiecewiseLinear((0, x)) s, x = self.get_common_basis(x, include_crossings=True) return PiecewiseLinear( - * [ (sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]) + *[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)] + ) def __eq__(self, other): return self.pairs == other.pairs - def get_common_basis(self, - p: 'PiecewiseLinear', - include_crossings: bool = False): + def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False): """ Returns (self_mod, p_mod) which are equivalent piecewise lienar functions to self and p, but with the same x values. @@ -113,29 +112,30 @@ class PiecewiseLinear(object): assert isinstance(p, PiecewiseLinear) # get sorted x-values without repetition. - x_vals = sorted(set([ x for x, y in self.pairs ] + [ x for x, y in p.pairs ])) - y_vals1 = [ self(x) for x in x_vals ] - y_vals2 = [ p(x) for x in x_vals ] + x_vals = sorted(set([x for x, y in self.pairs] + [x for x, y in p.pairs])) + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] if include_crossings: extra_x_vals = [] for i in range(len(x_vals) - 1): - if (y_vals1[i] > y_vals2[i]) != (y_vals1[i+1] > y_vals2[i+1]): + if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]): # if the two lines in this subsegment potentially cross each other.. diff_cur = abs(y_vals1[i] - y_vals2[i]) - diff_next = abs(y_vals1[i+1] - y_vals2[i+1]) + diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1]) # `pos`, between 0 and 1, gives the relative x position, # with 0 being x_vals[i] and 1 being x_vals[i+1]. 
pos = diff_cur / (diff_cur + diff_next) - extra_x_val = x_vals[i] + pos * (x_vals[i+1] - x_vals[i]) + extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i]) extra_x_vals.append(extra_x_val) if len(extra_x_vals) > 0: x_vals = sorted(set(x_vals + extra_x_vals)) - y_vals1 = [ self(x) for x in x_vals ] - y_vals2 = [ p(x) for x in x_vals ] - return ( PiecewiseLinear(* zip(x_vals, y_vals1)), - PiecewiseLinear(* zip(x_vals, y_vals2)) ) - + y_vals1 = [self(x) for x in x_vals] + y_vals2 = [p(x) for x in x_vals] + return ( + PiecewiseLinear(*zip(x_vals, y_vals1)), + PiecewiseLinear(*zip(x_vals, y_vals2)), + ) class ScheduledFloat(torch.nn.Module): @@ -155,9 +155,8 @@ class ScheduledFloat(torch.nn.Module): `default` is used when self.batch_count is not set or in training or mode or in torch.jit scripting mode. """ - def __init__(self, - *args, - default: float = 0.0): + + def __init__(self, *args, default: float = 0.0): super().__init__() # self.batch_count and self.name will be written to in the training loop. self.batch_count = None @@ -166,7 +165,9 @@ class ScheduledFloat(torch.nn.Module): self.schedule = PiecewiseLinear(*args) def extra_repr(self) -> str: - return f'batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}' + return ( + f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}" + ) def __float__(self): batch_count = self.batch_count @@ -175,40 +176,39 @@ class ScheduledFloat(torch.nn.Module): else: ans = self.schedule(self.batch_count) if random.random() < 0.0002: - logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}") + logging.info( + f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}" + ) return ans def __add__(self, x): if isinstance(x, float) or isinstance(x, int): - return ScheduledFloat(self.schedule + x, - default=self.default) + return ScheduledFloat(self.schedule + x, default=self.default) else: - return ScheduledFloat(self.schedule + x.schedule, - default=self.default+x.default) + return ScheduledFloat( + self.schedule + x.schedule, default=self.default + x.default + ) def max(self, x): if isinstance(x, float) or isinstance(x, int): - return ScheduledFloat(self.schedule.max(x), - default=self.default) + return ScheduledFloat(self.schedule.max(x), default=self.default) else: - return ScheduledFloat(self.schedule.max(x.schedule), - default=max(self.default, x.default)) + return ScheduledFloat( + self.schedule.max(x.schedule), default=max(self.default, x.default) + ) FloatLike = Union[float, ScheduledFloat] - - -def random_cast_to_half(x: Tensor, - min_abs: float = 5.0e-06) -> Tensor: +def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: """ A randomized way of casting a floating point value to half precision. """ if x.dtype == torch.float16: return x x_abs = x.abs() - is_too_small = (x_abs < min_abs) + is_too_small = x_abs < min_abs # for elements where is_too_small is true, random_val will contain +-min_abs with # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, # for those elements]. @@ -216,7 +216,6 @@ def random_cast_to_half(x: Tensor, return torch.where(is_too_small, random_val, x).to(torch.float16) - class CutoffEstimator: """ Estimates cutoffs of an arbitrary numerical quantity such that a specified @@ -224,6 +223,7 @@ class CutoffEstimator: p is the proportion of items that should be above the cutoff. 
""" + def __init__(self, p: float): self.p = p # total count of items @@ -233,12 +233,11 @@ class CutoffEstimator: # initial cutoff value self.cutoff = 0 - def __call__(self, x: float) -> bool: """ Returns true if x is above the cutoff. """ - ans = (x > self.cutoff) + ans = x > self.cutoff self.count += 1 if ans: self.count_above += 1 @@ -246,16 +245,16 @@ class CutoffEstimator: delta_p = cur_p - self.p if (delta_p > 0) == ans: q = abs(delta_p) - self.cutoff = x * q + self.cutoff * (1-q) + self.cutoff = x * q + self.cutoff * (1 - q) return ans - class SoftmaxFunction(torch.autograd.Function): """ Tries to handle half-precision derivatives in a randomized way that should be more accurate for training than the default behavior. """ + @staticmethod def forward(ctx, x: Tensor, dim: int): ans = x.softmax(dim=dim) @@ -271,7 +270,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): - ans, = ctx.saved_tensors + (ans,) = ctx.saved_tensors with torch.cuda.amp.autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) @@ -280,29 +279,25 @@ class SoftmaxFunction(torch.autograd.Function): return x_grad, None - -def softmax(x: Tensor, - dim: int): +def softmax(x: Tensor, dim: int): return SoftmaxFunction.apply(x, dim) class MaxEigLimiterFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float) -> Tensor: + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float, + ) -> Tensor: ctx.channel_dim = channel_dim ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), - coeffs.detach(), - direction.detach()) + ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) return x - @staticmethod def backward(ctx, x_grad, *args): with torch.enable_grad(): @@ -312,20 +307,23 @@ class MaxEigLimiterFunction(torch.autograd.Function): x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) new_direction.requires_grad = False x = x - x.mean(dim=0) - x_var = (x ** 2).mean() + x_var = (x**2).mean() x_residual = x - coeffs * new_direction - x_residual_var = (x_residual ** 2).mean() + x_residual_var = (x_residual**2).mean() # `variance_proportion` is the proportion of the variance accounted for # by the top eigen-direction. This is to be minimized. variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) variance_proportion.backward() x_orig_grad = x_orig.grad - x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20) + x_extra_grad = ( + x_orig.grad + * ctx.grad_scale + * x_grad.norm() + / (x_orig_grad.norm() + 1.0e-20) + ) return x_grad + x_extra_grad.detach(), None, None, None, None - - class BiasNormFunction(torch.autograd.Function): # This computes: # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp() @@ -334,8 +332,14 @@ class BiasNormFunction(torch.autograd.Function): # it can just store the returned value (chances are, this will also be needed for # some other reason, related to the next operation, so we can save memory). 
@staticmethod - def forward(ctx, x: Tensor, bias: Tensor, log_scale: Tensor, channel_dim: int, - store_output_for_backprop: bool) -> Tensor: + def forward( + ctx, + x: Tensor, + bias: Tensor, + log_scale: Tensor, + channel_dim: int, + store_output_for_backprop: bool, + ) -> Tensor: assert bias.ndim == 1 if channel_dim < 0: channel_dim = channel_dim + x.ndim @@ -343,10 +347,16 @@ class BiasNormFunction(torch.autograd.Function): ctx.channel_dim = channel_dim for _ in range(channel_dim + 1, x.ndim): bias = bias.unsqueeze(-1) - scales = (torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * log_scale.exp() + scales = ( + torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 + ) * log_scale.exp() ans = x * scales - ctx.save_for_backward(ans.detach() if store_output_for_backprop else x, - scales.detach(), bias.detach(), log_scale.detach()) + ctx.save_for_backward( + ans.detach() if store_output_for_backprop else x, + scales.detach(), + bias.detach(), + log_scale.detach(), + ) return ans @staticmethod @@ -362,13 +372,14 @@ class BiasNormFunction(torch.autograd.Function): log_scale.requires_grad = True with torch.enable_grad(): # recompute scales from x, bias and log_scale. - scales = (torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5) * log_scale.exp() + scales = ( + torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5 + ) * log_scale.exp() ans = x * scales ans.backward(gradient=ans_grad) return x.grad, bias.grad.flatten(), log_scale.grad, None, None - class BiasNorm(torch.nn.Module): """ This is intended to be a simpler, and hopefully cheaper, replacement for @@ -399,14 +410,15 @@ class BiasNorm(torch.nn.Module): than the input of this module to be required to be stored for the backprop. """ + def __init__( - self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - log_scale: float = 1.0, - log_scale_min: float = -1.5, - log_scale_max: float = 1.5, - store_output_for_backprop: bool = False + self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. 
+ log_scale: float = 1.0, + log_scale_min: float = -1.5, + log_scale_max: float = 1.5, + store_output_for_backprop: bool = False, ) -> None: super(BiasNorm, self).__init__() self.num_channels = num_channels @@ -422,7 +434,6 @@ class BiasNorm(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels - if torch.jit.is_scripting(): channel_dim = self.channel_dim if channel_dim < 0: @@ -430,24 +441,24 @@ class BiasNorm(torch.nn.Module): bias = self.bias for _ in range(channel_dim + 1, x.ndim): bias = bias.unsqueeze(-1) - scales = ((torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5) * - self.log_scale.exp()) + scales = ( + torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 + ) * self.log_scale.exp() return x * scales - log_scale = limit_param_value(self.log_scale, - min=float(self.log_scale_min), - max=float(self.log_scale_max), - training=self.training) + log_scale = limit_param_value( + self.log_scale, + min=float(self.log_scale_min), + max=float(self.log_scale_max), + training=self.training, + ) - return BiasNormFunction.apply(x, self.bias, log_scale, - self.channel_dim, - self.store_output_for_backprop) + return BiasNormFunction.apply( + x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop + ) - -def ScaledLinear(*args, - initial_scale: float = 1.0, - **kwargs ) -> nn.Linear: +def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: """ Behaves like a constructor of a modified version of nn.Linear that gives an easy way to set the default initial parameter scale. @@ -466,16 +477,11 @@ def ScaledLinear(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans - -def ScaledConv1d(*args, - initial_scale: float = 1.0, - **kwargs ) -> nn.Conv1d: +def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: """ Behaves like a constructor of a modified version of nn.Conv1d that gives an easy way to set the default initial parameter scale. @@ -494,15 +500,11 @@ def ScaledConv1d(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans -def ScaledConv2d(*args, - initial_scale: float = 1.0, - **kwargs ) -> nn.Conv2d: +def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: """ Behaves like a constructor of a modified version of nn.Conv2d that gives an easy way to set the default initial parameter scale. @@ -522,9 +524,7 @@ def ScaledConv2d(*args, with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) return ans @@ -552,29 +552,36 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): Another option, if you want to do something like this, is to re-initialize the parameters. 
""" - def __init__(self, - channels: int, - kernel_size: int, - initial_scale: float = 1.0, - bias: bool = True): + + def __init__( + self, + channels: int, + kernel_size: int, + initial_scale: float = 1.0, + bias: bool = True, + ): super().__init__() assert kernel_size % 2 == 1 half_kernel_size = (kernel_size + 1) // 2 # will pad manually, on one side. - self.causal_conv = nn.Conv1d(in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=half_kernel_size, - padding=0, - bias=True) + self.causal_conv = nn.Conv1d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=half_kernel_size, + padding=0, + bias=True, + ) - self.chunkwise_conv = nn.Conv1d(in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=kernel_size, - padding=kernel_size // 2, - bias=bias) + self.chunkwise_conv = nn.Conv1d( + in_channels=channels, + out_channels=channels, + groups=channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + bias=bias, + ) # first row is correction factors added to the scale near the left edge of the chunk, # second row is correction factors added to the scale near the right edge of the chunk, @@ -586,18 +593,15 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): self.causal_conv.weight[:] *= initial_scale self.chunkwise_conv.weight[:] *= initial_scale if bias: - torch.nn.init.uniform_(self.causal_conv.bias, - -0.1 * initial_scale, - 0.1 * initial_scale) + torch.nn.init.uniform_( + self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale + ) - - def forward(self, - x: Tensor, - chunk_size: int = -1) -> Tensor: + def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor: """ - Forward function. Args: - x: a Tensor of shape (batch_size, channels, seq_len) - chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. + Forward function. Args: + x: a Tensor of shape (batch_size, channels, seq_len) + chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. 
""" (batch_size, num_channels, seq_len) = x.shape @@ -613,28 +617,32 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): x = torch.nn.functional.pad(x, (left_pad, right_pad)) - x_causal = self.causal_conv(x[..., :left_pad + seq_len]) + x_causal = self.causal_conv(x[..., : left_pad + seq_len]) assert x_causal.shape == (batch_size, num_channels, seq_len) x_chunk = x[..., left_pad:] num_chunks = x_chunk.shape[2] // chunk_size x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size) - x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(batch_size * num_chunks, - num_channels, chunk_size) + x_chunk = x_chunk.permute(0, 2, 1, 3).reshape( + batch_size * num_chunks, num_channels, chunk_size + ) x_chunk = self.chunkwise_conv(x_chunk) # does not change shape chunk_scale = self._get_chunk_scale(chunk_size) x_chunk = x_chunk * chunk_scale - x_chunk = x_chunk.reshape(batch_size, num_chunks, - num_channels, chunk_size).permute(0, 2, 1, 3) - x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[..., :seq_len] + x_chunk = x_chunk.reshape( + batch_size, num_chunks, num_channels, chunk_size + ).permute(0, 2, 1, 3) + x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[ + ..., :seq_len + ] return x_chunk + x_causal def _get_chunk_scale(self, chunk_size: int): """Returns tensor of shape (num_channels, chunk_size) that will be used to - scale the output of self.chunkwise_conv.""" + scale the output of self.chunkwise_conv.""" left_edge = self.chunkwise_conv_scale[0] right_edge = self.chunkwise_conv_scale[1] if chunk_size < self.kernel_size: @@ -643,30 +651,25 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): else: t = chunk_size - self.kernel_size channels = left_edge.shape[0] - pad = torch.zeros(channels, t, - device=left_edge.device, - dtype=left_edge.dtype) + pad = torch.zeros( + channels, t, device=left_edge.device, dtype=left_edge.dtype + ) left_edge = torch.cat((left_edge, pad), dim=-1) right_edge = torch.cat((pad, right_edge), dim=-1) return 1.0 + (left_edge + right_edge) - - - - - class BalancerFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - min_mean: float, - max_mean: float, - min_rms: float, - max_rms: float, - grad_scale: float, - channel_dim: int, + ctx, + x: Tensor, + min_mean: float, + max_mean: float, + min_rms: float, + max_rms: float, + grad_scale: float, + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim @@ -675,23 +678,19 @@ class BalancerFunction(torch.autograd.Function): ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) return x - @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None, None, None]: - x, = ctx.saved_tensors + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: + (x,) = ctx.saved_tensors (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config - try: with torch.enable_grad(): with torch.cuda.amp.autocast(enabled=False): x = x.to(torch.float32) x = x.detach() x.requires_grad = True - mean_dims = [ i for i in range(x.ndim) if i != channel_dim ] - uncentered_var = (x ** 2).mean(dim=mean_dims, keepdim=True) + mean_dims = [i for i in range(x.ndim) if i != channel_dim] + uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True) mean = x.mean(dim=mean_dims, keepdim=True) stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt() rms = uncentered_var.clamp(min=1.0e-20).sqrt() @@ -705,11 +704,16 @@ class 
BalancerFunction(torch.autograd.Function): rms_clamped = rms.clamp(min=min_rms, max=max_rms) r_loss = (rms_clamped / rms).log().abs() - loss = (m_loss + r_loss) + loss = m_loss + r_loss loss.backward(gradient=torch.ones_like(loss)) loss_grad = x.grad - loss_grad_rms = (loss_grad ** 2).mean(dim=mean_dims, keepdim=True).sqrt().clamp(min=1.0e-20) + loss_grad_rms = ( + (loss_grad**2) + .mean(dim=mean_dims, keepdim=True) + .sqrt() + .clamp(min=1.0e-20) + ) loss_grad = loss_grad * (grad_scale / loss_grad_rms) @@ -720,7 +724,9 @@ class BalancerFunction(torch.autograd.Function): x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad) x_grad = x_grad_mod.to(x_grad.dtype) except Exception as e: - logging.info(f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue.") + logging.info( + f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue." + ) return x_grad, None, None, None, None, None, None @@ -756,16 +762,17 @@ class Balancer(torch.nn.Module): on each forward(). This is done randomly to prevent all layers from doing it at the same time. """ + def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: FloatLike = 0.05, - max_positive: FloatLike = 0.95, - min_abs: FloatLike = 0.2, - max_abs: FloatLike = 100.0, - grad_scale: FloatLike = 0.04, - prob: Optional[FloatLike] = None, + self, + num_channels: int, + channel_dim: int, + min_positive: FloatLike = 0.05, + max_positive: FloatLike = 0.95, + min_abs: FloatLike = 0.2, + max_abs: FloatLike = 100.0, + grad_scale: FloatLike = 0.04, + prob: Optional[FloatLike] = None, ): super().__init__() @@ -786,8 +793,11 @@ class Balancer(torch.nn.Module): self.grad_scale = grad_scale def forward(self, x: Tensor) -> Tensor: - if (torch.jit.is_scripting() or not x.requires_grad or - (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))): + if ( + torch.jit.is_scripting() + or not x.requires_grad + or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())) + ): return _no_op(x) prob = float(self.prob) @@ -799,12 +809,14 @@ class Balancer(torch.nn.Module): # for normally distributed data, if the expected absolute value is x, the # expected rms value will be sqrt(pi/2) * x. return 1.25331413732 * x + def _proportion_positive_to_mean(x): def _atanh(x): eps = 1.0e-10 # eps is to prevent crashes if x is exactly 0 or 1. # we'll just end up returning a fairly large value. - return (math.log (1+x+eps) - math.log (1-x+eps)) / 2. + return (math.log(1 + x + eps) - math.log(1 - x + eps)) / 2.0 + def _approx_inverse_erf(x): # 1 / (sqrt(pi) * ln(2)), # see https://math.stackexchange.com/questions/321569/approximating-the-error-function-erf-by-analytical-functions @@ -814,6 +826,7 @@ class Balancer(torch.nn.Module): # and math.erf(0.0407316414078772) = 0.045935330944660666, # which is pretty close to 0.05. return 0.8139535143 * _atanh(x) + # first convert x from the range 0..1 to the range -1..1 which the error # function returns x = -1 + (2 * x) @@ -834,8 +847,9 @@ class Balancer(torch.nn.Module): return _no_op(x) -def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float, - name: str = None) -> Tensor: +def penalize_abs_values_gt( + x: Tensor, limit: float, penalty: float, name: str = None +) -> Tensor: """ Returns x unmodified, but in backprop will put a penalty for the excess of the absolute values of elements of x over the limit "limit". E.g. if @@ -871,14 +885,12 @@ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. 
else: (batch, dim, dim) = x.shape x = x.reshape(batch, dim * dim) - x = x[:, ::dim+1] + x = x[:, :: dim + 1] assert x.shape == (batch, dim) return x - -def _whitening_metric(x: Tensor, - num_groups: int): +def _whitening_metric(x: Tensor, num_groups: int): """ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of of the centered feature covariance are the same within each group's covariance matrix @@ -908,25 +920,22 @@ def _whitening_metric(x: Tensor, # the following expression is what we'd get if we took the matrix product # of each covariance and measured the mean of its trace, i.e. # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group) + x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) # this metric will be >= 1.0; the larger it is, the less 'white' the data was. - metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) + metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) return metric class WhiteningPenaltyFunction(torch.autograd.Function): @staticmethod - def forward(ctx, - x: Tensor, - module: nn.Module) -> Tensor: + def forward(ctx, x: Tensor, module: nn.Module) -> Tensor: ctx.save_for_backward(x) ctx.module = module return x @staticmethod - def backward(ctx, - x_grad: Tensor): - x_orig, = ctx.saved_tensors + def backward(ctx, x_grad: Tensor): + (x_orig,) = ctx.saved_tensors w = ctx.module try: @@ -938,8 +947,10 @@ class WhiteningPenaltyFunction(torch.autograd.Function): metric = _whitening_metric(x_detached, w.num_groups) if random.random() < 0.005 or __name__ == "__main__": - logging.info(f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " - f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}") + logging.info( + f"Whitening: name={w.name}, num_groups={w.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. limit={float(w.whitening_limit)}" + ) if metric < float(w.whitening_limit): w.prob = w.min_prob @@ -948,23 +959,27 @@ class WhiteningPenaltyFunction(torch.autograd.Function): w.prob = w.max_prob metric.backward() penalty_grad = x_detached.grad - scale = w.grad_scale * (x_grad.to(torch.float32).norm() / - (penalty_grad.norm() + 1.0e-20)) + scale = w.grad_scale * ( + x_grad.to(torch.float32).norm() + / (penalty_grad.norm() + 1.0e-20) + ) penalty_grad = penalty_grad * scale return x_grad + penalty_grad.to(x_grad.dtype), None except Exception as e: - logging.info(f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue.") + logging.info( + f"Caught exception in Whiten backward: {e}, size={list(x_grad.shape)}, will continue." 
+ ) return x_grad, None - class Whiten(nn.Module): def __init__( - self, - num_groups: int, - whitening_limit: FloatLike, - prob: Union[float, Tuple[float,float]], - grad_scale: FloatLike): + self, + num_groups: int, + whitening_limit: FloatLike, + prob: Union[float, Tuple[float, float]], + grad_scale: FloatLike, + ): """ Args: num_groups: the number of groups to divide the channel dim into before @@ -996,10 +1011,9 @@ class Whiten(nn.Module): (self.min_prob, self.max_prob) = prob assert 0 < self.min_prob <= self.max_prob <= 1 self.prob = self.max_prob - self.name = None # will be set in training loop + self.name = None # will be set in training loop - def forward(self, - x: Tensor) -> Tensor: + def forward(self, x: Tensor) -> Tensor: """ In the forward pass, this function just returns the input unmodified. In the backward pass, it will modify the gradients to ensure that the @@ -1031,11 +1045,16 @@ class WithLoss(torch.autograd.Function): loss_sum = y.sum().item() logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}") return x + @staticmethod def backward(ctx, ans_grad: Tensor): - return ans_grad, torch.ones(ctx.y_shape, - dtype=ans_grad.dtype, - device=ans_grad.device), None + return ( + ans_grad, + torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device), + None, + ) + + def with_loss(x, y, name): # returns x but adds y.sum() to the loss function. return WithLoss.apply(x, y, name) @@ -1046,13 +1065,16 @@ class ScaleGradFunction(torch.autograd.Function): def forward(ctx, x: Tensor, alpha: float) -> Tensor: ctx.alpha = alpha return x + @staticmethod def backward(ctx, grad: Tensor): return grad * ctx.alpha, None + def scale_grad(x: Tensor, alpha: float): return ScaleGradFunction.apply(x, alpha) + class ScaleGrad(nn.Module): def __init__(self, alpha: float): super().__init__() @@ -1061,6 +1083,7 @@ class ScaleGrad(nn.Module): def forward(self, x: Tensor) -> Tensor: return scale_grad(x, self.alpha) + class LimitParamValue(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor, min: float, max: float): @@ -1069,21 +1092,24 @@ class LimitParamValue(torch.autograd.Function): ctx.min = min ctx.max = max return x + @staticmethod def backward(ctx, x_grad: Tensor): - x, = ctx.saved_tensors + (x,) = ctx.saved_tensors # where x < ctx.min, ensure all grads are negative (this will tend to make # x more positive). - x_grad = x_grad * torch.where(torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0) + x_grad = x_grad * torch.where( + torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0 + ) # where x > ctx.max, ensure all grads are positive (this will tend to make # x more negative). x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0) return x_grad, None, None -def limit_param_value(x: Tensor, - min: float, max: float, - prob: float = 0.6, - training: bool = True): + +def limit_param_value( + x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True +): # You apply this to (typically) an nn.Parameter during training to ensure that its # (elements mostly) stays within a supplied range. This is done by modifying the # gradients in backprop. 
@@ -1096,7 +1122,7 @@ def limit_param_value(x: Tensor, def _no_op(x: Tensor) -> Tensor: - if (torch.jit.is_scripting()): + if torch.jit.is_scripting(): return x else: # a no-op function that will have a node in the autograd graph, @@ -1140,7 +1166,7 @@ class DoubleSwishFunction(torch.autograd.Function): y = x * s if requires_grad: - deriv = (y * (1 - s) + s) + deriv = y * (1 - s) + s # notes on derivative of x * sigmoid(x - 1): # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 @@ -1150,7 +1176,9 @@ class DoubleSwishFunction(torch.autograd.Function): # floors), should be expectation-preserving. floor = -0.044 ceil = 1.2 - d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)) + d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + deriv + ) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -1163,19 +1191,19 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors + (d,) = ctx.saved_tensors # the same constants as used in forward pass. floor = -0.043637 ceil = 1.2 - d = (d * ((ceil - floor) / 255.0) + floor) + d = d * ((ceil - floor) / 255.0) + floor return y_grad * d + class DoubleSwish(torch.nn.Module): def __init__(self): super().__init__() - def forward(self, x: Tensor) -> Tensor: """Return double-swish activation function which is an approximation to Swish(Swish(x)), that we approximate closely with x * sigmoid(x-1). @@ -1185,23 +1213,22 @@ class DoubleSwish(torch.nn.Module): return DoubleSwishFunction.apply(x) - # Dropout2 is just like normal dropout, except it supports schedules on the dropout rates. class Dropout2(nn.Module): def __init__(self, p: FloatLike): super().__init__() self.p = p + def forward(self, x: Tensor) -> Tensor: - return torch.nn.functional.dropout(x, - p=float(self.p), - training=self.training) + return torch.nn.functional.dropout(x, p=float(self.p), training=self.training) + class MulForDropout3(torch.autograd.Function): # returns (x * y * alpha) where alpha is a float and y doesn't require # grad and is zero-or-one. 
@staticmethod @custom_fwd - def forward(ctx, x, y, alpha): + def forward(ctx, x, y, alpha): assert not y.requires_grad ans = x * y * alpha ctx.save_for_backward(ans) @@ -1211,10 +1238,11 @@ class MulForDropout3(torch.autograd.Function): @staticmethod @custom_bwd def backward(ctx, ans_grad): - ans, = ctx.saved_tensors + (ans,) = ctx.saved_tensors x_grad = ctx.alpha * ans_grad * (ans != 0) return x_grad, None, None + # Dropout3 is just like normal dropout, except it supports schedules on the dropout rates, # and it lets you choose one dimension to share the dropout mask over class Dropout3(nn.Module): @@ -1222,6 +1250,7 @@ class Dropout3(nn.Module): super().__init__() self.p = p self.shared_dim = shared_dim + def forward(self, x: Tensor) -> Tensor: p = float(self.p) if not self.training or p == 0: @@ -1236,7 +1265,7 @@ class Dropout3(nn.Module): class SwooshLFunction(torch.autograd.Function): """ - swoosh(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 + swoosh(x) = log(1 + exp(x-4)) - 0.08*x - 0.035 """ @staticmethod @@ -1255,17 +1284,19 @@ class SwooshLFunction(torch.autograd.Function): with torch.enable_grad(): x = x.detach() x.requires_grad = True - y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035 + y = torch.logaddexp(zero, x - 4.0) + coeff * x - 0.035 if not requires_grad: return y - y.backward(gradient = torch.ones_like(y)) + y.backward(gradient=torch.ones_like(y)) grad = x.grad floor = coeff ceil = 1.0 + coeff + 0.005 - d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)) + d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + grad + ) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -1279,35 +1310,34 @@ class SwooshLFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors + (d,) = ctx.saved_tensors # the same constants as used in forward pass. coeff = -0.08 floor = coeff ceil = 1.0 + coeff + 0.005 - d = (d * ((ceil - floor) / 255.0) + floor) - return (y_grad * d) + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d class SwooshL(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-L activation. - """ + """Return Swoosh-L activation.""" if torch.jit.is_scripting(): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 + return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035 if not x.requires_grad: return k2.swoosh_l_forward(x) else: return k2.swoosh_l(x) - #return SwooshLFunction.apply(x) + # return SwooshLFunction.apply(x) class SwooshRFunction(torch.autograd.Function): """ - swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 + swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687 - derivatives are between -0.08 and 0.92. + derivatives are between -0.08 and 0.92. """ @@ -1325,17 +1355,19 @@ class SwooshRFunction(torch.autograd.Function): with torch.enable_grad(): x = x.detach() x.requires_grad = True - y = torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687 + y = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 if not requires_grad: return y - y.backward(gradient = torch.ones_like(y)) + y.backward(gradient=torch.ones_like(y)) grad = x.grad floor = -0.08 ceil = 0.925 - d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)) + d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( + grad + ) if __name__ == "__main__": # for self-testing only. 
assert d_scaled.min() >= 0.0 @@ -1348,21 +1380,20 @@ class SwooshRFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - d, = ctx.saved_tensors + (d,) = ctx.saved_tensors # the same constants as used in forward pass. floor = -0.08 ceil = 0.925 - d = (d * ((ceil - floor) / 255.0) + floor) - return (y_grad * d) + d = d * ((ceil - floor) / 255.0) + floor + return y_grad * d class SwooshR(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: - """Return Swoosh-R activation. - """ + """Return Swoosh-R activation.""" if torch.jit.is_scripting(): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687 + return torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687 if not x.requires_grad: return k2.swoosh_r_forward(x) else: @@ -1375,7 +1406,7 @@ class SwooshR(torch.nn.Module): def SwooshLForward(x: Tensor): x_offset = x - 4.0 log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum) + log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) return log_sum - 0.08 * x - 0.035 @@ -1384,28 +1415,30 @@ def SwooshLForward(x: Tensor): def SwooshRForward(x: Tensor): x_offset = x - 1.0 log_sum = (1.0 + x_offset.exp()).log().to(x.dtype) - log_sum = torch.where(log_sum == float('inf'), x_offset, log_sum) + log_sum = torch.where(log_sum == float("inf"), x_offset, log_sum) return log_sum - 0.08 * x - 0.313261687 class ActivationDropoutAndLinearFunction(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, - x: Tensor, - weight: Tensor, - bias: Optional[Tensor], - activation: str, - dropout_p: float, - dropout_shared_dim: Optional[int]): + def forward( + ctx, + x: Tensor, + weight: Tensor, + bias: Optional[Tensor], + activation: str, + dropout_p: float, + dropout_shared_dim: Optional[int], + ): if dropout_p != 0.0: dropout_shape = list(x.shape) if dropout_shared_dim is not None: dropout_shape[dropout_shared_dim] = 1 # else it won't be very memory efficient. - dropout_mask = ((1.0 / (1.0 - dropout_p)) * - (torch.rand(*dropout_shape, - device=x.device, dtype=x.dtype) > dropout_p)) + dropout_mask = (1.0 / (1.0 - dropout_p)) * ( + torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p + ) else: dropout_mask = None @@ -1414,8 +1447,8 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function): ctx.activation = activation forward_activation_dict = { - 'SwooshL': k2.swoosh_l_forward, - 'SwooshR': k2.swoosh_r_forward + "SwooshL": k2.swoosh_l_forward, + "SwooshR": k2.swoosh_r_forward, } # it will raise a KeyError if this fails. This will be an error. We let it # propagate to the user. @@ -1433,8 +1466,8 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function): (x, weight, bias, dropout_mask) = saved forward_and_deriv_activation_dict = { - 'SwooshL': k2.swoosh_l_forward_and_deriv, - 'SwooshR': k2.swoosh_r_forward_and_deriv + "SwooshL": k2.swoosh_l_forward_and_deriv, + "SwooshR": k2.swoosh_r_forward_and_deriv, } # the following lines a KeyError if the activation is unrecognized. # This will be an error. We let it propagate to the user. 
@@ -1449,8 +1482,7 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function): in_channels = y.shape[-1] g = ans_grad.reshape(-1, out_channels) - weight_deriv = torch.matmul(g.t(), - y.reshape(-1, in_channels)) + weight_deriv = torch.matmul(g.t(), y.reshape(-1, in_channels)) y_deriv = torch.matmul(ans_grad, weight) bias_deriv = None if bias is None else g.sum(dim=0) x_deriv = y_deriv * func_deriv @@ -1463,78 +1495,81 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function): class ActivationDropoutAndLinear(torch.nn.Module): """ - This merges an activation function followed by dropout and then a nn.Linear module; - it does so in a memory efficient way so that it only stores the input to the whole - module. If activation == SwooshL and dropout_shared_dim != None, this will be - equivalent to: - nn.Sequential(SwooshL(), - Dropout3(dropout_p, shared_dim=dropout_shared_dim), - ScaledLinear(in_channels, out_channels, bias=bias, - initial_scale=initial_scale)) - If dropout_shared_dim is None, the dropout would be equivalent to - Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout - mask is smaller. + This merges an activation function followed by dropout and then a nn.Linear module; + it does so in a memory efficient way so that it only stores the input to the whole + module. If activation == SwooshL and dropout_shared_dim != None, this will be + equivalent to: + nn.Sequential(SwooshL(), + Dropout3(dropout_p, shared_dim=dropout_shared_dim), + ScaledLinear(in_channels, out_channels, bias=bias, + initial_scale=initial_scale)) + If dropout_shared_dim is None, the dropout would be equivalent to + Dropout2(dropout_p). Note: Dropout3 will be more memory efficient as the dropout + mask is smaller. - Args: - in_channels: number of input channels, e.g. 256 - out_channels: number of output channels, e.g. 256 - bias: if true, have a bias - activation: the activation function, for now just support SwooshL. - dropout_p: the dropout probability or schedule (happens after nonlinearity). - dropout_shared_dim: the dimension, if any, across which the dropout mask is - shared (e.g. the time dimension). If None, this may be less memory - efficient if there are modules before this one that cache the input - for their backprop (e.g. Balancer or Whiten). + Args: + in_channels: number of input channels, e.g. 256 + out_channels: number of output channels, e.g. 256 + bias: if true, have a bias + activation: the activation function, for now just support SwooshL. + dropout_p: the dropout probability or schedule (happens after nonlinearity). + dropout_shared_dim: the dimension, if any, across which the dropout mask is + shared (e.g. the time dimension). If None, this may be less memory + efficient if there are modules before this one that cache the input + for their backprop (e.g. Balancer or Whiten). 
""" - def __init__(self, - in_channels: int, - out_channels: int, - bias: bool = True, - activation: str = 'SwooshL', - dropout_p: FloatLike = 0.0, - dropout_shared_dim: Optional[int] = -1, - initial_scale: float = 1.0): + + def __init__( + self, + in_channels: int, + out_channels: int, + bias: bool = True, + activation: str = "SwooshL", + dropout_p: FloatLike = 0.0, + dropout_shared_dim: Optional[int] = -1, + initial_scale: float = 1.0, + ): super().__init__() # create a temporary module of nn.Linear that we'll steal the # weights and bias from - l = ScaledLinear(in_channels, out_channels, - bias=bias, - initial_scale=initial_scale) + layer = ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=initial_scale + ) - self.weight = l.weight + self.weight = layer.weight # register_parameter properly handles making it a parameter when l.bias # is None. I think there is some reason for doing it this way rather # than just setting it to None but I don't know what it is, maybe # something to do with exporting the module.. - self.register_parameter('bias', l.bias) + self.register_parameter("bias", layer.bias) self.activation = activation self.dropout_p = dropout_p self.dropout_shared_dim = dropout_shared_dim - def forward(self, - x: Tensor): + def forward(self, x: Tensor): if torch.jit.is_scripting(): - if self.activation == 'SwooshL': + if self.activation == "SwooshL": x = SwooshLForward(x) elif self.activation == "SwooshR": x = SwooshRForward(x) else: assert False, self.activation - return torch.nn.functional.linear(x, - self.weight, - self.bias) + return torch.nn.functional.linear(x, self.weight, self.bias) return ActivationDropoutAndLinearFunction.apply( - x, self.weight, self.bias, self.activation, - float(self.dropout_p), self.dropout_shared_dim) + x, + self.weight, + self.bias, + self.activation, + float(self.dropout_p), + self.dropout_shared_dim, + ) + class ClipGradFunction(torch.autograd.Function): @staticmethod - def forward( - ctx, - x: Tensor, - limit: float): + def forward(ctx, x: Tensor, limit: float): ctx.limit = limit return x @@ -1546,15 +1581,14 @@ class ClipGradFunction(torch.autograd.Function): def clip_grad(x: Tensor, limit: float): return ClipGradFunction.apply(x, limit) + class AbsValuePenalizer(nn.Module): """ This module adds a penalty to the loss function when ever the absolute value of any element of the input tensor exceeds a certain limit. """ - def __init__(self, - limit: float, - prob: float = 0.1, - penalty: float = 1.0e-04): + + def __init__(self, limit: float, prob: float = 0.1, penalty: float = 1.0e-04): super().__init__() self.limit = limit self.penalty = penalty @@ -1567,18 +1601,21 @@ class AbsValuePenalizer(nn.Module): self.mem_cutoff = CutoffEstimator(0.2) def forward(self, x: Tensor) -> Tensor: - if (torch.jit.is_scripting() or not x.requires_grad + if ( + torch.jit.is_scripting() + or not x.requires_grad or not self.training - or random.random() > self.prob): + or random.random() > self.prob + ): # or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())) return _no_op(x) # the _no_op op is to make our diagnostics code work. 
- x = penalize_abs_values_gt(x, - limit=self.limit, - penalty=self.penalty, - name=self.name) + x = penalize_abs_values_gt( + x, limit=self.limit, penalty=self.penalty, name=self.name + ) return x + def convert_num_channels(x: Tensor, num_channels: int) -> Tensor: if num_channels <= x.shape[-1]: return x[..., :num_channels] @@ -1600,11 +1637,9 @@ def _test_whiten(): x.requires_grad = True num_channels = 128 - m = Whiten(1, # num_groups - 5.0, # whitening_limit, - prob=1.0, - grad_scale=0.1) # grad_scale - + m = Whiten( + 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, + ) # grad_scale for _ in range(4): y = m(x) @@ -1642,14 +1677,10 @@ def _test_balancer_sign(): print("_test_balancer_sign: x grad = ", x.grad) - - def _test_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( - -1 - ) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True m = Balancer( @@ -1671,29 +1702,26 @@ def _test_balancer_magnitude(): print("_test_balancer_magnitude: x grad = ", x.grad) - - def _test_double_swish_deriv(): x = torch.randn(10, 12, dtype=torch.double) * 3.0 x.requires_grad = True m = DoubleSwish() - tol = ((1.2-(-0.043637))/255.0) + tol = (1.2 - (-0.043637)) / 255.0 torch.autograd.gradcheck(m, x, atol=tol) - # for self-test. x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 x.requires_grad = True y = m(x) + def _test_swooshl_deriv(): x = torch.randn(10, 12, dtype=torch.double) * 3.0 x.requires_grad = True m = SwooshL() - - tol = (1.0 / 255.0) + tol = 1.0 / 255.0 torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) # for self-test. @@ -1701,12 +1729,13 @@ def _test_swooshl_deriv(): x.requires_grad = True y = m(x) + def _test_swooshr_deriv(): x = torch.randn(10, 12, dtype=torch.double) * 3.0 x.requires_grad = True m = SwooshR() - tol = (1.0 / 255.0) + tol = 1.0 / 255.0 torch.autograd.gradcheck(m, x, atol=tol, eps=0.01) # for self-test. 
@@ -1715,55 +1744,29 @@ def _test_swooshr_deriv(): y = m(x) - def _test_softmax(): a = torch.randn(2, 10, dtype=torch.float64) b = a.clone() a.requires_grad = True b.requires_grad = True - a.softmax(dim=1)[:,0].sum().backward() + a.softmax(dim=1)[:, 0].sum().backward() print("a grad = ", a.grad) - softmax(b, dim=1)[:,0].sum().backward() + softmax(b, dim=1)[:, 0].sum().backward() print("b grad = ", b.grad) assert torch.allclose(a.grad, b.grad) -def _test_caching_eval(): - m = nn.Sequential(nn.Linear(10, 100, bias=False), nn.ReLU(), nn.Linear(100, 10, bias=False)) - - x = torch.randn(50, 10) - y_grad = torch.randn(50, 10) - x.requires_grad = True - - y1 = m(x) - y1.backward(gradient=y_grad) - x_grad1 = x.grad - x.grad = None - weight_grad1a = m[0].weight.grad - weight_grad1b = m[2].weight.grad - m[0].weight.grad = None - m[2].weight.grad = None - m.zero_grad() - - y2 = caching_eval(m, x) - assert torch.allclose(y1, y2) - y2.backward(gradient=y_grad) - assert torch.allclose(x.grad, x_grad1) - assert torch.allclose(m[0].weight.grad, weight_grad1a) - assert torch.allclose(m[2].weight.grad, weight_grad1b) - - def _test_piecewise_linear(): - p = PiecewiseLinear( (0, 10.0) ) + p = PiecewiseLinear((0, 10.0)) for x in [-100, 0, 100]: assert p(x) == 10.0 - p = PiecewiseLinear( (0, 10.0), (1, 0.0) ) - for x, y in [ (-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0) ]: + p = PiecewiseLinear((0, 10.0), (1, 0.0)) + for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]: print("x, y = ", x, y) assert p(x) == y, (x, p(x), y) q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0)) - x_vals = [ -1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0 ] + x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0] pq = p.max(q) for x in x_vals: y1 = max(p(x), q(x)) @@ -1776,7 +1779,7 @@ def _test_piecewise_linear(): assert abs(y1 - y2) < 0.001 pq = p + q for x in x_vals: - y1 = p(x) + q(x) + y1 = p(x) + q(x) y2 = pq(x) assert abs(y1 - y2) < 0.001 @@ -1791,15 +1794,22 @@ def _test_activation_dropout_and_linear(): # swoosh_l an swoosh_r inside SwooshL() and SwooshR(), and they call randn() # internally, messing up the random state. for dropout_p in [0.0]: - for activation in ['SwooshL', 'SwooshR']: - m1 = nn.Sequential(SwooshL() if activation == 'SwooshL' else SwooshR(), - Dropout3(p=dropout_p, shared_dim=-1), - ScaledLinear(in_channels, out_channels, bias=bias, - initial_scale=0.5)) - m2 = ActivationDropoutAndLinear(in_channels, out_channels, - bias=bias, initial_scale=0.5, - activation=activation, - dropout_p=dropout_p) + for activation in ["SwooshL", "SwooshR"]: + m1 = nn.Sequential( + SwooshL() if activation == "SwooshL" else SwooshR(), + Dropout3(p=dropout_p, shared_dim=-1), + ScaledLinear( + in_channels, out_channels, bias=bias, initial_scale=0.5 + ), + ) + m2 = ActivationDropoutAndLinear( + in_channels, + out_channels, + bias=bias, + initial_scale=0.5, + activation=activation, + dropout_p=dropout_p, + ) with torch.no_grad(): m2.weight[:] = m1[2].weight if bias: @@ -1809,9 +1819,9 @@ def _test_activation_dropout_and_linear(): x1.requires_grad = True # TEMP. 
- assert torch.allclose(SwooshRFunction.apply(x1), - SwooshRForward(x1), - atol=1.0e-03) + assert torch.allclose( + SwooshRFunction.apply(x1), SwooshRForward(x1), atol=1.0e-03 + ) x2 = x1.clone().detach() x2.requires_grad = True @@ -1824,20 +1834,24 @@ def _test_activation_dropout_and_linear(): y2 = m2(x2) y2.backward(gradient=y_grad) - print(f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}") + print( + f"bias = {bias}, dropout_p = {dropout_p}, activation = {activation}" + ) print("y1 = ", y1) print("y2 = ", y2) assert torch.allclose(y1, y2, atol=0.02) - assert torch.allclose(m1[2].weight.grad, m2.weight.grad, - atol=1.0e-05) + assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05) if bias: - assert torch.allclose(m1[2].bias.grad, m2.bias.grad, - atol=1.0e-05) + assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05) print("x1.grad = ", x1.grad) print("x2.grad = ", x2.grad) + def isclose(a, b): # return true if cosine similarity is > 0.9. - return (a * b).sum() > 0.9 * ((a**2).sum() * (b**2).sum()).sqrt() + return (a * b).sum() > 0.9 * ( + (a**2).sum() * (b**2).sum() + ).sqrt() + # the SwooshL() implementation has a noisy gradient due to 1-byte # storage of it. assert isclose(x1.grad, x2.grad)
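
A minimal usage sketch of the exported checkpoint (the path below is an example; `get_params` and `get_transducer_model` come from train_bert_encoder.py, as in export_PromptASR.py above): `pretrained.pt` is saved as a dict with a single "model" entry, so it can be restored via `icefall.checkpoint.load_checkpoint()` as the script's docstring notes, or loaded directly with `torch.load`:

    import torch
    from train_bert_encoder import get_params, get_transducer_model

    params = get_params()
    # params.vocab_size / params.blank_id (and any model arguments) must be
    # filled in from tokens.txt and the CLI, as done in main() above.
    model = get_transducer_model(params)

    # pretrained.pt was written as {"model": model.state_dict()}.
    ckpt = torch.load("zipformer/exp/pretrained.pt", map_location="cpu")
    model.load_state_dict(ckpt["model"])
    model.eval()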