Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)

Add projection support to LayerNormLSTMCell.

parent 1d004ca966
commit 2c7547e1b7
@@ -30,7 +30,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
-# TODO(fangjun): Support projection, see https://arxiv.org/pdf/1402.1128.pdf
 class LayerNormLSTMCell(nn.Module):
     """This class places a `nn.LayerNorm` after the output of
     each gate (right before the activation).
@@ -60,6 +59,7 @@ class LayerNormLSTMCell(nn.Module):
         hidden_size: int,
         bias: bool = True,
         ln: nn.Module = nn.LayerNorm,
+        proj_size: int = 0,
         device=None,
         dtype=None,
     ):
@@ -70,7 +70,9 @@ class LayerNormLSTMCell(nn.Module):
             be of shape (batch_size, input_size).
           hidden_size:
             The number of features in the hidden state `h` and `c`.
-            Both `h` and `c` are of shape (batch_size, hidden_size).
+            Both `h` and `c` are of shape (batch_size, hidden_size) when
+            proj_size is 0. If proj_size is not zero, the shape of `h`
+            is (batch_size, proj_size).
           bias:
             If ``False``, then the cell does not use bias weights
             `bias_ih` and `bias_hh`.
@@ -78,19 +80,38 @@ class LayerNormLSTMCell(nn.Module):
             Defaults to `nn.LayerNorm`. The output of all gates are processed
             by `ln`. We pass it as an argument so that we can replace it
             with `nn.Identity` at the testing time.
+          proj_size:
+            If not zero, it applies an affine transform to the output. In this
+            case, the shape of `h` is (batch_size, proj_size).
+            See https://arxiv.org/pdf/1402.1128.pdf
         """
         super().__init__()
         factory_kwargs = {"device": device, "dtype": dtype}
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.bias = bias
+        self.proj_size = proj_size
+
+        if proj_size < 0:
+            raise ValueError(
+                f"proj_size {proj_size} should be a positive integer "
+                "or zero to disable projections"
+            )
+
+        if proj_size >= hidden_size:
+            raise ValueError(
+                f"proj_size {proj_size} has to be smaller "
+                f"than hidden_size {hidden_size}"
+            )
+
+        real_hidden_size = proj_size if proj_size > 0 else hidden_size
 
         self.weight_ih = nn.Parameter(
             torch.empty((4 * hidden_size, input_size), **factory_kwargs)
         )
 
         self.weight_hh = nn.Parameter(
-            torch.empty((4 * hidden_size, hidden_size), **factory_kwargs)
+            torch.empty((4 * hidden_size, real_hidden_size), **factory_kwargs)
         )
 
         if bias:
@@ -104,6 +125,13 @@ class LayerNormLSTMCell(nn.Module):
             self.register_parameter("bias_ih", None)
             self.register_parameter("bias_hh", None)
 
+        if proj_size > 0:
+            self.weight_hr = nn.Parameter(
+                torch.empty((proj_size, hidden_size), **factory_kwargs)
+            )
+        else:
+            self.register_parameter("weight_hr", None)
+
         self.layernorm_i = ln(hidden_size)
         self.layernorm_f = ln(hidden_size)
         self.layernorm_cx = ln(hidden_size)
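To keep the new shape bookkeeping straight: `weight_ih` still consumes the raw input, `weight_hh` now consumes the projected hidden state (hence `real_hidden_size` columns), and `weight_hr` maps the hidden state down to `proj_size`. A minimal standalone sketch with arbitrary example sizes (not part of the commit):

import torch

input_size, hidden_size, proj_size = 10, 20, 5
real_hidden_size = proj_size if proj_size > 0 else hidden_size

# The four gates (i, f, g, o) are stacked along dim 0, hence the factor 4.
weight_ih = torch.empty(4 * hidden_size, input_size)        # (80, 10)
# The recurrent input is the projected hidden state, so its width is
# real_hidden_size rather than hidden_size.
weight_hh = torch.empty(4 * hidden_size, real_hidden_size)  # (80, 5)
# weight_hr projects the hidden state from hidden_size down to proj_size.
weight_hr = torch.empty(proj_size, hidden_size)             # (5, 20)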
@@ -123,12 +151,15 @@ class LayerNormLSTMCell(nn.Module):
             A 2-D tensor of shape (batch_size, input_size).
           state:
             If not ``None``, it contains the hidden state (h, c) for each
-            element in the batch. Both are of shape (batch_size, hidden_size).
+            element in the batch. Both are of shape (batch_size, hidden_size)
+            if proj_size is 0. If proj_size is not zero, the shape of `h` is
+            (batch_size, proj_size).
             If ``None``, it uses zeros for `h` and `c`.
         Returns:
           Return two tensors:
-          - `next_h`: It is of shape (batch_size, hidden_size) containing the
-            next hidden state for each element in the batch.
+          - `next_h`: It is of shape (batch_size, hidden_size) if proj_size
+            is 0, else (batch_size, proj_size), containing the next hidden
+            state for each element in the batch.
           - `next_c`: It is of shape (batch_size, hidden_size) containing the
             next cell state for each element in the batch.
         """
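A short usage sketch of the documented shapes, assuming `LayerNormLSTMCell` is importable from this module (sizes are arbitrary):

import torch

cell = LayerNormLSTMCell(input_size=10, hidden_size=20, proj_size=5)
x = torch.rand(3, 10)
h = torch.rand(3, 5)   # proj_size, because projection is enabled
c = torch.rand(3, 20)  # the cell state always keeps hidden_size
next_h, next_c = cell(x, (h, c))
assert next_h.shape == (3, 5)
assert next_c.shape == (3, 20)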
@@ -162,6 +193,9 @@ class LayerNormLSTMCell(nn.Module):
         cy = self.layernorm_cy(cy)
         hy = out_gate * torch.tanh(cy)
 
+        if self.weight_hr is not None:
+            hy = torch.matmul(hy, self.weight_hr.t())
+
         return hy, cy
 
     def extra_repr(self) -> str:
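The new branch is the projection step from the LSTMP paper referenced above: next_h = (out_gate * tanh(cy)) applied to the transpose of weight_hr. A self-contained sketch of just that step, with made-up sizes:

import torch

batch_size, hidden_size, proj_size = 3, 20, 5
out_gate = torch.sigmoid(torch.randn(batch_size, hidden_size))
cy = torch.randn(batch_size, hidden_size)
weight_hr = torch.randn(proj_size, hidden_size)

hy = out_gate * torch.tanh(cy)        # (batch_size, hidden_size)
hy = torch.matmul(hy, weight_hr.t())  # (batch_size, proj_size)
assert hy.shape == (batch_size, proj_size)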
@@ -172,8 +206,9 @@ class LayerNormLSTMCell(nn.Module):
 
     def reset_parameters(self) -> None:
         stdv = 1.0 / math.sqrt(self.hidden_size)
-        for weight in self.parameters():
-            nn.init.uniform_(weight, -stdv, stdv)
+        for name, weight in self.named_parameters():
+            if "layernorm" not in name:
+                nn.init.uniform_(weight, -stdv, stdv)
 
 
 class LayerNormLSTMLayer(nn.Module):
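`reset_parameters` now iterates over `named_parameters()` so it can skip parameters whose names contain "layernorm"; the `nn.LayerNorm` sub-modules therefore keep their default initialization (weight = 1, bias = 0) instead of being re-drawn from uniform(-stdv, stdv). A tiny illustration of what is being preserved:

import torch.nn as nn

ln = nn.LayerNorm(4)
# nn.LayerNorm initializes weight to ones and bias to zeros; these defaults
# are what the "layernorm" name filter above leaves untouched.
print(ln.weight.data)  # all ones
print(ln.bias.data)    # all zeros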
@@ -199,6 +234,7 @@ class LayerNormLSTMLayer(nn.Module):
         hidden_size: int,
         bias: bool = True,
         ln: nn.Module = nn.LayerNorm,
+        proj_size: int = 0,
         device=None,
         dtype=None,
     ):
@@ -211,6 +247,7 @@ class LayerNormLSTMLayer(nn.Module):
             hidden_size=hidden_size,
             bias=bias,
             ln=ln,
+            proj_size=proj_size,
             device=device,
             dtype=dtype,
         )
@@ -228,13 +265,14 @@ class LayerNormLSTMLayer(nn.Module):
             We use `batch_first=True` here.
           state:
             If not ``None``, it contains the hidden state (h, c) of this layer.
-            Both are of shape (batch_size, hidden_size).
+            Both are of shape (batch_size, hidden_size) if proj_size is 0.
+            If proj_size is not 0, the shape of `h` is (batch_size, proj_size).
           Note:
             We did not annotate `state` with `Optional[Tuple[...]]` since
             torchscript will complain.
         Return:
           - output, a tensor of shape (batch_size, seq_len, hidden_size)
-          - (next_h, next_c) containing the hidden state of this layer
+          - (next_h, next_c) containing the next hidden state
         """
         inputs = input.unbind(1)
         outputs = torch.jit.annotate(List[torch.Tensor], [])
@@ -270,6 +308,7 @@ class LayerNormLSTM(nn.Module):
         hidden_size: int,
         num_layers: int,
         bias: bool = True,
+        proj_size: int = 0,
         ln: nn.Module = nn.LayerNorm,
         device=None,
         dtype=None,
@@ -283,6 +322,7 @@ class LayerNormLSTM(nn.Module):
             hidden_size=hidden_size,
             bias=bias,
             ln=ln,
+            proj_size=proj_size,
             device=device,
             dtype=dtype,
         )
@@ -293,7 +333,7 @@ class LayerNormLSTM(nn.Module):
         for i in range(1, num_layers):
             layers.append(
                 LayerNormLSTMLayer(
-                    input_size=hidden_size,
+                    input_size=proj_size if proj_size > 0 else hidden_size,
                     **factory_kwargs,
                 )
             )
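With projection enabled each layer emits features of width `proj_size`, so every layer after the first must consume `proj_size` inputs rather than `hidden_size`; that is what the changed `input_size=` expression encodes. A small illustration with hypothetical numbers:

input_size, hidden_size, proj_size, num_layers = 10, 20, 5, 3
layer_input_sizes = [input_size] + [
    proj_size if proj_size > 0 else hidden_size for _ in range(num_layers - 1)
]
print(layer_input_sizes)  # [10, 5, 5]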
@@ -313,7 +353,9 @@ class LayerNormLSTM(nn.Module):
             We use `batch_first=True` here.
           states:
             One state per layer. Each entry contains the hidden state (h, c)
-            for a layer. Both are of shape (batch_size, hidden_size).
+            for a layer. Both are of shape (batch_size, hidden_size) if
+            proj_size is 0. If proj_size is not 0, the shape of `h` is
+            (batch_size, proj_size).
           Returns:
             Return a tuple containing:
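The test hunks that follow build exactly this kind of `states` list: one (h, c) pair per layer, with `h` sized by `proj_size` when projection is on. A sketch with arbitrary sizes (not part of the commit):

import torch

num_layers, batch_size, hidden_size, proj_size = 3, 4, 20, 5
states = [
    (torch.rand(batch_size, proj_size), torch.rand(batch_size, hidden_size))
    for _ in range(num_layers)
]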
@@ -55,6 +55,14 @@ def test_layernorm_lstm_cell_constructor():
     assert len(self_cell.state_dict()) == len(torch_cell.state_dict())
 
 
+def test_layernorm_lstm_cell_with_projection_jit():
+    input_size = 10
+    hidden_size = 20
+    proj_size = 5
+    self_cell = LayerNormLSTMCell(input_size, hidden_size, proj_size=proj_size)
+    torch.jit.script(self_cell)
+
+
 def test_layernorm_lstm_cell_forward():
     input_size = torch.randint(low=2, high=100, size=(1,)).item()
     hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
@@ -90,6 +98,54 @@ def test_layernorm_lstm_cell_forward():
     assert_allclose(x.grad, x_clone.grad)
 
 
+def test_layernorm_lstm_cell_with_projection_forward():
+    input_size = torch.randint(low=2, high=100, size=(1,)).item()
+    hidden_size = torch.randint(low=10, high=100, size=(1,)).item()
+    bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
+    proj_size = torch.randint(low=2, high=hidden_size, size=(1,)).item()
+
+    self_cell = LayerNormLSTMCell(
+        input_size,
+        hidden_size,
+        bias=bias,
+        ln=nn.Identity,
+        proj_size=proj_size,
+    )
+    torch_cell = nn.LSTM(
+        input_size,
+        hidden_size,
+        bias=bias,
+        proj_size=proj_size,
+        batch_first=True,
+    )
+    with torch.no_grad():
+        for name, self_param in self_cell.named_parameters():
+            getattr(torch_cell, f"{name}_l0").copy_(self_param)
+
+    N = torch.randint(low=2, high=100, size=(1,))
+    x = torch.rand(N, input_size).requires_grad_()
+    h = torch.rand(N, proj_size)
+    c = torch.rand(N, hidden_size)
+
+    x_clone = x.detach().clone().requires_grad_()
+
+    self_h, self_c = self_cell(x.clone(), (h, c))
+    _, (torch_h, torch_c) = torch_cell(
+        x_clone.unsqueeze(1), (h.unsqueeze(0), c.unsqueeze(0))
+    )
+
+    torch_h = torch_h.squeeze(0)
+    torch_c = torch_c.squeeze(0)
+
+    assert_allclose(self_h, torch_h)
+    assert_allclose(self_c, torch_c)
+
+    (self_h.sum() * self_c.sum()).backward()
+    (torch_h.sum() * torch_c.sum()).backward()
+
+    assert_allclose(x.grad, x_clone.grad)
+
+
 def test_layernorm_lstm_layer_jit():
     input_size = 10
     hidden_size = 20
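The copy loop above works because the cell reuses the parameter names of `torch.nn.LSTM` (`weight_ih`, `weight_hh`, `weight_hr`, `bias_ih`, `bias_hh`), which `nn.LSTM` exposes with an `_l{layer}` suffix. An illustrative way to inspect that naming (exact order may vary by PyTorch version):

import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20, proj_size=5, batch_first=True)
for name, _ in lstm.named_parameters():
    print(name)  # e.g. weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0, weight_hr_l0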
@@ -97,6 +153,78 @@ def test_layernorm_lstm_layer_jit():
     torch.jit.script(layer)
 
 
+def test_layernorm_lstm_layer_with_project_jit():
+    input_size = 10
+    hidden_size = 20
+    proj_size = 5
+    layer = LayerNormLSTMLayer(
+        input_size,
+        hidden_size=hidden_size,
+        proj_size=proj_size,
+    )
+    torch.jit.script(layer)
+
+
+def test_layernorm_lstm_layer_with_projection_forward():
+    input_size = torch.randint(low=2, high=100, size=(1,)).item()
+    hidden_size = torch.randint(low=10, high=100, size=(1,)).item()
+    bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
+    proj_size = torch.randint(low=2, high=hidden_size, size=(1,)).item()
+
+    self_layer = LayerNormLSTMLayer(
+        input_size,
+        hidden_size,
+        bias=bias,
+        proj_size=proj_size,
+        ln=nn.Identity,
+    )
+
+    N = torch.randint(low=2, high=100, size=(1,))
+    T = torch.randint(low=2, high=100, size=(1,))
+
+    x = torch.rand(N, T, input_size).requires_grad_()
+    h = torch.rand(N, proj_size)
+    c = torch.rand(N, hidden_size)
+
+    x_clone = x.detach().clone().requires_grad_()
+
+    self_y, (self_h, self_c) = self_layer(x, (h, c))
+
+    torch_layer = nn.LSTM(
+        input_size=input_size,
+        hidden_size=hidden_size,
+        num_layers=1,
+        bias=bias,
+        proj_size=proj_size,
+        batch_first=True,
+        dropout=0,
+        bidirectional=False,
+    )
+    with torch.no_grad():
+        for name, self_param in self_layer.cell.named_parameters():
+            getattr(torch_layer, f"{name}_l0").copy_(self_param)
+
+    torch_y, (torch_h, torch_c) = torch_layer(
+        x_clone, (h.unsqueeze(0), c.unsqueeze(0))
+    )
+    assert_allclose(self_y, torch_y)
+    assert_allclose(self_h, torch_h)
+    assert_allclose(self_c, torch_c)
+
+    self_hc = self_h * self_c.sum()
+    torch_hc = torch_h * torch_c.sum()
+    self_hc_sum = (self_hc.reshape(-1) * torch.arange(self_hc.numel())).sum()
+    torch_hc_sum = (torch_hc.reshape(-1) * torch.arange(torch_hc.numel())).sum()
+
+    self_y_sum = (self_y.reshape(-1) * torch.arange(self_y.numel())).sum()
+    torch_y_sum = (torch_y.reshape(-1) * torch.arange(torch_y.numel())).sum()
+
+    (self_hc_sum * self_y_sum).backward()
+    (torch_hc_sum * torch_y_sum).backward()
+
+    assert_allclose(x.grad, x_clone.grad, atol=1e-5)
+
+
 def test_layernorm_lstm_layer_forward():
     input_size = torch.randint(low=2, high=100, size=(1,)).item()
     hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
@@ -169,6 +297,24 @@ def test_layernorm_lstm_jit():
     torch.jit.script(lstm)
 
 
+def test_layernorm_lstm_with_projection_jit():
+    input_size = 2
+    hidden_size = 5
+    proj_size = 3
+    num_layers = 4
+    bias = True
+
+    lstm = LayerNormLSTM(
+        input_size=input_size,
+        hidden_size=hidden_size,
+        num_layers=num_layers,
+        bias=bias,
+        proj_size=proj_size,
+        ln=nn.Identity,
+    )
+    torch.jit.script(lstm)
+
+
 def test_layernorm_lstm_forward():
     input_size = torch.randint(low=2, high=100, size=(1,)).item()
     hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
@@ -235,6 +381,75 @@ def test_layernorm_lstm_forward():
     assert_allclose(x.grad, x_clone.grad)
 
 
+def test_layernorm_lstm_with_projection_forward():
+    input_size = torch.randint(low=2, high=100, size=(1,)).item()
+    hidden_size = torch.randint(low=10, high=100, size=(1,)).item()
+    proj_size = torch.randint(low=2, high=hidden_size, size=(1,)).item()
+    num_layers = torch.randint(low=2, high=100, size=(1,)).item()
+    bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
+
+    self_lstm = LayerNormLSTM(
+        input_size=input_size,
+        hidden_size=hidden_size,
+        num_layers=num_layers,
+        bias=bias,
+        proj_size=proj_size,
+        ln=nn.Identity,
+    )
+    torch_lstm = nn.LSTM(
+        input_size=input_size,
+        hidden_size=hidden_size,
+        num_layers=num_layers,
+        bias=bias,
+        proj_size=proj_size,
+        batch_first=True,
+        bidirectional=False,
+    )
+    assert len(self_lstm.state_dict()) == len(torch_lstm.state_dict())
+    with torch.no_grad():
+        for name, param in self_lstm.named_parameters():
+            # name has the form layers.0.cell.weight_hh
+            parts = name.split(".")
+            layer_num = parts[1]
+            getattr(torch_lstm, f"{parts[-1]}_l{layer_num}").copy_(param)
+
+    N = torch.randint(low=2, high=100, size=(1,))
+    T = torch.randint(low=2, high=100, size=(1,))
+
+    x = torch.rand(N, T, input_size).requires_grad_()
+    hs = [torch.rand(N, proj_size) for _ in range(num_layers)]
+    cs = [torch.rand(N, hidden_size) for _ in range(num_layers)]
+    states = list(zip(hs, cs))
+
+    x_clone = x.detach().clone().requires_grad_()
+
+    self_y, self_states = self_lstm(x, states)
+
+    h = torch.stack(hs)
+    c = torch.stack(cs)
+    torch_y, (torch_h, torch_c) = torch_lstm(x_clone, (h, c))
+
+    assert_allclose(self_y, torch_y)
+
+    self_h = torch.stack([s[0] for s in self_states])
+    self_c = torch.stack([s[1] for s in self_states])
+
+    assert_allclose(self_h, torch_h)
+    assert_allclose(self_c, torch_c)
+
+    s = self_y.reshape(-1)
+    t = torch_y.reshape(-1)
+
+    s_sum = (s * torch.arange(s.numel())).sum()
+    t_sum = (t * torch.arange(t.numel())).sum()
+    shc_sum = s_sum * self_h.sum() * self_c.sum()
+    thc_sum = t_sum * torch_h.sum() * torch_c.sum()
+
+    shc_sum.backward()
+    thc_sum.backward()
+    assert_allclose(x.grad, x_clone.grad)
+
+
 def test_layernorm_gru_cell_jit():
     input_size = 10
     hidden_size = 20
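The multi-layer test relies on the naming convention spelled out in its comment: `layers.{i}.cell.{param}` in `LayerNormLSTM` maps to `{param}_l{i}` in `torch.nn.LSTM`. A standalone illustration of that mapping:

name = "layers.0.cell.weight_hh"
parts = name.split(".")
torch_name = f"{parts[-1]}_l{parts[1]}"
assert torch_name == "weight_hh_l0"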
@@ -332,7 +547,7 @@ def test_layernorm_gru_layer_forward():
 
     torch_y, torch_h = torch_layer(x_clone, h.unsqueeze(0))
     assert_allclose(self_y, torch_y, atol=1e-6)
-    assert_allclose(self_h, torch_h)
+    assert_allclose(self_h, torch_h, atol=1e-6)
 
     self_y_sum = (self_y.reshape(-1) * torch.arange(self_y.numel())).sum()
     torch_y_sum = (torch_y.reshape(-1) * torch.arange(torch_y.numel())).sum()
@@ -416,19 +631,25 @@ def test_layernorm_gru_forward():
 
     s_state_sum.backward()
     t_state_sum.backward()
-    assert_allclose(x.grad, x_clone.grad)
+    assert_allclose(x.grad, x_clone.grad, atol=1e-6)
 
 
 def test_lstm():
     test_layernorm_lstm_cell_jit()
     test_layernorm_lstm_cell_constructor()
+    test_layernorm_lstm_cell_with_projection_jit()
     test_layernorm_lstm_cell_forward()
+    test_layernorm_lstm_cell_with_projection_forward()
     #
     test_layernorm_lstm_layer_jit()
+    test_layernorm_lstm_layer_with_project_jit()
     test_layernorm_lstm_layer_forward()
-    #
+    test_layernorm_lstm_layer_with_projection_forward()
+
     test_layernorm_lstm_jit()
+    test_layernorm_lstm_with_projection_jit()
     test_layernorm_lstm_forward()
+    test_layernorm_lstm_with_projection_forward()
 
 
 def test_gru():