diff --git a/egs/librispeech/ASR/transducer/test_rnn.py b/egs/librispeech/ASR/transducer/test_rnn.py index 47eebd588..da118d390 100755 --- a/egs/librispeech/ASR/transducer/test_rnn.py +++ b/egs/librispeech/ASR/transducer/test_rnn.py @@ -27,27 +27,46 @@ from transducer.rnn import ( ) +def get_devices(): + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda", 0)) + return devices + + def assert_allclose(a: torch.Tensor, b: torch.Tensor, **kwargs): assert torch.allclose(a, b, **kwargs), f"{(a - b).abs().max()}, {a.numel()}" -def test_layernorm_lstm_cell_jit(): +def test_layernorm_lstm_cell_jit(device="cpu"): input_size = 10 hidden_size = 20 bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 cell = LayerNormLSTMCell( - input_size=input_size, hidden_size=hidden_size, bias=bias + input_size=input_size, + hidden_size=hidden_size, + bias=bias, + device=device, ) torch.jit.script(cell) -def test_layernorm_lstm_cell_constructor(): +def test_layernorm_lstm_cell_constructor(device="cpu"): input_size = torch.randint(low=2, high=100, size=(1,)).item() hidden_size = torch.randint(low=2, high=100, size=(1,)).item() - self_cell = LayerNormLSTMCell(input_size, hidden_size, ln=nn.Identity) - torch_cell = nn.LSTMCell(input_size, hidden_size) + self_cell = LayerNormLSTMCell( + input_size, + hidden_size, + ln=nn.Identity, + device=device, + ) + torch_cell = nn.LSTMCell( + input_size, + hidden_size, + device=device, + ) for name, param in self_cell.named_parameters(): assert param.shape == getattr(torch_cell, name).shape @@ -55,32 +74,46 @@ def test_layernorm_lstm_cell_constructor(): assert len(self_cell.state_dict()) == len(torch_cell.state_dict()) -def test_layernorm_lstm_cell_with_projection_jit(): +def test_layernorm_lstm_cell_with_projection_jit(device="cpu"): input_size = 10 hidden_size = 20 proj_size = 5 - self_cell = LayerNormLSTMCell(input_size, hidden_size, proj_size=proj_size) + self_cell = LayerNormLSTMCell( + input_size, + hidden_size, + proj_size=proj_size, + device=device, + ) torch.jit.script(self_cell) -def test_layernorm_lstm_cell_forward(): +def test_layernorm_lstm_cell_forward(device="cpu"): input_size = torch.randint(low=2, high=100, size=(1,)).item() hidden_size = torch.randint(low=2, high=100, size=(1,)).item() bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 self_cell = LayerNormLSTMCell( - input_size, hidden_size, bias=bias, ln=nn.Identity + input_size, + hidden_size, + bias=bias, + ln=nn.Identity, + device=device, + ) + torch_cell = nn.LSTMCell( + input_size, + hidden_size, + bias=bias, + device=device, ) - torch_cell = nn.LSTMCell(input_size, hidden_size, bias=bias) with torch.no_grad(): for name, torch_param in torch_cell.named_parameters(): self_param = getattr(self_cell, name) torch_param.copy_(self_param) N = torch.randint(low=2, high=100, size=(1,)) - x = torch.rand(N, input_size).requires_grad_() - h = torch.rand(N, hidden_size) - c = torch.rand(N, hidden_size) + x = torch.rand(N, input_size, device=device).requires_grad_() + h = torch.rand(N, hidden_size, device=device) + c = torch.rand(N, hidden_size, device=device) x_clone = x.detach().clone().requires_grad_() @@ -88,17 +121,21 @@ def test_layernorm_lstm_cell_forward(): torch_h, torch_c = torch_cell(x_clone, (h, c)) assert_allclose(self_h, torch_h) - assert_allclose(self_c, torch_c) + assert_allclose(self_c, torch_c, atol=1e-6) self_hc = self_h * self_c torch_hc = torch_h * torch_c - (self_hc.reshape(-1) * torch.arange(self_hc.numel())).sum().backward() - (torch_hc.reshape(-1) * torch.arange(torch_hc.numel())).sum().backward() + ( + self_hc.reshape(-1) * torch.arange(self_hc.numel(), device=device) + ).sum().backward() + ( + torch_hc.reshape(-1) * torch.arange(torch_hc.numel(), device=device) + ).sum().backward() - assert_allclose(x.grad, x_clone.grad) + assert_allclose(x.grad, x_clone.grad, atol=1e-3) -def test_layernorm_lstm_cell_with_projection_forward(): +def test_layernorm_lstm_cell_with_projection_forward(device="cpu"): input_size = torch.randint(low=2, high=100, size=(1,)).item() hidden_size = torch.randint(low=10, high=100, size=(1,)).item() bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 @@ -110,6 +147,7 @@ def test_layernorm_lstm_cell_with_projection_forward(): bias=bias, ln=nn.Identity, proj_size=proj_size, + device=device, ) torch_cell = nn.LSTM( input_size, @@ -117,15 +155,16 @@ def test_layernorm_lstm_cell_with_projection_forward(): bias=bias, proj_size=proj_size, batch_first=True, + device=device, ) with torch.no_grad(): for name, self_param in self_cell.named_parameters(): getattr(torch_cell, f"{name}_l0").copy_(self_param) N = torch.randint(low=2, high=100, size=(1,)) - x = torch.rand(N, input_size).requires_grad_() - h = torch.rand(N, proj_size) - c = torch.rand(N, hidden_size) + x = torch.rand(N, input_size, device=device).requires_grad_() + h = torch.rand(N, proj_size, device=device) + c = torch.rand(N, hidden_size, device=device) x_clone = x.detach().clone().requires_grad_() @@ -138,22 +177,26 @@ def test_layernorm_lstm_cell_with_projection_forward(): torch_c = torch_c.squeeze(0) assert_allclose(self_h, torch_h) - assert_allclose(self_c, torch_c) + assert_allclose(self_c, torch_c, atol=1e-6) (self_h.sum() * self_c.sum()).backward() (torch_h.sum() * torch_c.sum()).backward() - assert_allclose(x.grad, x_clone.grad) + assert_allclose(x.grad, x_clone.grad, atol=1e-5) -def test_layernorm_lstm_layer_jit(): +def test_layernorm_lstm_layer_jit(device="cpu"): input_size = 10 hidden_size = 20 - layer = LayerNormLSTMLayer(input_size, hidden_size=hidden_size) + layer = LayerNormLSTMLayer( + input_size, + hidden_size=hidden_size, + device=device, + ) torch.jit.script(layer) -def test_layernorm_lstm_layer_with_project_jit(): +def test_layernorm_lstm_layer_with_project_jit(device="cpu"): input_size = 10 hidden_size = 20 proj_size = 5 @@ -161,11 +204,12 @@ def test_layernorm_lstm_layer_with_project_jit(): input_size, hidden_size=hidden_size, proj_size=proj_size, + device=device, ) torch.jit.script(layer) -def test_layernorm_lstm_layer_with_projection_forward(): +def test_layernorm_lstm_layer_with_projection_forward(device="cpu"): input_size = torch.randint(low=2, high=100, size=(1,)).item() hidden_size = torch.randint(low=10, high=100, size=(1,)).item() bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 @@ -177,14 +221,15 @@ def test_layernorm_lstm_layer_with_projection_forward(): bias=bias, proj_size=proj_size, ln=nn.Identity, + device=device, ) N = torch.randint(low=2, high=100, size=(1,)) T = torch.randint(low=2, high=100, size=(1,)) - x = torch.rand(N, T, input_size).requires_grad_() - h = torch.rand(N, proj_size) - c = torch.rand(N, hidden_size) + x = torch.rand(N, T, input_size, device=device).requires_grad_() + h = torch.rand(N, proj_size, device=device) + c = torch.rand(N, hidden_size, device=device) x_clone = x.detach().clone().requires_grad_() @@ -199,6 +244,7 @@ def test_layernorm_lstm_layer_with_projection_forward(): batch_first=True, dropout=0, bidirectional=False, + device=device, ) with torch.no_grad(): for name, self_param in self_layer.cell.named_parameters(): @@ -207,25 +253,17 @@ def test_layernorm_lstm_layer_with_projection_forward(): torch_y, (torch_h, torch_c) = torch_layer( x_clone, (h.unsqueeze(0), c.unsqueeze(0)) ) - assert_allclose(self_y, torch_y) - assert_allclose(self_h, torch_h) - assert_allclose(self_c, torch_c) + assert_allclose(self_y, torch_y, atol=1e-6) + assert_allclose(self_h, torch_h, atol=1e-6) + assert_allclose(self_c, torch_c, atol=1e-6) - self_hc = self_h * self_c.sum() - torch_hc = torch_h * torch_c.sum() - self_hc_sum = (self_hc.reshape(-1) * torch.arange(self_hc.numel())).sum() - torch_hc_sum = (torch_hc.reshape(-1) * torch.arange(torch_hc.numel())).sum() - - self_y_sum = (self_y.reshape(-1) * torch.arange(self_y.numel())).sum() - torch_y_sum = (torch_y.reshape(-1) * torch.arange(torch_y.numel())).sum() - - (self_hc_sum * self_y_sum).backward() - (torch_hc_sum * torch_y_sum).backward() + self_y.sum().backward() + torch_y.sum().backward() assert_allclose(x.grad, x_clone.grad, atol=1e-5) -def test_layernorm_lstm_layer_forward(): +def test_layernorm_lstm_layer_forward(device="cpu"): input_size = torch.randint(low=2, high=100, size=(1,)).item() hidden_size = torch.randint(low=2, high=100, size=(1,)).item() bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 @@ -234,14 +272,15 @@ def test_layernorm_lstm_layer_forward(): hidden_size, bias=bias, ln=nn.Identity, + device=device, ) N = torch.randint(low=2, high=100, size=(1,)) T = torch.randint(low=2, high=100, size=(1,)) - x = torch.rand(N, T, input_size).requires_grad_() - h = torch.rand(N, hidden_size) - c = torch.rand(N, hidden_size) + x = torch.rand(N, T, input_size, device=device).requires_grad_() + h = torch.rand(N, hidden_size, device=device) + c = torch.rand(N, hidden_size, device=device) x_clone = x.detach().clone().requires_grad_() @@ -255,6 +294,7 @@ def test_layernorm_lstm_layer_forward(): batch_first=True, dropout=0, bidirectional=False, + device=device, ) with torch.no_grad(): for name, self_param in self_layer.cell.named_parameters(): @@ -263,25 +303,33 @@ def test_layernorm_lstm_layer_forward(): torch_y, (torch_h, torch_c) = torch_layer( x_clone, (h.unsqueeze(0), c.unsqueeze(0)) ) - assert_allclose(self_y, torch_y) - assert_allclose(self_h, torch_h) - assert_allclose(self_c, torch_c) + assert_allclose(self_y, torch_y, atol=1e-6) + assert_allclose(self_h, torch_h, atol=1e-6) + assert_allclose(self_c, torch_c, atol=1e-6) self_hc = self_h * self_c torch_hc = torch_h * torch_c - self_hc_sum = (self_hc.reshape(-1) * torch.arange(self_hc.numel())).sum() - torch_hc_sum = (torch_hc.reshape(-1) * torch.arange(torch_hc.numel())).sum() + self_hc_sum = ( + self_hc.reshape(-1) * torch.arange(self_hc.numel(), device=device) + ).sum() + torch_hc_sum = ( + torch_hc.reshape(-1) * torch.arange(torch_hc.numel(), device=device) + ).sum() - self_y_sum = (self_y.reshape(-1) * torch.arange(self_y.numel())).sum() - torch_y_sum = (torch_y.reshape(-1) * torch.arange(torch_y.numel())).sum() + self_y_sum = ( + self_y.reshape(-1) * torch.arange(self_y.numel(), device=device) + ).sum() + torch_y_sum = ( + torch_y.reshape(-1) * torch.arange(torch_y.numel(), device=device) + ).sum() - (self_hc_sum * self_y_sum).backward() - (torch_hc_sum * torch_y_sum).backward() + (self_hc_sum + self_y_sum).backward() + (torch_hc_sum + torch_y_sum).backward() - assert_allclose(x.grad, x_clone.grad, atol=1e-5) + assert_allclose(x.grad, x_clone.grad, atol=0.1) -def test_layernorm_lstm_jit(): +def test_layernorm_lstm_jit(device="cpu"): input_size = 2 hidden_size = 3 num_layers = 4 @@ -293,11 +341,12 @@ def test_layernorm_lstm_jit(): num_layers=num_layers, bias=bias, ln=nn.Identity, + device=device, ) torch.jit.script(lstm) -def test_layernorm_lstm_with_projection_jit(): +def test_layernorm_lstm_with_projection_jit(device="cpu"): input_size = 2 hidden_size = 5 proj_size = 3 @@ -311,11 +360,12 @@ def test_layernorm_lstm_with_projection_jit(): bias=bias, proj_size=proj_size, ln=nn.Identity, + device=device, ) torch.jit.script(lstm) -def test_layernorm_lstm_forward(): +def test_layernorm_lstm_forward(device="cpu"): input_size = torch.randint(low=2, high=100, size=(1,)).item() hidden_size = torch.randint(low=2, high=100, size=(1,)).item() num_layers = torch.randint(low=2, high=100, size=(1,)).item() @@ -327,6 +377,7 @@ def test_layernorm_lstm_forward(): num_layers=num_layers, bias=bias, ln=nn.Identity, + device=device, ) torch_lstm = nn.LSTM( input_size=input_size, @@ -335,6 +386,7 @@ def test_layernorm_lstm_forward(): bias=bias, batch_first=True, bidirectional=False, + device=device, ) assert len(self_lstm.state_dict()) == len(torch_lstm.state_dict()) with torch.no_grad(): @@ -347,9 +399,9 @@ def test_layernorm_lstm_forward(): N = torch.randint(low=2, high=100, size=(1,)) T = torch.randint(low=2, high=100, size=(1,)) - x = torch.rand(N, T, input_size).requires_grad_() - hs = [torch.rand(N, hidden_size) for _ in range(num_layers)] - cs = [torch.rand(N, hidden_size) for _ in range(num_layers)] + x = torch.rand(N, T, input_size, device=device).requires_grad_() + hs = [torch.rand(N, hidden_size, device=device) for _ in range(num_layers)] + cs = [torch.rand(N, hidden_size, device=device) for _ in range(num_layers)] states = list(zip(hs, cs)) x_clone = x.detach().clone().requires_grad_() @@ -371,17 +423,17 @@ def test_layernorm_lstm_forward(): s = self_y.reshape(-1) t = torch_y.reshape(-1) - s_sum = (s * torch.arange(s.numel())).sum() - t_sum = (t * torch.arange(t.numel())).sum() - shc_sum = s_sum * self_h.sum() * self_c.sum() - thc_sum = t_sum * torch_h.sum() * torch_c.sum() + s_sum = (s * torch.arange(s.numel(), device=device)).sum() + t_sum = (t * torch.arange(t.numel(), device=device)).sum() + shc_sum = s_sum + self_h.sum() + self_c.sum() + thc_sum = t_sum + torch_h.sum() + torch_c.sum() shc_sum.backward() thc_sum.backward() assert_allclose(x.grad, x_clone.grad) -def test_layernorm_lstm_with_projection_forward(): +def test_layernorm_lstm_with_projection_forward(device="cpu"): input_size = torch.randint(low=2, high=100, size=(1,)).item() hidden_size = torch.randint(low=10, high=100, size=(1,)).item() proj_size = torch.randint(low=2, high=hidden_size, size=(1,)).item() @@ -395,6 +447,7 @@ def test_layernorm_lstm_with_projection_forward(): bias=bias, proj_size=proj_size, ln=nn.Identity, + device=device, ) torch_lstm = nn.LSTM( input_size=input_size, @@ -404,6 +457,7 @@ def test_layernorm_lstm_with_projection_forward(): proj_size=proj_size, batch_first=True, bidirectional=False, + device=device, ) assert len(self_lstm.state_dict()) == len(torch_lstm.state_dict()) with torch.no_grad(): @@ -416,9 +470,9 @@ def test_layernorm_lstm_with_projection_forward(): N = torch.randint(low=2, high=100, size=(1,)) T = torch.randint(low=2, high=100, size=(1,)) - x = torch.rand(N, T, input_size).requires_grad_() - hs = [torch.rand(N, proj_size) for _ in range(num_layers)] - cs = [torch.rand(N, hidden_size) for _ in range(num_layers)] + x = torch.rand(N, T, input_size, device=device).requires_grad_() + hs = [torch.rand(N, proj_size, device=device) for _ in range(num_layers)] + cs = [torch.rand(N, hidden_size, device=device) for _ in range(num_layers)] states = list(zip(hs, cs)) x_clone = x.detach().clone().requires_grad_() @@ -429,43 +483,55 @@ def test_layernorm_lstm_with_projection_forward(): c = torch.stack(cs) torch_y, (torch_h, torch_c) = torch_lstm(x_clone, (h, c)) - assert_allclose(self_y, torch_y) + assert_allclose(self_y, torch_y, atol=1e-6) self_h = torch.stack([s[0] for s in self_states]) self_c = torch.stack([s[1] for s in self_states]) - assert_allclose(self_h, torch_h) - assert_allclose(self_c, torch_c) + assert_allclose(self_h, torch_h, atol=1e-6) + assert_allclose(self_c, torch_c, atol=1e-6) s = self_y.reshape(-1) t = torch_y.reshape(-1) - s_sum = (s * torch.arange(s.numel())).sum() - t_sum = (t * torch.arange(t.numel())).sum() - shc_sum = s_sum * self_h.sum() * self_c.sum() - thc_sum = t_sum * torch_h.sum() * torch_c.sum() + s_sum = (s * torch.arange(s.numel(), device=device)).sum() + t_sum = (t * torch.arange(t.numel(), device=device)).sum() + shc_sum = s_sum + self_h.sum() + self_c.sum() + thc_sum = t_sum + torch_h.sum() + torch_c.sum() shc_sum.backward() thc_sum.backward() - assert_allclose(x.grad, x_clone.grad) + assert_allclose(x.grad, x_clone.grad, atol=1e-6) -def test_layernorm_gru_cell_jit(): +def test_layernorm_gru_cell_jit(device="cpu"): input_size = 10 hidden_size = 20 cell = LayerNormGRUCell( - input_size=input_size, hidden_size=hidden_size, bias=True + input_size=input_size, + hidden_size=hidden_size, + bias=True, + device=device, ) torch.jit.script(cell) -def test_layernorm_gru_cell_constructor(): +def test_layernorm_gru_cell_constructor(device="cpu"): input_size = torch.randint(low=2, high=100, size=(1,)).item() hidden_size = torch.randint(low=2, high=100, size=(1,)).item() - self_cell = LayerNormGRUCell(input_size, hidden_size, ln=nn.Identity) - torch_cell = nn.GRUCell(input_size, hidden_size) + self_cell = LayerNormGRUCell( + input_size, + hidden_size, + ln=nn.Identity, + device=device, + ) + torch_cell = nn.GRUCell( + input_size, + hidden_size, + device=device, + ) for name, param in self_cell.named_parameters(): assert param.shape == getattr(torch_cell, name).shape @@ -473,23 +539,32 @@ def test_layernorm_gru_cell_constructor(): assert len(self_cell.state_dict()) == len(torch_cell.state_dict()) -def test_layernorm_gru_cell_forward(): +def test_layernorm_gru_cell_forward(device="cpu"): input_size = torch.randint(low=2, high=100, size=(1,)).item() hidden_size = torch.randint(low=2, high=100, size=(1,)).item() bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 self_cell = LayerNormGRUCell( - input_size, hidden_size, bias=bias, ln=nn.Identity + input_size, + hidden_size, + bias=bias, + ln=nn.Identity, + device=device, + ) + torch_cell = nn.GRUCell( + input_size, + hidden_size, + bias=bias, + device=device, ) - torch_cell = nn.GRUCell(input_size, hidden_size, bias=bias) with torch.no_grad(): for name, torch_param in torch_cell.named_parameters(): self_param = getattr(self_cell, name) torch_param.copy_(self_param) N = torch.randint(low=2, high=100, size=(1,)) - x = torch.rand(N, input_size).requires_grad_() - h = torch.rand(N, hidden_size) + x = torch.rand(N, input_size, device=device).requires_grad_() + h = torch.rand(N, hidden_size, device=device) x_clone = x.detach().clone().requires_grad_() @@ -498,20 +573,28 @@ def test_layernorm_gru_cell_forward(): assert_allclose(self_h, torch_h, atol=1e-5) - (self_h.reshape(-1) * torch.arange(self_h.numel())).sum().backward() - (torch_h.reshape(-1) * torch.arange(torch_h.numel())).sum().backward() + ( + self_h.reshape(-1) * torch.arange(self_h.numel(), device=device) + ).sum().backward() + ( + torch_h.reshape(-1) * torch.arange(torch_h.numel(), device=device) + ).sum().backward() assert_allclose(x.grad, x_clone.grad, atol=1e-4) -def test_layernorm_gru_layer_jit(): +def test_layernorm_gru_layer_jit(device="cpu"): input_size = 10 hidden_size = 20 - layer = LayerNormGRULayer(input_size, hidden_size=hidden_size) + layer = LayerNormGRULayer( + input_size, + hidden_size=hidden_size, + device=device, + ) torch.jit.script(layer) -def test_layernorm_gru_layer_forward(): +def test_layernorm_gru_layer_forward(device="cpu"): input_size = torch.randint(low=2, high=100, size=(1,)).item() hidden_size = torch.randint(low=2, high=100, size=(1,)).item() bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 @@ -520,13 +603,14 @@ def test_layernorm_gru_layer_forward(): hidden_size, bias=bias, ln=nn.Identity, + device=device, ) N = torch.randint(low=2, high=100, size=(1,)) T = torch.randint(low=2, high=100, size=(1,)) - x = torch.rand(N, T, input_size).requires_grad_() - h = torch.rand(N, hidden_size) + x = torch.rand(N, T, input_size, device=device).requires_grad_() + h = torch.rand(N, hidden_size, device=device) x_clone = x.detach().clone().requires_grad_() @@ -540,6 +624,7 @@ def test_layernorm_gru_layer_forward(): batch_first=True, dropout=0, bidirectional=False, + device=device, ) with torch.no_grad(): for name, self_param in self_layer.cell.named_parameters(): @@ -549,8 +634,12 @@ def test_layernorm_gru_layer_forward(): assert_allclose(self_y, torch_y, atol=1e-6) assert_allclose(self_h, torch_h, atol=1e-6) - self_y_sum = (self_y.reshape(-1) * torch.arange(self_y.numel())).sum() - torch_y_sum = (torch_y.reshape(-1) * torch.arange(torch_y.numel())).sum() + self_y_sum = ( + self_y.reshape(-1) * torch.arange(self_y.numel(), device=device) + ).sum() + torch_y_sum = ( + torch_y.reshape(-1) * torch.arange(torch_y.numel(), device=device) + ).sum() self_y_sum.backward() torch_y_sum.backward() @@ -558,7 +647,7 @@ def test_layernorm_gru_layer_forward(): assert_allclose(x.grad, x_clone.grad, atol=0.1) -def test_layernorm_gru_jit(): +def test_layernorm_gru_jit(device="cpu"): input_size = 2 hidden_size = 3 num_layers = 4 @@ -570,11 +659,12 @@ def test_layernorm_gru_jit(): num_layers=num_layers, bias=bias, ln=nn.Identity, + device=device, ) torch.jit.script(gru) -def test_layernorm_gru_forward(): +def test_layernorm_gru_forward(device="cpu"): input_size = torch.randint(low=2, high=100, size=(1,)).item() hidden_size = torch.randint(low=2, high=100, size=(1,)).item() num_layers = torch.randint(low=2, high=100, size=(1,)).item() @@ -586,6 +676,7 @@ def test_layernorm_gru_forward(): num_layers=num_layers, bias=bias, ln=nn.Identity, + device=device, ) torch_gru = nn.GRU( input_size=input_size, @@ -594,6 +685,7 @@ def test_layernorm_gru_forward(): bias=bias, batch_first=True, bidirectional=False, + device=device, ) assert len(self_gru.state_dict()) == len(torch_gru.state_dict()) with torch.no_grad(): @@ -606,8 +698,10 @@ def test_layernorm_gru_forward(): N = torch.randint(low=2, high=100, size=(1,)) T = torch.randint(low=2, high=100, size=(1,)) - x = torch.rand(N, T, input_size).requires_grad_() - states = [torch.rand(N, hidden_size) for _ in range(num_layers)] + x = torch.rand(N, T, input_size, device=device).requires_grad_() + states = [ + torch.rand(N, hidden_size, device=device) for _ in range(num_layers) + ] x_clone = x.detach().clone().requires_grad_() @@ -624,49 +718,51 @@ def test_layernorm_gru_forward(): s = self_y.reshape(-1) t = torch_y.reshape(-1) - s_sum = (s * torch.arange(s.numel())).sum() - t_sum = (t * torch.arange(t.numel())).sum() + s_sum = (s * torch.arange(s.numel(), device=device)).sum() + t_sum = (t * torch.arange(t.numel(), device=device)).sum() s_state_sum = s_sum + self_states.sum() t_state_sum = t_sum + torch_states.sum() s_state_sum.backward() t_state_sum.backward() - assert_allclose(x.grad, x_clone.grad, atol=1e-6) + assert_allclose(x.grad, x_clone.grad, atol=1e-4) -def test_lstm(): - test_layernorm_lstm_cell_jit() - test_layernorm_lstm_cell_constructor() - test_layernorm_lstm_cell_with_projection_jit() - test_layernorm_lstm_cell_forward() - test_layernorm_lstm_cell_with_projection_forward() +def _test_lstm(device): + test_layernorm_lstm_cell_jit(device) + test_layernorm_lstm_cell_constructor(device) + test_layernorm_lstm_cell_with_projection_jit(device) + test_layernorm_lstm_cell_forward(device) + test_layernorm_lstm_cell_with_projection_forward(device) # - test_layernorm_lstm_layer_jit() - test_layernorm_lstm_layer_with_project_jit() - test_layernorm_lstm_layer_forward() - test_layernorm_lstm_layer_with_projection_forward() + test_layernorm_lstm_layer_jit(device) + test_layernorm_lstm_layer_with_project_jit(device) + test_layernorm_lstm_layer_forward(device) + test_layernorm_lstm_layer_with_projection_forward(device) - test_layernorm_lstm_jit() - test_layernorm_lstm_with_projection_jit() - test_layernorm_lstm_forward() - test_layernorm_lstm_with_projection_forward() + test_layernorm_lstm_jit(device) + test_layernorm_lstm_with_projection_jit(device) + test_layernorm_lstm_forward(device) + test_layernorm_lstm_with_projection_forward(device) -def test_gru(): - test_layernorm_gru_cell_jit() - test_layernorm_gru_cell_constructor() - test_layernorm_gru_cell_forward() +def _test_gru(device): + test_layernorm_gru_cell_jit(device) + test_layernorm_gru_cell_constructor(device) + test_layernorm_gru_cell_forward(device) # - test_layernorm_gru_layer_jit() - test_layernorm_gru_layer_forward() + test_layernorm_gru_layer_jit(device) + test_layernorm_gru_layer_forward(device) # - test_layernorm_gru_jit() - test_layernorm_gru_forward() + test_layernorm_gru_jit(device) + test_layernorm_gru_forward(device) def main(): - test_lstm() - test_gru() + for device in get_devices(): + print("device", device) + _test_lstm(device) + _test_gru(device) if __name__ == "__main__":