mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
Add GPU tests.
This commit is contained in:
parent
2c7547e1b7
commit
3d38f7bd31
@ -27,27 +27,46 @@ from transducer.rnn import (
|
||||
)
|
||||
|
||||
|
||||
def get_devices():
|
||||
devices = [torch.device("cpu")]
|
||||
if torch.cuda.is_available():
|
||||
devices.append(torch.device("cuda", 0))
|
||||
return devices
|
||||
|
||||
|
||||
def assert_allclose(a: torch.Tensor, b: torch.Tensor, **kwargs):
|
||||
assert torch.allclose(a, b, **kwargs), f"{(a - b).abs().max()}, {a.numel()}"
|
||||
|
||||
|
||||
def test_layernorm_lstm_cell_jit():
|
||||
def test_layernorm_lstm_cell_jit(device="cpu"):
|
||||
input_size = 10
|
||||
hidden_size = 20
|
||||
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
|
||||
cell = LayerNormLSTMCell(
|
||||
input_size=input_size, hidden_size=hidden_size, bias=bias
|
||||
input_size=input_size,
|
||||
hidden_size=hidden_size,
|
||||
bias=bias,
|
||||
device=device,
|
||||
)
|
||||
|
||||
torch.jit.script(cell)
|
||||
|
||||
|
||||
def test_layernorm_lstm_cell_constructor():
|
||||
def test_layernorm_lstm_cell_constructor(device="cpu"):
|
||||
input_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
|
||||
self_cell = LayerNormLSTMCell(input_size, hidden_size, ln=nn.Identity)
|
||||
torch_cell = nn.LSTMCell(input_size, hidden_size)
|
||||
self_cell = LayerNormLSTMCell(
|
||||
input_size,
|
||||
hidden_size,
|
||||
ln=nn.Identity,
|
||||
device=device,
|
||||
)
|
||||
torch_cell = nn.LSTMCell(
|
||||
input_size,
|
||||
hidden_size,
|
||||
device=device,
|
||||
)
|
||||
|
||||
for name, param in self_cell.named_parameters():
|
||||
assert param.shape == getattr(torch_cell, name).shape
|
||||
@ -55,32 +74,46 @@ def test_layernorm_lstm_cell_constructor():
|
||||
assert len(self_cell.state_dict()) == len(torch_cell.state_dict())
|
||||
|
||||
|
||||
def test_layernorm_lstm_cell_with_projection_jit():
|
||||
def test_layernorm_lstm_cell_with_projection_jit(device="cpu"):
|
||||
input_size = 10
|
||||
hidden_size = 20
|
||||
proj_size = 5
|
||||
self_cell = LayerNormLSTMCell(input_size, hidden_size, proj_size=proj_size)
|
||||
self_cell = LayerNormLSTMCell(
|
||||
input_size,
|
||||
hidden_size,
|
||||
proj_size=proj_size,
|
||||
device=device,
|
||||
)
|
||||
torch.jit.script(self_cell)
|
||||
|
||||
|
||||
def test_layernorm_lstm_cell_forward():
|
||||
def test_layernorm_lstm_cell_forward(device="cpu"):
|
||||
input_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
|
||||
|
||||
self_cell = LayerNormLSTMCell(
|
||||
input_size, hidden_size, bias=bias, ln=nn.Identity
|
||||
input_size,
|
||||
hidden_size,
|
||||
bias=bias,
|
||||
ln=nn.Identity,
|
||||
device=device,
|
||||
)
|
||||
torch_cell = nn.LSTMCell(
|
||||
input_size,
|
||||
hidden_size,
|
||||
bias=bias,
|
||||
device=device,
|
||||
)
|
||||
torch_cell = nn.LSTMCell(input_size, hidden_size, bias=bias)
|
||||
with torch.no_grad():
|
||||
for name, torch_param in torch_cell.named_parameters():
|
||||
self_param = getattr(self_cell, name)
|
||||
torch_param.copy_(self_param)
|
||||
|
||||
N = torch.randint(low=2, high=100, size=(1,))
|
||||
x = torch.rand(N, input_size).requires_grad_()
|
||||
h = torch.rand(N, hidden_size)
|
||||
c = torch.rand(N, hidden_size)
|
||||
x = torch.rand(N, input_size, device=device).requires_grad_()
|
||||
h = torch.rand(N, hidden_size, device=device)
|
||||
c = torch.rand(N, hidden_size, device=device)
|
||||
|
||||
x_clone = x.detach().clone().requires_grad_()
|
||||
|
||||
@ -88,17 +121,21 @@ def test_layernorm_lstm_cell_forward():
|
||||
torch_h, torch_c = torch_cell(x_clone, (h, c))
|
||||
|
||||
assert_allclose(self_h, torch_h)
|
||||
assert_allclose(self_c, torch_c)
|
||||
assert_allclose(self_c, torch_c, atol=1e-6)
|
||||
|
||||
self_hc = self_h * self_c
|
||||
torch_hc = torch_h * torch_c
|
||||
(self_hc.reshape(-1) * torch.arange(self_hc.numel())).sum().backward()
|
||||
(torch_hc.reshape(-1) * torch.arange(torch_hc.numel())).sum().backward()
|
||||
(
|
||||
self_hc.reshape(-1) * torch.arange(self_hc.numel(), device=device)
|
||||
).sum().backward()
|
||||
(
|
||||
torch_hc.reshape(-1) * torch.arange(torch_hc.numel(), device=device)
|
||||
).sum().backward()
|
||||
|
||||
assert_allclose(x.grad, x_clone.grad)
|
||||
assert_allclose(x.grad, x_clone.grad, atol=1e-3)
|
||||
|
||||
|
||||
def test_layernorm_lstm_cell_with_projection_forward():
|
||||
def test_layernorm_lstm_cell_with_projection_forward(device="cpu"):
|
||||
input_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
hidden_size = torch.randint(low=10, high=100, size=(1,)).item()
|
||||
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
|
||||
@ -110,6 +147,7 @@ def test_layernorm_lstm_cell_with_projection_forward():
|
||||
bias=bias,
|
||||
ln=nn.Identity,
|
||||
proj_size=proj_size,
|
||||
device=device,
|
||||
)
|
||||
torch_cell = nn.LSTM(
|
||||
input_size,
|
||||
@ -117,15 +155,16 @@ def test_layernorm_lstm_cell_with_projection_forward():
|
||||
bias=bias,
|
||||
proj_size=proj_size,
|
||||
batch_first=True,
|
||||
device=device,
|
||||
)
|
||||
with torch.no_grad():
|
||||
for name, self_param in self_cell.named_parameters():
|
||||
getattr(torch_cell, f"{name}_l0").copy_(self_param)
|
||||
|
||||
N = torch.randint(low=2, high=100, size=(1,))
|
||||
x = torch.rand(N, input_size).requires_grad_()
|
||||
h = torch.rand(N, proj_size)
|
||||
c = torch.rand(N, hidden_size)
|
||||
x = torch.rand(N, input_size, device=device).requires_grad_()
|
||||
h = torch.rand(N, proj_size, device=device)
|
||||
c = torch.rand(N, hidden_size, device=device)
|
||||
|
||||
x_clone = x.detach().clone().requires_grad_()
|
||||
|
||||
@ -138,22 +177,26 @@ def test_layernorm_lstm_cell_with_projection_forward():
|
||||
torch_c = torch_c.squeeze(0)
|
||||
|
||||
assert_allclose(self_h, torch_h)
|
||||
assert_allclose(self_c, torch_c)
|
||||
assert_allclose(self_c, torch_c, atol=1e-6)
|
||||
|
||||
(self_h.sum() * self_c.sum()).backward()
|
||||
(torch_h.sum() * torch_c.sum()).backward()
|
||||
|
||||
assert_allclose(x.grad, x_clone.grad)
|
||||
assert_allclose(x.grad, x_clone.grad, atol=1e-5)
|
||||
|
||||
|
||||
def test_layernorm_lstm_layer_jit():
|
||||
def test_layernorm_lstm_layer_jit(device="cpu"):
|
||||
input_size = 10
|
||||
hidden_size = 20
|
||||
layer = LayerNormLSTMLayer(input_size, hidden_size=hidden_size)
|
||||
layer = LayerNormLSTMLayer(
|
||||
input_size,
|
||||
hidden_size=hidden_size,
|
||||
device=device,
|
||||
)
|
||||
torch.jit.script(layer)
|
||||
|
||||
|
||||
def test_layernorm_lstm_layer_with_project_jit():
|
||||
def test_layernorm_lstm_layer_with_project_jit(device="cpu"):
|
||||
input_size = 10
|
||||
hidden_size = 20
|
||||
proj_size = 5
|
||||
@ -161,11 +204,12 @@ def test_layernorm_lstm_layer_with_project_jit():
|
||||
input_size,
|
||||
hidden_size=hidden_size,
|
||||
proj_size=proj_size,
|
||||
device=device,
|
||||
)
|
||||
torch.jit.script(layer)
|
||||
|
||||
|
||||
def test_layernorm_lstm_layer_with_projection_forward():
|
||||
def test_layernorm_lstm_layer_with_projection_forward(device="cpu"):
|
||||
input_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
hidden_size = torch.randint(low=10, high=100, size=(1,)).item()
|
||||
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
|
||||
@ -177,14 +221,15 @@ def test_layernorm_lstm_layer_with_projection_forward():
|
||||
bias=bias,
|
||||
proj_size=proj_size,
|
||||
ln=nn.Identity,
|
||||
device=device,
|
||||
)
|
||||
|
||||
N = torch.randint(low=2, high=100, size=(1,))
|
||||
T = torch.randint(low=2, high=100, size=(1,))
|
||||
|
||||
x = torch.rand(N, T, input_size).requires_grad_()
|
||||
h = torch.rand(N, proj_size)
|
||||
c = torch.rand(N, hidden_size)
|
||||
x = torch.rand(N, T, input_size, device=device).requires_grad_()
|
||||
h = torch.rand(N, proj_size, device=device)
|
||||
c = torch.rand(N, hidden_size, device=device)
|
||||
|
||||
x_clone = x.detach().clone().requires_grad_()
|
||||
|
||||
@ -199,6 +244,7 @@ def test_layernorm_lstm_layer_with_projection_forward():
|
||||
batch_first=True,
|
||||
dropout=0,
|
||||
bidirectional=False,
|
||||
device=device,
|
||||
)
|
||||
with torch.no_grad():
|
||||
for name, self_param in self_layer.cell.named_parameters():
|
||||
@ -207,25 +253,17 @@ def test_layernorm_lstm_layer_with_projection_forward():
|
||||
torch_y, (torch_h, torch_c) = torch_layer(
|
||||
x_clone, (h.unsqueeze(0), c.unsqueeze(0))
|
||||
)
|
||||
assert_allclose(self_y, torch_y)
|
||||
assert_allclose(self_h, torch_h)
|
||||
assert_allclose(self_c, torch_c)
|
||||
assert_allclose(self_y, torch_y, atol=1e-6)
|
||||
assert_allclose(self_h, torch_h, atol=1e-6)
|
||||
assert_allclose(self_c, torch_c, atol=1e-6)
|
||||
|
||||
self_hc = self_h * self_c.sum()
|
||||
torch_hc = torch_h * torch_c.sum()
|
||||
self_hc_sum = (self_hc.reshape(-1) * torch.arange(self_hc.numel())).sum()
|
||||
torch_hc_sum = (torch_hc.reshape(-1) * torch.arange(torch_hc.numel())).sum()
|
||||
|
||||
self_y_sum = (self_y.reshape(-1) * torch.arange(self_y.numel())).sum()
|
||||
torch_y_sum = (torch_y.reshape(-1) * torch.arange(torch_y.numel())).sum()
|
||||
|
||||
(self_hc_sum * self_y_sum).backward()
|
||||
(torch_hc_sum * torch_y_sum).backward()
|
||||
self_y.sum().backward()
|
||||
torch_y.sum().backward()
|
||||
|
||||
assert_allclose(x.grad, x_clone.grad, atol=1e-5)
|
||||
|
||||
|
||||
def test_layernorm_lstm_layer_forward():
|
||||
def test_layernorm_lstm_layer_forward(device="cpu"):
|
||||
input_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
|
||||
@ -234,14 +272,15 @@ def test_layernorm_lstm_layer_forward():
|
||||
hidden_size,
|
||||
bias=bias,
|
||||
ln=nn.Identity,
|
||||
device=device,
|
||||
)
|
||||
|
||||
N = torch.randint(low=2, high=100, size=(1,))
|
||||
T = torch.randint(low=2, high=100, size=(1,))
|
||||
|
||||
x = torch.rand(N, T, input_size).requires_grad_()
|
||||
h = torch.rand(N, hidden_size)
|
||||
c = torch.rand(N, hidden_size)
|
||||
x = torch.rand(N, T, input_size, device=device).requires_grad_()
|
||||
h = torch.rand(N, hidden_size, device=device)
|
||||
c = torch.rand(N, hidden_size, device=device)
|
||||
|
||||
x_clone = x.detach().clone().requires_grad_()
|
||||
|
||||
@ -255,6 +294,7 @@ def test_layernorm_lstm_layer_forward():
|
||||
batch_first=True,
|
||||
dropout=0,
|
||||
bidirectional=False,
|
||||
device=device,
|
||||
)
|
||||
with torch.no_grad():
|
||||
for name, self_param in self_layer.cell.named_parameters():
|
||||
@ -263,25 +303,33 @@ def test_layernorm_lstm_layer_forward():
|
||||
torch_y, (torch_h, torch_c) = torch_layer(
|
||||
x_clone, (h.unsqueeze(0), c.unsqueeze(0))
|
||||
)
|
||||
assert_allclose(self_y, torch_y)
|
||||
assert_allclose(self_h, torch_h)
|
||||
assert_allclose(self_c, torch_c)
|
||||
assert_allclose(self_y, torch_y, atol=1e-6)
|
||||
assert_allclose(self_h, torch_h, atol=1e-6)
|
||||
assert_allclose(self_c, torch_c, atol=1e-6)
|
||||
|
||||
self_hc = self_h * self_c
|
||||
torch_hc = torch_h * torch_c
|
||||
self_hc_sum = (self_hc.reshape(-1) * torch.arange(self_hc.numel())).sum()
|
||||
torch_hc_sum = (torch_hc.reshape(-1) * torch.arange(torch_hc.numel())).sum()
|
||||
self_hc_sum = (
|
||||
self_hc.reshape(-1) * torch.arange(self_hc.numel(), device=device)
|
||||
).sum()
|
||||
torch_hc_sum = (
|
||||
torch_hc.reshape(-1) * torch.arange(torch_hc.numel(), device=device)
|
||||
).sum()
|
||||
|
||||
self_y_sum = (self_y.reshape(-1) * torch.arange(self_y.numel())).sum()
|
||||
torch_y_sum = (torch_y.reshape(-1) * torch.arange(torch_y.numel())).sum()
|
||||
self_y_sum = (
|
||||
self_y.reshape(-1) * torch.arange(self_y.numel(), device=device)
|
||||
).sum()
|
||||
torch_y_sum = (
|
||||
torch_y.reshape(-1) * torch.arange(torch_y.numel(), device=device)
|
||||
).sum()
|
||||
|
||||
(self_hc_sum * self_y_sum).backward()
|
||||
(torch_hc_sum * torch_y_sum).backward()
|
||||
(self_hc_sum + self_y_sum).backward()
|
||||
(torch_hc_sum + torch_y_sum).backward()
|
||||
|
||||
assert_allclose(x.grad, x_clone.grad, atol=1e-5)
|
||||
assert_allclose(x.grad, x_clone.grad, atol=0.1)
|
||||
|
||||
|
||||
def test_layernorm_lstm_jit():
|
||||
def test_layernorm_lstm_jit(device="cpu"):
|
||||
input_size = 2
|
||||
hidden_size = 3
|
||||
num_layers = 4
|
||||
@ -293,11 +341,12 @@ def test_layernorm_lstm_jit():
|
||||
num_layers=num_layers,
|
||||
bias=bias,
|
||||
ln=nn.Identity,
|
||||
device=device,
|
||||
)
|
||||
torch.jit.script(lstm)
|
||||
|
||||
|
||||
def test_layernorm_lstm_with_projection_jit():
|
||||
def test_layernorm_lstm_with_projection_jit(device="cpu"):
|
||||
input_size = 2
|
||||
hidden_size = 5
|
||||
proj_size = 3
|
||||
@ -311,11 +360,12 @@ def test_layernorm_lstm_with_projection_jit():
|
||||
bias=bias,
|
||||
proj_size=proj_size,
|
||||
ln=nn.Identity,
|
||||
device=device,
|
||||
)
|
||||
torch.jit.script(lstm)
|
||||
|
||||
|
||||
def test_layernorm_lstm_forward():
|
||||
def test_layernorm_lstm_forward(device="cpu"):
|
||||
input_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
num_layers = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
@ -327,6 +377,7 @@ def test_layernorm_lstm_forward():
|
||||
num_layers=num_layers,
|
||||
bias=bias,
|
||||
ln=nn.Identity,
|
||||
device=device,
|
||||
)
|
||||
torch_lstm = nn.LSTM(
|
||||
input_size=input_size,
|
||||
@ -335,6 +386,7 @@ def test_layernorm_lstm_forward():
|
||||
bias=bias,
|
||||
batch_first=True,
|
||||
bidirectional=False,
|
||||
device=device,
|
||||
)
|
||||
assert len(self_lstm.state_dict()) == len(torch_lstm.state_dict())
|
||||
with torch.no_grad():
|
||||
@ -347,9 +399,9 @@ def test_layernorm_lstm_forward():
|
||||
N = torch.randint(low=2, high=100, size=(1,))
|
||||
T = torch.randint(low=2, high=100, size=(1,))
|
||||
|
||||
x = torch.rand(N, T, input_size).requires_grad_()
|
||||
hs = [torch.rand(N, hidden_size) for _ in range(num_layers)]
|
||||
cs = [torch.rand(N, hidden_size) for _ in range(num_layers)]
|
||||
x = torch.rand(N, T, input_size, device=device).requires_grad_()
|
||||
hs = [torch.rand(N, hidden_size, device=device) for _ in range(num_layers)]
|
||||
cs = [torch.rand(N, hidden_size, device=device) for _ in range(num_layers)]
|
||||
states = list(zip(hs, cs))
|
||||
|
||||
x_clone = x.detach().clone().requires_grad_()
|
||||
@ -371,17 +423,17 @@ def test_layernorm_lstm_forward():
|
||||
s = self_y.reshape(-1)
|
||||
t = torch_y.reshape(-1)
|
||||
|
||||
s_sum = (s * torch.arange(s.numel())).sum()
|
||||
t_sum = (t * torch.arange(t.numel())).sum()
|
||||
shc_sum = s_sum * self_h.sum() * self_c.sum()
|
||||
thc_sum = t_sum * torch_h.sum() * torch_c.sum()
|
||||
s_sum = (s * torch.arange(s.numel(), device=device)).sum()
|
||||
t_sum = (t * torch.arange(t.numel(), device=device)).sum()
|
||||
shc_sum = s_sum + self_h.sum() + self_c.sum()
|
||||
thc_sum = t_sum + torch_h.sum() + torch_c.sum()
|
||||
|
||||
shc_sum.backward()
|
||||
thc_sum.backward()
|
||||
assert_allclose(x.grad, x_clone.grad)
|
||||
|
||||
|
||||
def test_layernorm_lstm_with_projection_forward():
|
||||
def test_layernorm_lstm_with_projection_forward(device="cpu"):
|
||||
input_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
hidden_size = torch.randint(low=10, high=100, size=(1,)).item()
|
||||
proj_size = torch.randint(low=2, high=hidden_size, size=(1,)).item()
|
||||
@ -395,6 +447,7 @@ def test_layernorm_lstm_with_projection_forward():
|
||||
bias=bias,
|
||||
proj_size=proj_size,
|
||||
ln=nn.Identity,
|
||||
device=device,
|
||||
)
|
||||
torch_lstm = nn.LSTM(
|
||||
input_size=input_size,
|
||||
@ -404,6 +457,7 @@ def test_layernorm_lstm_with_projection_forward():
|
||||
proj_size=proj_size,
|
||||
batch_first=True,
|
||||
bidirectional=False,
|
||||
device=device,
|
||||
)
|
||||
assert len(self_lstm.state_dict()) == len(torch_lstm.state_dict())
|
||||
with torch.no_grad():
|
||||
@ -416,9 +470,9 @@ def test_layernorm_lstm_with_projection_forward():
|
||||
N = torch.randint(low=2, high=100, size=(1,))
|
||||
T = torch.randint(low=2, high=100, size=(1,))
|
||||
|
||||
x = torch.rand(N, T, input_size).requires_grad_()
|
||||
hs = [torch.rand(N, proj_size) for _ in range(num_layers)]
|
||||
cs = [torch.rand(N, hidden_size) for _ in range(num_layers)]
|
||||
x = torch.rand(N, T, input_size, device=device).requires_grad_()
|
||||
hs = [torch.rand(N, proj_size, device=device) for _ in range(num_layers)]
|
||||
cs = [torch.rand(N, hidden_size, device=device) for _ in range(num_layers)]
|
||||
states = list(zip(hs, cs))
|
||||
|
||||
x_clone = x.detach().clone().requires_grad_()
|
||||
@ -429,43 +483,55 @@ def test_layernorm_lstm_with_projection_forward():
|
||||
c = torch.stack(cs)
|
||||
torch_y, (torch_h, torch_c) = torch_lstm(x_clone, (h, c))
|
||||
|
||||
assert_allclose(self_y, torch_y)
|
||||
assert_allclose(self_y, torch_y, atol=1e-6)
|
||||
|
||||
self_h = torch.stack([s[0] for s in self_states])
|
||||
self_c = torch.stack([s[1] for s in self_states])
|
||||
|
||||
assert_allclose(self_h, torch_h)
|
||||
assert_allclose(self_c, torch_c)
|
||||
assert_allclose(self_h, torch_h, atol=1e-6)
|
||||
assert_allclose(self_c, torch_c, atol=1e-6)
|
||||
|
||||
s = self_y.reshape(-1)
|
||||
t = torch_y.reshape(-1)
|
||||
|
||||
s_sum = (s * torch.arange(s.numel())).sum()
|
||||
t_sum = (t * torch.arange(t.numel())).sum()
|
||||
shc_sum = s_sum * self_h.sum() * self_c.sum()
|
||||
thc_sum = t_sum * torch_h.sum() * torch_c.sum()
|
||||
s_sum = (s * torch.arange(s.numel(), device=device)).sum()
|
||||
t_sum = (t * torch.arange(t.numel(), device=device)).sum()
|
||||
shc_sum = s_sum + self_h.sum() + self_c.sum()
|
||||
thc_sum = t_sum + torch_h.sum() + torch_c.sum()
|
||||
|
||||
shc_sum.backward()
|
||||
thc_sum.backward()
|
||||
assert_allclose(x.grad, x_clone.grad)
|
||||
assert_allclose(x.grad, x_clone.grad, atol=1e-6)
|
||||
|
||||
|
||||
def test_layernorm_gru_cell_jit():
|
||||
def test_layernorm_gru_cell_jit(device="cpu"):
|
||||
input_size = 10
|
||||
hidden_size = 20
|
||||
cell = LayerNormGRUCell(
|
||||
input_size=input_size, hidden_size=hidden_size, bias=True
|
||||
input_size=input_size,
|
||||
hidden_size=hidden_size,
|
||||
bias=True,
|
||||
device=device,
|
||||
)
|
||||
|
||||
torch.jit.script(cell)
|
||||
|
||||
|
||||
def test_layernorm_gru_cell_constructor():
|
||||
def test_layernorm_gru_cell_constructor(device="cpu"):
|
||||
input_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
|
||||
self_cell = LayerNormGRUCell(input_size, hidden_size, ln=nn.Identity)
|
||||
torch_cell = nn.GRUCell(input_size, hidden_size)
|
||||
self_cell = LayerNormGRUCell(
|
||||
input_size,
|
||||
hidden_size,
|
||||
ln=nn.Identity,
|
||||
device=device,
|
||||
)
|
||||
torch_cell = nn.GRUCell(
|
||||
input_size,
|
||||
hidden_size,
|
||||
device=device,
|
||||
)
|
||||
|
||||
for name, param in self_cell.named_parameters():
|
||||
assert param.shape == getattr(torch_cell, name).shape
|
||||
@ -473,23 +539,32 @@ def test_layernorm_gru_cell_constructor():
|
||||
assert len(self_cell.state_dict()) == len(torch_cell.state_dict())
|
||||
|
||||
|
||||
def test_layernorm_gru_cell_forward():
|
||||
def test_layernorm_gru_cell_forward(device="cpu"):
|
||||
input_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
|
||||
|
||||
self_cell = LayerNormGRUCell(
|
||||
input_size, hidden_size, bias=bias, ln=nn.Identity
|
||||
input_size,
|
||||
hidden_size,
|
||||
bias=bias,
|
||||
ln=nn.Identity,
|
||||
device=device,
|
||||
)
|
||||
torch_cell = nn.GRUCell(
|
||||
input_size,
|
||||
hidden_size,
|
||||
bias=bias,
|
||||
device=device,
|
||||
)
|
||||
torch_cell = nn.GRUCell(input_size, hidden_size, bias=bias)
|
||||
with torch.no_grad():
|
||||
for name, torch_param in torch_cell.named_parameters():
|
||||
self_param = getattr(self_cell, name)
|
||||
torch_param.copy_(self_param)
|
||||
|
||||
N = torch.randint(low=2, high=100, size=(1,))
|
||||
x = torch.rand(N, input_size).requires_grad_()
|
||||
h = torch.rand(N, hidden_size)
|
||||
x = torch.rand(N, input_size, device=device).requires_grad_()
|
||||
h = torch.rand(N, hidden_size, device=device)
|
||||
|
||||
x_clone = x.detach().clone().requires_grad_()
|
||||
|
||||
@ -498,20 +573,28 @@ def test_layernorm_gru_cell_forward():
|
||||
|
||||
assert_allclose(self_h, torch_h, atol=1e-5)
|
||||
|
||||
(self_h.reshape(-1) * torch.arange(self_h.numel())).sum().backward()
|
||||
(torch_h.reshape(-1) * torch.arange(torch_h.numel())).sum().backward()
|
||||
(
|
||||
self_h.reshape(-1) * torch.arange(self_h.numel(), device=device)
|
||||
).sum().backward()
|
||||
(
|
||||
torch_h.reshape(-1) * torch.arange(torch_h.numel(), device=device)
|
||||
).sum().backward()
|
||||
|
||||
assert_allclose(x.grad, x_clone.grad, atol=1e-4)
|
||||
|
||||
|
||||
def test_layernorm_gru_layer_jit():
|
||||
def test_layernorm_gru_layer_jit(device="cpu"):
|
||||
input_size = 10
|
||||
hidden_size = 20
|
||||
layer = LayerNormGRULayer(input_size, hidden_size=hidden_size)
|
||||
layer = LayerNormGRULayer(
|
||||
input_size,
|
||||
hidden_size=hidden_size,
|
||||
device=device,
|
||||
)
|
||||
torch.jit.script(layer)
|
||||
|
||||
|
||||
def test_layernorm_gru_layer_forward():
|
||||
def test_layernorm_gru_layer_forward(device="cpu"):
|
||||
input_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
|
||||
@ -520,13 +603,14 @@ def test_layernorm_gru_layer_forward():
|
||||
hidden_size,
|
||||
bias=bias,
|
||||
ln=nn.Identity,
|
||||
device=device,
|
||||
)
|
||||
|
||||
N = torch.randint(low=2, high=100, size=(1,))
|
||||
T = torch.randint(low=2, high=100, size=(1,))
|
||||
|
||||
x = torch.rand(N, T, input_size).requires_grad_()
|
||||
h = torch.rand(N, hidden_size)
|
||||
x = torch.rand(N, T, input_size, device=device).requires_grad_()
|
||||
h = torch.rand(N, hidden_size, device=device)
|
||||
|
||||
x_clone = x.detach().clone().requires_grad_()
|
||||
|
||||
@ -540,6 +624,7 @@ def test_layernorm_gru_layer_forward():
|
||||
batch_first=True,
|
||||
dropout=0,
|
||||
bidirectional=False,
|
||||
device=device,
|
||||
)
|
||||
with torch.no_grad():
|
||||
for name, self_param in self_layer.cell.named_parameters():
|
||||
@ -549,8 +634,12 @@ def test_layernorm_gru_layer_forward():
|
||||
assert_allclose(self_y, torch_y, atol=1e-6)
|
||||
assert_allclose(self_h, torch_h, atol=1e-6)
|
||||
|
||||
self_y_sum = (self_y.reshape(-1) * torch.arange(self_y.numel())).sum()
|
||||
torch_y_sum = (torch_y.reshape(-1) * torch.arange(torch_y.numel())).sum()
|
||||
self_y_sum = (
|
||||
self_y.reshape(-1) * torch.arange(self_y.numel(), device=device)
|
||||
).sum()
|
||||
torch_y_sum = (
|
||||
torch_y.reshape(-1) * torch.arange(torch_y.numel(), device=device)
|
||||
).sum()
|
||||
|
||||
self_y_sum.backward()
|
||||
torch_y_sum.backward()
|
||||
@ -558,7 +647,7 @@ def test_layernorm_gru_layer_forward():
|
||||
assert_allclose(x.grad, x_clone.grad, atol=0.1)
|
||||
|
||||
|
||||
def test_layernorm_gru_jit():
|
||||
def test_layernorm_gru_jit(device="cpu"):
|
||||
input_size = 2
|
||||
hidden_size = 3
|
||||
num_layers = 4
|
||||
@ -570,11 +659,12 @@ def test_layernorm_gru_jit():
|
||||
num_layers=num_layers,
|
||||
bias=bias,
|
||||
ln=nn.Identity,
|
||||
device=device,
|
||||
)
|
||||
torch.jit.script(gru)
|
||||
|
||||
|
||||
def test_layernorm_gru_forward():
|
||||
def test_layernorm_gru_forward(device="cpu"):
|
||||
input_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
num_layers = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
@ -586,6 +676,7 @@ def test_layernorm_gru_forward():
|
||||
num_layers=num_layers,
|
||||
bias=bias,
|
||||
ln=nn.Identity,
|
||||
device=device,
|
||||
)
|
||||
torch_gru = nn.GRU(
|
||||
input_size=input_size,
|
||||
@ -594,6 +685,7 @@ def test_layernorm_gru_forward():
|
||||
bias=bias,
|
||||
batch_first=True,
|
||||
bidirectional=False,
|
||||
device=device,
|
||||
)
|
||||
assert len(self_gru.state_dict()) == len(torch_gru.state_dict())
|
||||
with torch.no_grad():
|
||||
@ -606,8 +698,10 @@ def test_layernorm_gru_forward():
|
||||
N = torch.randint(low=2, high=100, size=(1,))
|
||||
T = torch.randint(low=2, high=100, size=(1,))
|
||||
|
||||
x = torch.rand(N, T, input_size).requires_grad_()
|
||||
states = [torch.rand(N, hidden_size) for _ in range(num_layers)]
|
||||
x = torch.rand(N, T, input_size, device=device).requires_grad_()
|
||||
states = [
|
||||
torch.rand(N, hidden_size, device=device) for _ in range(num_layers)
|
||||
]
|
||||
|
||||
x_clone = x.detach().clone().requires_grad_()
|
||||
|
||||
@ -624,49 +718,51 @@ def test_layernorm_gru_forward():
|
||||
s = self_y.reshape(-1)
|
||||
t = torch_y.reshape(-1)
|
||||
|
||||
s_sum = (s * torch.arange(s.numel())).sum()
|
||||
t_sum = (t * torch.arange(t.numel())).sum()
|
||||
s_sum = (s * torch.arange(s.numel(), device=device)).sum()
|
||||
t_sum = (t * torch.arange(t.numel(), device=device)).sum()
|
||||
s_state_sum = s_sum + self_states.sum()
|
||||
t_state_sum = t_sum + torch_states.sum()
|
||||
|
||||
s_state_sum.backward()
|
||||
t_state_sum.backward()
|
||||
assert_allclose(x.grad, x_clone.grad, atol=1e-6)
|
||||
assert_allclose(x.grad, x_clone.grad, atol=1e-4)
|
||||
|
||||
|
||||
def test_lstm():
|
||||
test_layernorm_lstm_cell_jit()
|
||||
test_layernorm_lstm_cell_constructor()
|
||||
test_layernorm_lstm_cell_with_projection_jit()
|
||||
test_layernorm_lstm_cell_forward()
|
||||
test_layernorm_lstm_cell_with_projection_forward()
|
||||
def _test_lstm(device):
|
||||
test_layernorm_lstm_cell_jit(device)
|
||||
test_layernorm_lstm_cell_constructor(device)
|
||||
test_layernorm_lstm_cell_with_projection_jit(device)
|
||||
test_layernorm_lstm_cell_forward(device)
|
||||
test_layernorm_lstm_cell_with_projection_forward(device)
|
||||
#
|
||||
test_layernorm_lstm_layer_jit()
|
||||
test_layernorm_lstm_layer_with_project_jit()
|
||||
test_layernorm_lstm_layer_forward()
|
||||
test_layernorm_lstm_layer_with_projection_forward()
|
||||
test_layernorm_lstm_layer_jit(device)
|
||||
test_layernorm_lstm_layer_with_project_jit(device)
|
||||
test_layernorm_lstm_layer_forward(device)
|
||||
test_layernorm_lstm_layer_with_projection_forward(device)
|
||||
|
||||
test_layernorm_lstm_jit()
|
||||
test_layernorm_lstm_with_projection_jit()
|
||||
test_layernorm_lstm_forward()
|
||||
test_layernorm_lstm_with_projection_forward()
|
||||
test_layernorm_lstm_jit(device)
|
||||
test_layernorm_lstm_with_projection_jit(device)
|
||||
test_layernorm_lstm_forward(device)
|
||||
test_layernorm_lstm_with_projection_forward(device)
|
||||
|
||||
|
||||
def test_gru():
|
||||
test_layernorm_gru_cell_jit()
|
||||
test_layernorm_gru_cell_constructor()
|
||||
test_layernorm_gru_cell_forward()
|
||||
def _test_gru(device):
|
||||
test_layernorm_gru_cell_jit(device)
|
||||
test_layernorm_gru_cell_constructor(device)
|
||||
test_layernorm_gru_cell_forward(device)
|
||||
#
|
||||
test_layernorm_gru_layer_jit()
|
||||
test_layernorm_gru_layer_forward()
|
||||
test_layernorm_gru_layer_jit(device)
|
||||
test_layernorm_gru_layer_forward(device)
|
||||
#
|
||||
test_layernorm_gru_jit()
|
||||
test_layernorm_gru_forward()
|
||||
test_layernorm_gru_jit(device)
|
||||
test_layernorm_gru_forward(device)
|
||||
|
||||
|
||||
def main():
|
||||
test_lstm()
|
||||
test_gru()
|
||||
for device in get_devices():
|
||||
print("device", device)
|
||||
_test_lstm(device)
|
||||
_test_gru(device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user