From 2c7547e1b7948714cb0600a4ed3de0a87c1d1fb9 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 3 Dec 2021 16:47:40 +0800 Subject: [PATCH] Add projection support to LayerNormLSTMCell. --- egs/librispeech/ASR/transducer/rnn.py | 66 ++++-- egs/librispeech/ASR/transducer/test_rnn.py | 227 ++++++++++++++++++++- 2 files changed, 278 insertions(+), 15 deletions(-) diff --git a/egs/librispeech/ASR/transducer/rnn.py b/egs/librispeech/ASR/transducer/rnn.py index 5afb4e907..062a27645 100644 --- a/egs/librispeech/ASR/transducer/rnn.py +++ b/egs/librispeech/ASR/transducer/rnn.py @@ -30,7 +30,6 @@ import torch.nn as nn import torch.nn.functional as F -# TODO(fangjun): Support projection, see https://arxiv.org/pdf/1402.1128.pdf class LayerNormLSTMCell(nn.Module): """This class places a `nn.LayerNorm` after the output of each gate (right before the activation). @@ -60,6 +59,7 @@ class LayerNormLSTMCell(nn.Module): hidden_size: int, bias: bool = True, ln: nn.Module = nn.LayerNorm, + proj_size: int = 0, device=None, dtype=None, ): @@ -70,7 +70,9 @@ class LayerNormLSTMCell(nn.Module): be of shape (batch_size, input_size). hidden_size: The number of features in the hidden state `h` and `c`. - Both `h` and `c` are of shape (batch_size, hidden_size). + Both `h` and `c` are of shape (batch_size, hidden_size) when + proj_size is 0. If proj_size is not zero, the shape of `h` + is (batch_size, proj_size). bias: If ``False``, then the cell does not use bias weights `bias_ih` and `bias_hh`. @@ -78,19 +80,38 @@ class LayerNormLSTMCell(nn.Module): Defaults to `nn.LayerNorm`. The output of all gates are processed by `ln`. We pass it as an argument so that we can replace it with `nn.Identity` at the testing time. + proj_size: + If not zero, it applies an affine transform to the output. In this + case, the shape of `h` is (batch_size, proj_size). + See https://arxiv.org/pdf/1402.1128.pdf """ super().__init__() factory_kwargs = {"device": device, "dtype": dtype} self.input_size = input_size self.hidden_size = hidden_size self.bias = bias + self.proj_size = proj_size + + if proj_size < 0: + raise ValueError( + f"proj_size {proj_size} should be a positive integer " + "or zero to disable projections" + ) + + if proj_size >= hidden_size: + raise ValueError( + f"proj_size {proj_size} has to be smaller " + f"than hidden_size {hidden_size}" + ) + + real_hidden_size = proj_size if proj_size > 0 else hidden_size self.weight_ih = nn.Parameter( torch.empty((4 * hidden_size, input_size), **factory_kwargs) ) self.weight_hh = nn.Parameter( - torch.empty((4 * hidden_size, hidden_size), **factory_kwargs) + torch.empty((4 * hidden_size, real_hidden_size), **factory_kwargs) ) if bias: @@ -104,6 +125,13 @@ class LayerNormLSTMCell(nn.Module): self.register_parameter("bias_ih", None) self.register_parameter("bias_hh", None) + if proj_size > 0: + self.weight_hr = nn.Parameter( + torch.empty((proj_size, hidden_size), **factory_kwargs) + ) + else: + self.register_parameter("weight_hr", None) + self.layernorm_i = ln(hidden_size) self.layernorm_f = ln(hidden_size) self.layernorm_cx = ln(hidden_size) @@ -123,12 +151,15 @@ class LayerNormLSTMCell(nn.Module): A 2-D tensor of shape (batch_size, input_size). state: If not ``None``, it contains the hidden state (h, c) for each - element in the batch. Both are of shape (batch_size, hidden_size). + element in the batch. Both are of shape (batch_size, hidden_size) + if proj_size is 0. If proj_size is not zero, the shape of `h` is + (batch_size, proj_size). If ``None``, it uses zeros for `h` and `c`. 
Returns: Return two tensors: - - `next_h`: It is of shape (batch_size, hidden_size) containing the - next hidden state for each element in the batch. + - `next_h`: It is of shape (batch_size, hidden_size) if proj_size + is 0, else (batch_size, proj_size), containing the next hidden + state for each element in the batch. - `next_c`: It is of shape (batch_size, hidden_size) containing the next cell state for each element in the batch. """ @@ -162,6 +193,9 @@ class LayerNormLSTMCell(nn.Module): cy = self.layernorm_cy(cy) hy = out_gate * torch.tanh(cy) + if self.weight_hr is not None: + hy = torch.matmul(hy, self.weight_hr.t()) + return hy, cy def extra_repr(self) -> str: @@ -172,8 +206,9 @@ class LayerNormLSTMCell(nn.Module): def reset_parameters(self) -> None: stdv = 1.0 / math.sqrt(self.hidden_size) - for weight in self.parameters(): - nn.init.uniform_(weight, -stdv, stdv) + for name, weight in self.named_parameters(): + if "layernorm" not in name: + nn.init.uniform_(weight, -stdv, stdv) class LayerNormLSTMLayer(nn.Module): @@ -199,6 +234,7 @@ class LayerNormLSTMLayer(nn.Module): hidden_size: int, bias: bool = True, ln: nn.Module = nn.LayerNorm, + proj_size: int = 0, device=None, dtype=None, ): @@ -211,6 +247,7 @@ class LayerNormLSTMLayer(nn.Module): hidden_size=hidden_size, bias=bias, ln=ln, + proj_size=proj_size, device=device, dtype=dtype, ) @@ -228,13 +265,14 @@ class LayerNormLSTMLayer(nn.Module): We use `batch_first=True` here. state: If not ``None``, it contains the hidden state (h, c) of this layer. - Both are of shape (batch_size, hidden_size). + Both are of shape (batch_size, hidden_size) if proj_size is 0. + If proj_size is not 0, the shape of `h` is (batch_size, proj_size). Note: We did not annotate `state` with `Optional[Tuple[...]]` since torchscript will complain. Return: - output, a tensor of shape (batch_size, seq_len, hidden_size) - - (next_h, next_c) containing the hidden state of this layer + - (next_h, next_c) containing the next hidden state """ inputs = input.unbind(1) outputs = torch.jit.annotate(List[torch.Tensor], []) @@ -270,6 +308,7 @@ class LayerNormLSTM(nn.Module): hidden_size: int, num_layers: int, bias: bool = True, + proj_size: int = 0, ln: nn.Module = nn.LayerNorm, device=None, dtype=None, @@ -283,6 +322,7 @@ class LayerNormLSTM(nn.Module): hidden_size=hidden_size, bias=bias, ln=ln, + proj_size=proj_size, device=device, dtype=dtype, ) @@ -293,7 +333,7 @@ class LayerNormLSTM(nn.Module): for i in range(1, num_layers): layers.append( LayerNormLSTMLayer( - input_size=hidden_size, + input_size=proj_size if proj_size > 0 else hidden_size, **factory_kwargs, ) ) @@ -313,7 +353,9 @@ class LayerNormLSTM(nn.Module): We use `batch_first=True` here. states: One state per layer. Each entry contains the hidden state (h, c) - for a layer. Both are of shape (batch_size, hidden_size). + for a layer. Both are of shape (batch_size, hidden_size) if + proj_size is 0. If proj_size is not 0, the shape of `h` is + (batch_size, proj_size). 
Returns: Return a tuple containing: diff --git a/egs/librispeech/ASR/transducer/test_rnn.py b/egs/librispeech/ASR/transducer/test_rnn.py index 7df7c7b78..47eebd588 100755 --- a/egs/librispeech/ASR/transducer/test_rnn.py +++ b/egs/librispeech/ASR/transducer/test_rnn.py @@ -55,6 +55,14 @@ def test_layernorm_lstm_cell_constructor(): assert len(self_cell.state_dict()) == len(torch_cell.state_dict()) +def test_layernorm_lstm_cell_with_projection_jit(): + input_size = 10 + hidden_size = 20 + proj_size = 5 + self_cell = LayerNormLSTMCell(input_size, hidden_size, proj_size=proj_size) + torch.jit.script(self_cell) + + def test_layernorm_lstm_cell_forward(): input_size = torch.randint(low=2, high=100, size=(1,)).item() hidden_size = torch.randint(low=2, high=100, size=(1,)).item() @@ -90,6 +98,54 @@ def test_layernorm_lstm_cell_forward(): assert_allclose(x.grad, x_clone.grad) +def test_layernorm_lstm_cell_with_projection_forward(): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=10, high=100, size=(1,)).item() + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + proj_size = torch.randint(low=2, high=hidden_size, size=(1,)).item() + + self_cell = LayerNormLSTMCell( + input_size, + hidden_size, + bias=bias, + ln=nn.Identity, + proj_size=proj_size, + ) + torch_cell = nn.LSTM( + input_size, + hidden_size, + bias=bias, + proj_size=proj_size, + batch_first=True, + ) + with torch.no_grad(): + for name, self_param in self_cell.named_parameters(): + getattr(torch_cell, f"{name}_l0").copy_(self_param) + + N = torch.randint(low=2, high=100, size=(1,)) + x = torch.rand(N, input_size).requires_grad_() + h = torch.rand(N, proj_size) + c = torch.rand(N, hidden_size) + + x_clone = x.detach().clone().requires_grad_() + + self_h, self_c = self_cell(x.clone(), (h, c)) + _, (torch_h, torch_c) = torch_cell( + x_clone.unsqueeze(1), (h.unsqueeze(0), c.unsqueeze(0)) + ) + + torch_h = torch_h.squeeze(0) + torch_c = torch_c.squeeze(0) + + assert_allclose(self_h, torch_h) + assert_allclose(self_c, torch_c) + + (self_h.sum() * self_c.sum()).backward() + (torch_h.sum() * torch_c.sum()).backward() + + assert_allclose(x.grad, x_clone.grad) + + def test_layernorm_lstm_layer_jit(): input_size = 10 hidden_size = 20 @@ -97,6 +153,78 @@ def test_layernorm_lstm_layer_jit(): torch.jit.script(layer) +def test_layernorm_lstm_layer_with_project_jit(): + input_size = 10 + hidden_size = 20 + proj_size = 5 + layer = LayerNormLSTMLayer( + input_size, + hidden_size=hidden_size, + proj_size=proj_size, + ) + torch.jit.script(layer) + + +def test_layernorm_lstm_layer_with_projection_forward(): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=10, high=100, size=(1,)).item() + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + proj_size = torch.randint(low=2, high=hidden_size, size=(1,)).item() + + self_layer = LayerNormLSTMLayer( + input_size, + hidden_size, + bias=bias, + proj_size=proj_size, + ln=nn.Identity, + ) + + N = torch.randint(low=2, high=100, size=(1,)) + T = torch.randint(low=2, high=100, size=(1,)) + + x = torch.rand(N, T, input_size).requires_grad_() + h = torch.rand(N, proj_size) + c = torch.rand(N, hidden_size) + + x_clone = x.detach().clone().requires_grad_() + + self_y, (self_h, self_c) = self_layer(x, (h, c)) + + torch_layer = nn.LSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=1, + bias=bias, + proj_size=proj_size, + batch_first=True, + dropout=0, + bidirectional=False, + ) + 
with torch.no_grad(): + for name, self_param in self_layer.cell.named_parameters(): + getattr(torch_layer, f"{name}_l0").copy_(self_param) + + torch_y, (torch_h, torch_c) = torch_layer( + x_clone, (h.unsqueeze(0), c.unsqueeze(0)) + ) + assert_allclose(self_y, torch_y) + assert_allclose(self_h, torch_h) + assert_allclose(self_c, torch_c) + + self_hc = self_h * self_c.sum() + torch_hc = torch_h * torch_c.sum() + self_hc_sum = (self_hc.reshape(-1) * torch.arange(self_hc.numel())).sum() + torch_hc_sum = (torch_hc.reshape(-1) * torch.arange(torch_hc.numel())).sum() + + self_y_sum = (self_y.reshape(-1) * torch.arange(self_y.numel())).sum() + torch_y_sum = (torch_y.reshape(-1) * torch.arange(torch_y.numel())).sum() + + (self_hc_sum * self_y_sum).backward() + (torch_hc_sum * torch_y_sum).backward() + + assert_allclose(x.grad, x_clone.grad, atol=1e-5) + + def test_layernorm_lstm_layer_forward(): input_size = torch.randint(low=2, high=100, size=(1,)).item() hidden_size = torch.randint(low=2, high=100, size=(1,)).item() @@ -169,6 +297,24 @@ def test_layernorm_lstm_jit(): torch.jit.script(lstm) +def test_layernorm_lstm_with_projection_jit(): + input_size = 2 + hidden_size = 5 + proj_size = 3 + num_layers = 4 + bias = True + + lstm = LayerNormLSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + proj_size=proj_size, + ln=nn.Identity, + ) + torch.jit.script(lstm) + + def test_layernorm_lstm_forward(): input_size = torch.randint(low=2, high=100, size=(1,)).item() hidden_size = torch.randint(low=2, high=100, size=(1,)).item() @@ -235,6 +381,75 @@ def test_layernorm_lstm_forward(): assert_allclose(x.grad, x_clone.grad) +def test_layernorm_lstm_with_projection_forward(): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=10, high=100, size=(1,)).item() + proj_size = torch.randint(low=2, high=hidden_size, size=(1,)).item() + num_layers = torch.randint(low=2, high=100, size=(1,)).item() + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + + self_lstm = LayerNormLSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + proj_size=proj_size, + ln=nn.Identity, + ) + torch_lstm = nn.LSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + proj_size=proj_size, + batch_first=True, + bidirectional=False, + ) + assert len(self_lstm.state_dict()) == len(torch_lstm.state_dict()) + with torch.no_grad(): + for name, param in self_lstm.named_parameters(): + # name has the form layers.0.cell.weight_hh + parts = name.split(".") + layer_num = parts[1] + getattr(torch_lstm, f"{parts[-1]}_l{layer_num}").copy_(param) + + N = torch.randint(low=2, high=100, size=(1,)) + T = torch.randint(low=2, high=100, size=(1,)) + + x = torch.rand(N, T, input_size).requires_grad_() + hs = [torch.rand(N, proj_size) for _ in range(num_layers)] + cs = [torch.rand(N, hidden_size) for _ in range(num_layers)] + states = list(zip(hs, cs)) + + x_clone = x.detach().clone().requires_grad_() + + self_y, self_states = self_lstm(x, states) + + h = torch.stack(hs) + c = torch.stack(cs) + torch_y, (torch_h, torch_c) = torch_lstm(x_clone, (h, c)) + + assert_allclose(self_y, torch_y) + + self_h = torch.stack([s[0] for s in self_states]) + self_c = torch.stack([s[1] for s in self_states]) + + assert_allclose(self_h, torch_h) + assert_allclose(self_c, torch_c) + + s = self_y.reshape(-1) + t = torch_y.reshape(-1) + + s_sum = (s * torch.arange(s.numel())).sum() + t_sum = (t * 
torch.arange(t.numel())).sum() + shc_sum = s_sum * self_h.sum() * self_c.sum() + thc_sum = t_sum * torch_h.sum() * torch_c.sum() + + shc_sum.backward() + thc_sum.backward() + assert_allclose(x.grad, x_clone.grad) + + def test_layernorm_gru_cell_jit(): input_size = 10 hidden_size = 20 @@ -332,7 +547,7 @@ def test_layernorm_gru_layer_forward(): torch_y, torch_h = torch_layer(x_clone, h.unsqueeze(0)) assert_allclose(self_y, torch_y, atol=1e-6) - assert_allclose(self_h, torch_h) + assert_allclose(self_h, torch_h, atol=1e-6) self_y_sum = (self_y.reshape(-1) * torch.arange(self_y.numel())).sum() torch_y_sum = (torch_y.reshape(-1) * torch.arange(torch_y.numel())).sum() @@ -416,19 +631,25 @@ def test_layernorm_gru_forward(): s_state_sum.backward() t_state_sum.backward() - assert_allclose(x.grad, x_clone.grad) + assert_allclose(x.grad, x_clone.grad, atol=1e-6) def test_lstm(): test_layernorm_lstm_cell_jit() test_layernorm_lstm_cell_constructor() + test_layernorm_lstm_cell_with_projection_jit() test_layernorm_lstm_cell_forward() + test_layernorm_lstm_cell_with_projection_forward() # test_layernorm_lstm_layer_jit() + test_layernorm_lstm_layer_with_project_jit() test_layernorm_lstm_layer_forward() - # + test_layernorm_lstm_layer_with_projection_forward() + test_layernorm_lstm_jit() + test_layernorm_lstm_with_projection_jit() test_layernorm_lstm_forward() + test_layernorm_lstm_with_projection_forward() def test_gru():
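
A minimal usage sketch of the proj_size option added above, for the single cell. The import path is an assumption (rnn.py importable as `rnn` from egs/librispeech/ASR/transducer); with a non-zero proj_size, `h` has shape (batch_size, proj_size) while `c` keeps (batch_size, hidden_size):

# Sketch: LayerNormLSTMCell with projection (proj_size > 0).
# The import path below is an assumption; adjust to where rnn.py lives.
import torch

from rnn import LayerNormLSTMCell

N = 4                                            # batch size
input_size, hidden_size, proj_size = 10, 20, 5   # 0 < proj_size < hidden_size

cell = LayerNormLSTMCell(input_size, hidden_size, proj_size=proj_size)

x = torch.rand(N, input_size)
h0 = torch.zeros(N, proj_size)    # projected hidden state: (N, proj_size)
c0 = torch.zeros(N, hidden_size)  # cell state keeps full size: (N, hidden_size)

h1, c1 = cell(x, (h0, c0))
assert h1.shape == (N, proj_size)
assert c1.shape == (N, hidden_size)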
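
Similarly, a sketch for the stacked LayerNormLSTM under the same import assumption: with proj_size > 0, every per-layer `h` (and the final output) has proj_size features, and the states argument carries one (h, c) pair per layer, as in the tests above:

# Sketch: stacked LayerNormLSTM with projection; import path is an assumption.
import torch

from rnn import LayerNormLSTM

N, T = 4, 7
input_size, hidden_size, proj_size, num_layers = 10, 20, 5, 3

lstm = LayerNormLSTM(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    proj_size=proj_size,
)

x = torch.rand(N, T, input_size)  # batch_first layout
states = [
    (torch.zeros(N, proj_size), torch.zeros(N, hidden_size))
    for _ in range(num_layers)
]

y, next_states = lstm(x, states)
assert y.shape == (N, T, proj_size)      # last layer's projected outputs
assert len(next_states) == num_layers    # one (h, c) pair per layer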