add ScaledLSTM

parent 0fcdd15fec
commit 9165de5f57
@@ -1,4 +1,4 @@
-# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
+# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey, Zengwei Yao)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -21,7 +21,8 @@ from typing import Optional, Tuple

 import torch
 import torch.nn as nn
-from torch import Tensor
+from torch import Tensor, _VF
+from torch.nn.utils.rnn import PackedSequence


 def _ntuple(n):
@@ -154,7 +155,7 @@ class BasicNorm(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         assert x.shape[self.channel_dim] == self.num_channels
         scales = (
-            torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
+            torch.mean(x**2, dim=self.channel_dim, keepdim=True)
             + self.eps.exp()
         ) ** -0.5
         return x * scales
@@ -207,12 +208,12 @@ class ScaledLinear(nn.Linear):

     def _reset_parameters(self, initial_speed: float):
         std = 0.1 / initial_speed
-        a = (3 ** 0.5) * std
+        a = (3**0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
             nn.init.constant_(self.bias, 0.0)
         fan_in = self.weight.shape[1] * self.weight[0][0].numel()
-        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
+        scale = fan_in**-0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
             self.weight_scale += torch.tensor(scale / std).log()

@@ -256,12 +257,12 @@ class ScaledConv1d(nn.Conv1d):

     def _reset_parameters(self, initial_speed: float):
         std = 0.1 / initial_speed
-        a = (3 ** 0.5) * std
+        a = (3**0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
             nn.init.constant_(self.bias, 0.0)
         fan_in = self.weight.shape[1] * self.weight[0][0].numel()
-        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
+        scale = fan_in**-0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
             self.weight_scale += torch.tensor(scale / std).log()

@@ -325,12 +326,12 @@ class ScaledConv2d(nn.Conv2d):

     def _reset_parameters(self, initial_speed: float):
         std = 0.1 / initial_speed
-        a = (3 ** 0.5) * std
+        a = (3**0.5) * std
         nn.init.uniform_(self.weight, -a, a)
         if self.bias is not None:
             nn.init.constant_(self.bias, 0.0)
         fan_in = self.weight.shape[1] * self.weight[0][0].numel()
-        scale = fan_in ** -0.5  # 1/sqrt(fan_in)
+        scale = fan_in**-0.5  # 1/sqrt(fan_in)
         with torch.no_grad():
             self.weight_scale += torch.tensor(scale / std).log()

@@ -376,6 +377,137 @@ class ScaledConv2d(nn.Conv2d):
         return self._conv_forward(input, self.get_weight())


+class ScaledLSTM(nn.LSTM):
+    # See docs for ScaledLinear.
+    # This class implements single-layer LSTM with scaling mechanism, using `torch._VF.lstm`
+    # Please refer to https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py
+    def __init__(
+        self,
+        *args,
+        initial_scale: float = 1.0,
+        initial_speed: float = 1.0,
+        **kwargs
+    ):
+        # Hardcode num_layers=1 and proj_size=0 here
+        super(ScaledLSTM, self).__init__(
+            *args, num_layers=1, proj_size=0, **kwargs
+        )
+        initial_scale = torch.tensor(initial_scale).log()
+        self._scales_names = []
+        self._scales = []
+        for name in self._flat_weights_names:
+            scale_name = name + "_scale"
+            self._scales_names.append(scale_name)
+            param = nn.Parameter(initial_scale.clone().detach())
+            setattr(self, scale_name, param)
+            self._scales.append(param)
+
+        self._reset_parameters(
+            initial_speed
+        )  # Overrides the reset_parameters in base class
+
+    def _reset_parameters(self, initial_speed: float):
+        std = 0.1 / initial_speed
+        a = (3**0.5) * std
+        fan_in = self.input_size
+        scale = fan_in**-0.5
+        v = scale / std
+        for idx, name in enumerate(self._flat_weights_names):
+            if "weight" in name:
+                nn.init.uniform_(self._flat_weights[idx], -a, a)
+                with torch.no_grad():
+                    self._scales[idx] += torch.tensor(v).log()
+            elif "bias" in name:
+                nn.init.constant_(self._flat_weights[idx], 0.0)
+
+    def get_flat_weights(self):
+        flat_weights = []
+        for idx in range(len(self._flat_weights_names)):
+            flat_weights.append(
+                self._flat_weights[idx] * self._scales[idx].exp()
+            )
+        return flat_weights
+
+    def forward(self, input, hx=None):
+        # This function is copied from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py  # noqa
+        # The only change is for calling `_VF.lstm()`:
+        #   self._flat_weights -> self.get_flat_weights()
+        orig_input = input
+        # xxx: isinstance check needs to be in conditional for TorchScript to compile
+        if isinstance(orig_input, PackedSequence):
+            input, batch_sizes, sorted_indices, unsorted_indices = input
+            max_batch_size = batch_sizes[0]
+            max_batch_size = int(max_batch_size)
+        else:
+            batch_sizes = None
+            max_batch_size = (
+                input.size(0) if self.batch_first else input.size(1)
+            )
+            sorted_indices = None
+            unsorted_indices = None
+
+        if hx is None:
+            num_directions = 2 if self.bidirectional else 1
+            real_hidden_size = (
+                self.proj_size if self.proj_size > 0 else self.hidden_size
+            )
+            h_zeros = torch.zeros(
+                self.num_layers * num_directions,
+                max_batch_size,
+                real_hidden_size,
+                dtype=input.dtype,
+                device=input.device,
+            )
+            c_zeros = torch.zeros(
+                self.num_layers * num_directions,
+                max_batch_size,
+                self.hidden_size,
+                dtype=input.dtype,
+                device=input.device,
+            )
+            hx = (h_zeros, c_zeros)
+        else:
+            # Each batch of the hidden state should match the input sequence that
+            # the user believes he/she is passing in.
+            hx = self.permute_hidden(hx, sorted_indices)
+
+        self.check_forward_args(input, hx, batch_sizes)
+        if batch_sizes is None:
+            result = _VF.lstm(
+                input,
+                hx,
+                self.get_flat_weights(),
+                self.bias,
+                self.num_layers,
+                self.dropout,
+                self.training,
+                self.bidirectional,
+                self.batch_first,
+            )
+        else:
+            result = _VF.lstm(
+                input,
+                batch_sizes,
+                hx,
+                self.get_flat_weights(),
+                self.bias,
+                self.num_layers,
+                self.dropout,
+                self.training,
+                self.bidirectional,
+            )
+        output = result[0]
+        hidden = result[1:]
+        # xxx: isinstance check needs to be in conditional for TorchScript to compile
+        if isinstance(orig_input, PackedSequence):
+            output_packed = PackedSequence(
+                output, batch_sizes, sorted_indices, unsorted_indices
+            )
+            return output_packed, self.permute_hidden(hidden, unsorted_indices)
+        else:
+            return output, self.permute_hidden(hidden, unsorted_indices)
+
+
 class ActivationBalancer(torch.nn.Module):
     """
     Modifies the backpropped derivatives of a function to try to encourage, for
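The Scaled* modules above, including the new ScaledLSTM, keep a learned log-scale next to each weight tensor and compute with weight * scale.exp() as the effective parameter (see get_flat_weights() here and the weight_scale handling in ScaledLinear). A minimal toy sketch of that reparameterization, using a made-up ToyScaled module that is not part of this file:

import torch
import torch.nn as nn

class ToyScaled(nn.Module):
    """Hypothetical toy module (not from this commit): it keeps a raw weight plus
    a learned log-scale, and forward() uses weight * weight_scale.exp(), the same
    reparameterization that ScaledLinear/ScaledConv*/ScaledLSTM apply."""

    def __init__(self, dim: int, initial_scale: float = 1.0):
        super().__init__()
        self.weight = nn.Parameter(0.1 * torch.randn(dim, dim))
        self.weight_scale = nn.Parameter(torch.tensor(initial_scale).log())

    def get_weight(self):
        # Effective parameter actually used in the computation.
        return self.weight * self.weight_scale.exp()

    def forward(self, x):
        return x @ self.get_weight().t()

m = ToyScaled(4)
loss = m(torch.randn(2, 4)).sum()
loss.backward()
# Gradients flow into both the raw weight and its log-scale.
print(m.weight.grad.shape, m.weight_scale.grad)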
@@ -711,8 +843,8 @@ def _test_basic_norm():
     y = m(x)

     assert y.shape == x.shape
-    x_rms = (x ** 2).mean().sqrt()
-    y_rms = (y ** 2).mean().sqrt()
+    x_rms = (x**2).mean().sqrt()
+    y_rms = (y**2).mean().sqrt()
     print("x rms = ", x_rms)
     print("y rms = ", y_rms)
     assert y_rms < x_rms
@@ -726,8 +858,22 @@ def _test_double_swish_deriv():
     torch.autograd.gradcheck(m, x)


+def _test_scaled_lstm():
+    N, L = 2, 30
+    dim_in, dim_hidden = 10, 20
+    m = ScaledLSTM(input_size=dim_in, hidden_size=dim_hidden, bias=True)
+    x = torch.randn(L, N, dim_in)
+    h0 = torch.randn(1, N, dim_hidden)
+    c0 = torch.randn(1, N, dim_hidden)
+    y, (h, c) = m(x, (h0, c0))
+    assert y.shape == (L, N, dim_hidden)
+    assert h.shape == (1, N, dim_hidden)
+    assert c.shape == (1, N, dim_hidden)
+
+
 if __name__ == "__main__":
     _test_activation_balancer_sign()
     _test_activation_balancer_magnitude()
     _test_basic_norm()
     _test_double_swish_deriv()
+    _test_scaled_lstm()
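For quick reference, a minimal usage sketch of the new class, mirroring _test_scaled_lstm above. The `from scaling import ScaledLSTM` path is an assumption about how this file is imported; shapes follow the default batch_first=False convention inherited from nn.LSTM:

import torch
from scaling import ScaledLSTM  # assumed import path for this file

lstm = ScaledLSTM(input_size=80, hidden_size=512, bias=True)  # num_layers=1, proj_size=0 are hardcoded
x = torch.randn(100, 8, 80)  # (seq_len, batch, input_size)
y, (h, c) = lstm(x)  # hx defaults to zero states, exactly like nn.LSTM
print(y.shape, h.shape, c.shape)  # torch.Size([100, 8, 512]) torch.Size([1, 8, 512]) torch.Size([1, 8, 512])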