mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Fix test failures for torch 1.8.0
This commit is contained in:
parent
b3a5b04e13
commit
cafd06e909
@ -34,8 +34,10 @@ def get_devices():
|
|||||||
return devices
|
return devices
|
||||||
|
|
||||||
|
|
||||||
def assert_allclose(a: torch.Tensor, b: torch.Tensor, **kwargs):
|
def assert_allclose(a: torch.Tensor, b: torch.Tensor, atol=1e-6, **kwargs):
|
||||||
assert torch.allclose(a, b, **kwargs), f"{(a - b).abs().max()}, {a.numel()}"
|
assert torch.allclose(
|
||||||
|
a, b, atol=atol, **kwargs
|
||||||
|
), f"{(a - b).abs().max()}, {a.numel()}"
|
||||||
|
|
||||||
|
|
||||||
def test_layernorm_lstm_cell_jit(device="cpu"):
|
def test_layernorm_lstm_cell_jit(device="cpu"):
|
||||||
@ -65,8 +67,7 @@ def test_layernorm_lstm_cell_constructor(device="cpu"):
|
|||||||
torch_cell = nn.LSTMCell(
|
torch_cell = nn.LSTMCell(
|
||||||
input_size,
|
input_size,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
device=device,
|
).to(device)
|
||||||
)
|
|
||||||
|
|
||||||
for name, param in self_cell.named_parameters():
|
for name, param in self_cell.named_parameters():
|
||||||
assert param.shape == getattr(torch_cell, name).shape
|
assert param.shape == getattr(torch_cell, name).shape
|
||||||
@ -103,8 +104,7 @@ def test_layernorm_lstm_cell_forward(device="cpu"):
|
|||||||
input_size,
|
input_size,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
device=device,
|
).to(device)
|
||||||
)
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for name, torch_param in torch_cell.named_parameters():
|
for name, torch_param in torch_cell.named_parameters():
|
||||||
self_param = getattr(self_cell, name)
|
self_param = getattr(self_cell, name)
|
||||||
@ -121,7 +121,7 @@ def test_layernorm_lstm_cell_forward(device="cpu"):
|
|||||||
torch_h, torch_c = torch_cell(x_clone, (h, c))
|
torch_h, torch_c = torch_cell(x_clone, (h, c))
|
||||||
|
|
||||||
assert_allclose(self_h, torch_h)
|
assert_allclose(self_h, torch_h)
|
||||||
assert_allclose(self_c, torch_c, atol=1e-6)
|
assert_allclose(self_c, torch_c)
|
||||||
|
|
||||||
self_hc = self_h * self_c
|
self_hc = self_h * self_c
|
||||||
torch_hc = torch_h * torch_c
|
torch_hc = torch_h * torch_c
|
||||||
@ -155,8 +155,7 @@ def test_layernorm_lstm_cell_with_projection_forward(device="cpu"):
|
|||||||
bias=bias,
|
bias=bias,
|
||||||
proj_size=proj_size,
|
proj_size=proj_size,
|
||||||
batch_first=True,
|
batch_first=True,
|
||||||
device=device,
|
).to(device)
|
||||||
)
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for name, self_param in self_cell.named_parameters():
|
for name, self_param in self_cell.named_parameters():
|
||||||
getattr(torch_cell, f"{name}_l0").copy_(self_param)
|
getattr(torch_cell, f"{name}_l0").copy_(self_param)
|
||||||
@ -177,7 +176,7 @@ def test_layernorm_lstm_cell_with_projection_forward(device="cpu"):
|
|||||||
torch_c = torch_c.squeeze(0)
|
torch_c = torch_c.squeeze(0)
|
||||||
|
|
||||||
assert_allclose(self_h, torch_h)
|
assert_allclose(self_h, torch_h)
|
||||||
assert_allclose(self_c, torch_c, atol=1e-6)
|
assert_allclose(self_c, torch_c)
|
||||||
|
|
||||||
(self_h.sum() * self_c.sum()).backward()
|
(self_h.sum() * self_c.sum()).backward()
|
||||||
(torch_h.sum() * torch_c.sum()).backward()
|
(torch_h.sum() * torch_c.sum()).backward()
|
||||||
@ -244,8 +243,7 @@ def test_layernorm_lstm_layer_with_projection_forward(device="cpu"):
|
|||||||
batch_first=True,
|
batch_first=True,
|
||||||
dropout=0,
|
dropout=0,
|
||||||
bidirectional=False,
|
bidirectional=False,
|
||||||
device=device,
|
).to(device)
|
||||||
)
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for name, self_param in self_layer.cell.named_parameters():
|
for name, self_param in self_layer.cell.named_parameters():
|
||||||
getattr(torch_layer, f"{name}_l0").copy_(self_param)
|
getattr(torch_layer, f"{name}_l0").copy_(self_param)
|
||||||
@ -253,9 +251,9 @@ def test_layernorm_lstm_layer_with_projection_forward(device="cpu"):
|
|||||||
torch_y, (torch_h, torch_c) = torch_layer(
|
torch_y, (torch_h, torch_c) = torch_layer(
|
||||||
x_clone, (h.unsqueeze(0), c.unsqueeze(0))
|
x_clone, (h.unsqueeze(0), c.unsqueeze(0))
|
||||||
)
|
)
|
||||||
assert_allclose(self_y, torch_y, atol=1e-6)
|
assert_allclose(self_y, torch_y)
|
||||||
assert_allclose(self_h, torch_h, atol=1e-6)
|
assert_allclose(self_h, torch_h)
|
||||||
assert_allclose(self_c, torch_c, atol=1e-6)
|
assert_allclose(self_c, torch_c)
|
||||||
|
|
||||||
self_y.sum().backward()
|
self_y.sum().backward()
|
||||||
torch_y.sum().backward()
|
torch_y.sum().backward()
|
||||||
@ -294,8 +292,7 @@ def test_layernorm_lstm_layer_forward(device="cpu"):
|
|||||||
batch_first=True,
|
batch_first=True,
|
||||||
dropout=0,
|
dropout=0,
|
||||||
bidirectional=False,
|
bidirectional=False,
|
||||||
device=device,
|
).to(device)
|
||||||
)
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for name, self_param in self_layer.cell.named_parameters():
|
for name, self_param in self_layer.cell.named_parameters():
|
||||||
getattr(torch_layer, f"{name}_l0").copy_(self_param)
|
getattr(torch_layer, f"{name}_l0").copy_(self_param)
|
||||||
@ -303,9 +300,9 @@ def test_layernorm_lstm_layer_forward(device="cpu"):
|
|||||||
torch_y, (torch_h, torch_c) = torch_layer(
|
torch_y, (torch_h, torch_c) = torch_layer(
|
||||||
x_clone, (h.unsqueeze(0), c.unsqueeze(0))
|
x_clone, (h.unsqueeze(0), c.unsqueeze(0))
|
||||||
)
|
)
|
||||||
assert_allclose(self_y, torch_y, atol=1e-6)
|
assert_allclose(self_y, torch_y)
|
||||||
assert_allclose(self_h, torch_h, atol=1e-6)
|
assert_allclose(self_h, torch_h)
|
||||||
assert_allclose(self_c, torch_c, atol=1e-6)
|
assert_allclose(self_c, torch_c)
|
||||||
|
|
||||||
self_hc = self_h * self_c
|
self_hc = self_h * self_c
|
||||||
torch_hc = torch_h * torch_c
|
torch_hc = torch_h * torch_c
|
||||||
@ -386,8 +383,7 @@ def test_layernorm_lstm_forward(device="cpu"):
|
|||||||
bias=bias,
|
bias=bias,
|
||||||
batch_first=True,
|
batch_first=True,
|
||||||
bidirectional=False,
|
bidirectional=False,
|
||||||
device=device,
|
).to(device)
|
||||||
)
|
|
||||||
assert len(self_lstm.state_dict()) == len(torch_lstm.state_dict())
|
assert len(self_lstm.state_dict()) == len(torch_lstm.state_dict())
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for name, param in self_lstm.named_parameters():
|
for name, param in self_lstm.named_parameters():
|
||||||
@ -457,8 +453,7 @@ def test_layernorm_lstm_with_projection_forward(device="cpu"):
|
|||||||
proj_size=proj_size,
|
proj_size=proj_size,
|
||||||
batch_first=True,
|
batch_first=True,
|
||||||
bidirectional=False,
|
bidirectional=False,
|
||||||
device=device,
|
).to(device)
|
||||||
)
|
|
||||||
assert len(self_lstm.state_dict()) == len(torch_lstm.state_dict())
|
assert len(self_lstm.state_dict()) == len(torch_lstm.state_dict())
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for name, param in self_lstm.named_parameters():
|
for name, param in self_lstm.named_parameters():
|
||||||
@ -483,13 +478,13 @@ def test_layernorm_lstm_with_projection_forward(device="cpu"):
|
|||||||
c = torch.stack(cs)
|
c = torch.stack(cs)
|
||||||
torch_y, (torch_h, torch_c) = torch_lstm(x_clone, (h, c))
|
torch_y, (torch_h, torch_c) = torch_lstm(x_clone, (h, c))
|
||||||
|
|
||||||
assert_allclose(self_y, torch_y, atol=1e-6)
|
assert_allclose(self_y, torch_y)
|
||||||
|
|
||||||
self_h = torch.stack([s[0] for s in self_states])
|
self_h = torch.stack([s[0] for s in self_states])
|
||||||
self_c = torch.stack([s[1] for s in self_states])
|
self_c = torch.stack([s[1] for s in self_states])
|
||||||
|
|
||||||
assert_allclose(self_h, torch_h, atol=1e-6)
|
assert_allclose(self_h, torch_h)
|
||||||
assert_allclose(self_c, torch_c, atol=1e-6)
|
assert_allclose(self_c, torch_c)
|
||||||
|
|
||||||
s = self_y.reshape(-1)
|
s = self_y.reshape(-1)
|
||||||
t = torch_y.reshape(-1)
|
t = torch_y.reshape(-1)
|
||||||
@ -501,7 +496,7 @@ def test_layernorm_lstm_with_projection_forward(device="cpu"):
|
|||||||
|
|
||||||
shc_sum.backward()
|
shc_sum.backward()
|
||||||
thc_sum.backward()
|
thc_sum.backward()
|
||||||
assert_allclose(x.grad, x_clone.grad, atol=1e-6)
|
assert_allclose(x.grad, x_clone.grad)
|
||||||
|
|
||||||
|
|
||||||
def test_layernorm_gru_cell_jit(device="cpu"):
|
def test_layernorm_gru_cell_jit(device="cpu"):
|
||||||
@ -530,8 +525,7 @@ def test_layernorm_gru_cell_constructor(device="cpu"):
|
|||||||
torch_cell = nn.GRUCell(
|
torch_cell = nn.GRUCell(
|
||||||
input_size,
|
input_size,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
device=device,
|
).to(device)
|
||||||
)
|
|
||||||
|
|
||||||
for name, param in self_cell.named_parameters():
|
for name, param in self_cell.named_parameters():
|
||||||
assert param.shape == getattr(torch_cell, name).shape
|
assert param.shape == getattr(torch_cell, name).shape
|
||||||
@ -555,8 +549,7 @@ def test_layernorm_gru_cell_forward(device="cpu"):
|
|||||||
input_size,
|
input_size,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
device=device,
|
).to(device)
|
||||||
)
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for name, torch_param in torch_cell.named_parameters():
|
for name, torch_param in torch_cell.named_parameters():
|
||||||
self_param = getattr(self_cell, name)
|
self_param = getattr(self_cell, name)
|
||||||
@ -580,7 +573,7 @@ def test_layernorm_gru_cell_forward(device="cpu"):
|
|||||||
torch_h.reshape(-1) * torch.arange(torch_h.numel(), device=device)
|
torch_h.reshape(-1) * torch.arange(torch_h.numel(), device=device)
|
||||||
).sum().backward()
|
).sum().backward()
|
||||||
|
|
||||||
assert_allclose(x.grad, x_clone.grad, atol=1e-4)
|
assert_allclose(x.grad, x_clone.grad, atol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
def test_layernorm_gru_layer_jit(device="cpu"):
|
def test_layernorm_gru_layer_jit(device="cpu"):
|
||||||
@ -624,15 +617,14 @@ def test_layernorm_gru_layer_forward(device="cpu"):
|
|||||||
batch_first=True,
|
batch_first=True,
|
||||||
dropout=0,
|
dropout=0,
|
||||||
bidirectional=False,
|
bidirectional=False,
|
||||||
device=device,
|
).to(device)
|
||||||
)
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for name, self_param in self_layer.cell.named_parameters():
|
for name, self_param in self_layer.cell.named_parameters():
|
||||||
getattr(torch_layer, f"{name}_l0").copy_(self_param)
|
getattr(torch_layer, f"{name}_l0").copy_(self_param)
|
||||||
|
|
||||||
torch_y, torch_h = torch_layer(x_clone, h.unsqueeze(0))
|
torch_y, torch_h = torch_layer(x_clone, h.unsqueeze(0))
|
||||||
assert_allclose(self_y, torch_y, atol=1e-6)
|
assert_allclose(self_y, torch_y)
|
||||||
assert_allclose(self_h, torch_h, atol=1e-6)
|
assert_allclose(self_h, torch_h)
|
||||||
|
|
||||||
self_y_sum = (
|
self_y_sum = (
|
||||||
self_y.reshape(-1) * torch.arange(self_y.numel(), device=device)
|
self_y.reshape(-1) * torch.arange(self_y.numel(), device=device)
|
||||||
@ -685,8 +677,7 @@ def test_layernorm_gru_forward(device="cpu"):
|
|||||||
bias=bias,
|
bias=bias,
|
||||||
batch_first=True,
|
batch_first=True,
|
||||||
bidirectional=False,
|
bidirectional=False,
|
||||||
device=device,
|
).to(device)
|
||||||
)
|
|
||||||
assert len(self_gru.state_dict()) == len(torch_gru.state_dict())
|
assert len(self_gru.state_dict()) == len(torch_gru.state_dict())
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for name, param in self_gru.named_parameters():
|
for name, param in self_gru.named_parameters():
|
||||||
@ -709,11 +700,11 @@ def test_layernorm_gru_forward(device="cpu"):
|
|||||||
|
|
||||||
torch_y, torch_states = torch_gru(x_clone, torch.stack(states))
|
torch_y, torch_states = torch_gru(x_clone, torch.stack(states))
|
||||||
|
|
||||||
assert_allclose(self_y, torch_y, atol=1e-6)
|
assert_allclose(self_y, torch_y)
|
||||||
|
|
||||||
self_states = torch.stack(self_states)
|
self_states = torch.stack(self_states)
|
||||||
|
|
||||||
assert_allclose(self_states, torch_states, atol=1e-6)
|
assert_allclose(self_states, torch_states)
|
||||||
|
|
||||||
s = self_y.reshape(-1)
|
s = self_y.reshape(-1)
|
||||||
t = torch_y.reshape(-1)
|
t = torch_y.reshape(-1)
|
||||||
@ -758,6 +749,10 @@ def _test_gru(device):
|
|||||||
test_layernorm_gru_forward(device)
|
test_layernorm_gru_forward(device)
|
||||||
|
|
||||||
|
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
for device in get_devices():
|
for device in get_devices():
|
||||||
print("device", device)
|
print("device", device)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user