add ScaledLSTM

yaozengwei 2022-07-16 22:47:05 +08:00
parent 0fcdd15fec
commit 9165de5f57


@@ -1,4 +1,4 @@
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -21,7 +21,8 @@ from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from torch import Tensor, _VF
from torch.nn.utils.rnn import PackedSequence
def _ntuple(n):
@@ -154,7 +155,7 @@ class BasicNorm(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
assert x.shape[self.channel_dim] == self.num_channels
scales = (
torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
torch.mean(x**2, dim=self.channel_dim, keepdim=True)
+ self.eps.exp()
) ** -0.5
return x * scales
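
For context, the hunk above only reformats `x ** 2` to `x**2`; the computation itself is an RMS-style normalization that scales each frame by (mean(x**2) + eps)**-0.5 over the channel dimension, with eps learnt in log space. A minimal standalone sketch (the shapes and the eps value are illustrative, not taken from this commit):

import torch

x = torch.randn(2, 50, 512)                 # (batch, time, channels), channel_dim = -1
log_eps = torch.tensor(0.25).log()          # illustrative value; BasicNorm learns this in log space
scales = (torch.mean(x**2, dim=-1, keepdim=True) + log_eps.exp()) ** -0.5
y = x * scales
print((y**2).mean(dim=-1).sqrt())           # per-frame RMS, roughly (1 + eps) ** -0.5 for unit-variance input
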
@@ -207,12 +208,12 @@ class ScaledLinear(nn.Linear):
def _reset_parameters(self, initial_speed: float):
std = 0.1 / initial_speed
a = (3 ** 0.5) * std
a = (3**0.5) * std
nn.init.uniform_(self.weight, -a, a)
if self.bias is not None:
nn.init.constant_(self.bias, 0.0)
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
scale = fan_in**-0.5 # 1/sqrt(fan_in)
with torch.no_grad():
self.weight_scale += torch.tensor(scale / std).log()
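
For reference, the hunks in this file only reformat the power operators; the underlying scheme is that the stored weight is initialized with std = 0.1 / initial_speed, while the learnt log-scale absorbs the usual 1/sqrt(fan_in) factor, so the effective weight (weight * weight_scale.exp()) starts out with a standard-sized std. A hedged sketch of that arithmetic, with illustrative dimensions:

import torch
import torch.nn as nn

std = 0.1                       # 0.1 / initial_speed, with initial_speed = 1.0
a = (3**0.5) * std              # Uniform(-a, a) has std a / sqrt(3) = std
w = torch.empty(512, 256)
nn.init.uniform_(w, -a, a)

fan_in = 256
scale = fan_in**-0.5            # 1/sqrt(fan_in)
log_scale = torch.tensor(scale / std).log()

print(w.std())                      # ~0.1, what the optimizer sees
print((w * log_scale.exp()).std())  # ~1/sqrt(256) = 0.0625, the effective weight
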
@@ -256,12 +257,12 @@ class ScaledConv1d(nn.Conv1d):
def _reset_parameters(self, initial_speed: float):
std = 0.1 / initial_speed
a = (3 ** 0.5) * std
a = (3**0.5) * std
nn.init.uniform_(self.weight, -a, a)
if self.bias is not None:
nn.init.constant_(self.bias, 0.0)
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
scale = fan_in**-0.5 # 1/sqrt(fan_in)
with torch.no_grad():
self.weight_scale += torch.tensor(scale / std).log()
@@ -325,12 +326,12 @@ class ScaledConv2d(nn.Conv2d):
def _reset_parameters(self, initial_speed: float):
std = 0.1 / initial_speed
a = (3 ** 0.5) * std
a = (3**0.5) * std
nn.init.uniform_(self.weight, -a, a)
if self.bias is not None:
nn.init.constant_(self.bias, 0.0)
fan_in = self.weight.shape[1] * self.weight[0][0].numel()
scale = fan_in ** -0.5 # 1/sqrt(fan_in)
scale = fan_in**-0.5 # 1/sqrt(fan_in)
with torch.no_grad():
self.weight_scale += torch.tensor(scale / std).log()
@@ -376,6 +377,137 @@ class ScaledConv2d(nn.Conv2d):
return self._conv_forward(input, self.get_weight())
class ScaledLSTM(nn.LSTM):
# See docs for ScaledLinear.
# This class implements a single-layer LSTM with a scaling mechanism, using `torch._VF.lstm`.
# Please refer to https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py
def __init__(
self,
*args,
initial_scale: float = 1.0,
initial_speed: float = 1.0,
**kwargs
):
# Hardcode num_layers=1 and proj_size=0 here
super(ScaledLSTM, self).__init__(
*args, num_layers=1, proj_size=0, **kwargs
)
initial_scale = torch.tensor(initial_scale).log()
self._scales_names = []
self._scales = []
for name in self._flat_weights_names:
scale_name = name + "_scale"
self._scales_names.append(scale_name)
param = nn.Parameter(initial_scale.clone().detach())
setattr(self, scale_name, param)
self._scales.append(param)
self._reset_parameters(
initial_speed
) # Overrides the reset_parameters in base class
def _reset_parameters(self, initial_speed: float):
std = 0.1 / initial_speed
a = (3**0.5) * std
fan_in = self.input_size
scale = fan_in**-0.5
v = scale / std
for idx, name in enumerate(self._flat_weights_names):
if "weight" in name:
nn.init.uniform_(self._flat_weights[idx], -a, a)
with torch.no_grad():
self._scales[idx] += torch.tensor(v).log()
elif "bias" in name:
nn.init.constant_(self._flat_weights[idx], 0.0)
def get_flat_weights(self):
flat_weights = []
for idx in range(len(self._flat_weights_names)):
flat_weights.append(
self._flat_weights[idx] * self._scales[idx].exp()
)
return flat_weights
def forward(self, input, hx=None):
# This function is copied from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa
# The only change is in the call to `_VF.lstm()`:
# self._flat_weights -> self.get_flat_weights()
orig_input = input
# xxx: isinstance check needs to be in conditional for TorchScript to compile
if isinstance(orig_input, PackedSequence):
input, batch_sizes, sorted_indices, unsorted_indices = input
max_batch_size = batch_sizes[0]
max_batch_size = int(max_batch_size)
else:
batch_sizes = None
max_batch_size = (
input.size(0) if self.batch_first else input.size(1)
)
sorted_indices = None
unsorted_indices = None
if hx is None:
num_directions = 2 if self.bidirectional else 1
real_hidden_size = (
self.proj_size if self.proj_size > 0 else self.hidden_size
)
h_zeros = torch.zeros(
self.num_layers * num_directions,
max_batch_size,
real_hidden_size,
dtype=input.dtype,
device=input.device,
)
c_zeros = torch.zeros(
self.num_layers * num_directions,
max_batch_size,
self.hidden_size,
dtype=input.dtype,
device=input.device,
)
hx = (h_zeros, c_zeros)
else:
# Each batch of the hidden state should match the input sequence that
# the user believes he/she is passing in.
hx = self.permute_hidden(hx, sorted_indices)
self.check_forward_args(input, hx, batch_sizes)
if batch_sizes is None:
result = _VF.lstm(
input,
hx,
self.get_flat_weights(),
self.bias,
self.num_layers,
self.dropout,
self.training,
self.bidirectional,
self.batch_first,
)
else:
result = _VF.lstm(
input,
batch_sizes,
hx,
self.get_flat_weights(),
self.bias,
self.num_layers,
self.dropout,
self.training,
self.bidirectional,
)
output = result[0]
hidden = result[1:]
# xxx: isinstance check needs to be in conditional for TorchScript to compile
if isinstance(orig_input, PackedSequence):
output_packed = PackedSequence(
output, batch_sizes, sorted_indices, unsorted_indices
)
return output_packed, self.permute_hidden(hidden, unsorted_indices)
else:
return output, self.permute_hidden(hidden, unsorted_indices)
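
Because forward() above only replaces self._flat_weights with self.get_flat_weights() in the `_VF.lstm` call, a ScaledLSTM should behave like a plain single-layer nn.LSTM whose parameters equal the scaled weights. A rough sanity check along those lines, assuming the ScaledLSTM class defined above (not part of this commit):

import torch
import torch.nn as nn

m = ScaledLSTM(input_size=10, hidden_size=20, bias=True)
ref = nn.LSTM(input_size=10, hidden_size=20, num_layers=1, bias=True)
with torch.no_grad():
    for name, w in zip(m._flat_weights_names, m.get_flat_weights()):
        # names are weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0
        getattr(ref, name).copy_(w)

x = torch.randn(30, 2, 10)
y_scaled, _ = m(x)
y_ref, _ = ref(x)
assert torch.allclose(y_scaled, y_ref, atol=1e-5)
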
class ActivationBalancer(torch.nn.Module):
"""
Modifies the backpropped derivatives of a function to try to encourage, for
@@ -711,8 +843,8 @@ def _test_basic_norm():
y = m(x)
assert y.shape == x.shape
x_rms = (x ** 2).mean().sqrt()
y_rms = (y ** 2).mean().sqrt()
x_rms = (x**2).mean().sqrt()
y_rms = (y**2).mean().sqrt()
print("x rms = ", x_rms)
print("y rms = ", y_rms)
assert y_rms < x_rms
@@ -726,8 +858,22 @@ def _test_double_swish_deriv():
torch.autograd.gradcheck(m, x)
def _test_scaled_lstm():
N, L = 2, 30
dim_in, dim_hidden = 10, 20
m = ScaledLSTM(input_size=dim_in, hidden_size=dim_hidden, bias=True)
x = torch.randn(L, N, dim_in)
h0 = torch.randn(1, N, dim_hidden)
c0 = torch.randn(1, N, dim_hidden)
y, (h, c) = m(x, (h0, c0))
assert y.shape == (L, N, dim_hidden)
assert h.shape == (1, N, dim_hidden)
assert c.shape == (1, N, dim_hidden)
if __name__ == "__main__":
_test_activation_balancer_sign()
_test_activation_balancer_magnitude()
_test_basic_norm()
_test_double_swish_deriv()
_test_scaled_lstm()