mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 18:42:19 +00:00
* Fix torch.nn.Embedding error for torch below 1.8.0 * Changes to fbank computation, use lilcom chunky writer * Add min in q,k,v of attention * Remove learnable offset, use relu instead. * Experiments based on SpecAugment change * Merge specaug change from Mingshuang. * Use much more aggressive SpecAug setup * Fix to num_feature_masks bug I introduced; reduce max_frames_mask_fraction 0.4->0.3 * Change p=0.5->0.9, mask_fraction 0.3->0.2 * Change p=0.9 to p=0.8 in SpecAug * Fix num_time_masks code; revert 0.8 to 0.9 * Change max_frames from 0.2 to 0.15 * Remove ReLU in attention * Adding diagnostics code... * Refactor/simplify ConformerEncoder * First version of rand-combine iterated-training-like idea. * Improvements to diagnostics (RE those with 1 dim * Add pelu to this good-performing setup.. * Small bug fixes/imports * Add baseline for the PeLU expt, keeping only the small normalization-related changes. * pelu_base->expscale, add 2xExpScale in subsampling, and in feedforward units. * Double learning rate of exp-scale units * Combine ExpScale and swish for memory reduction * Add import * Fix backprop bug * Fix bug in diagnostics * Increase scale on Scale from 4 to 20 * Increase scale from 20 to 50. * Fix duplicate Swish; replace norm+swish with swish+exp-scale in convolution module * Reduce scale from 50 to 20 * Add deriv-balancing code * Double the threshold in brelu; slightly increase max_factor. * Fix exp dir * Convert swish nonlinearities to ReLU * Replace relu with swish-squared. * Restore ConvolutionModule to state before changes; change all Swish,Swish(Swish) to SwishOffset. * Replace norm on input layer with scale of 0.1. * Extensions to diagnostics code * Update diagnostics * Add BasicNorm module * Replace most normalizations with scales (still have norm in conv) * Change exp dir * Replace norm in ConvolutionModule with a scaling factor. 
* use nonzero threshold in DerivBalancer * Add min-abs-value 0.2 * Fix dirname * Change min-abs threshold from 0.2 to 0.5 * Scale up pos_bias_u and pos_bias_v before use. * Reduce max_factor to 0.01 * Fix q*scaling logic * Change max_factor in DerivBalancer from 0.025 to 0.01; fix scaling code. * init 1st conv module to smaller variance * Change how scales are applied; fix residual bug * Reduce min_abs from 0.5 to 0.2 * Introduce in_scale=0.5 for SwishExpScale * Fix scale from 0.5 to 2.0 as I really intended.. * Set scaling on SwishExpScale * Add identity pre_norm_final for diagnostics. * Add learnable post-scale for mha * Fix self.post-scale-mha * Another rework, use scales on linear/conv * Change dir name * Reduce initial scaling of modules * Bug-fix RE bias * Cosmetic change * Reduce initial_scale. * Replace ExpScaleRelu with DoubleSwish() * DoubleSwish fix * Use learnable scales for joiner and decoder * Add max-abs-value constraint in DerivBalancer * Add max-abs-value * Change dir name * Remove ExpScale in feedforward layes. * Reduce max-abs limit from 1000 to 100; introduce 2 DerivBalancer modules in conv layer. * Make DoubleSwish more memory efficient * Reduce constraints from deriv-balancer in ConvModule. * Add warmup mode * Remove max-positive constraint in deriv-balancing; add second DerivBalancer in conv module. * Add some extra info to diagnostics * Add deriv-balancer at output of embedding. * Add more stats. * Make epsilon in BasicNorm learnable, optionally. * Draft of 0mean changes.. * Rework of initialization * Fix typo * Remove dead code * Modifying initialization from normal->uniform; add initial_scale when initializing * bug fix re sqrt * Remove xscale from pos_embedding * Remove some dead code. * Cosmetic changes/renaming things * Start adding some files.. * Add more files.. 
* update decode.py file type * Add remaining files in pruned_transducer_stateless2 * Fix diagnostics-getting code * Scale down pruned loss in warmup mode * Reduce warmup scale on pruned loss form 0.1 to 0.01. * Remove scale_speed, make swish deriv more efficient. * Cosmetic changes to swish * Double warm_step * Fix bug with import * Change initial std from 0.05 to 0.025. * Set also scale for embedding to 0.025. * Remove logging code that broke with newer Lhotse; fix bug with pruned_loss * Add norm+balancer to VggSubsampling * Incorporate changes from master into pruned_transducer_stateless2. * Add max-abs=6, debugged version * Change 0.025,0.05 to 0.01 in initializations * Fix balancer code * Whitespace fix * Reduce initial pruned_loss scale from 0.01 to 0.0 * Increase warm_step (and valid_interval) * Change max-abs from 6 to 10 * Change how warmup works. * Add changes from master to decode.py, train.py * Simplify the warmup code; max_abs 10->6 * Make warmup work by scaling layer contributions; leave residual layer-drop * Fix bug * Fix test mode with random layer dropout * Add random-number-setting function in dataloader * Fix/patch how fix_random_seed() is imported. * Reduce layer-drop prob * Reduce layer-drop prob after warmup to 1 in 100 * Change power of lr-schedule from -0.5 to -0.333 * Increase model_warm_step to 4k * Change max-keep-prob to 0.95 * Refactoring and simplifying conformer and frontend * Rework conformer, remove some code. * Reduce 1st conv channels from 64 to 32 * Add another convolutional layer * Fix padding bug * Remove dropout in output layer * Reduce speed of some components * Initial refactoring to remove unnecessary vocab_size * Fix RE identity * Bug-fix * Add final dropout to conformer * Remove some un-used code * Replace nn.Linear with ScaledLinear in simple joiner * Make 2 projections.. 
* Reduce initial_speed * Use initial_speed=0.5 * Reduce initial_speed further from 0.5 to 0.25 * Reduce initial_speed from 0.5 to 0.25 * Change how warmup is applied. * Bug fix to warmup_scale * Fix test-mode * Remove final dropout * Make layer dropout rate 0.075, was 0.1. * First draft of model rework * Various bug fixes * Change learning speed of simple_lm_proj * Revert transducer_stateless/ to state in upstream/master * Fix to joiner to allow different dims * Some cleanups * Make training more efficient, avoid redoing some projections. * Change how warm-step is set * First draft of new approach to learning rates + init * Some fixes.. * Change initialization to 0.25 * Fix type of parameter * Fix weight decay formula by adding 1/1-beta * Fix weight decay formula by adding 1/1-beta * Fix checkpoint-writing * Fix to reading scheudler from optim * Simplified optimizer, rework somet things.. * Reduce model_warm_step from 4k to 3k * Fix bug in lambda * Bug-fix RE sign of target_rms * Changing initial_speed from 0.25 to 01 * Change some defaults in LR-setting rule. * Remove initial_speed * Set new scheduler * Change exponential part of lrate to be epoch based * Fix bug * Set 2n rule.. * Implement 2o schedule * Make lrate rule more symmetric * Implement 2p version of learning rate schedule. * Refactor how learning rate is set. * Fix import * Modify init (#301) * update icefall/__init__.py to import more common functions. * update icefall/__init__.py * make imports style consistent. * exclude black check for icefall/__init__.py in pyproject.toml. 
* Minor fixes for logging (#296) * Minor fixes for logging * Minor fix * Fix dir names * Modify beam search to be efficient with current joienr * Fix adding learning rate to tensorboard * Fix docs in optim.py * Support mix precision training on the reworked model (#305) * Add mix precision support * Minor fixes * Minor fixes * Minor fixes * Tedlium3 pruned transducer stateless (#261) * update tedlium3-pruned-transducer-stateless-codes * update README.md * update README.md * add fast beam search for decoding * do a change for RESULTS.md * do a change for RESULTS.md * do a fix * do some changes for pruned RNN-T * Add mix precision support * Minor fixes * Minor fixes * Updating RESULTS.md; fix in beam_search.py * Fix rebase * Code style check for librispeech pruned transducer stateless2 (#308) * Update results for tedlium3 pruned RNN-T (#307) * Update README.md * Fix CI errors. (#310) * Add more results * Fix tensorboard log location * Add one more epoch of full expt * fix comments * Add results for mixed precision with max-duration 300 * Changes for pretrained.py (tedlium3 pruned RNN-T) (#311) * GigaSpeech recipe (#120) * initial commit * support download, data prep, and fbank * on-the-fly feature extraction by default * support BPE based lang * support HLG for BPE * small fix * small fix * chunked feature extraction by default * Compute features for GigaSpeech by splitting the manifest. * Fixes after review. * Split manifests into 2000 pieces. 
* set audio duration mismatch tolerance to 0.01 * small fix * add conformer training recipe * Add conformer.py without pre-commit checking * lazy loading and use SingleCutSampler * DynamicBucketingSampler * use KaldifeatFbank to compute fbank for musan * use pretrained language model and lexicon * use 3gram to decode, 4gram to rescore * Add decode.py * Update .flake8 * Delete compute_fbank_gigaspeech.py * Use BucketingSampler for valid and test dataloader * Update params in train.py * Use bpe_500 * update params in decode.py * Decrease num_paths while CUDA OOM * Added README * Update RESULTS * black * Decrease num_paths while CUDA OOM * Decode with post-processing * Update results * Remove lazy_load option * Use default `storage_type` * Keep the original tolerance * Use split-lazy * black * Update pretrained model Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com> * Add LG decoding (#277) * Add LG decoding * Add log weight pushing * Minor fixes * Support computing RNN-T loss with torchaudio (#316) * Support modified beam search decoding for streaming inference with Emformer model. * Formatted imports. * Update results for torchaudio RNN-T. (#322) * Fixed streaming decoding codes for emformer model. * Fixed docs. * Sorted imports for transducer_emformer/streaming_feature_extractor.py * Minor fix for transducer_emformer/streaming_feature_extractor.py Co-authored-by: pkufool <wkang@pku.org.cn> Co-authored-by: Daniel Povey <dpovey@gmail.com> Co-authored-by: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com> Co-authored-by: Guo Liyong <guonwpu@qq.com> Co-authored-by: Wang, Guanbo <wgb14@outlook.com>
703 lines
25 KiB
Python
703 lines
25 KiB
Python
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
|
|
#
|
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
import collections
|
|
from itertools import repeat
|
|
from typing import Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch import Tensor
|
|
|
|
|
|
def _ntuple(n):
|
|
def parse(x):
|
|
if isinstance(x, collections.Iterable):
|
|
return x
|
|
return tuple(repeat(x, n))
|
|
|
|
return parse
|
|
|
|
|
|
# Scalar -> 1-tuple and scalar -> 2-tuple converters; used below to build the
# explicit zero padding passed to F.conv1d / F.conv2d after manual F.pad.
_single, _pair = _ntuple(1), _ntuple(2)
|
|
|
|
|
|
class ActivationBalancerFunction(torch.autograd.Function):
    """Identity function whose backward pass nudges per-channel statistics.

    The forward pass returns ``x`` unchanged.  When ``x`` requires grad it
    records, for every channel (index along ``channel_dim``), the fraction of
    elements that are > 0 and the mean absolute value, both reduced over all
    other dimensions.  The backward pass then returns::

        x_grad - |x_grad| * (sign_factor + magnitude_factor)

    where, per channel:

    * ``sign_factor`` grows linearly from 0 to ``max_factor`` as the positive
      fraction drops from ``min_positive`` to 0, and from 0 to ``-max_factor``
      as it rises from ``max_positive`` to 1.  Setting ``min_positive=0.0``
      or ``max_positive=1.0`` disables the corresponding side.
    * ``magnitude_factor`` is ``+max_factor`` for elements > 0 and
      ``-max_factor`` for the others when the channel's mean absolute value
      is below ``min_abs``, with the signs flipped when it is above
      ``max_abs``, and 0 otherwise.

    Under gradient descent this pushes channels that are too rarely (too
    often) positive towards positive (negative) values, and channels whose
    magnitude is too small (too large) away from (towards) zero.
    """

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        channel_dim: int,
        min_positive: float,  # e.g. 0.05
        max_positive: float,  # e.g. 0.95
        max_factor: float,  # e.g. 0.01
        min_abs: float,  # e.g. 0.2
        max_abs: float,  # e.g. 100.0
    ) -> Tensor:
        if not x.requires_grad:
            # No backward pass will run, so no statistics are needed.
            return x

        if channel_dim < 0:
            channel_dim += x.ndim
        # Statistics are per channel: reduce over every other axis.
        reduce_dims = [dim for dim in range(x.ndim) if dim != channel_dim]

        is_positive = x > 0
        frac_positive = is_positive.to(x.dtype).mean(dim=reduce_dims, keepdim=True)

        if min_positive != 0.0:
            too_few = (min_positive - frac_positive).relu() * (
                max_factor / min_positive
            )
        else:
            too_few = 0.0
        if max_positive != 1.0:
            # (max_positive - 1.0) < 0, so this term is <= 0.
            too_many = (frac_positive - max_positive).relu() * (
                max_factor / (max_positive - 1.0)
            )
        else:
            too_many = 0.0

        sign_factor = too_few + too_many
        if isinstance(sign_factor, float):
            # Both sign constraints are disabled; backward still needs a tensor.
            sign_factor = torch.zeros_like(frac_positive)

        mean_abs = x.abs().mean(dim=reduce_dims, keepdim=True)
        ctx.save_for_backward(
            sign_factor, is_positive, mean_abs < min_abs, mean_abs > max_abs
        )
        ctx.max_factor = max_factor
        ctx.sum_dims = reduce_dims
        return x

    @staticmethod
    def backward(
        ctx, x_grad: Tensor
    ) -> Tuple[Tensor, None, None, None, None, None, None]:
        sign_factor, is_positive, too_small, too_large = ctx.saved_tensors
        dtype = x_grad.dtype
        # +/- max_factor per element (sign follows x > 0) for channels that
        # are too small, with the opposite sign for channels that are too large.
        magnitude_factor = (
            (too_small.to(dtype) - too_large.to(dtype))
            * (is_positive.to(dtype) - 0.5)
            * (ctx.max_factor * 2.0)
        )
        adjusted = x_grad - x_grad.abs() * (sign_factor + magnitude_factor)
        # Only ``x`` gets a gradient; the six hyper-parameters do not.
        return (adjusted,) + (None,) * 6
|
|
|
|
|
|
class BasicNorm(torch.nn.Module):
    """Cheap RMS-style normalization intended as a simpler replacement for
    LayerNorm.

    Motivation: Transformer-type networks, especially with pre-norm,
    sometimes seem to set one feature dimension to a large constant value
    (e.g. 50).  That "defeats" LayerNorm, because the output magnitude then
    no longer depends strongly on the other (useful) features.  Presumably
    LayerNorm's weight and bias are what allow the network to do this.

    BasicNorm makes that large constant an explicit parameter, playing the
    role of LayerNorm's "eps", so the network does not have to emulate it.
    The output is ``x * (mean(x**2 over channel_dim) + eps) ** -0.5``.  There
    is no mean subtraction and no per-channel weight or bias.  ``self.eps``
    stores ``log(eps)``, so eps stays positive while being learned.

    Args:
        num_channels: the number of channels, e.g. 512.
        channel_dim: the axis corresponding to the channel; negative values
            count from the end of the input's dimensions.  This is NOT the
            number of channels; it is typically one of {-2, -1, 0, 1, 2, 3}.
        eps: the initial value of the "epsilon" added as ballast in
            ``scale = ((input_vec**2).mean() + epsilon)**-0.5``.  Our epsilon
            is actually large, but the name keeps the connection with
            conventional LayerNorm.
        learn_eps: if True, epsilon is learned; if False, it is kept fixed at
            its initial value (stored as a buffer).
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int = -1,  # CAUTION: see documentation.
        eps: float = 0.25,
        learn_eps: bool = True,
    ) -> None:
        super().__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        log_eps = torch.tensor(eps).log().detach()
        if not learn_eps:
            self.register_buffer("eps", log_eps)
        else:
            self.eps = nn.Parameter(log_eps)

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
        mean_square = (x ** 2).mean(dim=self.channel_dim, keepdim=True)
        return x * (mean_square + self.eps.exp()) ** -0.5
|
|
|
|
|
|
class ScaledLinear(nn.Linear):
    """A version of nn.Linear whose parameters are rescaled before use:

        weight = self.weight * self.weight_scale.exp()
        bias = self.bias * self.bias_scale.exp()

    Initialization (which replaces nn.Linear's default):

    * ``self.weight`` is drawn uniformly with standard deviation
      ``0.1 / initial_speed``.
    * ``self.bias`` is set to zero.
    * ``bias_scale`` starts at ``log(initial_scale)``.
    * ``weight_scale`` starts at ``log(initial_scale)`` plus the offset that
      gives the effective weight a standard deviation of
      ``initial_scale / sqrt(fan_in)``.

    Args:
        Accepts the standard args and kwargs that nn.Linear accepts,
        e.g. in_features, out_features, bias=False.

        initial_scale: override this to increase or decrease the initial
            magnitude of the module's output (it sets the initial values of
            weight_scale and bias_scale).  Re-initializing the parameters is
            another way to achieve this.
        initial_speed: affects how fast the parameters learn near the start
            of training.  Use a value below 1 if you suspect this module is
            contributing to instability early in training (regardless of
            this option, it's best to use schedulers like Noam that have a
            warm-up period).  Use a value above 1 to make it train faster
            initially.  Must be greater than 0.
    """

    def __init__(
        self,
        *args,
        initial_scale: float = 1.0,
        initial_speed: float = 1.0,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        log_scale = torch.tensor(initial_scale).log()
        self.weight_scale = nn.Parameter(log_scale.clone().detach())
        if self.bias is None:
            self.register_parameter("bias_scale", None)
        else:
            self.bias_scale = nn.Parameter(log_scale.clone().detach())
        # nn.Linear.__init__ already ran its own reset_parameters(); this
        # re-initializes weight/bias and folds the fan-in into weight_scale.
        self._reset_parameters(initial_speed)

    def _reset_parameters(self, initial_speed: float):
        std = 0.1 / initial_speed
        bound = std * 3 ** 0.5  # a uniform on [-b, b] has std b / sqrt(3)
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.0)
        fan_in = self.weight.shape[1] * self.weight[0][0].numel()
        with torch.no_grad():
            # Make the effective weight std equal initial_scale / sqrt(fan_in).
            self.weight_scale += torch.tensor(fan_in ** -0.5 / std).log()

    def get_weight(self) -> Tensor:
        return self.weight * self.weight_scale.exp()

    def get_bias(self) -> Optional[Tensor]:
        if self.bias is None:
            return None
        return self.bias * self.bias_scale.exp()

    def forward(self, input: Tensor) -> Tensor:
        weight, bias = self.get_weight(), self.get_bias()
        return torch.nn.functional.linear(input, weight, bias)
|
|
|
|
|
|
class ScaledConv1d(nn.Conv1d):
    """nn.Conv1d with learnable log-domain scales on weight and bias.

    The effective parameters are ``weight * weight_scale.exp()`` and
    ``bias * bias_scale.exp()``.  ``initial_scale`` and ``initial_speed``
    mean the same as for ScaledLinear: the raw weight is uniform with std
    ``0.1 / initial_speed``, the bias is zero, and the effective weight
    starts with std ``initial_scale / sqrt(fan_in)``.
    """

    def __init__(
        self,
        *args,
        initial_scale: float = 1.0,
        initial_speed: float = 1.0,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        log_scale = torch.tensor(initial_scale).log()
        self.weight_scale = nn.Parameter(log_scale.clone().detach())
        if self.bias is None:
            self.register_parameter("bias_scale", None)
        else:
            self.bias_scale = nn.Parameter(log_scale.clone().detach())
        # Replaces the initialization already done by nn.Conv1d.__init__.
        self._reset_parameters(initial_speed)

    def _reset_parameters(self, initial_speed: float):
        std = 0.1 / initial_speed
        bound = std * 3 ** 0.5  # uniform on [-b, b] has std b / sqrt(3)
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.0)
        # in_channels / groups times the kernel size.
        fan_in = self.weight.shape[1] * self.weight[0][0].numel()
        with torch.no_grad():
            self.weight_scale += torch.tensor(fan_in ** -0.5 / std).log()

    def get_weight(self) -> Tensor:
        return self.weight * self.weight_scale.exp()

    def get_bias(self) -> Optional[Tensor]:
        if self.bias is None:
            return None
        return self.bias * self.bias_scale.exp()

    def forward(self, input: Tensor) -> Tensor:
        F = torch.nn.functional
        if self.padding_mode == "zeros":
            padding = self.padding
        else:
            # Non-zero padding modes are applied explicitly, after which the
            # convolution itself uses no padding.
            input = F.pad(
                input,
                self._reversed_padding_repeated_twice,
                mode=self.padding_mode,
            )
            padding = _single(0)
        return F.conv1d(
            input,
            self.get_weight(),
            self.get_bias(),
            self.stride,
            padding,
            self.dilation,
            self.groups,
        )
|
|
|
|
|
|
class ScaledConv2d(nn.Conv2d):
    """nn.Conv2d with learnable log-domain scales on weight and bias.

    The effective parameters are ``weight * weight_scale.exp()`` and
    ``bias * bias_scale.exp()``.  ``initial_scale`` and ``initial_speed``
    mean the same as for ScaledLinear: the raw weight is uniform with std
    ``0.1 / initial_speed``, the bias is zero, and the effective weight
    starts with std ``initial_scale / sqrt(fan_in)``.
    """

    def __init__(
        self,
        *args,
        initial_scale: float = 1.0,
        initial_speed: float = 1.0,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        log_scale = torch.tensor(initial_scale).log()
        self.weight_scale = nn.Parameter(log_scale.clone().detach())
        if self.bias is None:
            self.register_parameter("bias_scale", None)
        else:
            self.bias_scale = nn.Parameter(log_scale.clone().detach())
        # Replaces the initialization already done by nn.Conv2d.__init__.
        self._reset_parameters(initial_speed)

    def _reset_parameters(self, initial_speed: float):
        std = 0.1 / initial_speed
        bound = std * 3 ** 0.5  # uniform on [-b, b] has std b / sqrt(3)
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.0)
        # in_channels / groups times the kernel area.
        fan_in = self.weight.shape[1] * self.weight[0][0].numel()
        with torch.no_grad():
            self.weight_scale += torch.tensor(fan_in ** -0.5 / std).log()

    def get_weight(self) -> Tensor:
        return self.weight * self.weight_scale.exp()

    def get_bias(self) -> Optional[Tensor]:
        if self.bias is None:
            return None
        return self.bias * self.bias_scale.exp()

    def _conv_forward(self, input, weight):
        F = torch.nn.functional
        if self.padding_mode == "zeros":
            padding = self.padding
        else:
            # Non-zero padding modes are applied explicitly, after which the
            # convolution itself uses no padding.
            input = F.pad(
                input,
                self._reversed_padding_repeated_twice,
                mode=self.padding_mode,
            )
            padding = _pair(0)
        return F.conv2d(
            input,
            weight,
            self.get_bias(),
            self.stride,
            padding,
            self.dilation,
            self.groups,
        )

    def forward(self, input: Tensor) -> Tensor:
        return self._conv_forward(input, self.get_weight())
|
|
|
|
|
|
class ActivationBalancer(torch.nn.Module):
    """Identity module that modifies backpropagated derivatives to keep
    per-channel activation statistics within bounds.

    For each channel it tries to keep:

    * the proportion of time the input is positive within
      [min_positive, max_positive], and
    * the mean absolute value within [min_abs, max_abs].

    When a channel is positive too rarely, negative derivatives are
    multiplied by up to (1 + max_factor) and positive ones by up to
    (1 - max_factor).  The factor is interpolated from 1 at min_positive to
    those extremes when none of the inputs are positive.  The
    too-often-positive side is handled symmetrically.  The magnitude
    constraint adjusts derivatives by up to the same max_factor.  The work is
    done by ActivationBalancerFunction.

    Args:
        channel_dim: the dimension/axis corresponding to the channel, e.g.
            -1, 0, 1, 2; interpreted as an offset from x.ndim if negative.
        min_positive: the minimum, per channel, of the proportion of the time
            that (x > 0), below which we start to modify the derivatives.
        max_positive: the maximum, per channel, of the proportion of the time
            that (x > 0), above which we start to modify the derivatives.
        max_factor: the maximum factor by which we modify the derivatives for
            either the sign constraint or the magnitude constraint; e.g. with
            max_factor=0.02, the derivatives would be multiplied by values in
            the range [0.98..1.02].
        min_abs: the minimum average absolute value per channel that we
            allow before we start modifying the derivatives to prevent it.
        max_abs: the maximum average absolute value per channel that we
            allow before we start modifying the derivatives to prevent it.
    """

    def __init__(
        self,
        channel_dim: int,
        min_positive: float = 0.05,
        max_positive: float = 0.95,
        max_factor: float = 0.01,
        min_abs: float = 0.2,
        max_abs: float = 100.0,
    ):
        super().__init__()
        self.channel_dim = channel_dim
        self.min_positive, self.max_positive = min_positive, max_positive
        self.max_factor = max_factor
        self.min_abs, self.max_abs = min_abs, max_abs

    def forward(self, x: Tensor) -> Tensor:
        hyper_params = (
            self.channel_dim,
            self.min_positive,
            self.max_positive,
            self.max_factor,
            self.min_abs,
            self.max_abs,
        )
        return ActivationBalancerFunction.apply(x, *hyper_params)
|
|
|
|
|
|
class DoubleSwishFunction(torch.autograd.Function):
    """
    double_swish(x) = x * torch.sigmoid(x-1)

    This definition was originally motivated by its close numerical
    similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).

    Memory-efficient derivative: with s = sigmoid(x - 1) and y = x * s,

        y'(x) = s + x * s'(x) = s + x * s * (1 - s) = y * (1 - s) + s,

    using s'(x) = s * (1 - s).  So backward needs only s and y, not x.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        inp = x.detach()
        gate = torch.sigmoid(inp - 1.0)
        out = inp * gate
        ctx.save_for_backward(gate, out)
        return out

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        gate, out = ctx.saved_tensors
        local_deriv = out * (1 - gate) + gate
        return local_deriv * y_grad
|
|
|
|
|
|
class DoubleSwish(torch.nn.Module):
    """Module form of the double-swish activation ``x * sigmoid(x - 1)``,
    a close approximation to Swish(Swish(x))."""

    def forward(self, x: Tensor) -> Tensor:
        """Apply ``x * sigmoid(x - 1)`` elementwise via DoubleSwishFunction,
        whose backward pass stores sigmoid(x - 1) and the output rather than
        ``x`` itself.
        """
        double_swish = DoubleSwishFunction.apply
        return double_swish(x)
|
|
|
|
|
|
class ScaledEmbedding(nn.Module):
    r"""A modified version of nn.Embedding that introduces a learnable scale
    on the parameters.  Note: due to how it is initialized, it's best used
    with schedulers like Noam that have a warm-up period.

    It is a simple lookup table that stores embeddings of a fixed dictionary
    and size.  The input to the module is a tensor of indices, and the output
    is the corresponding rows of :attr:`weight` multiplied by
    ``self.scale.exp()``.

    Args:
        num_embeddings (int): size of the dictionary of embeddings.
        embedding_dim (int): the size of each embedding vector.
        padding_idx (int, optional): if given, the row at :attr:`padding_idx`
            is initialized to zeros and is passed to
            ``torch.nn.functional.embedding`` as its ``padding_idx``, so that
            row receives no gradient.  Negative values are counted back from
            ``num_embeddings``.
        scale_grad_by_freq (bool, optional): if ``True``, gradients are
            scaled by the inverse of the frequency of the words in the
            mini-batch.  Default ``False``.
        sparse (bool, optional): if ``True``, the gradient w.r.t.
            :attr:`weight` is a sparse tensor.  Default ``False``.
        initial_speed (float, optional): affects how fast the parameters
            learn near the start of training.  Use a value below 1 if you
            suspect this module is contributing to instability early in
            training (regardless of this option, it's best to use schedulers
            like Noam that have a warm-up period).  Use a value above 1 to
            make it train faster initially.  Must be greater than 0.

    Attributes:
        weight (Tensor): the learnable table of shape
            (num_embeddings, embedding_dim), initialized from
            :math:`\mathcal{N}(0, \sigma^2)` with
            :math:`\sigma = 0.1 / \text{initial\_speed}`.
        scale (Tensor): a learnable scalar holding the *log* of the output
            multiplier, initialized to :math:`\log(1 / \sigma)` so that the
            effective embeddings initially have unit standard deviation.

    Shape:
        - Input: :math:`(*)`, LongTensor of arbitrary shape containing the
          indices to extract.
        - Output: :math:`(*, H)`, where ``*`` is the input shape and
          :math:`H=\text{embedding\_dim}`.

    .. note::
        Keep in mind that only a limited number of optimizers support sparse
        gradients: currently :class:`optim.SGD` (CUDA and CPU),
        :class:`optim.SparseAdam` (CUDA and CPU) and :class:`optim.Adagrad`
        (CPU).

    Examples::

        >>> embedding = ScaledEmbedding(10, 3)
        >>> input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
        >>> embedding(input).shape
        torch.Size([2, 4, 3])
        >>> embedding = ScaledEmbedding(10, 3, padding_idx=0)
        >>> embedding(torch.LongTensor([[0, 2, 0, 5]]))[0, 0].abs().sum().item()
        0.0
    """

    __constants__ = [
        "num_embeddings",
        "embedding_dim",
        "padding_idx",
        "scale_grad_by_freq",
        "sparse",
    ]

    num_embeddings: int
    embedding_dim: int
    padding_idx: int
    scale_grad_by_freq: bool
    weight: Tensor
    sparse: bool

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        initial_speed: float = 1.0,
    ) -> None:
        super(ScaledEmbedding, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if padding_idx is not None:
            if padding_idx > 0:
                assert (
                    padding_idx < self.num_embeddings
                ), "Padding_idx must be within num_embeddings"
            elif padding_idx < 0:
                assert (
                    padding_idx >= -self.num_embeddings
                ), "Padding_idx must be within num_embeddings"
                # Normalize a negative index to its non-negative equivalent.
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        self.scale_grad_by_freq = scale_grad_by_freq

        self.scale = nn.Parameter(torch.zeros(()))  # see reset_parameters()
        self.sparse = sparse

        self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
        self.reset_parameters(initial_speed)

    def reset_parameters(self, initial_speed: float = 1.0) -> None:
        """Draw weight from N(0, std^2), std = 0.1 / initial_speed, set
        scale = log(1 / std), and zero the padding row if there is one."""
        std = 0.1 / initial_speed
        nn.init.normal_(self.weight, std=std)
        nn.init.constant_(self.scale, torch.tensor(1.0 / std).log())

        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        F = torch.nn.functional
        scale = self.scale.exp()
        # Both branches compute the same value.  Scaling the looked-up rows
        # touches input.numel() * embedding_dim elements, while scaling the
        # whole table touches num_embeddings * embedding_dim, so pick the
        # smaller one.
        if input.numel() < self.num_embeddings:
            return (
                F.embedding(
                    input,
                    self.weight,
                    self.padding_idx,
                    None,
                    2.0,  # None, 2.0 relate to normalization
                    self.scale_grad_by_freq,
                    self.sparse,
                )
                * scale
            )
        else:
            return F.embedding(
                input,
                self.weight * scale,
                self.padding_idx,
                None,
                2.0,  # None, 2.0 relates to normalization
                self.scale_grad_by_freq,
                self.sparse,
            )

    def extra_repr(self) -> str:
        # ``scale`` is an nn.Parameter, which nn.Module stores in
        # ``self._parameters`` rather than ``self.__dict__``.  The previous
        # ``"...scale={scale}".format(**self.__dict__)`` therefore raised
        # KeyError: 'scale' whenever the module was printed.  The value shown
        # is the stored (log-domain) parameter.
        s = f"{self.num_embeddings}, {self.embedding_dim}, scale={self.scale.item():g}"
        if self.padding_idx is not None:
            s += f", padding_idx={self.padding_idx}"
        if self.scale_grad_by_freq is not False:
            s += f", scale_grad_by_freq={self.scale_grad_by_freq}"
        if self.sparse is not False:
            s += ", sparse=True"
        return s
|
|
|
|
|
|
def _test_activation_balancer_sign():
    """Smoke test for the sign constraint of ActivationBalancer.

    Builds one channel of 0/1 inputs per probability in
    ``torch.arange(0, 1, 0.01)``.  Each element is 1 with that channel's
    probability.  It then backprops random sign gradients and prints the
    inputs and both gradients for manual inspection.
    """
    probs = torch.arange(0, 1, 0.01)
    num_channels, num_frames = probs.numel(), 1000
    is_on = torch.rand(num_channels, num_frames) < probs.unsqueeze(-1)
    x = (1.0 * is_on).detach()
    x.requires_grad = True
    balancer = ActivationBalancer(
        channel_dim=0,
        min_positive=0.05,
        max_positive=0.95,
        max_factor=0.2,
        min_abs=0.0,
    )

    y_grad = torch.sign(torch.randn(num_channels, num_frames))

    balancer(x).backward(gradient=y_grad)
    print("_test_activation_balancer_sign: x = ", x)
    print("_test_activation_balancer_sign: y grad = ", y_grad)
    print("_test_activation_balancer_sign: x grad = ", x.grad)
|
|
|
|
|
|
def _test_activation_balancer_magnitude():
    """Smoke test for the magnitude constraint of ActivationBalancer.

    Builds one channel per magnitude in ``torch.arange(0, 1, 0.01)``, whose
    elements are that magnitude with a random sign.  It then backprops random
    sign gradients through a balancer with min_abs=0.2 and max_abs=0.8 and
    prints the inputs and both gradients for manual inspection.
    """
    magnitudes = torch.arange(0, 1, 0.01)
    num_channels, num_frames = magnitudes.numel(), 1000
    signs = torch.sign(torch.randn(num_channels, num_frames))
    x = (signs * magnitudes.unsqueeze(-1)).detach()
    x.requires_grad = True
    balancer = ActivationBalancer(
        channel_dim=0,
        min_positive=0.0,
        max_positive=1.0,
        max_factor=0.2,
        min_abs=0.2,
        max_abs=0.8,
    )

    y_grad = torch.sign(torch.randn(num_channels, num_frames))

    balancer(x).backward(gradient=y_grad)
    print("_test_activation_balancer_magnitude: x = ", x)
    print("_test_activation_balancer_magnitude: y grad = ", y_grad)
    print("_test_activation_balancer_magnitude: x grad = ", x.grad)
|
|
|
|
|
|
def _test_basic_norm():
    """Check that BasicNorm keeps the input shape and reduces the overall
    RMS of standard-normal input, but by less than a factor of two."""
    num_channels = 128
    norm = BasicNorm(num_channels=num_channels, channel_dim=1)

    x = torch.randn(500, num_channels)

    y = norm(x)

    assert y.shape == x.shape

    def rms(t):
        return t.pow(2).mean().sqrt()

    x_rms, y_rms = rms(x), rms(y)
    print("x rms = ", x_rms)
    print("y rms = ", y_rms)
    assert y_rms < x_rms
    assert y_rms > 0.5 * x_rms
|
|
|
|
|
|
def _test_double_swish_deriv():
    """Numerically check DoubleSwish's hand-written backward pass with
    torch.autograd.gradcheck on small double-precision inputs."""
    x = torch.randn(10, 12, dtype=torch.double) * 0.5
    x.requires_grad = True
    torch.autograd.gradcheck(DoubleSwish(), x)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the module's self-tests in their original order.
    for _self_test in (
        _test_activation_balancer_sign,
        _test_activation_balancer_magnitude,
        _test_basic_norm,
        _test_double_swish_deriv,
    ):
        _self_test()
|