Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-27 02:34:21 +00:00)
Apply layer normalization to the output of each gate in GRU.
This commit is contained in:
parent 8a038b8f1a
commit 1d004ca966
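This commit adds GRU counterparts (LayerNormGRUCell, LayerNormGRULayer, LayerNormGRU) to the existing layer-normalized LSTM classes: nn.LayerNorm is applied to each gate's pre-activation, right before the sigmoid/tanh. A minimal standalone sketch of the cell update introduced in the diff below (the helper name and the way the parameters are passed here are illustrative only, not part of the commit):

import torch
import torch.nn.functional as F


def layernorm_gru_step(x, h, weight_ih, weight_hh, bias_ih, bias_hh, ln_r, ln_i, ln_n):
    # x: (batch, input_size), h: (batch, hidden_size).
    # ln_r, ln_i, ln_n are nn.LayerNorm(hidden_size) modules (or nn.Identity).
    # Project input and hidden state, then split into the three GRU gates.
    i_r, i_i, i_n = F.linear(x, weight_ih, bias_ih).chunk(3, dim=1)
    h_r, h_i, h_n = F.linear(h, weight_hh, bias_hh).chunk(3, dim=1)

    # Layer normalization is applied to each gate right before its activation.
    reset_gate = torch.sigmoid(ln_r(i_r + h_r))
    input_gate = torch.sigmoid(ln_i(i_i + h_i))
    new_gate = torch.tanh(ln_n(i_n + reset_gate * h_n))

    # (1 - input_gate) * new_gate + input_gate * h
    #   == new_gate + input_gate * (h - new_gate)
    return new_gate + input_gate * (h - new_gate)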
@@ -1,284 +0,0 @@ (deleted file)
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Apply layer normalization to the output of each gate in LSTM/GRU.

This file uses
https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py
as a reference.
"""

import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# TODO(fangjun): Support projection, see https://arxiv.org/pdf/1402.1128.pdf
class LayerNormLSTMCell(nn.Module):
    """This class places a `nn.LayerNorm` after the output of
    each gate (right before the activation).

    See the following paper for more details

    'Improving RNN Transducer Modeling for End-to-End Speech Recognition'
    https://arxiv.org/abs/1909.12415
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        ln: nn.Module = nn.LayerNorm,
        device=None,
        dtype=None,
    ):
        """
        Args:
          input_size:
            The number of expected features in the input `x`. `x` should
            be of shape (batch_size, input_size).
          hidden_size:
            The number of features in the hidden state `h` and `c`.
            Both `h` and `c` are of shape (batch_size, hidden_size).
          bias:
            If ``False``, then the cell does not use bias weights
            `bias_ih` and `bias_hh`.
          ln:
            Defaults to `nn.LayerNorm`. The output of all gates are processed
            by `ln`. We pass it as an argument so that we can replace it
            with `nn.Identity` at the testing time.
        """
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        self.weight_ih = nn.Parameter(
            torch.empty((4 * hidden_size, input_size), **factory_kwargs)
        )

        self.weight_hh = nn.Parameter(
            torch.empty((4 * hidden_size, hidden_size), **factory_kwargs)
        )

        if bias:
            self.bias_ih = nn.Parameter(
                torch.empty(4 * hidden_size, **factory_kwargs)
            )
            self.bias_hh = nn.Parameter(
                torch.empty(4 * hidden_size, **factory_kwargs)
            )
        else:
            self.register_parameter("bias_ih", None)
            self.register_parameter("bias_hh", None)

        self.layernorm_i = ln(hidden_size)
        self.layernorm_f = ln(hidden_size)
        self.layernorm_cx = ln(hidden_size)
        self.layernorm_cy = ln(hidden_size)
        self.layernorm_o = ln(hidden_size)

        self.reset_parameters()

    def forward(
        self,
        input: torch.Tensor,
        state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          input:
            A 2-D tensor of shape (batch_size, input_size).
          state:
            If not ``None``, it contains the hidden state (h, c); both
            are of shape (batch_size, hidden_size). If ``None``, it uses
            zeros for `h` and `c`.
        Returns:
          Return two tensors:
            - `next_h`: It is of shape (batch_size, hidden_size) containing the
              next hidden state for each element in the batch.
            - `next_c`: It is of shape (batch_size, hidden_size) containing the
              next cell state for each element in the batch.
        """
        if state is None:
            zeros = torch.zeros(
                input.size(0),
                self.hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
            state = (zeros, zeros)

        hx, cx = state
        gates = F.linear(input, self.weight_ih, self.bias_ih) + F.linear(
            hx, self.weight_hh, self.bias_hh
        )

        in_gate, forget_gate, cell_gate, out_gate = gates.chunk(chunks=4, dim=1)

        in_gate = self.layernorm_i(in_gate)
        forget_gate = self.layernorm_f(forget_gate)
        cell_gate = self.layernorm_cx(cell_gate)
        out_gate = self.layernorm_o(out_gate)

        in_gate = torch.sigmoid(in_gate)
        forget_gate = torch.sigmoid(forget_gate)
        cell_gate = torch.tanh(cell_gate)
        out_gate = torch.sigmoid(out_gate)

        cy = (forget_gate * cx) + (in_gate * cell_gate)
        cy = self.layernorm_cy(cy)
        hy = out_gate * torch.tanh(cy)

        return hy, cy

    def extra_repr(self) -> str:
        s = "{input_size}, {hidden_size}"
        if "bias" in self.__dict__ and self.bias is not True:
            s += ", bias={bias}"
        return s.format(**self.__dict__)

    def reset_parameters(self) -> None:
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)


class LayerNormLSTMLayer(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        ln: nn.Module = nn.LayerNorm,
        device=None,
        dtype=None,
    ):
        """
        See the args in LayerNormLSTMCell
        """
        super().__init__()
        self.cell = LayerNormLSTMCell(
            input_size=input_size,
            hidden_size=hidden_size,
            bias=bias,
            ln=ln,
            device=device,
            dtype=dtype,
        )

    def forward(
        self,
        input: torch.Tensor,
        state: Tuple[torch.Tensor, torch.Tensor],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
          input:
            A 3-D tensor of shape (batch_size, seq_len, input_size).
            Caution:
              We use `batch_first=True` here.
          state:
            If not ``None``, it contains the hidden state (h, c) of this layer.
            Both are of shape (batch_size, hidden_size).
            Note:
              We did not annotate `state` with `Optional[Tuple[...]]` since
              torchscript will complain.
        Return:
          - output, a tensor of shape (batch_size, seq_len, hidden_size)
          - (next_h, next_c) containing the hidden state of this layer
        """
        inputs = input.unbind(1)
        outputs = torch.jit.annotate(List[torch.Tensor], [])
        for i in range(len(inputs)):
            state = self.cell(inputs[i], state)
            outputs += [state[0]]
        return torch.stack(outputs, dim=1), state


class LayerNormLSTM(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bias: bool = True,
        ln: nn.Module = nn.LayerNorm,
        device=None,
        dtype=None,
    ):
        """
        See the args in LSTMLayer.
        """
        super().__init__()
        assert num_layers >= 1
        factory_kwargs = dict(
            hidden_size=hidden_size,
            bias=bias,
            ln=ln,
            device=device,
            dtype=dtype,
        )
        first_layer = LayerNormLSTMLayer(
            input_size=input_size, **factory_kwargs
        )
        layers = [first_layer]
        for i in range(1, num_layers):
            layers.append(
                LayerNormLSTMLayer(
                    input_size=hidden_size,
                    **factory_kwargs,
                )
            )
        self.layers = nn.ModuleList(layers)
        self.num_layers = num_layers

    def forward(
        self,
        input: torch.Tensor,
        states: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Args:
          input:
            A 3-D tensor of shape (batch_size, seq_len, input_size).
            Caution:
              We use `batch_first=True` here.
          states:
            One state per layer. Each entry contains the hidden state (h, c)
            for a layer. Both are of shape (batch_size, hidden_size).
        Returns:
          Return a tuple containing:

            - output: A tensor of shape (batch_size, seq_len, hidden_size)
            - List[(next_h, next_c)] containing the hidden states for all layers

        """
        output_states = torch.jit.annotate(
            List[Tuple[torch.Tensor, torch.Tensor]], []
        )
        output = input
        for i, rnn_layer in enumerate(self.layers):
            state = states[i]
            output, out_state = rnn_layer(output, state)
            output_states += [out_state]
        return output, output_states
@@ -1,240 +0,0 @@ (deleted file)
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
from rnnt.rnn import LayerNormLSTM, LayerNormLSTMCell, LayerNormLSTMLayer


def test_layernorm_lstm_cell_jit():
    input_size = 10
    hidden_size = 20
    cell = LayerNormLSTMCell(
        input_size=input_size, hidden_size=hidden_size, bias=True
    )

    torch.jit.script(cell)


def test_layernorm_lstm_cell_constructor():
    input_size = torch.randint(low=2, high=100, size=(1,)).item()
    hidden_size = torch.randint(low=2, high=100, size=(1,)).item()

    self_cell = LayerNormLSTMCell(input_size, hidden_size, ln=nn.Identity)
    torch_cell = nn.LSTMCell(input_size, hidden_size)

    for name, param in self_cell.named_parameters():
        assert param.shape == getattr(torch_cell, name).shape

    assert len(self_cell.state_dict()) == len(torch_cell.state_dict())


def test_layernorm_lstm_cell_forward():
    input_size = torch.randint(low=2, high=100, size=(1,)).item()
    hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
    bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0

    self_cell = LayerNormLSTMCell(
        input_size, hidden_size, bias=bias, ln=nn.Identity
    )
    torch_cell = nn.LSTMCell(input_size, hidden_size, bias=bias)
    with torch.no_grad():
        for name, torch_param in torch_cell.named_parameters():
            self_param = getattr(self_cell, name)
            torch_param.copy_(self_param)

    N = torch.randint(low=2, high=100, size=(1,))
    x = torch.rand(N, input_size).requires_grad_()
    h = torch.rand(N, hidden_size)
    c = torch.rand(N, hidden_size)

    x_clone = x.detach().clone().requires_grad_()

    self_h, self_c = self_cell(x.clone(), (h, c))
    torch_h, torch_c = torch_cell(x_clone, (h, c))

    assert torch.allclose(self_h, torch_h)
    assert torch.allclose(self_c, torch_c)

    self_hc = self_h * self_c
    torch_hc = torch_h * torch_c
    (self_hc.reshape(-1) * torch.arange(self_hc.numel())).sum().backward()
    (torch_hc.reshape(-1) * torch.arange(torch_hc.numel())).sum().backward()

    assert torch.allclose(x.grad, x_clone.grad)


def test_lstm_layer_jit():
    input_size = 10
    hidden_size = 20
    layer = LayerNormLSTMLayer(input_size, hidden_size=hidden_size)
    torch.jit.script(layer)


def test_lstm_layer_forward():
    input_size = torch.randint(low=2, high=100, size=(1,)).item()
    hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
    bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
    self_layer = LayerNormLSTMLayer(
        input_size,
        hidden_size,
        bias=bias,
        ln=nn.Identity,
    )

    N = torch.randint(low=2, high=100, size=(1,))
    T = torch.randint(low=2, high=100, size=(1,))

    x = torch.rand(N, T, input_size).requires_grad_()
    h = torch.rand(N, hidden_size)
    c = torch.rand(N, hidden_size)

    x_clone = x.detach().clone().requires_grad_()

    self_y, (self_h, self_c) = self_layer(x, (h, c))

    # now for pytorch
    torch_layer = nn.LSTM(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=1,
        bias=bias,
        batch_first=True,
        dropout=0,
        bidirectional=False,
    )
    with torch.no_grad():
        for name, self_param in self_layer.cell.named_parameters():
            getattr(torch_layer, f"{name}_l0").copy_(self_param)

    torch_y, (torch_h, torch_c) = torch_layer(
        x_clone, (h.unsqueeze(0), c.unsqueeze(0))
    )
    assert torch.allclose(self_y, torch_y)
    assert torch.allclose(self_h, torch_h)
    assert torch.allclose(self_c, torch_c)

    self_hc = self_h * self_c
    torch_hc = torch_h * torch_c
    self_hc_sum = (self_hc.reshape(-1) * torch.arange(self_hc.numel())).sum()
    torch_hc_sum = (torch_hc.reshape(-1) * torch.arange(torch_hc.numel())).sum()

    self_y_sum = (self_y.reshape(-1) * torch.arange(self_y.numel())).sum()
    torch_y_sum = (torch_y.reshape(-1) * torch.arange(torch_y.numel())).sum()

    (self_hc_sum * self_y_sum).backward()
    (torch_hc_sum * torch_y_sum).backward()

    assert torch.allclose(x.grad, x_clone.grad, rtol=0.1)


def test_stacked_lstm_jit():
    input_size = 2
    hidden_size = 3
    num_layers = 4
    bias = True

    lstm = LayerNormLSTM(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        bias=bias,
        ln=nn.Identity,
    )
    torch.jit.script(lstm)


def test_stacked_lstm_forward():
    input_size = torch.randint(low=2, high=100, size=(1,)).item()
    hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
    num_layers = torch.randint(low=2, high=100, size=(1,)).item()
    bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0

    self_lstm = LayerNormLSTM(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        bias=bias,
        ln=nn.Identity,
    )
    torch_lstm = nn.LSTM(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        bias=bias,
        batch_first=True,
        bidirectional=False,
    )
    assert len(self_lstm.state_dict()) == len(torch_lstm.state_dict())
    with torch.no_grad():
        for name, param in self_lstm.named_parameters():
            # name has the form layers.0.cell.weight_hh
            parts = name.split(".")
            layer_num = parts[1]
            getattr(torch_lstm, f"{parts[-1]}_l{layer_num}").copy_(param)

    N = torch.randint(low=2, high=100, size=(1,))
    T = torch.randint(low=2, high=100, size=(1,))

    x = torch.rand(N, T, input_size).requires_grad_()
    hs = [torch.rand(N, hidden_size) for _ in range(num_layers)]
    cs = [torch.rand(N, hidden_size) for _ in range(num_layers)]
    states = list(zip(hs, cs))

    x_clone = x.detach().clone().requires_grad_()

    self_y, self_states = self_lstm(x, states)

    h = torch.stack(hs)
    c = torch.stack(cs)
    torch_y, (torch_h, torch_c) = torch_lstm(x_clone, (h, c))

    assert torch.allclose(self_y, torch_y)

    self_h = torch.stack([s[0] for s in self_states])
    self_c = torch.stack([s[1] for s in self_states])

    assert torch.allclose(self_h, torch_h)
    assert torch.allclose(self_c, torch_c)

    s = self_y.reshape(-1)
    t = torch_y.reshape(-1)

    s_sum = (s * torch.arange(s.numel())).sum()
    t_sum = (t * torch.arange(t.numel())).sum()
    shc_sum = s_sum * self_h.sum() * self_c.sum()
    thc_sum = t_sum * torch_h.sum() * torch_c.sum()

    shc_sum.backward()
    thc_sum.backward()
    assert torch.allclose(x.grad, x_clone.grad)


def main():
    test_layernorm_lstm_cell_jit()
    test_layernorm_lstm_cell_constructor()
    test_layernorm_lstm_cell_forward()
    #
    test_lstm_layer_jit()
    test_lstm_layer_forward()
    #
    test_stacked_lstm_jit()
    test_stacked_lstm_forward()


if __name__ == "__main__":
    main()
egs/librispeech/ASR/transducer/rnn.py (new file, 610 lines)
@@ -0,0 +1,610 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Apply layer normalization to the output of each gate in LSTM/GRU.

This file uses
https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py
as a reference.
"""

import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# TODO(fangjun): Support projection, see https://arxiv.org/pdf/1402.1128.pdf
class LayerNormLSTMCell(nn.Module):
    """This class places a `nn.LayerNorm` after the output of
    each gate (right before the activation).

    See the following paper for more details

    'Improving RNN Transducer Modeling for End-to-End Speech Recognition'
    https://arxiv.org/abs/1909.12415

    Examples::

        >>> cell = LayerNormLSTMCell(10, 20)
        >>> input = torch.rand(5, 10)
        >>> h0 = torch.rand(5, 20)
        >>> c0 = torch.rand(5, 20)
        >>> h1, c1 = cell(input, (h0, c0))
        >>> output = h1
        >>> h1.shape
        torch.Size([5, 20])
        >>> c1.shape
        torch.Size([5, 20])
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        ln: nn.Module = nn.LayerNorm,
        device=None,
        dtype=None,
    ):
        """
        Args:
          input_size:
            The number of expected features in the input `x`. `x` should
            be of shape (batch_size, input_size).
          hidden_size:
            The number of features in the hidden state `h` and `c`.
            Both `h` and `c` are of shape (batch_size, hidden_size).
          bias:
            If ``False``, then the cell does not use bias weights
            `bias_ih` and `bias_hh`.
          ln:
            Defaults to `nn.LayerNorm`. The output of all gates are processed
            by `ln`. We pass it as an argument so that we can replace it
            with `nn.Identity` at the testing time.
        """
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        self.weight_ih = nn.Parameter(
            torch.empty((4 * hidden_size, input_size), **factory_kwargs)
        )

        self.weight_hh = nn.Parameter(
            torch.empty((4 * hidden_size, hidden_size), **factory_kwargs)
        )

        if bias:
            self.bias_ih = nn.Parameter(
                torch.empty(4 * hidden_size, **factory_kwargs)
            )
            self.bias_hh = nn.Parameter(
                torch.empty(4 * hidden_size, **factory_kwargs)
            )
        else:
            self.register_parameter("bias_ih", None)
            self.register_parameter("bias_hh", None)

        self.layernorm_i = ln(hidden_size)
        self.layernorm_f = ln(hidden_size)
        self.layernorm_cx = ln(hidden_size)
        self.layernorm_cy = ln(hidden_size)
        self.layernorm_o = ln(hidden_size)

        self.reset_parameters()

    def forward(
        self,
        input: torch.Tensor,
        state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          input:
            A 2-D tensor of shape (batch_size, input_size).
          state:
            If not ``None``, it contains the hidden state (h, c) for each
            element in the batch. Both are of shape (batch_size, hidden_size).
            If ``None``, it uses zeros for `h` and `c`.
        Returns:
          Return two tensors:
            - `next_h`: It is of shape (batch_size, hidden_size) containing the
              next hidden state for each element in the batch.
            - `next_c`: It is of shape (batch_size, hidden_size) containing the
              next cell state for each element in the batch.
        """
        if state is None:
            zeros = torch.zeros(
                input.size(0),
                self.hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
            state = (zeros, zeros)

        hx, cx = state
        gates = F.linear(input, self.weight_ih, self.bias_ih) + F.linear(
            hx, self.weight_hh, self.bias_hh
        )

        in_gate, forget_gate, cell_gate, out_gate = gates.chunk(chunks=4, dim=1)

        in_gate = self.layernorm_i(in_gate)
        forget_gate = self.layernorm_f(forget_gate)
        cell_gate = self.layernorm_cx(cell_gate)
        out_gate = self.layernorm_o(out_gate)

        in_gate = torch.sigmoid(in_gate)
        forget_gate = torch.sigmoid(forget_gate)
        cell_gate = torch.tanh(cell_gate)
        out_gate = torch.sigmoid(out_gate)

        cy = (forget_gate * cx) + (in_gate * cell_gate)
        cy = self.layernorm_cy(cy)
        hy = out_gate * torch.tanh(cy)

        return hy, cy

    def extra_repr(self) -> str:
        s = "{input_size}, {hidden_size}"
        if "bias" in self.__dict__ and self.bias is not True:
            s += ", bias={bias}"
        return s.format(**self.__dict__)

    def reset_parameters(self) -> None:
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)


class LayerNormLSTMLayer(nn.Module):
    """
    Examples::

        >>> layer = LayerNormLSTMLayer(10, 20)
        >>> input = torch.rand(2, 5, 10)
        >>> h0 = torch.rand(2, 20)
        >>> c0 = torch.rand(2, 20)
        >>> output, (hn, cn) = layer(input, (h0, c0))
        >>> output.shape
        torch.Size([2, 5, 20])
        >>> hn.shape
        torch.Size([2, 20])
        >>> cn.shape
        torch.Size([2, 20])
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        ln: nn.Module = nn.LayerNorm,
        device=None,
        dtype=None,
    ):
        """
        See the args in LayerNormLSTMCell
        """
        super().__init__()
        self.cell = LayerNormLSTMCell(
            input_size=input_size,
            hidden_size=hidden_size,
            bias=bias,
            ln=ln,
            device=device,
            dtype=dtype,
        )

    def forward(
        self,
        input: torch.Tensor,
        state: Tuple[torch.Tensor, torch.Tensor],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
          input:
            A 3-D tensor of shape (batch_size, seq_len, input_size).
            Caution:
              We use `batch_first=True` here.
          state:
            If not ``None``, it contains the hidden state (h, c) of this layer.
            Both are of shape (batch_size, hidden_size).
            Note:
              We did not annotate `state` with `Optional[Tuple[...]]` since
              torchscript will complain.
        Return:
          - output, a tensor of shape (batch_size, seq_len, hidden_size)
          - (next_h, next_c) containing the hidden state of this layer
        """
        inputs = input.unbind(1)
        outputs = torch.jit.annotate(List[torch.Tensor], [])
        for i in range(len(inputs)):
            state = self.cell(inputs[i], state)
            outputs.append(state[0])
        return torch.stack(outputs, dim=1), state


class LayerNormLSTM(nn.Module):
    """
    Examples::

        >>> lstm = LayerNormLSTM(10, 20, 8)
        >>> input = torch.rand(2, 3, 10)
        >>> h0 = torch.rand(8, 2, 20).unbind(0)
        >>> c0 = torch.rand(8, 2, 20).unbind(0)
        >>> states = list(zip(h0, c0))
        >>> output, next_states = lstm(input, states)
        >>> output.shape
        torch.Size([2, 3, 20])
        >>> hn = torch.stack([s[0] for s in next_states])
        >>> cn = torch.stack([s[1] for s in next_states])
        >>> hn.shape
        torch.Size([8, 2, 20])
        >>> cn.shape
        torch.Size([8, 2, 20])
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bias: bool = True,
        ln: nn.Module = nn.LayerNorm,
        device=None,
        dtype=None,
    ):
        """
        See the args in LayerNormLSTMLayer.
        """
        super().__init__()
        assert num_layers >= 1
        factory_kwargs = dict(
            hidden_size=hidden_size,
            bias=bias,
            ln=ln,
            device=device,
            dtype=dtype,
        )
        first_layer = LayerNormLSTMLayer(
            input_size=input_size, **factory_kwargs
        )
        layers = [first_layer]
        for i in range(1, num_layers):
            layers.append(
                LayerNormLSTMLayer(
                    input_size=hidden_size,
                    **factory_kwargs,
                )
            )
        self.layers = nn.ModuleList(layers)
        self.num_layers = num_layers

    def forward(
        self,
        input: torch.Tensor,
        states: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Args:
          input:
            A 3-D tensor of shape (batch_size, seq_len, input_size).
            Caution:
              We use `batch_first=True` here.
          states:
            One state per layer. Each entry contains the hidden state (h, c)
            for a layer. Both are of shape (batch_size, hidden_size).
        Returns:
          Return a tuple containing:

            - output: A tensor of shape (batch_size, seq_len, hidden_size)
            - List[(next_h, next_c)] containing the hidden states for all layers

        """
        output_states = torch.jit.annotate(
            List[Tuple[torch.Tensor, torch.Tensor]], []
        )
        output = input
        for i, rnn_layer in enumerate(self.layers):
            state = states[i]
            output, out_state = rnn_layer(output, state)
            output_states += [out_state]
        return output, output_states


class LayerNormGRUCell(nn.Module):
    """This class places a `nn.LayerNorm` after the output of
    each gate (right before the activation).

    See the following paper for more details

    'Improving RNN Transducer Modeling for End-to-End Speech Recognition'
    https://arxiv.org/abs/1909.12415

    Examples::

        >>> cell = LayerNormGRUCell(10, 20)
        >>> input = torch.rand(2, 10)
        >>> h0 = torch.rand(2, 20)
        >>> hn = cell(input, h0)
        >>> hn.shape
        torch.Size([2, 20])
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        ln: nn.Module = nn.LayerNorm,
        device=None,
        dtype=None,
    ):
        """
        Args:
          input_size:
            The number of expected features in the input `x`. `x` should
            be of shape (batch_size, input_size).
          hidden_size:
            The number of features in the hidden state `h` and `c`.
            Both `h` and `c` are of shape (batch_size, hidden_size).
          bias:
            If ``False``, then the cell does not use bias weights
            `bias_ih` and `bias_hh`.
          ln:
            Defaults to `nn.LayerNorm`. The output of all gates are processed
            by `ln`. We pass it as an argument so that we can replace it
            with `nn.Identity` at the testing time.
        """
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        self.weight_ih = nn.Parameter(
            torch.empty((3 * hidden_size, input_size), **factory_kwargs)
        )

        self.weight_hh = nn.Parameter(
            torch.empty((3 * hidden_size, hidden_size), **factory_kwargs)
        )

        if bias:
            self.bias_ih = nn.Parameter(
                torch.empty(3 * hidden_size, **factory_kwargs)
            )
            self.bias_hh = nn.Parameter(
                torch.empty(3 * hidden_size, **factory_kwargs)
            )
        else:
            self.register_parameter("bias_ih", None)
            self.register_parameter("bias_hh", None)

        self.layernorm_r = ln(hidden_size)
        self.layernorm_i = ln(hidden_size)
        self.layernorm_n = ln(hidden_size)

        self.reset_parameters()

    def forward(
        self,
        input: torch.Tensor,
        hx: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
          input:
            A 2-D tensor of shape (batch_size, input_size) containing
            input features.
          hx:
            If not `None`, it is a tensor of shape (batch_size, hidden_size)
            containing the initial hidden state for each element in the batch.
            If `None`, it uses zeros for the hidden state.
        Returns:
          Return a tensor of shape (batch_size, hidden_size) containing the
          next hidden state for each element in the batch
        """
        if hx is None:
            hx = torch.zeros(
                input.size(0),
                self.hidden_size,
                dtype=input.dtype,
                device=input.device,
            )

        i_r, i_i, i_n = F.linear(input, self.weight_ih, self.bias_ih).chunk(
            chunks=3, dim=1
        )

        h_r, h_i, h_n = F.linear(hx, self.weight_hh, self.bias_hh).chunk(
            chunks=3, dim=1
        )

        reset_gate = torch.sigmoid(self.layernorm_r(i_r + h_r))
        input_gate = torch.sigmoid(self.layernorm_i(i_i + h_i))
        new_gate = torch.tanh(self.layernorm_n(i_n + reset_gate * h_n))

        # hy = (1 - input_gate) * new_gate + input_gate * hx
        #    = new_gate - input_gate * new_gate + input_gate * hx
        #    = new_gate + input_gate * (hx - new_gate)
        hy = new_gate + input_gate * (hx - new_gate)

        return hy

    def extra_repr(self) -> str:
        s = "{input_size}, {hidden_size}"
        if "bias" in self.__dict__ and self.bias is not True:
            s += ", bias={bias}"
        return s.format(**self.__dict__)

    def reset_parameters(self) -> None:
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)


class LayerNormGRULayer(nn.Module):
    """
    Examples::

        >>> layer = LayerNormGRULayer(10, 20)
        >>> input = torch.rand(2, 3, 10)
        >>> hx = torch.rand(2, 20)
        >>> output, hn = layer(input, hx)
        >>> output.shape
        torch.Size([2, 3, 20])
        >>> hn.shape
        torch.Size([2, 20])
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        ln: nn.Module = nn.LayerNorm,
        device=None,
        dtype=None,
    ):
        """
        See the args in LayerNormGRUCell
        """
        super().__init__()
        self.cell = LayerNormGRUCell(
            input_size=input_size,
            hidden_size=hidden_size,
            bias=bias,
            ln=ln,
            device=device,
            dtype=dtype,
        )

    def forward(
        self,
        input: torch.Tensor,
        hx: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          input:
            A 3-D tensor of shape (batch_size, seq_len, input_size).
            Caution:
              We use `batch_first=True` here.
          hx:
            If not ``None``, it is a tensor of shape (batch_size, hidden_size)
            containing the hidden state for each element in the batch.
        Return:
          - output, a tensor of shape (batch_size, seq_len, hidden_size)
          - next_h, a tensor of shape (batch_size, hidden_size) containing the
            final hidden state for each element in the batch.
        """
        inputs = input.unbind(1)
        outputs = torch.jit.annotate(List[torch.Tensor], [])
        next_h = hx
        for i in range(len(inputs)):
            next_h = self.cell(inputs[i], next_h)
            outputs.append(next_h)
        return torch.stack(outputs, dim=1), next_h


class LayerNormGRU(nn.Module):
    """
    Examples::

        >>> gru = LayerNormGRU(10, 20, 8)
        >>> input = torch.rand(2, 3, 10)
        >>> h0 = torch.rand(8, 2, 20)
        >>> states = h0.unbind(0)
        >>> output, next_states = gru(input, states)
        >>> output.shape
        torch.Size([2, 3, 20])
        >>> hn = torch.stack(next_states)
        >>> hn.shape
        torch.Size([8, 2, 20])
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bias: bool = True,
        ln: nn.Module = nn.LayerNorm,
        device=None,
        dtype=None,
    ):
        """
        See the args in LayerNormGRULayer.
        """
        super().__init__()
        assert num_layers >= 1
        factory_kwargs = dict(
            hidden_size=hidden_size,
            bias=bias,
            ln=ln,
            device=device,
            dtype=dtype,
        )
        first_layer = LayerNormGRULayer(input_size=input_size, **factory_kwargs)
        layers = [first_layer]
        for i in range(1, num_layers):
            layers.append(
                LayerNormGRULayer(
                    input_size=hidden_size,
                    **factory_kwargs,
                )
            )
        self.layers = nn.ModuleList(layers)
        self.num_layers = num_layers

    def forward(
        self,
        input: torch.Tensor,
        states: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Args:
          input:
            A tensor of shape (batch_size, seq_len, input_size) containing
            input features.
            Caution:
              We use `batch_first=True` here.
          states:
            One state per layer. Each entry contains the hidden state for each
            element in the batch. Each hidden state is of shape
            (batch_size, hidden_size)
        Returns:
          Return a tuple containing:

            - output: A tensor of shape (batch_size, seq_len, hidden_size)
            - List[next_state] containing the final hidden states for each
              element in the batch

        """
        output_states = torch.jit.annotate(List[torch.Tensor], [])
        output = input
        for i, rnn_layer in enumerate(self.layers):
            state = states[i]
            output, out_state = rnn_layer(output, state)
            output_states += [out_state]
        return output, output_states
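Note that, unlike torch.nn.LSTM and torch.nn.GRU, these modules take batch-first input and a Python list of per-layer states (one (h, c) tuple per layer for LayerNormLSTM, one hidden tensor per layer for LayerNormGRU), as the docstring examples above show. A small usage sketch converting to and from the stacked state layout of the built-in modules (the sizes below are arbitrary and only for illustration):

import torch
from transducer.rnn import LayerNormGRU

gru = LayerNormGRU(input_size=80, hidden_size=512, num_layers=4)
x = torch.rand(8, 100, 80)   # (batch, seq_len, input_size); batch_first layout
h0 = torch.zeros(4, 8, 512)  # stacked like nn.GRU's initial state

states = list(h0.unbind(0))      # one (batch, hidden_size) tensor per layer
y, next_states = gru(x, states)
hn = torch.stack(next_states)    # back to the (num_layers, batch, hidden_size) convention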
egs/librispeech/ASR/transducer/test_rnn.py (new executable file, 453 lines)
@@ -0,0 +1,453 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
from transducer.rnn import (
    LayerNormGRU,
    LayerNormGRUCell,
    LayerNormGRULayer,
    LayerNormLSTM,
    LayerNormLSTMCell,
    LayerNormLSTMLayer,
)


def assert_allclose(a: torch.Tensor, b: torch.Tensor, **kwargs):
    assert torch.allclose(a, b, **kwargs), f"{(a - b).abs().max()}, {a.numel()}"


def test_layernorm_lstm_cell_jit():
    input_size = 10
    hidden_size = 20
    bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
    cell = LayerNormLSTMCell(
        input_size=input_size, hidden_size=hidden_size, bias=bias
    )

    torch.jit.script(cell)


def test_layernorm_lstm_cell_constructor():
    input_size = torch.randint(low=2, high=100, size=(1,)).item()
    hidden_size = torch.randint(low=2, high=100, size=(1,)).item()

    self_cell = LayerNormLSTMCell(input_size, hidden_size, ln=nn.Identity)
    torch_cell = nn.LSTMCell(input_size, hidden_size)

    for name, param in self_cell.named_parameters():
        assert param.shape == getattr(torch_cell, name).shape

    assert len(self_cell.state_dict()) == len(torch_cell.state_dict())


def test_layernorm_lstm_cell_forward():
    input_size = torch.randint(low=2, high=100, size=(1,)).item()
    hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
    bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0

    self_cell = LayerNormLSTMCell(
        input_size, hidden_size, bias=bias, ln=nn.Identity
    )
    torch_cell = nn.LSTMCell(input_size, hidden_size, bias=bias)
    with torch.no_grad():
        for name, torch_param in torch_cell.named_parameters():
            self_param = getattr(self_cell, name)
            torch_param.copy_(self_param)

    N = torch.randint(low=2, high=100, size=(1,))
    x = torch.rand(N, input_size).requires_grad_()
    h = torch.rand(N, hidden_size)
    c = torch.rand(N, hidden_size)

    x_clone = x.detach().clone().requires_grad_()

    self_h, self_c = self_cell(x.clone(), (h, c))
    torch_h, torch_c = torch_cell(x_clone, (h, c))

    assert_allclose(self_h, torch_h)
    assert_allclose(self_c, torch_c)

    self_hc = self_h * self_c
    torch_hc = torch_h * torch_c
    (self_hc.reshape(-1) * torch.arange(self_hc.numel())).sum().backward()
    (torch_hc.reshape(-1) * torch.arange(torch_hc.numel())).sum().backward()

    assert_allclose(x.grad, x_clone.grad)


def test_layernorm_lstm_layer_jit():
    input_size = 10
    hidden_size = 20
    layer = LayerNormLSTMLayer(input_size, hidden_size=hidden_size)
    torch.jit.script(layer)


def test_layernorm_lstm_layer_forward():
    input_size = torch.randint(low=2, high=100, size=(1,)).item()
    hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
    bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
    self_layer = LayerNormLSTMLayer(
        input_size,
        hidden_size,
        bias=bias,
        ln=nn.Identity,
    )

    N = torch.randint(low=2, high=100, size=(1,))
    T = torch.randint(low=2, high=100, size=(1,))

    x = torch.rand(N, T, input_size).requires_grad_()
    h = torch.rand(N, hidden_size)
    c = torch.rand(N, hidden_size)

    x_clone = x.detach().clone().requires_grad_()

    self_y, (self_h, self_c) = self_layer(x, (h, c))

    torch_layer = nn.LSTM(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=1,
        bias=bias,
        batch_first=True,
        dropout=0,
        bidirectional=False,
    )
    with torch.no_grad():
        for name, self_param in self_layer.cell.named_parameters():
            getattr(torch_layer, f"{name}_l0").copy_(self_param)

    torch_y, (torch_h, torch_c) = torch_layer(
        x_clone, (h.unsqueeze(0), c.unsqueeze(0))
    )
    assert_allclose(self_y, torch_y)
    assert_allclose(self_h, torch_h)
    assert_allclose(self_c, torch_c)

    self_hc = self_h * self_c
    torch_hc = torch_h * torch_c
    self_hc_sum = (self_hc.reshape(-1) * torch.arange(self_hc.numel())).sum()
    torch_hc_sum = (torch_hc.reshape(-1) * torch.arange(torch_hc.numel())).sum()

    self_y_sum = (self_y.reshape(-1) * torch.arange(self_y.numel())).sum()
    torch_y_sum = (torch_y.reshape(-1) * torch.arange(torch_y.numel())).sum()

    (self_hc_sum * self_y_sum).backward()
    (torch_hc_sum * torch_y_sum).backward()

    assert_allclose(x.grad, x_clone.grad, atol=1e-5)


def test_layernorm_lstm_jit():
    input_size = 2
    hidden_size = 3
    num_layers = 4
    bias = True

    lstm = LayerNormLSTM(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        bias=bias,
        ln=nn.Identity,
    )
    torch.jit.script(lstm)


def test_layernorm_lstm_forward():
    input_size = torch.randint(low=2, high=100, size=(1,)).item()
    hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
    num_layers = torch.randint(low=2, high=100, size=(1,)).item()
    bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0

    self_lstm = LayerNormLSTM(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        bias=bias,
        ln=nn.Identity,
    )
    torch_lstm = nn.LSTM(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        bias=bias,
        batch_first=True,
        bidirectional=False,
    )
    assert len(self_lstm.state_dict()) == len(torch_lstm.state_dict())
    with torch.no_grad():
        for name, param in self_lstm.named_parameters():
            # name has the form layers.0.cell.weight_hh
            parts = name.split(".")
            layer_num = parts[1]
            getattr(torch_lstm, f"{parts[-1]}_l{layer_num}").copy_(param)

    N = torch.randint(low=2, high=100, size=(1,))
    T = torch.randint(low=2, high=100, size=(1,))

    x = torch.rand(N, T, input_size).requires_grad_()
    hs = [torch.rand(N, hidden_size) for _ in range(num_layers)]
    cs = [torch.rand(N, hidden_size) for _ in range(num_layers)]
    states = list(zip(hs, cs))

    x_clone = x.detach().clone().requires_grad_()

    self_y, self_states = self_lstm(x, states)

    h = torch.stack(hs)
    c = torch.stack(cs)
    torch_y, (torch_h, torch_c) = torch_lstm(x_clone, (h, c))

    assert_allclose(self_y, torch_y)

    self_h = torch.stack([s[0] for s in self_states])
    self_c = torch.stack([s[1] for s in self_states])

    assert_allclose(self_h, torch_h)
    assert_allclose(self_c, torch_c)

    s = self_y.reshape(-1)
    t = torch_y.reshape(-1)

    s_sum = (s * torch.arange(s.numel())).sum()
    t_sum = (t * torch.arange(t.numel())).sum()
    shc_sum = s_sum * self_h.sum() * self_c.sum()
    thc_sum = t_sum * torch_h.sum() * torch_c.sum()

    shc_sum.backward()
    thc_sum.backward()
    assert_allclose(x.grad, x_clone.grad)


def test_layernorm_gru_cell_jit():
    input_size = 10
    hidden_size = 20
    cell = LayerNormGRUCell(
        input_size=input_size, hidden_size=hidden_size, bias=True
    )

    torch.jit.script(cell)


def test_layernorm_gru_cell_constructor():
    input_size = torch.randint(low=2, high=100, size=(1,)).item()
    hidden_size = torch.randint(low=2, high=100, size=(1,)).item()

    self_cell = LayerNormGRUCell(input_size, hidden_size, ln=nn.Identity)
    torch_cell = nn.GRUCell(input_size, hidden_size)

    for name, param in self_cell.named_parameters():
        assert param.shape == getattr(torch_cell, name).shape

    assert len(self_cell.state_dict()) == len(torch_cell.state_dict())


def test_layernorm_gru_cell_forward():
    input_size = torch.randint(low=2, high=100, size=(1,)).item()
    hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
    bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0

    self_cell = LayerNormGRUCell(
        input_size, hidden_size, bias=bias, ln=nn.Identity
    )
    torch_cell = nn.GRUCell(input_size, hidden_size, bias=bias)
    with torch.no_grad():
        for name, torch_param in torch_cell.named_parameters():
            self_param = getattr(self_cell, name)
            torch_param.copy_(self_param)

    N = torch.randint(low=2, high=100, size=(1,))
    x = torch.rand(N, input_size).requires_grad_()
    h = torch.rand(N, hidden_size)

    x_clone = x.detach().clone().requires_grad_()

    self_h = self_cell(x.clone(), h)
    torch_h = torch_cell(x_clone, h)

    assert_allclose(self_h, torch_h, atol=1e-5)

    (self_h.reshape(-1) * torch.arange(self_h.numel())).sum().backward()
    (torch_h.reshape(-1) * torch.arange(torch_h.numel())).sum().backward()

    assert_allclose(x.grad, x_clone.grad, atol=1e-4)


def test_layernorm_gru_layer_jit():
    input_size = 10
    hidden_size = 20
    layer = LayerNormGRULayer(input_size, hidden_size=hidden_size)
    torch.jit.script(layer)


def test_layernorm_gru_layer_forward():
    input_size = torch.randint(low=2, high=100, size=(1,)).item()
    hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
    bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
    self_layer = LayerNormGRULayer(
        input_size,
        hidden_size,
        bias=bias,
        ln=nn.Identity,
    )

    N = torch.randint(low=2, high=100, size=(1,))
    T = torch.randint(low=2, high=100, size=(1,))

    x = torch.rand(N, T, input_size).requires_grad_()
    h = torch.rand(N, hidden_size)

    x_clone = x.detach().clone().requires_grad_()

    self_y, self_h = self_layer(x, h.clone())

    torch_layer = nn.GRU(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=1,
        bias=bias,
        batch_first=True,
        dropout=0,
        bidirectional=False,
    )
    with torch.no_grad():
        for name, self_param in self_layer.cell.named_parameters():
            getattr(torch_layer, f"{name}_l0").copy_(self_param)

    torch_y, torch_h = torch_layer(x_clone, h.unsqueeze(0))
    assert_allclose(self_y, torch_y, atol=1e-6)
    assert_allclose(self_h, torch_h)

    self_y_sum = (self_y.reshape(-1) * torch.arange(self_y.numel())).sum()
    torch_y_sum = (torch_y.reshape(-1) * torch.arange(torch_y.numel())).sum()

    self_y_sum.backward()
    torch_y_sum.backward()

    assert_allclose(x.grad, x_clone.grad, atol=0.1)


def test_layernorm_gru_jit():
    input_size = 2
    hidden_size = 3
    num_layers = 4
    bias = True

    gru = LayerNormGRU(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        bias=bias,
        ln=nn.Identity,
    )
    torch.jit.script(gru)


def test_layernorm_gru_forward():
    input_size = torch.randint(low=2, high=100, size=(1,)).item()
    hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
    num_layers = torch.randint(low=2, high=100, size=(1,)).item()
    bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0

    self_gru = LayerNormGRU(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        bias=bias,
        ln=nn.Identity,
    )
    torch_gru = nn.GRU(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        bias=bias,
        batch_first=True,
        bidirectional=False,
    )
    assert len(self_gru.state_dict()) == len(torch_gru.state_dict())
    with torch.no_grad():
        for name, param in self_gru.named_parameters():
            # name has the form layers.0.cell.weight_hh
            parts = name.split(".")
            layer_num = parts[1]
            getattr(torch_gru, f"{parts[-1]}_l{layer_num}").copy_(param)

    N = torch.randint(low=2, high=100, size=(1,))
    T = torch.randint(low=2, high=100, size=(1,))

    x = torch.rand(N, T, input_size).requires_grad_()
    states = [torch.rand(N, hidden_size) for _ in range(num_layers)]

    x_clone = x.detach().clone().requires_grad_()

    self_y, self_states = self_gru(x, states)

    torch_y, torch_states = torch_gru(x_clone, torch.stack(states))

    assert_allclose(self_y, torch_y, atol=1e-6)

    self_states = torch.stack(self_states)

    assert_allclose(self_states, torch_states, atol=1e-6)

    s = self_y.reshape(-1)
    t = torch_y.reshape(-1)

    s_sum = (s * torch.arange(s.numel())).sum()
    t_sum = (t * torch.arange(t.numel())).sum()
    s_state_sum = s_sum + self_states.sum()
    t_state_sum = t_sum + torch_states.sum()

    s_state_sum.backward()
    t_state_sum.backward()
    assert_allclose(x.grad, x_clone.grad)


def test_lstm():
    test_layernorm_lstm_cell_jit()
    test_layernorm_lstm_cell_constructor()
    test_layernorm_lstm_cell_forward()
    #
    test_layernorm_lstm_layer_jit()
    test_layernorm_lstm_layer_forward()
    #
    test_layernorm_lstm_jit()
    test_layernorm_lstm_forward()


def test_gru():
    test_layernorm_gru_cell_jit()
    test_layernorm_gru_cell_constructor()
    test_layernorm_gru_cell_forward()
    #
    test_layernorm_gru_layer_jit()
    test_layernorm_gru_layer_forward()
    #
    test_layernorm_gru_jit()
    test_layernorm_gru_forward()


def main():
    test_lstm()
    test_gru()


if __name__ == "__main__":
    torch.manual_seed(20211202)
    main()