From 99e9d6c4b8ab035d9c1962fc5b6086586d336090 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 4 Apr 2022 13:37:10 +0800 Subject: [PATCH] Some cleanups --- .../ASR/conformer_ctc/subsampling.py | 422 +----------------- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 2 - .../ASR/transducer_stateless/diagnostics.py | 338 -------------- 3 files changed, 5 insertions(+), 757 deletions(-) delete mode 100644 egs/librispeech/ASR/transducer_stateless/diagnostics.py diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 0a39b0f33..542fb0364 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -17,8 +17,6 @@ import torch import torch.nn as nn -from torch import Tensor -from typing import Tuple class Conv2dSubsampling(nn.Module): @@ -44,27 +42,16 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - ScaledConv2d( + nn.Conv2d( in_channels=1, out_channels=odim, kernel_size=3, stride=2 ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), - ScaledConv2d( + nn.ReLU(), + nn.Conv2d( in_channels=odim, out_channels=odim, kernel_size=3, stride=2 ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), + nn.ReLU(), ) - self.out = ScaledLinear(odim * (((idim - 1) // 2 - 1) // 2), odim) - # set learn_eps=False because out_norm is preceded by `out`, and `out` - # itself has learned scale, so the extra degree of freedom is not - # needed. - self.out_norm = BasicNorm(odim, learn_eps=False) - # constrain median of output to be close to zero. - self.out_balancer = ActivationBalancer(channel_dim=-1, - min_positive=0.45, - max_positive=0.55) - + self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -83,8 +70,6 @@ class Conv2dSubsampling(nn.Module): b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) - x = self.out_norm(x) - x = self.out_balancer(x) return x @@ -174,400 +159,3 @@ class VggSubsampling(nn.Module): b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) return x - - - - - -class ActivationBalancerFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, - channel_dim: int, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 
100.0 - ) -> Tensor: - if x.requires_grad: - if channel_dim < 0: - channel_dim += x.ndim - sum_dims = [d for d in range(x.ndim) if d != channel_dim] - xgt0 = x > 0 - proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) - factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive) - if min_positive != 0.0 else 0.0) - factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0)) - if max_positive != 1.0 else 0.0) - factor = factor1 + factor2 - if isinstance(factor, float): - factor = torch.zeros_like(proportion_positive) - - mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) - below_threshold = (mean_abs < min_abs) - above_threshold = (mean_abs > max_abs) - - ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) - ctx.max_factor = max_factor - ctx.sum_dims = sum_dims - return x - - @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]: - factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors - dtype = x_grad.dtype - scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) * - (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)) - - neg_delta_grad = x_grad.abs() * (factor + scale_factor) - return x_grad - neg_delta_grad, None, None, None, None, None, None - - -class BasicNorm(torch.nn.Module): - """ - This is intended to be a simpler, and hopefully cheaper, replacement for - LayerNorm. The observation this is based on, is that Transformer-type - networks, especially with pre-norm, sometimes seem to set one of the - feature dimensions to a large constant value (e.g. 50), which "defeats" - the LayerNorm because the output magnitude is then not strongly dependent - on the other (useful) features. Presumably the weight and bias of the - LayerNorm are required to allow it to do this. - - So the idea is to introduce this large constant value as an explicit - parameter, that takes the role of the "eps" in LayerNorm, so the network - doesn't have to do this trick. We make the "eps" learnable. - - Args: - num_channels: the number of channels, e.g. 512. - channel_dim: the axis/dimension corresponding to the channel, - interprted as an offset from the input's ndim if negative. - shis is NOT the num_channels; it should typically be one of - {-2, -1, 0, 1, 2, 3}. - eps: the initial "epsilon" that we add as ballast in: - scale = ((input_vec**2).mean() + epsilon)**-0.5 - Note: our epsilon is actually large, but we keep the name - to indicate the connection with conventional LayerNorm. - learn_eps: if true, we learn epsilon; if false, we keep it - at the initial value. - eps_speed: a constant that determines how fast "eps" learns; - with Adam and variants, this should probably be >= 1, - e.g. 5.0. For SGD and variants, probably a value less than one, - like 0.1, would be suitable, to prevent instability. - """ - def __init__(self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. 
- eps: float = 0.25, - learn_eps: bool = True, - eps_speed: float = 5.0): - super(BasicNorm, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - self.eps_speed = eps_speed - if learn_eps: - self.eps = nn.Parameter((torch.tensor(eps).log() / eps_speed).detach()) - else: - self.register_buffer('eps', (torch.tensor(eps).log() / eps_speed).detach()) - - - def forward(self, x: Tensor) -> Tensor: - assert x.shape[self.channel_dim] == self.num_channels - scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + - (self.eps * self.eps_speed).exp()) ** -0.5 - return x * scales - - - - -class ScaledLinear(nn.Linear): - """ - A modified version of nn.Linear where the parameters are scaled before - use, via: - weight = self.weight * (self.weight_scale * self.scale_speed).exp() - bias = self.bias * (self.bias_scale * self.scale_speed).exp() - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - scale_speed: a factor that affects how fast the weight_scale - and bias_scale learn; this value is suitable for Adam-type - optimizers. - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - - Note: it uses the default initialization for the weight and bias, - inherited from nn.Linear. For modules with small fan-in, this - may be larger than optimal. - """ - def __init__(self, *args, - scale_speed: float = 5.0, - initial_scale: float = 1.0, - **kwargs): - super(ScaledLinear, self).__init__(*args, **kwargs) - initial_scale = (torch.tensor(initial_scale).log() / scale_speed) - self.weight_scale = nn.Parameter(initial_scale.clone().detach()) - self.scale_speed = scale_speed - if self.bias is not None: - self.bias_scale = nn.Parameter(initial_scale.clone().detach()) - else: - self.register_parameter('bias_scale', None) - - self._reset_parameters() # Overrides the reset_parameters in nn.Linear - - def _reset_parameters(self): - std = 0.05 - a = (3 ** 0.5) * std - nn.init.uniform_(self.weight, -a, a) - if self.bias is not None: - nn.init.constant_(self.bias, 0.0) - fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) - with torch.no_grad(): - self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) - - def get_weight(self): - return self.weight * (self.weight_scale * self.scale_speed).exp() - - def get_bias(self): - return (None if self.bias is None else - self.bias * (self.bias_scale * self.scale_speed).exp()) - - def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear(input, self.get_weight(), - self.get_bias()) - - -class ScaledConv1d(nn.Conv1d): - def __init__(self, *args, scale_speed = 5.0, - initial_scale=1.0, **kwargs): - super(ScaledConv1d, self).__init__(*args, **kwargs) - self.scale_speed = scale_speed - initial_scale = (torch.tensor(initial_scale).log() / scale_speed) - self.weight_scale = nn.Parameter(initial_scale.clone().detach()) - if self.bias is not None: - self.bias_scale = nn.Parameter(initial_scale.clone().detach()) - else: - self.register_parameter('bias_scale', None) - self._reset_parameters() # Overrides the reset_parameters in base class - - def _reset_parameters(self): - std = 0.05 - a = (3 ** 0.5) * std - nn.init.uniform_(self.weight, -a, a) - if self.bias is not None: - 
nn.init.constant_(self.bias, 0.0) - fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) - with torch.no_grad(): - self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) - - - def get_weight(self): - return self.weight * (self.weight_scale * self.scale_speed).exp() - - def get_bias(self): - return (None if self.bias is None else - self.bias * (self.bias_scale * self.scale_speed).exp()) - - def forward(self, input: Tensor) -> Tensor: - F = torch.nn.functional - if self.padding_mode != 'zeros': - return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - self.get_weight(), self.get_bias(), self.stride, - _single(0), self.dilation, self.groups) - return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride, - self.padding, self.dilation, self.groups) - - - -class ScaledConv2d(nn.Conv2d): - def __init__(self, *args, scale_speed=5.0, initial_scale=1.0, **kwargs): - super(ScaledConv2d, self).__init__(*args, **kwargs) - self.scale_speed = scale_speed - initial_scale = (torch.tensor(initial_scale).log() / scale_speed) - self.weight_scale = nn.Parameter(initial_scale.clone().detach()) - if self.bias is not None: - self.bias_scale = nn.Parameter(initial_scale.clone().detach()) - else: - self.register_parameter('bias_scale', None) - self._reset_parameters() # Overrides the reset_parameters in base class - - def _reset_parameters(self): - std = 0.05 - a = (3 ** 0.5) * std - nn.init.uniform_(self.weight, -a, a) - if self.bias is not None: - nn.init.constant_(self.bias, 0.0) - fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in ** -0.5 # 1/sqrt(fan_in) - with torch.no_grad(): - self.weight_scale += (torch.tensor(scale / std).log() / self.scale_speed) - - - def get_weight(self): - return self.weight * (self.weight_scale * self.scale_speed).exp() - - def get_bias(self): - return (None if self.bias is None else - self.bias * (self.bias_scale * self.scale_speed).exp()) - - def _conv_forward(self, input, weight): - F = torch.nn.functional - if self.padding_mode != 'zeros': - return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - weight, self.get_bias(), self.stride, - _pair(0), self.dilation, self.groups) - return F.conv2d(input, weight, self.get_bias(), self.stride, - self.padding, self.dilation, self.groups) - - def forward(self, input: Tensor) -> Tensor: - return self._conv_forward(input, self.get_weight()) - - - - -class ActivationBalancer(torch.nn.Module): - """ - Modifies the backpropped derivatives of a function to try to encourage, for - each channel, that it is positive at least a proportion `threshold` of the - time. It does this by multiplying negative derivative values by up to - (1+max_factor), and positive derivative values by up to (1-max_factor), - interpolated from 1 at the threshold to those extremal values when none - of the inputs are positive. - - - Args: - channel_dim: the dimension/axis corresponding to the channel, e.g. - -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. - min_positive: the minimum, per channel, of the proportion of the time - that (x > 0), below which we start to modify the derivatives. - max_positive: the maximum, per channel, of the proportion of the time - that (x > 0), below which we start to modify the derivatives. - max_factor: the maximum factor by which we modify the derivatives for - either the sign constraint or the magnitude constraint; - e.g. 
with max_factor=0.02, the the derivatives would be multiplied by - values in the range [0.98..1.02]. - min_abs: the minimum average-absolute-value per channel, which - we allow, before we start to modify the derivatives to prevent - this. - max_abs: the maximum average-absolute-value per channel, which - we allow, before we start to modify the derivatives to prevent - this. - """ - def __init__(self, channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.01, - min_abs: float = 0.2, - max_abs: float = 100.0): - super(ActivationBalancer, self).__init__() - self.channel_dim = channel_dim - self.min_positive = min_positive - self.max_positive = max_positive - self.max_factor = max_factor - self.min_abs = min_abs - self.max_abs = max_abs - - def forward(self, x: Tensor) -> Tensor: - return ActivationBalancerFunction.apply(x, self.channel_dim, - self.min_positive, self.max_positive, - self.max_factor, self.min_abs, - self.max_abs) - - -def _double_swish(x: Tensor) -> Tensor: - # double-swish, implemented/approximated as offset-swish - return x * torch.sigmoid(x - 1.0) - -class DoubleSwishFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - ctx.save_for_backward(x.detach()) - return _double_swish(x) - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - # TODO: can make this more efficient. - x, = ctx.saved_tensors - x.requires_grad = True - with torch.enable_grad(): - y = _double_swish(x) - y.backward(gradient=y_grad) - return x.grad - -class DoubleSwish(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). - """ - return DoubleSwishFunction.apply(x) - - - -def _test_deriv_balancer_sign(): - channel_dim = 0 - probs = torch.arange(0, 1, 0.01) - N = 1000 - x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) - x = x.detach() - x.requires_grad = True - m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, - max_factor=0.2, min_abs=0.0) - - y_grad = torch.sign(torch.randn(probs.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_deriv_balancer_sign: x = ", x) - print("_test_deriv_balancer_sign: y grad = ", y_grad) - print("_test_deriv_balancer_sign: x grad = ", x.grad) - -def _test_deriv_balancer_magnitude(): - channel_dim = 0 - magnitudes = torch.arange(0, 1, 0.01) - N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) - x = x.detach() - x.requires_grad = True - m = ActivationBalancer(channel_dim=0, - min_positive=0.0, max_positive=1.0, - max_factor=0.2, - min_abs=0.2, max_abs=0.8) - - y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_deriv_balancer_magnitude: x = ", x) - print("_test_deriv_balancer_magnitude: y grad = ", y_grad) - print("_test_deriv_balancer_magnitude: x grad = ", x.grad) - - -def _test_basic_norm(): - num_channels = 128 - m = BasicNorm(num_channels=num_channels, channel_dim=1) - - x = torch.randn(500, num_channels) - - y = m(x) - - assert y.shape == x.shape - x_rms = (x**2).mean().sqrt() - y_rms = (y**2).mean().sqrt() - print("x rms = ", x_rms) - print("y rms = ", y_rms) - assert y_rms < x_rms - assert y_rms > 0.5 * x_rms - - - - - -if __name__ == '__main__': - _test_deriv_balancer_sign() - _test_deriv_balancer_magnitude() - _test_basic_norm() diff --git 
a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 477afcecb..8dd1459ca 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -22,8 +22,6 @@ import logging from functools import lru_cache from pathlib import Path from typing import Any, Dict, Optional -import torch -from lhotse.utils import fix_random_seed import torch from lhotse import CutSet, Fbank, FbankConfig, load_manifest diff --git a/egs/librispeech/ASR/transducer_stateless/diagnostics.py b/egs/librispeech/ASR/transducer_stateless/diagnostics.py deleted file mode 100644 index 7fd83d56b..000000000 --- a/egs/librispeech/ASR/transducer_stateless/diagnostics.py +++ /dev/null @@ -1,338 +0,0 @@ -import torch -from torch import Tensor -from torch import nn -import math -import random -from typing import Tuple, List - - -class TensorDiagnosticOptions(object): - """ - Options object for tensor diagnostics: - - Args: - memory_limit: the maximum number of bytes we store per tensor (limits how many copies - of the tensor we cache). - max_eig_dim: the maximum dimension for which we print out eigenvalues - (limited for speed reasons). - """ - def __init__(self, - memory_limit: int = (2 ** 20), - max_eig_dim: int = 512): - - self.memory_limit = memory_limit - self.max_eig_dim = max_eig_dim - - def dim_is_summarized(self, size: int): - return size > 10 and size != 31 - - - -def get_tensor_stats(x: Tensor, dim: int, - stats_type: str) -> Tuple[Tensor, int]: - """ - Returns the specified transformation of the Tensor (either x or x.abs() - or (x > 0), summed over all but the index `dim`. - - Args: - x: Tensor, tensor to be analyzed - dim: dimension with 0 <= dim < x.ndim - stats_type: - "abs" -> take abs() before summing - "positive" -> take (x > 0) before summing - "rms" -> square before summing, we'll take sqrt later - "value -> just sum x itself - Returns (stats, count) - where stats is a Tensor of shape (x.shape[dim],), and the count - is an integer saying how many items were counted in each element - of stats. - """ - count = x.numel() // x.shape[dim] - - if stats_type == "eigs": - x = x.transpose(dim, -1) - x = x.reshape(-1, x.shape[-1]) - # shape of returned tensor: (s, s) where s is size of dimension `dim` of original x. - return torch.matmul(x.transpose(0, 1), x), count - elif stats_type == "abs": - x = x.abs() - elif stats_type == "rms": - x = x ** 2 - elif stats_type == "positive": - x = (x > 0).to(dtype=torch.float) - else: - assert stats_type == "value" - - sum_dims = [ d for d in range(x.ndim) if d != dim ] - if len(sum_dims) > 0: - x = torch.sum(x, dim=sum_dims) - x = x.flatten() - return x, count - -def get_diagnostics_for_dim(dim: int, tensors: List[Tensor], - options: TensorDiagnosticOptions, - sizes_same: bool, - stats_type: str): - """ - This function gets diagnostics for a dimension of a module. - Args: - dim: the dimension to analyze, with 0 <= dim < tensors[0].ndim - options: options object - sizes_same: true if all the tensor sizes are the same on this dimension - stats_type: either "abs" or "positive" or "eigs" or "value", - imdictates the type of stats - we accumulate, abs is mean absolute value, "positive" - is proportion of positive to nonnegative values, "eigs" - is eigenvalues after doing outer product on this dim, sum - over all other dimes. - Returns: - Diagnostic as a string, either percentiles or the actual values, - see the code. 
Will return the empty string if the diagnostics did - not make sense to print out for this dimension, e.g. dimension - mismatch and stats_type == "eigs" - """ - # stats_and_counts is a list of pair (Tensor, int) - stats_and_counts = [ get_tensor_stats(x, dim, stats_type) for x in tensors ] - stats = [ x[0] for x in stats_and_counts ] - counts = [ x[1] for x in stats_and_counts ] - - if stats_type == "eigs": - try: - stats = torch.stack(stats).sum(dim=0) - except: - return '' - count = sum(counts) - stats = stats / count - stats, _ = torch.symeig(stats) - stats = stats.abs().sqrt() # sqrt so it reflects data magnitude, like stddev- not variance - elif sizes_same: - stats = torch.stack(stats).sum(dim=0) - count = sum(counts) - stats = stats / count - else: - stats = [ x[0] / x[1] for x in stats_and_counts ] - stats = torch.cat(stats, dim=0) - if stats_type == 'rms': - stats = stats.sqrt() - - # if `summarize` we print percentiles of the stats; else, - # we print out individual elements. - summarize = (not sizes_same) or options.dim_is_summarized(stats.numel()) - if summarize: - # print out percentiles. - stats = stats.sort()[0] - num_percentiles = 10 - size = stats.numel() - percentiles = [] - for i in range(num_percentiles + 1): - index = (i * (size - 1)) // num_percentiles - percentiles.append(stats[index].item()) - percentiles = [ '%.2g' % x for x in percentiles ] - percentiles = ' '.join(percentiles) - ans = f'percentiles: [{percentiles}]' - else: - ans = stats.tolist() - ans = [ '%.2g' % x for x in ans ] - ans = '[' + ' '.join(ans) + ']' - if stats_type == "value": - # This norm is useful because it is strictly less than the largest - # sqrt(eigenvalue) of the variance, which we print out, and shows, - # speaking in an approximate way, how much of that largest eigenvalue - # can be attributed to the mean of the distribution. - norm = (stats ** 2).sum().sqrt().item() - mean = stats.mean().item() - rms = (stats ** 2).mean().sqrt().item() - ans += f', norm={norm:.2g}, mean={mean:.2g}, rms={rms:.2g}' - else: - mean = stats.mean().item() - rms = (stats ** 2).mean().sqrt().item() - ans += f', mean={mean:.2g}, rms={rms:.2g}' - return ans - - - -def print_diagnostics_for_dim(name: str, dim: int, tensors: List[Tensor], - options: TensorDiagnosticOptions): - ndim = tensors[0].ndim - if ndim > 1: - stats_types = ["abs", "positive", "value", "rms"] - if tensors[0].shape[dim] <= options.max_eig_dim: - stats_types.append("eigs") - else: - stats_types = [ "value", "abs" ] - - for stats_type in stats_types: - sizes = [ x.shape[dim] for x in tensors ] - sizes_same = all([ x == sizes[0] for x in sizes ]) - s = get_diagnostics_for_dim(dim, tensors, - options, sizes_same, - stats_type) - if s == '': - continue - - min_size = min(sizes) - max_size = max(sizes) - size_str = f"{min_size}" if sizes_same else f"{min_size}..{max_size}" - # stats_type will be "abs" or "positive". - print(f"module={name}, dim={dim}, size={size_str}, {stats_type} {s}") - - -class TensorDiagnostic(object): - """ - This class is not directly used by the user, it is responsible for collecting - diagnostics for a single parameter tensor of a torch.Module. 
- """ - def __init__(self, - opts: TensorDiagnosticOptions, - name: str): - self.name = name - self.opts = opts - self.saved_tensors = [] - - def accumulate(self, x): - if isinstance(x, Tuple): - x = x[0] - if not isinstance(x, Tensor): - return - if x.device == torch.device('cpu'): - x = x.detach().clone() - else: - x = x.detach().to('cpu', non_blocking=True) - self.saved_tensors.append(x) - l = len(self.saved_tensors) - if l & (l - 1) == 0: # power of 2.. - self._limit_memory() - - def _limit_memory(self): - if len(self.saved_tensors) > 1024: - self.saved_tensors = self.saved_tensors[-1024:] - return - - tot_mem = 0.0 - for i in reversed(range(len(self.saved_tensors))): - tot_mem += self.saved_tensors[i].numel() * self.saved_tensors[i].element_size() - if tot_mem > self.opts.memory_limit: - self.saved_tensors = self.saved_tensors[i:] - return - - def print_diagnostics(self): - if len(self.saved_tensors) == 0: - print("{name}: no stats".format(name=self.name)) - return - if self.saved_tensors[0].ndim == 0: - # ensure there is at least one dim. - self.saved_tensors = [ x.unsqueeze(0) for x in self.saved_tensors ] - - try: - device = torch.device('cuda') - torch.ones(1, 1, device) - except: - device = torch.device('cpu') - - ndim = self.saved_tensors[0].ndim - tensors = [x.to(device) for x in self.saved_tensors] - for dim in range(ndim): - print_diagnostics_for_dim(self.name, dim, - tensors, - self.opts) - - -class ModelDiagnostic(object): - def __init__(self, opts: TensorDiagnosticOptions = TensorDiagnosticOptions()): - self.diagnostics = dict() - self.opts = opts - - def __getitem__(self, name: str): - if name not in self.diagnostics: - self.diagnostics[name] = TensorDiagnostic(self.opts, name) - return self.diagnostics[name] - - def print_diagnostics(self): - for k in sorted(self.diagnostics.keys()): - self.diagnostics[k].print_diagnostics() - - - -def attach_diagnostics(model: nn.Module, - opts: TensorDiagnosticOptions) -> ModelDiagnostic: - ans = ModelDiagnostic(opts) - for name, module in model.named_modules(): - if name == '': - name = "" - forward_diagnostic = TensorDiagnostic(opts, name + ".output") - backward_diagnostic = TensorDiagnostic(opts, name + ".grad") - - - # setting model_diagnostic=ans and n=name below, instead of trying to capture the variables, - # ensures that we use the current values. (matters for name, since - # the variable gets overwritten). 
these closures don't really capture - # by value, only by "the final value the variable got in the function" :-( - def forward_hook(_module, _input, _output, - _model_diagnostic=ans, _name=name): - if isinstance(_output, Tensor): - _model_diagnostic[f"{_name}.output"].accumulate(_output) - elif isinstance(_output, tuple): - for i, o in enumerate(_output): - _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o) - - def backward_hook(_module, _input, _output, - _model_diagnostic=ans, _name=name): - if isinstance(_output, Tensor): - _model_diagnostic[f"{_name}.grad"].accumulate(_output) - elif isinstance(_output, tuple): - for i, o in enumerate(_output): - _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o) - - module.register_forward_hook(forward_hook) - module.register_backward_hook(backward_hook) - - for name, parameter in model.named_parameters(): - - def param_backward_hook(grad, - _parameter=parameter, - _model_diagnostic=ans, - _name=name): - _model_diagnostic[f"{_name}.param_value"].accumulate(_parameter) - _model_diagnostic[f"{_name}.param_grad"].accumulate(grad) - - parameter.register_hook(param_backward_hook) - return ans - - - -def _test_tensor_diagnostic(): - opts = TensorDiagnosticOptions(2**20, 512) - - diagnostic = TensorDiagnostic(opts, "foo") - - for _ in range(10): - diagnostic.accumulate(torch.randn(50, 100) * 10.0) - - diagnostic.print_diagnostics() - - model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80)) - - diagnostic = attach_diagnostics(model, opts) - for _ in range(10): - T = random.randint(200, 300) - x = torch.randn(T, 100) - y = model(x) - y.sum().backward() - - diagnostic.print_diagnostics() - - - -if __name__ == '__main__': - _test_tensor_diagnostic() - - -def _test_func(): - ans = [] - for i in range(10): - x = list() - x.append(i) - def func(): - return x - ans.append(func) - return ans
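
For reference, a minimal sketch of how the subsampling module reads after this patch, assembled from the "+" lines and unchanged context of the first hunk. The unsqueeze step in forward() is an assumption inferred from the shape comments in the surrounding context, and the __main__ check is illustrative only, not part of the original file.

import torch
import torch.nn as nn


class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling: (N, T, idim) -> (N, T', odim),
    where T' = ((T - 1) // 2 - 1) // 2."""

    def __init__(self, idim: int, odim: int) -> None:
        assert idim >= 7
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2),
            nn.ReLU(),
        )
        self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (N, T, idim) -> (N, 1, T, idim): add a channel axis for Conv2d.
        # This step is assumed; it is not visible in the hunk context.
        x = x.unsqueeze(1)
        x = self.conv(x)
        b, c, t, f = x.size()
        # Fold channels and frequency back into one feature dimension and
        # project to odim, as in the context lines of the hunk.
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        return x


if __name__ == "__main__":
    y = Conv2dSubsampling(idim=80, odim=256)(torch.randn(2, 100, 80))
    print(y.shape)  # expected: torch.Size([2, 24, 256])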
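
The removed DoubleSwish docstring states that x * sigmoid(x - 1) closely approximates Swish(Swish(x)), where Swish(x) = x * sigmoid(x). The small standalone script below, which is not part of the original code, simply measures the gap between the two on a test interval.

import torch


def swish(x: torch.Tensor) -> torch.Tensor:
    return x * torch.sigmoid(x)


def double_swish_approx(x: torch.Tensor) -> torch.Tensor:
    # The offset-swish form used by the removed DoubleSwish module.
    return x * torch.sigmoid(x - 1.0)


if __name__ == "__main__":
    x = torch.linspace(-5.0, 5.0, steps=1001)
    gap = (swish(swish(x)) - double_swish_approx(x)).abs().max().item()
    print(f"max |Swish(Swish(x)) - x*sigmoid(x-1)| on [-5, 5]: {gap:.4f}")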
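
The comments in the removed attach_diagnostics(), and the trailing _test_func example, point at Python's late-binding closures: a hook defined in a loop captures the loop variable itself, not its current value, which is why the diagnostics hooks bind _model_diagnostic=ans and _name=name as default arguments. A minimal standalone illustration of that pitfall and the default-argument workaround (the module names below are purely illustrative):

def make_hooks_late_bound():
    hooks = []
    for name in ["encoder", "decoder", "joiner"]:
        # The lambda captures the variable `name`, not its current value, so
        # every hook ends up seeing the last value the loop variable took.
        hooks.append(lambda: name)
    return [h() for h in hooks]  # ['joiner', 'joiner', 'joiner']


def make_hooks_default_bound():
    hooks = []
    for name in ["encoder", "decoder", "joiner"]:
        # Binding the current value as a default argument freezes it per hook,
        # mirroring the _name=name, _model_diagnostic=ans pattern above.
        hooks.append(lambda _name=name: _name)
    return [h() for h in hooks]  # ['encoder', 'decoder', 'joiner']


if __name__ == "__main__":
    print(make_hooks_late_bound())
    print(make_hooks_default_bound())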