Apply layer normalization to the output of each gate in LSTM/GRU. (#139)

* Apply layer normalization to the output of each gate in LSTM.

* Apply layer normalization to the output of each gate in GRU.

* Add projection support to LayerNormLSTMCell.

* Add GPU tests.

* Use typeguard.check_argument_types() to validate type annotations.

* Add typeguard as a requirement.

* Minor fixes.

* Fix CI.

* Fix CI.

* Fix test failures for torch 1.8.0.

* Fix errors.
Fangjun Kuang 2021-12-07 18:38:03 +08:00 committed by GitHub
parent d1adc25338
commit 1aff64b708
5 changed files with 1432 additions and 1 deletion
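
For reference, here is a minimal usage sketch of the new `LayerNormLSTM` module, mirroring the docstring examples in the `transducer.rnn` file added below (shapes are illustrative):

import torch
from transducer.rnn import LayerNormLSTM

lstm = LayerNormLSTM(input_size=10, hidden_size=20, num_layers=8)
x = torch.rand(2, 3, 10)               # (batch_size, seq_len, input_size); batch_first
h0 = torch.rand(8, 2, 20).unbind(0)    # one initial hidden state per layer
c0 = torch.rand(8, 2, 20).unbind(0)    # one initial cell state per layer
states = list(zip(h0, c0))             # [(h, c), ...], one tuple per layer
output, next_states = lstm(x, states)  # output: (2, 3, 20)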


@@ -103,6 +103,9 @@ jobs:
cd egs/librispeech/ASR/conformer_ctc
pytest -v -s
cd ..
pytest -v -s ./transducer
- name: Run tests
if: startsWith(matrix.os, 'macos')
run: |
@@ -113,6 +116,9 @@ jobs:
export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH
pytest -v -s ./test
# runt tests for conformer ctc
# run tests for conformer ctc
cd egs/librispeech/ASR/conformer_ctc
pytest -v -s
cd ..
pytest -v -s ./transducer


@@ -0,0 +1,659 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Apply layer normalization to the output of each gate in LSTM/GRU.
This file uses
https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py
as a reference.
"""
import math
from typing import List, Optional, Tuple, Type
import torch
import torch.nn as nn
import torch.nn.functional as F
from typeguard import check_argument_types
class LayerNormLSTMCell(nn.Module):
"""This class places a `nn.LayerNorm` after the output of
each gate (right before the activation).
See the following paper for more details
'Improving RNN Transducer Modeling for End-to-End Speech Recognition'
https://arxiv.org/abs/1909.12415
Examples::
>>> cell = LayerNormLSTMCell(10, 20)
>>> input = torch.rand(5, 10)
>>> h0 = torch.rand(5, 20)
>>> c0 = torch.rand(5, 20)
>>> h1, c1 = cell(input, (h0, c0))
>>> output = h1
>>> h1.shape
torch.Size([5, 20])
>>> c1.shape
torch.Size([5, 20])
"""
def __init__(
self,
input_size: int,
hidden_size: int,
bias: bool = True,
ln: Type[nn.Module] = nn.LayerNorm,
proj_size: int = 0,
device=None,
dtype=None,
):
"""
Args:
input_size:
The number of expected features in the input `x`. `x` should
be of shape (batch_size, input_size).
hidden_size:
The number of features in the hidden state `h` and `c`.
Both `h` and `c` are of shape (batch_size, hidden_size) when
proj_size is 0. If proj_size is not zero, the shape of `h`
is (batch_size, proj_size).
bias:
If ``False``, then the cell does not use bias weights
`bias_ih` and `bias_hh`.
ln:
Defaults to `nn.LayerNorm`. The outputs of all gates are
processed by `ln`. We pass it as an argument so that we can
replace it with `nn.Identity` at test time.
proj_size:
If not zero, it applies an affine transform to the output. In this
case, the shape of `h` is (batch_size, proj_size).
See https://arxiv.org/pdf/1402.1128.pdf
"""
assert check_argument_types()
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
self.input_size = input_size
self.hidden_size = hidden_size
self.bias = bias
self.proj_size = proj_size
if proj_size < 0:
raise ValueError(
f"proj_size {proj_size} should be a positive integer "
"or zero to disable projections"
)
if proj_size >= hidden_size:
raise ValueError(
f"proj_size {proj_size} has to be smaller "
f"than hidden_size {hidden_size}"
)
real_hidden_size = proj_size if proj_size > 0 else hidden_size
self.weight_ih = nn.Parameter(
torch.empty((4 * hidden_size, input_size), **factory_kwargs)
)
self.weight_hh = nn.Parameter(
torch.empty((4 * hidden_size, real_hidden_size), **factory_kwargs)
)
if bias:
self.bias_ih = nn.Parameter(
torch.empty(4 * hidden_size, **factory_kwargs)
)
self.bias_hh = nn.Parameter(
torch.empty(4 * hidden_size, **factory_kwargs)
)
else:
self.register_parameter("bias_ih", None)
self.register_parameter("bias_hh", None)
if proj_size > 0:
self.weight_hr = nn.Parameter(
torch.empty((proj_size, hidden_size), **factory_kwargs)
)
else:
self.register_parameter("weight_hr", None)
self.layernorm_i = ln(hidden_size)
self.layernorm_f = ln(hidden_size)
self.layernorm_cx = ln(hidden_size)
self.layernorm_cy = ln(hidden_size)
self.layernorm_o = ln(hidden_size)
self.reset_parameters()
def forward(
self,
input: torch.Tensor,
state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
input:
A 2-D tensor of shape (batch_size, input_size).
state:
If not ``None``, it contains the hidden state (h, c) for each
element in the batch. Both are of shape (batch_size, hidden_size)
if proj_size is 0. If proj_size is not zero, the shape of `h` is
(batch_size, proj_size).
If ``None``, it uses zeros for `h` and `c`.
Returns:
Return two tensors:
- `next_h`: It is of shape (batch_size, hidden_size) if proj_size
is 0, else (batch_size, proj_size), containing the next hidden
state for each element in the batch.
- `next_c`: It is of shape (batch_size, hidden_size) containing the
next cell state for each element in the batch.
"""
if state is None:
zeros = torch.zeros(
input.size(0),
self.hidden_size,
dtype=input.dtype,
device=input.device,
)
state = (zeros, zeros)
hx, cx = state
gates = F.linear(input, self.weight_ih, self.bias_ih) + F.linear(
hx, self.weight_hh, self.bias_hh
)
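# weight_ih/weight_hh stack the four gates in the same order as nn.LSTMCell
# (input, forget, cell, output), so a single chunk() recovers them.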
in_gate, forget_gate, cell_gate, out_gate = gates.chunk(chunks=4, dim=1)
in_gate = self.layernorm_i(in_gate)
forget_gate = self.layernorm_f(forget_gate)
cell_gate = self.layernorm_cx(cell_gate)
out_gate = self.layernorm_o(out_gate)
in_gate = torch.sigmoid(in_gate)
forget_gate = torch.sigmoid(forget_gate)
cell_gate = torch.tanh(cell_gate)
out_gate = torch.sigmoid(out_gate)
cy = (forget_gate * cx) + (in_gate * cell_gate)
cy = self.layernorm_cy(cy)
hy = out_gate * torch.tanh(cy)
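# Optional projection (https://arxiv.org/pdf/1402.1128.pdf): maps hy from
# hidden_size down to proj_size.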
if self.weight_hr is not None:
hy = torch.matmul(hy, self.weight_hr.t())
return hy, cy
def extra_repr(self) -> str:
s = "{input_size}, {hidden_size}"
if "bias" in self.__dict__ and self.bias is not True:
s += ", bias={bias}"
return s.format(**self.__dict__)
def reset_parameters(self) -> None:
stdv = 1.0 / math.sqrt(self.hidden_size)
for name, weight in self.named_parameters():
if "layernorm" not in name:
nn.init.uniform_(weight, -stdv, stdv)
class LayerNormLSTMLayer(nn.Module):
"""
Examples::
>>> layer = LayerNormLSTMLayer(10, 20)
>>> input = torch.rand(2, 5, 10)
>>> h0 = torch.rand(2, 20)
>>> c0 = torch.rand(2, 20)
>>> output, (hn, cn) = layer(input, (h0, c0))
>>> output.shape
torch.Size([2, 5, 20])
>>> hn.shape
torch.Size([2, 20])
>>> cn.shape
torch.Size([2, 20])
"""
def __init__(
self,
input_size: int,
hidden_size: int,
bias: bool = True,
ln: Type[nn.Module] = nn.LayerNorm,
proj_size: int = 0,
device=None,
dtype=None,
):
"""
See the args in LayerNormLSTMCell
"""
assert check_argument_types()
super().__init__()
self.cell = LayerNormLSTMCell(
input_size=input_size,
hidden_size=hidden_size,
bias=bias,
ln=ln,
proj_size=proj_size,
device=device,
dtype=dtype,
)
def forward(
self,
input: torch.Tensor,
state: Tuple[torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Args:
input:
A 3-D tensor of shape (batch_size, seq_len, input_size).
Caution:
We use `batch_first=True` here.
state:
If not ``None``, it contains the hidden state (h, c) of this layer.
Both are of shape (batch_size, hidden_size) if proj_size is 0.
If proj_size is not 0, the shape of `h` is (batch_size, proj_size).
Note:
We did not annotate `state` with `Optional[Tuple[...]]` since
torchscript will complain.
Return:
- output, a tensor of shape (batch_size, seq_len, hidden_size) if
proj_size is 0, else (batch_size, seq_len, proj_size)
- (next_h, next_c) containing the next hidden state
"""
inputs = input.unbind(1)
outputs = torch.jit.annotate(List[torch.Tensor], [])
for i in range(len(inputs)):
state = self.cell(inputs[i], state)
outputs.append(state[0])
return torch.stack(outputs, dim=1), state
class LayerNormLSTM(nn.Module):
"""
Examples::
>>> lstm = LayerNormLSTM(10, 20, 8)
>>> input = torch.rand(2, 3, 10)
>>> h0 = torch.rand(8, 2, 20).unbind(0)
>>> c0 = torch.rand(8, 2, 20).unbind(0)
>>> states = list(zip(h0, c0))
>>> output, next_states = lstm(input, states)
>>> output.shape
torch.Size([2, 3, 20])
>>> hn = torch.stack([s[0] for s in next_states])
>>> cn = torch.stack([s[1] for s in next_states])
>>> hn.shape
torch.Size([8, 2, 20])
>>> cn.shape
torch.Size([8, 2, 20])
"""
def __init__(
self,
input_size: int,
hidden_size: int,
num_layers: int,
bias: bool = True,
proj_size: int = 0,
ln: Type[nn.Module] = nn.LayerNorm,
device=None,
dtype=None,
):
"""
See the args in LayerNormLSTMLayer.
"""
assert check_argument_types()
super().__init__()
assert num_layers >= 1
factory_kwargs = dict(
hidden_size=hidden_size,
bias=bias,
ln=ln,
proj_size=proj_size,
device=device,
dtype=dtype,
)
first_layer = LayerNormLSTMLayer(
input_size=input_size, **factory_kwargs
)
layers = [first_layer]
for i in range(1, num_layers):
layers.append(
LayerNormLSTMLayer(
input_size=proj_size if proj_size > 0 else hidden_size,
**factory_kwargs,
)
)
self.layers = nn.ModuleList(layers)
self.num_layers = num_layers
def forward(
self,
input: torch.Tensor,
states: List[Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
"""
Args:
input:
A 3-D tensor of shape (batch_size, seq_len, input_size).
Caution:
We use `batch_first=True` here.
states:
One state per layer. Each entry contains the hidden state (h, c)
for a layer. Both are of shape (batch_size, hidden_size) if
proj_size is 0. If proj_size is not 0, the shape of `h` is
(batch_size, proj_size).
Returns:
Return a tuple containing:
- output: A tensor of shape (batch_size, seq_len, hidden_size) if
proj_size is 0, else (batch_size, seq_len, proj_size)
- List[(next_h, next_c)] containing the hidden states for all layers
"""
output_states = torch.jit.annotate(
List[Tuple[torch.Tensor, torch.Tensor]], []
)
output = input
for i, rnn_layer in enumerate(self.layers):
state = states[i]
output, out_state = rnn_layer(output, state)
output_states += [out_state]
return output, output_states
class LayerNormGRUCell(nn.Module):
"""This class places a `nn.LayerNorm` after the output of
each gate (right before the activation).
See the following paper for more details
'Improving RNN Transducer Modeling for End-to-End Speech Recognition'
https://arxiv.org/abs/1909.12415
Examples::
>>> cell = LayerNormGRUCell(10, 20)
>>> input = torch.rand(2, 10)
>>> h0 = torch.rand(2, 20)
>>> hn = cell(input, h0)
>>> hn.shape
torch.Size([2, 20])
"""
def __init__(
self,
input_size: int,
hidden_size: int,
bias: bool = True,
ln: Type[nn.Module] = nn.LayerNorm,
device=None,
dtype=None,
):
"""
Args:
input_size:
The number of expected features in the input `x`. `x` should
be of shape (batch_size, input_size).
hidden_size:
The number of features in the hidden state `h`.
`h` is of shape (batch_size, hidden_size).
bias:
If ``False``, then the cell does not use bias weights
`bias_ih` and `bias_hh`.
ln:
Defaults to `nn.LayerNorm`. The outputs of all gates are
processed by `ln`. We pass it as an argument so that we can
replace it with `nn.Identity` at test time.
"""
assert check_argument_types()
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
self.input_size = input_size
self.hidden_size = hidden_size
self.bias = bias
self.weight_ih = nn.Parameter(
torch.empty((3 * hidden_size, input_size), **factory_kwargs)
)
self.weight_hh = nn.Parameter(
torch.empty((3 * hidden_size, hidden_size), **factory_kwargs)
)
if bias:
self.bias_ih = nn.Parameter(
torch.empty(3 * hidden_size, **factory_kwargs)
)
self.bias_hh = nn.Parameter(
torch.empty(3 * hidden_size, **factory_kwargs)
)
else:
self.register_parameter("bias_ih", None)
self.register_parameter("bias_hh", None)
self.layernorm_r = ln(hidden_size)
self.layernorm_i = ln(hidden_size)
self.layernorm_n = ln(hidden_size)
self.reset_parameters()
def forward(
self,
input: torch.Tensor,
hx: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
input:
A 2-D tensor of shape (batch_size, input_size) containing
input features.
hx:
If not `None`, it is a tensor of shape (batch_size, hidden_size)
containing the initial hidden state for each element in the batch.
If `None`, it uses zeros for the hidden state.
Returns:
Return a tensor of shape (batch_size, hidden_size) containing the
next hidden state for each element in the batch
"""
if hx is None:
hx = torch.zeros(
input.size(0),
self.hidden_size,
dtype=input.dtype,
device=input.device,
)
i_r, i_i, i_n = F.linear(input, self.weight_ih, self.bias_ih).chunk(
chunks=3, dim=1
)
h_r, h_i, h_n = F.linear(hx, self.weight_hh, self.bias_hh).chunk(
chunks=3, dim=1
)
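# Layer norm is applied to each gate's pre-activation (the sum of the
# input and hidden contributions) right before the nonlinearity.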
reset_gate = torch.sigmoid(self.layernorm_r(i_r + h_r))
input_gate = torch.sigmoid(self.layernorm_i(i_i + h_i))
new_gate = torch.tanh(self.layernorm_n(i_n + reset_gate * h_n))
# hy = (1 - input_gate) * new_gate + input_gate * hx
# = new_gate - input_gate * new_gate + input_gate * hx
# = new_gate + input_gate * (hx - new_gate)
hy = new_gate + input_gate * (hx - new_gate)
return hy
def extra_repr(self) -> str:
s = "{input_size}, {hidden_size}"
if "bias" in self.__dict__ and self.bias is not True:
s += ", bias={bias}"
return s.format(**self.__dict__)
def reset_parameters(self) -> None:
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
nn.init.uniform_(weight, -stdv, stdv)
class LayerNormGRULayer(nn.Module):
"""
Examples::
>>> layer = LayerNormGRULayer(10, 20)
>>> input = torch.rand(2, 3, 10)
>>> hx = torch.rand(2, 20)
>>> output, hn = layer(input, hx)
>>> output.shape
torch.Size([2, 3, 20])
>>> hn.shape
torch.Size([2, 20])
"""
def __init__(
self,
input_size: int,
hidden_size: int,
bias: bool = True,
ln: Type[nn.Module] = nn.LayerNorm,
device=None,
dtype=None,
):
"""
See the args in LayerNormGRUCell
"""
assert check_argument_types()
super().__init__()
self.cell = LayerNormGRUCell(
input_size=input_size,
hidden_size=hidden_size,
bias=bias,
ln=ln,
device=device,
dtype=dtype,
)
def forward(
self,
input: torch.Tensor,
hx: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
input:
A 3-D tensor of shape (batch_size, seq_len, input_size).
Caution:
We use `batch_first=True` here.
hx:
If not ``None``, it is a tensor of shape (batch_size, hidden_size)
containing the hidden state for each element in the batch.
Return:
- output, a tensor of shape (batch_size, seq_len, hidden_size)
- next_h, a tensor of shape (batch_size, hidden_size) containing the
final hidden state for each element in the batch.
"""
inputs = input.unbind(1)
outputs = torch.jit.annotate(List[torch.Tensor], [])
next_h = hx
for i in range(len(inputs)):
next_h = self.cell(inputs[i], next_h)
outputs.append(next_h)
return torch.stack(outputs, dim=1), next_h
class LayerNormGRU(nn.Module):
"""
Examples::
>>> gru = LayerNormGRU(10, 20, 8)
>>> input = torch.rand(2, 3, 10)
>>> h0 = torch.rand(8, 2, 20)
>>> states = h0.unbind(0)
>>> output, next_states = gru(input, states)
>>> output.shape
torch.Size([2, 3, 20])
>>> hn = torch.stack(next_states)
>>> hn.shape
torch.Size([8, 2, 20])
"""
def __init__(
self,
input_size: int,
hidden_size: int,
num_layers: int,
bias: bool = True,
ln: Type[nn.Module] = nn.LayerNorm,
device=None,
dtype=None,
):
"""
See the args in LayerNormGRULayer.
"""
assert check_argument_types()
super().__init__()
assert num_layers >= 1
factory_kwargs = dict(
hidden_size=hidden_size,
bias=bias,
ln=ln,
device=device,
dtype=dtype,
)
first_layer = LayerNormGRULayer(input_size=input_size, **factory_kwargs)
layers = [first_layer]
for i in range(1, num_layers):
layers.append(
LayerNormGRULayer(
input_size=hidden_size,
**factory_kwargs,
)
)
self.layers = nn.ModuleList(layers)
self.num_layers = num_layers
def forward(
self,
input: torch.Tensor,
states: List[torch.Tensor],
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
Args:
input:
A tensor of shape (batch_size, seq_len, input_size) containing
input features.
Caution:
We use `batch_first=True` here.
states:
One state per layer. Each entry contains the hidden state for each
element in the batch. Each hidden state is of shape
(batch_size, hidden_size)
Returns:
Return a tuple containing:
- output: A tensor of shape (batch_size, seq_len, hidden_size)
- List[next_state] containing the final hidden states for each
element in the batch
"""
output_states = torch.jit.annotate(List[torch.Tensor], [])
output = input
for i, rnn_layer in enumerate(self.layers):
state = states[i]
output, out_state = rnn_layer(output, state)
output_states += [out_state]
return output, output_states


@@ -0,0 +1,765 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from transducer.rnn import (
LayerNormGRU,
LayerNormGRUCell,
LayerNormGRULayer,
LayerNormLSTM,
LayerNormLSTMCell,
LayerNormLSTMLayer,
)
def get_devices():
devices = [torch.device("cpu")]
if torch.cuda.is_available():
devices.append(torch.device("cuda", 0))
return devices
def assert_allclose(a: torch.Tensor, b: torch.Tensor, atol=1e-6, **kwargs):
assert torch.allclose(
a, b, atol=atol, **kwargs
), f"{(a - b).abs().max()}, {a.numel()}"
def test_layernorm_lstm_cell_jit(device="cpu"):
input_size = 10
hidden_size = 20
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
cell = LayerNormLSTMCell(
input_size=input_size,
hidden_size=hidden_size,
bias=bias,
device=device,
)
torch.jit.script(cell)
def test_layernorm_lstm_cell_constructor(device="cpu"):
input_size = torch.randint(low=2, high=100, size=(1,)).item()
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
self_cell = LayerNormLSTMCell(
input_size,
hidden_size,
ln=nn.Identity,
device=device,
)
torch_cell = nn.LSTMCell(
input_size,
hidden_size,
).to(device)
for name, param in self_cell.named_parameters():
assert param.shape == getattr(torch_cell, name).shape
assert len(self_cell.state_dict()) == len(torch_cell.state_dict())
def test_layernorm_lstm_cell_with_projection_jit(device="cpu"):
input_size = 10
hidden_size = 20
proj_size = 5
self_cell = LayerNormLSTMCell(
input_size,
hidden_size,
proj_size=proj_size,
device=device,
)
torch.jit.script(self_cell)
def test_layernorm_lstm_cell_forward(device="cpu"):
input_size = torch.randint(low=2, high=100, size=(1,)).item()
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
self_cell = LayerNormLSTMCell(
input_size,
hidden_size,
bias=bias,
ln=nn.Identity,
device=device,
)
torch_cell = nn.LSTMCell(
input_size,
hidden_size,
bias=bias,
).to(device)
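# Copy parameters into the torch cell; the names and shapes match because
# LayerNormLSTMCell follows nn.LSTMCell's parameter layout.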
with torch.no_grad():
for name, torch_param in torch_cell.named_parameters():
self_param = getattr(self_cell, name)
torch_param.copy_(self_param)
N = torch.randint(low=2, high=100, size=(1,))
x = torch.rand(N, input_size, device=device).requires_grad_()
h = torch.rand(N, hidden_size, device=device)
c = torch.rand(N, hidden_size, device=device)
x_clone = x.detach().clone().requires_grad_()
self_h, self_c = self_cell(x.clone(), (h, c))
torch_h, torch_c = torch_cell(x_clone, (h, c))
assert_allclose(self_h, torch_h)
assert_allclose(self_c, torch_c)
self_hc = self_h * self_c
torch_hc = torch_h * torch_c
(
self_hc.reshape(-1) * torch.arange(self_hc.numel(), device=device)
).sum().backward()
(
torch_hc.reshape(-1) * torch.arange(torch_hc.numel(), device=device)
).sum().backward()
assert_allclose(x.grad, x_clone.grad, atol=1e-3)
def test_layernorm_lstm_cell_with_projection_forward(device="cpu"):
input_size = torch.randint(low=2, high=100, size=(1,)).item()
hidden_size = torch.randint(low=10, high=100, size=(1,)).item()
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
proj_size = torch.randint(low=2, high=hidden_size, size=(1,)).item()
self_cell = LayerNormLSTMCell(
input_size,
hidden_size,
bias=bias,
ln=nn.Identity,
proj_size=proj_size,
device=device,
)
torch_cell = nn.LSTM(
input_size,
hidden_size,
bias=bias,
proj_size=proj_size,
batch_first=True,
).to(device)
with torch.no_grad():
for name, self_param in self_cell.named_parameters():
getattr(torch_cell, f"{name}_l0").copy_(self_param)
N = torch.randint(low=2, high=100, size=(1,))
x = torch.rand(N, input_size, device=device).requires_grad_()
h = torch.rand(N, proj_size, device=device)
c = torch.rand(N, hidden_size, device=device)
x_clone = x.detach().clone().requires_grad_()
self_h, self_c = self_cell(x.clone(), (h, c))
_, (torch_h, torch_c) = torch_cell(
x_clone.unsqueeze(1), (h.unsqueeze(0), c.unsqueeze(0))
)
torch_h = torch_h.squeeze(0)
torch_c = torch_c.squeeze(0)
assert_allclose(self_h, torch_h)
assert_allclose(self_c, torch_c)
(self_h.sum() * self_c.sum()).backward()
(torch_h.sum() * torch_c.sum()).backward()
assert_allclose(x.grad, x_clone.grad, atol=1e-5)
def test_layernorm_lstm_layer_jit(device="cpu"):
input_size = 10
hidden_size = 20
layer = LayerNormLSTMLayer(
input_size,
hidden_size=hidden_size,
device=device,
)
torch.jit.script(layer)
def test_layernorm_lstm_layer_with_project_jit(device="cpu"):
input_size = 10
hidden_size = 20
proj_size = 5
layer = LayerNormLSTMLayer(
input_size,
hidden_size=hidden_size,
proj_size=proj_size,
device=device,
)
torch.jit.script(layer)
def test_layernorm_lstm_layer_with_projection_forward(device="cpu"):
input_size = torch.randint(low=2, high=100, size=(1,)).item()
hidden_size = torch.randint(low=10, high=100, size=(1,)).item()
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
proj_size = torch.randint(low=2, high=hidden_size, size=(1,)).item()
self_layer = LayerNormLSTMLayer(
input_size,
hidden_size,
bias=bias,
proj_size=proj_size,
ln=nn.Identity,
device=device,
)
N = torch.randint(low=2, high=100, size=(1,))
T = torch.randint(low=2, high=100, size=(1,))
x = torch.rand(N, T, input_size, device=device).requires_grad_()
h = torch.rand(N, proj_size, device=device)
c = torch.rand(N, hidden_size, device=device)
x_clone = x.detach().clone().requires_grad_()
self_y, (self_h, self_c) = self_layer(x, (h, c))
torch_layer = nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=1,
bias=bias,
proj_size=proj_size,
batch_first=True,
dropout=0,
bidirectional=False,
).to(device)
with torch.no_grad():
for name, self_param in self_layer.cell.named_parameters():
getattr(torch_layer, f"{name}_l0").copy_(self_param)
torch_y, (torch_h, torch_c) = torch_layer(
x_clone, (h.unsqueeze(0), c.unsqueeze(0))
)
assert_allclose(self_y, torch_y)
assert_allclose(self_h, torch_h)
assert_allclose(self_c, torch_c)
self_y.sum().backward()
torch_y.sum().backward()
assert_allclose(x.grad, x_clone.grad, atol=1e-5)
def test_layernorm_lstm_layer_forward(device="cpu"):
input_size = torch.randint(low=2, high=100, size=(1,)).item()
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
self_layer = LayerNormLSTMLayer(
input_size,
hidden_size,
bias=bias,
ln=nn.Identity,
device=device,
)
N = torch.randint(low=2, high=100, size=(1,))
T = torch.randint(low=2, high=100, size=(1,))
x = torch.rand(N, T, input_size, device=device).requires_grad_()
h = torch.rand(N, hidden_size, device=device)
c = torch.rand(N, hidden_size, device=device)
x_clone = x.detach().clone().requires_grad_()
self_y, (self_h, self_c) = self_layer(x, (h, c))
torch_layer = nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=1,
bias=bias,
batch_first=True,
dropout=0,
bidirectional=False,
).to(device)
with torch.no_grad():
for name, self_param in self_layer.cell.named_parameters():
getattr(torch_layer, f"{name}_l0").copy_(self_param)
torch_y, (torch_h, torch_c) = torch_layer(
x_clone, (h.unsqueeze(0), c.unsqueeze(0))
)
assert_allclose(self_y, torch_y)
assert_allclose(self_h, torch_h)
assert_allclose(self_c, torch_c)
self_hc = self_h * self_c
torch_hc = torch_h * torch_c
self_hc_sum = (
self_hc.reshape(-1) * torch.arange(self_hc.numel(), device=device)
).sum()
torch_hc_sum = (
torch_hc.reshape(-1) * torch.arange(torch_hc.numel(), device=device)
).sum()
self_y_sum = (
self_y.reshape(-1) * torch.arange(self_y.numel(), device=device)
).sum()
torch_y_sum = (
torch_y.reshape(-1) * torch.arange(torch_y.numel(), device=device)
).sum()
(self_hc_sum + self_y_sum).backward()
(torch_hc_sum + torch_y_sum).backward()
assert_allclose(x.grad, x_clone.grad, atol=0.1)
def test_layernorm_lstm_jit(device="cpu"):
input_size = 2
hidden_size = 3
num_layers = 4
bias = True
lstm = LayerNormLSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
bias=bias,
ln=nn.Identity,
device=device,
)
torch.jit.script(lstm)
def test_layernorm_lstm_with_projection_jit(device="cpu"):
input_size = 2
hidden_size = 5
proj_size = 3
num_layers = 4
bias = True
lstm = LayerNormLSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
bias=bias,
proj_size=proj_size,
ln=nn.Identity,
device=device,
)
torch.jit.script(lstm)
def test_layernorm_lstm_forward(device="cpu"):
input_size = torch.randint(low=2, high=100, size=(1,)).item()
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
num_layers = torch.randint(low=2, high=100, size=(1,)).item()
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
self_lstm = LayerNormLSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
bias=bias,
ln=nn.Identity,
device=device,
)
torch_lstm = nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
bias=bias,
batch_first=True,
bidirectional=False,
).to(device)
assert len(self_lstm.state_dict()) == len(torch_lstm.state_dict())
with torch.no_grad():
for name, param in self_lstm.named_parameters():
# name has the form layers.0.cell.weight_hh
parts = name.split(".")
layer_num = parts[1]
getattr(torch_lstm, f"{parts[-1]}_l{layer_num}").copy_(param)
N = torch.randint(low=2, high=100, size=(1,))
T = torch.randint(low=2, high=100, size=(1,))
x = torch.rand(N, T, input_size, device=device).requires_grad_()
hs = [torch.rand(N, hidden_size, device=device) for _ in range(num_layers)]
cs = [torch.rand(N, hidden_size, device=device) for _ in range(num_layers)]
states = list(zip(hs, cs))
x_clone = x.detach().clone().requires_grad_()
self_y, self_states = self_lstm(x, states)
h = torch.stack(hs)
c = torch.stack(cs)
torch_y, (torch_h, torch_c) = torch_lstm(x_clone, (h, c))
assert_allclose(self_y, torch_y)
self_h = torch.stack([s[0] for s in self_states])
self_c = torch.stack([s[1] for s in self_states])
assert_allclose(self_h, torch_h)
assert_allclose(self_c, torch_c)
s = self_y.reshape(-1)
t = torch_y.reshape(-1)
s_sum = (s * torch.arange(s.numel(), device=device)).sum()
t_sum = (t * torch.arange(t.numel(), device=device)).sum()
shc_sum = s_sum + self_h.sum() + self_c.sum()
thc_sum = t_sum + torch_h.sum() + torch_c.sum()
shc_sum.backward()
thc_sum.backward()
assert_allclose(x.grad, x_clone.grad)
def test_layernorm_lstm_with_projection_forward(device="cpu"):
input_size = torch.randint(low=2, high=100, size=(1,)).item()
hidden_size = torch.randint(low=10, high=100, size=(1,)).item()
proj_size = torch.randint(low=2, high=hidden_size, size=(1,)).item()
num_layers = torch.randint(low=2, high=100, size=(1,)).item()
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
self_lstm = LayerNormLSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
bias=bias,
proj_size=proj_size,
ln=nn.Identity,
device=device,
)
torch_lstm = nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
bias=bias,
proj_size=proj_size,
batch_first=True,
bidirectional=False,
).to(device)
assert len(self_lstm.state_dict()) == len(torch_lstm.state_dict())
with torch.no_grad():
for name, param in self_lstm.named_parameters():
# name has the form layers.0.cell.weight_hh
parts = name.split(".")
layer_num = parts[1]
getattr(torch_lstm, f"{parts[-1]}_l{layer_num}").copy_(param)
N = torch.randint(low=2, high=100, size=(1,))
T = torch.randint(low=2, high=100, size=(1,))
x = torch.rand(N, T, input_size, device=device).requires_grad_()
hs = [torch.rand(N, proj_size, device=device) for _ in range(num_layers)]
cs = [torch.rand(N, hidden_size, device=device) for _ in range(num_layers)]
states = list(zip(hs, cs))
x_clone = x.detach().clone().requires_grad_()
self_y, self_states = self_lstm(x, states)
h = torch.stack(hs)
c = torch.stack(cs)
torch_y, (torch_h, torch_c) = torch_lstm(x_clone, (h, c))
assert_allclose(self_y, torch_y)
self_h = torch.stack([s[0] for s in self_states])
self_c = torch.stack([s[1] for s in self_states])
assert_allclose(self_h, torch_h)
assert_allclose(self_c, torch_c)
s = self_y.reshape(-1)
t = torch_y.reshape(-1)
s_sum = (s * torch.arange(s.numel(), device=device)).sum()
t_sum = (t * torch.arange(t.numel(), device=device)).sum()
shc_sum = s_sum + self_h.sum() + self_c.sum()
thc_sum = t_sum + torch_h.sum() + torch_c.sum()
shc_sum.backward()
thc_sum.backward()
assert_allclose(x.grad, x_clone.grad)
def test_layernorm_gru_cell_jit(device="cpu"):
input_size = 10
hidden_size = 20
cell = LayerNormGRUCell(
input_size=input_size,
hidden_size=hidden_size,
bias=True,
device=device,
)
torch.jit.script(cell)
def test_layernorm_gru_cell_constructor(device="cpu"):
input_size = torch.randint(low=2, high=100, size=(1,)).item()
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
self_cell = LayerNormGRUCell(
input_size,
hidden_size,
ln=nn.Identity,
device=device,
)
torch_cell = nn.GRUCell(
input_size,
hidden_size,
).to(device)
for name, param in self_cell.named_parameters():
assert param.shape == getattr(torch_cell, name).shape
assert len(self_cell.state_dict()) == len(torch_cell.state_dict())
def test_layernorm_gru_cell_forward(device="cpu"):
input_size = torch.randint(low=2, high=100, size=(1,)).item()
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
self_cell = LayerNormGRUCell(
input_size,
hidden_size,
bias=bias,
ln=nn.Identity,
device=device,
)
torch_cell = nn.GRUCell(
input_size,
hidden_size,
bias=bias,
).to(device)
with torch.no_grad():
for name, torch_param in torch_cell.named_parameters():
self_param = getattr(self_cell, name)
torch_param.copy_(self_param)
N = torch.randint(low=2, high=100, size=(1,))
x = torch.rand(N, input_size, device=device).requires_grad_()
h = torch.rand(N, hidden_size, device=device)
x_clone = x.detach().clone().requires_grad_()
self_h = self_cell(x.clone(), h)
torch_h = torch_cell(x_clone, h)
assert_allclose(self_h, torch_h, atol=1e-5)
(
self_h.reshape(-1) * torch.arange(self_h.numel(), device=device)
).sum().backward()
(
torch_h.reshape(-1) * torch.arange(torch_h.numel(), device=device)
).sum().backward()
assert_allclose(x.grad, x_clone.grad, atol=1e-3)
def test_layernorm_gru_layer_jit(device="cpu"):
input_size = 10
hidden_size = 20
layer = LayerNormGRULayer(
input_size,
hidden_size=hidden_size,
device=device,
)
torch.jit.script(layer)
def test_layernorm_gru_layer_forward(device="cpu"):
input_size = torch.randint(low=2, high=100, size=(1,)).item()
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
self_layer = LayerNormGRULayer(
input_size,
hidden_size,
bias=bias,
ln=nn.Identity,
device=device,
)
N = torch.randint(low=2, high=100, size=(1,))
T = torch.randint(low=2, high=100, size=(1,))
x = torch.rand(N, T, input_size, device=device).requires_grad_()
h = torch.rand(N, hidden_size, device=device)
x_clone = x.detach().clone().requires_grad_()
self_y, self_h = self_layer(x, h.clone())
torch_layer = nn.GRU(
input_size=input_size,
hidden_size=hidden_size,
num_layers=1,
bias=bias,
batch_first=True,
dropout=0,
bidirectional=False,
).to(device)
with torch.no_grad():
for name, self_param in self_layer.cell.named_parameters():
getattr(torch_layer, f"{name}_l0").copy_(self_param)
torch_y, torch_h = torch_layer(x_clone, h.unsqueeze(0))
assert_allclose(self_y, torch_y)
assert_allclose(self_h, torch_h)
self_y_sum = (
self_y.reshape(-1) * torch.arange(self_y.numel(), device=device)
).sum()
torch_y_sum = (
torch_y.reshape(-1) * torch.arange(torch_y.numel(), device=device)
).sum()
self_y_sum.backward()
torch_y_sum.backward()
assert_allclose(x.grad, x_clone.grad, atol=0.1)
def test_layernorm_gru_jit(device="cpu"):
input_size = 2
hidden_size = 3
num_layers = 4
bias = True
gru = LayerNormGRU(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
bias=bias,
ln=nn.Identity,
device=device,
)
torch.jit.script(gru)
def test_layernorm_gru_forward(device="cpu"):
input_size = torch.randint(low=2, high=100, size=(1,)).item()
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
num_layers = torch.randint(low=2, high=100, size=(1,)).item()
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
self_gru = LayerNormGRU(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
bias=bias,
ln=nn.Identity,
device=device,
)
torch_gru = nn.GRU(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
bias=bias,
batch_first=True,
bidirectional=False,
).to(device)
assert len(self_gru.state_dict()) == len(torch_gru.state_dict())
with torch.no_grad():
for name, param in self_gru.named_parameters():
# name has the form layers.0.cell.weight_hh
parts = name.split(".")
layer_num = parts[1]
getattr(torch_gru, f"{parts[-1]}_l{layer_num}").copy_(param)
N = torch.randint(low=2, high=100, size=(1,))
T = torch.randint(low=2, high=100, size=(1,))
x = torch.rand(N, T, input_size, device=device).requires_grad_()
states = [
torch.rand(N, hidden_size, device=device) for _ in range(num_layers)
]
x_clone = x.detach().clone().requires_grad_()
self_y, self_states = self_gru(x, states)
torch_y, torch_states = torch_gru(x_clone, torch.stack(states))
assert_allclose(self_y, torch_y)
self_states = torch.stack(self_states)
assert_allclose(self_states, torch_states)
s = self_y.reshape(-1)
t = torch_y.reshape(-1)
s_sum = (s * torch.arange(s.numel(), device=device)).sum()
t_sum = (t * torch.arange(t.numel(), device=device)).sum()
s_state_sum = s_sum + self_states.sum()
t_state_sum = t_sum + torch_states.sum()
s_state_sum.backward()
t_state_sum.backward()
assert_allclose(x.grad, x_clone.grad, atol=1e-2)
def _test_lstm(device):
test_layernorm_lstm_cell_jit(device)
test_layernorm_lstm_cell_constructor(device)
test_layernorm_lstm_cell_with_projection_jit(device)
test_layernorm_lstm_cell_forward(device)
test_layernorm_lstm_cell_with_projection_forward(device)
#
test_layernorm_lstm_layer_jit(device)
test_layernorm_lstm_layer_with_project_jit(device)
test_layernorm_lstm_layer_forward(device)
test_layernorm_lstm_layer_with_projection_forward(device)
test_layernorm_lstm_jit(device)
test_layernorm_lstm_with_projection_jit(device)
test_layernorm_lstm_forward(device)
test_layernorm_lstm_with_projection_forward(device)
def _test_gru(device):
test_layernorm_gru_cell_jit(device)
test_layernorm_gru_cell_constructor(device)
test_layernorm_gru_cell_forward(device)
#
test_layernorm_gru_layer_jit(device)
test_layernorm_gru_layer_forward(device)
#
test_layernorm_gru_jit(device)
test_layernorm_gru_forward(device)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def main():
for device in get_devices():
print("device", device)
_test_lstm(device)
_test_gru(device)
if __name__ == "__main__":
torch.manual_seed(20211202)
main()


@@ -2,3 +2,4 @@ kaldilm
kaldialign
sentencepiece>=0.1.96
tensorboard
typeguard