#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
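"""
Tests for the layer-normalized RNN modules in the transducer.rnn module:
LayerNormLSTMCell / LayerNormLSTMLayer / LayerNormLSTM and
LayerNormGRUCell / LayerNormGRULayer / LayerNormGRU.

Each module is checked in up to three ways:

  * the *_jit tests verify that the module can be compiled with
    torch.jit.script();
  * the *_constructor tests verify that parameter names and shapes match
    the corresponding torch.nn module;
  * the *_forward tests disable layer normalization (ln=nn.Identity),
    copy the parameters into the corresponding torch.nn module
    (nn.LSTMCell, nn.LSTM, nn.GRUCell, nn.GRU) and compare the forward
    outputs as well as the gradients w.r.t. the input.

main() at the bottom runs all tests with a fixed random seed.
"""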
import torch
import torch.nn as nn
from transducer.rnn import (
LayerNormGRU,
LayerNormGRUCell,
LayerNormGRULayer,
LayerNormLSTM,
LayerNormLSTMCell,
LayerNormLSTMLayer,
)
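
# A thin wrapper around torch.allclose whose failure message reports the
# maximum absolute difference and the number of elements being compared.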
def assert_allclose(a: torch.Tensor, b: torch.Tensor, **kwargs):
assert torch.allclose(a, b, **kwargs), f"{(a - b).abs().max()}, {a.numel()}"
def test_layernorm_lstm_cell_jit():
input_size = 10
hidden_size = 20
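    # Randomly choose bias=True or bias=False: (x & 2) == 0 tests bit 1 of a
    # random integer, so each value is picked roughly half of the time.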
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
cell = LayerNormLSTMCell(
input_size=input_size, hidden_size=hidden_size, bias=bias
)
torch.jit.script(cell)
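
# ln=nn.Identity is parameter-free, so the custom cell should expose exactly
# the same parameter names and shapes as torch.nn.LSTMCell.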
def test_layernorm_lstm_cell_constructor():
input_size = torch.randint(low=2, high=100, size=(1,)).item()
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
self_cell = LayerNormLSTMCell(input_size, hidden_size, ln=nn.Identity)
torch_cell = nn.LSTMCell(input_size, hidden_size)
for name, param in self_cell.named_parameters():
assert param.shape == getattr(torch_cell, name).shape
assert len(self_cell.state_dict()) == len(torch_cell.state_dict())
def test_layernorm_lstm_cell_with_projection_jit():
input_size = 10
hidden_size = 20
proj_size = 5
self_cell = LayerNormLSTMCell(input_size, hidden_size, proj_size=proj_size)
torch.jit.script(self_cell)
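
# With layer normalization disabled (ln=nn.Identity) the cell should compute
# the same function as nn.LSTMCell, so after copying the parameters across,
# the outputs and the input gradients of both cells are compared.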
def test_layernorm_lstm_cell_forward():
input_size = torch.randint(low=2, high=100, size=(1,)).item()
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
self_cell = LayerNormLSTMCell(
input_size, hidden_size, bias=bias, ln=nn.Identity
)
torch_cell = nn.LSTMCell(input_size, hidden_size, bias=bias)
with torch.no_grad():
for name, torch_param in torch_cell.named_parameters():
self_param = getattr(self_cell, name)
torch_param.copy_(self_param)
N = torch.randint(low=2, high=100, size=(1,))
x = torch.rand(N, input_size).requires_grad_()
h = torch.rand(N, hidden_size)
c = torch.rand(N, hidden_size)
x_clone = x.detach().clone().requires_grad_()
self_h, self_c = self_cell(x.clone(), (h, c))
torch_h, torch_c = torch_cell(x_clone, (h, c))
assert_allclose(self_h, torch_h)
assert_allclose(self_c, torch_c)
self_hc = self_h * self_c
torch_hc = torch_h * torch_c
(self_hc.reshape(-1) * torch.arange(self_hc.numel())).sum().backward()
(torch_hc.reshape(-1) * torch.arange(torch_hc.numel())).sum().backward()
assert_allclose(x.grad, x_clone.grad)
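
# nn.LSTMCell has no proj_size argument, so the cell with projection is
# compared against a single-layer nn.LSTM fed a sequence of length 1.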
def test_layernorm_lstm_cell_with_projection_forward():
input_size = torch.randint(low=2, high=100, size=(1,)).item()
hidden_size = torch.randint(low=10, high=100, size=(1,)).item()
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
proj_size = torch.randint(low=2, high=hidden_size, size=(1,)).item()
self_cell = LayerNormLSTMCell(
input_size,
hidden_size,
bias=bias,
ln=nn.Identity,
proj_size=proj_size,
)
torch_cell = nn.LSTM(
input_size,
hidden_size,
bias=bias,
proj_size=proj_size,
batch_first=True,
)
with torch.no_grad():
for name, self_param in self_cell.named_parameters():
getattr(torch_cell, f"{name}_l0").copy_(self_param)
N = torch.randint(low=2, high=100, size=(1,))
x = torch.rand(N, input_size).requires_grad_()
h = torch.rand(N, proj_size)
c = torch.rand(N, hidden_size)
x_clone = x.detach().clone().requires_grad_()
self_h, self_c = self_cell(x.clone(), (h, c))
_, (torch_h, torch_c) = torch_cell(
x_clone.unsqueeze(1), (h.unsqueeze(0), c.unsqueeze(0))
)
torch_h = torch_h.squeeze(0)
torch_c = torch_c.squeeze(0)
assert_allclose(self_h, torch_h)
assert_allclose(self_c, torch_c)
(self_h.sum() * self_c.sum()).backward()
(torch_h.sum() * torch_c.sum()).backward()
assert_allclose(x.grad, x_clone.grad)
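
# Layer-level tests: a LayerNormLSTMLayer processes a whole (N, T, input_size)
# batch and is compared against a single-layer nn.LSTM with batch_first=True.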
def test_layernorm_lstm_layer_jit():
input_size = 10
hidden_size = 20
layer = LayerNormLSTMLayer(input_size, hidden_size=hidden_size)
torch.jit.script(layer)
def test_layernorm_lstm_layer_with_project_jit():
input_size = 10
hidden_size = 20
proj_size = 5
layer = LayerNormLSTMLayer(
input_size,
hidden_size=hidden_size,
proj_size=proj_size,
)
torch.jit.script(layer)
def test_layernorm_lstm_layer_with_projection_forward():
input_size = torch.randint(low=2, high=100, size=(1,)).item()
hidden_size = torch.randint(low=10, high=100, size=(1,)).item()
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
proj_size = torch.randint(low=2, high=hidden_size, size=(1,)).item()
self_layer = LayerNormLSTMLayer(
input_size,
hidden_size,
bias=bias,
proj_size=proj_size,
ln=nn.Identity,
)
N = torch.randint(low=2, high=100, size=(1,))
T = torch.randint(low=2, high=100, size=(1,))
x = torch.rand(N, T, input_size).requires_grad_()
h = torch.rand(N, proj_size)
c = torch.rand(N, hidden_size)
x_clone = x.detach().clone().requires_grad_()
self_y, (self_h, self_c) = self_layer(x, (h, c))
torch_layer = nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=1,
bias=bias,
proj_size=proj_size,
batch_first=True,
dropout=0,
bidirectional=False,
)
with torch.no_grad():
for name, self_param in self_layer.cell.named_parameters():
getattr(torch_layer, f"{name}_l0").copy_(self_param)
torch_y, (torch_h, torch_c) = torch_layer(
x_clone, (h.unsqueeze(0), c.unsqueeze(0))
)
assert_allclose(self_y, torch_y)
assert_allclose(self_h, torch_h)
assert_allclose(self_c, torch_c)
self_hc = self_h * self_c.sum()
torch_hc = torch_h * torch_c.sum()
self_hc_sum = (self_hc.reshape(-1) * torch.arange(self_hc.numel())).sum()
torch_hc_sum = (torch_hc.reshape(-1) * torch.arange(torch_hc.numel())).sum()
self_y_sum = (self_y.reshape(-1) * torch.arange(self_y.numel())).sum()
torch_y_sum = (torch_y.reshape(-1) * torch.arange(torch_y.numel())).sum()
(self_hc_sum * self_y_sum).backward()
(torch_hc_sum * torch_y_sum).backward()
assert_allclose(x.grad, x_clone.grad, atol=1e-5)
def test_layernorm_lstm_layer_forward():
input_size = torch.randint(low=2, high=100, size=(1,)).item()
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
self_layer = LayerNormLSTMLayer(
input_size,
hidden_size,
bias=bias,
ln=nn.Identity,
)
N = torch.randint(low=2, high=100, size=(1,))
T = torch.randint(low=2, high=100, size=(1,))
x = torch.rand(N, T, input_size).requires_grad_()
h = torch.rand(N, hidden_size)
c = torch.rand(N, hidden_size)
x_clone = x.detach().clone().requires_grad_()
self_y, (self_h, self_c) = self_layer(x, (h, c))
torch_layer = nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=1,
bias=bias,
batch_first=True,
dropout=0,
bidirectional=False,
)
with torch.no_grad():
for name, self_param in self_layer.cell.named_parameters():
getattr(torch_layer, f"{name}_l0").copy_(self_param)
torch_y, (torch_h, torch_c) = torch_layer(
x_clone, (h.unsqueeze(0), c.unsqueeze(0))
)
assert_allclose(self_y, torch_y)
assert_allclose(self_h, torch_h)
assert_allclose(self_c, torch_c)
self_hc = self_h * self_c
torch_hc = torch_h * torch_c
self_hc_sum = (self_hc.reshape(-1) * torch.arange(self_hc.numel())).sum()
torch_hc_sum = (torch_hc.reshape(-1) * torch.arange(torch_hc.numel())).sum()
self_y_sum = (self_y.reshape(-1) * torch.arange(self_y.numel())).sum()
torch_y_sum = (torch_y.reshape(-1) * torch.arange(torch_y.numel())).sum()
(self_hc_sum * self_y_sum).backward()
(torch_hc_sum * torch_y_sum).backward()
assert_allclose(x.grad, x_clone.grad, atol=1e-5)
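
# Tests for the stacked, multi-layer LayerNormLSTM.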
def test_layernorm_lstm_jit():
input_size = 2
hidden_size = 3
num_layers = 4
bias = True
lstm = LayerNormLSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
bias=bias,
ln=nn.Identity,
)
torch.jit.script(lstm)
def test_layernorm_lstm_with_projection_jit():
input_size = 2
hidden_size = 5
proj_size = 3
num_layers = 4
bias = True
lstm = LayerNormLSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
bias=bias,
proj_size=proj_size,
ln=nn.Identity,
)
torch.jit.script(lstm)
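
# LayerNormLSTM keeps per-layer states as a list of (h, c) tuples, while
# nn.LSTM expects h and c stacked along a leading num_layers dimension.
# A parameter named layers.<i>.cell.<name> corresponds to <name>_l<i>
# in nn.LSTM.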
def test_layernorm_lstm_forward():
input_size = torch.randint(low=2, high=100, size=(1,)).item()
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
num_layers = torch.randint(low=2, high=100, size=(1,)).item()
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
self_lstm = LayerNormLSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
bias=bias,
ln=nn.Identity,
)
torch_lstm = nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
bias=bias,
batch_first=True,
bidirectional=False,
)
assert len(self_lstm.state_dict()) == len(torch_lstm.state_dict())
with torch.no_grad():
for name, param in self_lstm.named_parameters():
# name has the form layers.0.cell.weight_hh
parts = name.split(".")
layer_num = parts[1]
getattr(torch_lstm, f"{parts[-1]}_l{layer_num}").copy_(param)
N = torch.randint(low=2, high=100, size=(1,))
T = torch.randint(low=2, high=100, size=(1,))
x = torch.rand(N, T, input_size).requires_grad_()
hs = [torch.rand(N, hidden_size) for _ in range(num_layers)]
cs = [torch.rand(N, hidden_size) for _ in range(num_layers)]
states = list(zip(hs, cs))
x_clone = x.detach().clone().requires_grad_()
self_y, self_states = self_lstm(x, states)
h = torch.stack(hs)
c = torch.stack(cs)
torch_y, (torch_h, torch_c) = torch_lstm(x_clone, (h, c))
assert_allclose(self_y, torch_y)
self_h = torch.stack([s[0] for s in self_states])
self_c = torch.stack([s[1] for s in self_states])
assert_allclose(self_h, torch_h)
assert_allclose(self_c, torch_c)
s = self_y.reshape(-1)
t = torch_y.reshape(-1)
s_sum = (s * torch.arange(s.numel())).sum()
t_sum = (t * torch.arange(t.numel())).sum()
shc_sum = s_sum * self_h.sum() * self_c.sum()
thc_sum = t_sum * torch_h.sum() * torch_c.sum()
shc_sum.backward()
thc_sum.backward()
assert_allclose(x.grad, x_clone.grad)
def test_layernorm_lstm_with_projection_forward():
input_size = torch.randint(low=2, high=100, size=(1,)).item()
hidden_size = torch.randint(low=10, high=100, size=(1,)).item()
proj_size = torch.randint(low=2, high=hidden_size, size=(1,)).item()
num_layers = torch.randint(low=2, high=100, size=(1,)).item()
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
self_lstm = LayerNormLSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
bias=bias,
proj_size=proj_size,
ln=nn.Identity,
)
torch_lstm = nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
bias=bias,
proj_size=proj_size,
batch_first=True,
bidirectional=False,
)
assert len(self_lstm.state_dict()) == len(torch_lstm.state_dict())
with torch.no_grad():
for name, param in self_lstm.named_parameters():
# name has the form layers.0.cell.weight_hh
parts = name.split(".")
layer_num = parts[1]
getattr(torch_lstm, f"{parts[-1]}_l{layer_num}").copy_(param)
N = torch.randint(low=2, high=100, size=(1,))
T = torch.randint(low=2, high=100, size=(1,))
x = torch.rand(N, T, input_size).requires_grad_()
hs = [torch.rand(N, proj_size) for _ in range(num_layers)]
cs = [torch.rand(N, hidden_size) for _ in range(num_layers)]
states = list(zip(hs, cs))
x_clone = x.detach().clone().requires_grad_()
self_y, self_states = self_lstm(x, states)
h = torch.stack(hs)
c = torch.stack(cs)
torch_y, (torch_h, torch_c) = torch_lstm(x_clone, (h, c))
assert_allclose(self_y, torch_y)
self_h = torch.stack([s[0] for s in self_states])
self_c = torch.stack([s[1] for s in self_states])
assert_allclose(self_h, torch_h)
assert_allclose(self_c, torch_c)
s = self_y.reshape(-1)
t = torch_y.reshape(-1)
s_sum = (s * torch.arange(s.numel())).sum()
t_sum = (t * torch.arange(t.numel())).sum()
shc_sum = s_sum * self_h.sum() * self_c.sum()
thc_sum = t_sum * torch_h.sum() * torch_c.sum()
shc_sum.backward()
thc_sum.backward()
assert_allclose(x.grad, x_clone.grad)
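
# The GRU tests mirror the LSTM tests above: JIT scriptability, constructor
# parameter shapes, and (with ln=nn.Identity) forward/backward agreement with
# nn.GRUCell / nn.GRU. Note the looser tolerances on the gradient checks.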
def test_layernorm_gru_cell_jit():
input_size = 10
hidden_size = 20
cell = LayerNormGRUCell(
input_size=input_size, hidden_size=hidden_size, bias=True
)
torch.jit.script(cell)
def test_layernorm_gru_cell_constructor():
input_size = torch.randint(low=2, high=100, size=(1,)).item()
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
self_cell = LayerNormGRUCell(input_size, hidden_size, ln=nn.Identity)
torch_cell = nn.GRUCell(input_size, hidden_size)
for name, param in self_cell.named_parameters():
assert param.shape == getattr(torch_cell, name).shape
assert len(self_cell.state_dict()) == len(torch_cell.state_dict())
def test_layernorm_gru_cell_forward():
input_size = torch.randint(low=2, high=100, size=(1,)).item()
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
self_cell = LayerNormGRUCell(
input_size, hidden_size, bias=bias, ln=nn.Identity
)
torch_cell = nn.GRUCell(input_size, hidden_size, bias=bias)
with torch.no_grad():
for name, torch_param in torch_cell.named_parameters():
self_param = getattr(self_cell, name)
torch_param.copy_(self_param)
N = torch.randint(low=2, high=100, size=(1,))
x = torch.rand(N, input_size).requires_grad_()
h = torch.rand(N, hidden_size)
x_clone = x.detach().clone().requires_grad_()
self_h = self_cell(x.clone(), h)
torch_h = torch_cell(x_clone, h)
assert_allclose(self_h, torch_h, atol=1e-5)
(self_h.reshape(-1) * torch.arange(self_h.numel())).sum().backward()
(torch_h.reshape(-1) * torch.arange(torch_h.numel())).sum().backward()
assert_allclose(x.grad, x_clone.grad, atol=1e-4)
def test_layernorm_gru_layer_jit():
input_size = 10
hidden_size = 20
layer = LayerNormGRULayer(input_size, hidden_size=hidden_size)
torch.jit.script(layer)
def test_layernorm_gru_layer_forward():
input_size = torch.randint(low=2, high=100, size=(1,)).item()
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
self_layer = LayerNormGRULayer(
input_size,
hidden_size,
bias=bias,
ln=nn.Identity,
)
N = torch.randint(low=2, high=100, size=(1,))
T = torch.randint(low=2, high=100, size=(1,))
x = torch.rand(N, T, input_size).requires_grad_()
h = torch.rand(N, hidden_size)
x_clone = x.detach().clone().requires_grad_()
self_y, self_h = self_layer(x, h.clone())
torch_layer = nn.GRU(
input_size=input_size,
hidden_size=hidden_size,
num_layers=1,
bias=bias,
batch_first=True,
dropout=0,
bidirectional=False,
)
with torch.no_grad():
for name, self_param in self_layer.cell.named_parameters():
getattr(torch_layer, f"{name}_l0").copy_(self_param)
torch_y, torch_h = torch_layer(x_clone, h.unsqueeze(0))
assert_allclose(self_y, torch_y, atol=1e-6)
assert_allclose(self_h, torch_h, atol=1e-6)
self_y_sum = (self_y.reshape(-1) * torch.arange(self_y.numel())).sum()
torch_y_sum = (torch_y.reshape(-1) * torch.arange(torch_y.numel())).sum()
self_y_sum.backward()
torch_y_sum.backward()
assert_allclose(x.grad, x_clone.grad, atol=0.1)
def test_layernorm_gru_jit():
input_size = 2
hidden_size = 3
num_layers = 4
bias = True
gru = LayerNormGRU(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
bias=bias,
ln=nn.Identity,
)
torch.jit.script(gru)
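
# Multi-layer GRU: LayerNormGRU takes one (N, hidden_size) state per layer as
# a list, which is stacked into (num_layers, N, hidden_size) for nn.GRU.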
def test_layernorm_gru_forward():
input_size = torch.randint(low=2, high=100, size=(1,)).item()
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
num_layers = torch.randint(low=2, high=100, size=(1,)).item()
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
self_gru = LayerNormGRU(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
bias=bias,
ln=nn.Identity,
)
torch_gru = nn.GRU(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
bias=bias,
batch_first=True,
bidirectional=False,
)
assert len(self_gru.state_dict()) == len(torch_gru.state_dict())
with torch.no_grad():
for name, param in self_gru.named_parameters():
# name has the form layers.0.cell.weight_hh
parts = name.split(".")
layer_num = parts[1]
getattr(torch_gru, f"{parts[-1]}_l{layer_num}").copy_(param)
N = torch.randint(low=2, high=100, size=(1,))
T = torch.randint(low=2, high=100, size=(1,))
x = torch.rand(N, T, input_size).requires_grad_()
states = [torch.rand(N, hidden_size) for _ in range(num_layers)]
x_clone = x.detach().clone().requires_grad_()
self_y, self_states = self_gru(x, states)
torch_y, torch_states = torch_gru(x_clone, torch.stack(states))
assert_allclose(self_y, torch_y, atol=1e-6)
self_states = torch.stack(self_states)
assert_allclose(self_states, torch_states, atol=1e-6)
s = self_y.reshape(-1)
t = torch_y.reshape(-1)
s_sum = (s * torch.arange(s.numel())).sum()
t_sum = (t * torch.arange(t.numel())).sum()
s_state_sum = s_sum + self_states.sum()
t_state_sum = t_sum + torch_states.sum()
s_state_sum.backward()
t_state_sum.backward()
assert_allclose(x.grad, x_clone.grad, atol=1e-6)
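
# Entry points that group the individual tests.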
def test_lstm():
test_layernorm_lstm_cell_jit()
test_layernorm_lstm_cell_constructor()
test_layernorm_lstm_cell_with_projection_jit()
test_layernorm_lstm_cell_forward()
test_layernorm_lstm_cell_with_projection_forward()
#
test_layernorm_lstm_layer_jit()
test_layernorm_lstm_layer_with_project_jit()
test_layernorm_lstm_layer_forward()
test_layernorm_lstm_layer_with_projection_forward()
test_layernorm_lstm_jit()
test_layernorm_lstm_with_projection_jit()
test_layernorm_lstm_forward()
test_layernorm_lstm_with_projection_forward()
def test_gru():
test_layernorm_gru_cell_jit()
test_layernorm_gru_cell_constructor()
test_layernorm_gru_cell_forward()
#
test_layernorm_gru_layer_jit()
test_layernorm_gru_layer_forward()
#
test_layernorm_gru_jit()
test_layernorm_gru_forward()
def main():
test_lstm()
test_gru()
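
# A fixed seed keeps the randomly drawn sizes and tensors reproducible.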
if __name__ == "__main__":
torch.manual_seed(20211202)
main()