From 1d004ca966b37734e1bf7b908bf3c483b88d7b74 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 3 Dec 2021 14:59:19 +0800 Subject: [PATCH] Apply layer normalization to the output of each gate in GRU. --- egs/librispeech/ASR/rnnt/rnn.py | 284 -------- egs/librispeech/ASR/rnnt/test_rnn.py | 240 ------- .../ASR/{rnnt => transducer}/__init__.py | 0 egs/librispeech/ASR/transducer/rnn.py | 610 ++++++++++++++++++ egs/librispeech/ASR/transducer/test_rnn.py | 453 +++++++++++++ 5 files changed, 1063 insertions(+), 524 deletions(-) delete mode 100644 egs/librispeech/ASR/rnnt/rnn.py delete mode 100755 egs/librispeech/ASR/rnnt/test_rnn.py rename egs/librispeech/ASR/{rnnt => transducer}/__init__.py (100%) create mode 100644 egs/librispeech/ASR/transducer/rnn.py create mode 100755 egs/librispeech/ASR/transducer/test_rnn.py diff --git a/egs/librispeech/ASR/rnnt/rnn.py b/egs/librispeech/ASR/rnnt/rnn.py deleted file mode 100644 index 31849225b..000000000 --- a/egs/librispeech/ASR/rnnt/rnn.py +++ /dev/null @@ -1,284 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Apply layer normalization to the output of each gate in LSTM/GRU. - -This file uses -https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py -as a reference. -""" - -import math -from typing import List, Optional, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -# TODO(fangjun): Support projection, see https://arxiv.org/pdf/1402.1128.pdf -class LayerNormLSTMCell(nn.Module): - """This class places a `nn.LayerNorm` after the output of - each gate (right before the activation). - - See the following paper for more details - - 'Improving RNN Transducer Modeling for End-to-End Speech Recognition' - https://arxiv.org/abs/1909.12415 - """ - - def __init__( - self, - input_size: int, - hidden_size: int, - bias: bool = True, - ln: nn.Module = nn.LayerNorm, - device=None, - dtype=None, - ): - """ - Args: - input_size: - The number of expected features in the input `x`. `x` should - be of shape (batch_size, input_size). - hidden_size: - The number of features in the hidden state `h` and `c`. - Both `h` and `c` are of shape (batch_size, hidden_size). - bias: - If ``False``, then the cell does not use bias weights - `bias_ih` and `bias_hh`. - ln: - Defaults to `nn.LayerNorm`. The output of all gates are processed - by `ln`. We pass it as an argument so that we can replace it - with `nn.Identity` at the testing time. 
- """ - super().__init__() - factory_kwargs = {"device": device, "dtype": dtype} - self.input_size = input_size - self.hidden_size = hidden_size - self.bias = bias - - self.weight_ih = nn.Parameter( - torch.empty((4 * hidden_size, input_size), **factory_kwargs) - ) - - self.weight_hh = nn.Parameter( - torch.empty((4 * hidden_size, hidden_size), **factory_kwargs) - ) - - if bias: - self.bias_ih = nn.Parameter( - torch.empty(4 * hidden_size, **factory_kwargs) - ) - self.bias_hh = nn.Parameter( - torch.empty(4 * hidden_size, **factory_kwargs) - ) - else: - self.register_parameter("bias_ih", None) - self.register_parameter("bias_hh", None) - - self.layernorm_i = ln(hidden_size) - self.layernorm_f = ln(hidden_size) - self.layernorm_cx = ln(hidden_size) - self.layernorm_cy = ln(hidden_size) - self.layernorm_o = ln(hidden_size) - - self.reset_parameters() - - def forward( - self, - input: torch.Tensor, - state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - input: - A 2-D tensor of shape (batch_size, input_size). - state: - If not ``None``, it contains the hidden state (h, c); both - are of shape (batch_size, hidden_size). If ``None``, it uses - zeros for `h` and `c`. - Returns: - Return two tensors: - - `next_h`: It is of shape (batch_size, hidden_size) containing the - next hidden state for each element in the batch. - - `next_c`: It is of shape (batch_size, hidden_size) containing the - next cell state for each element in the batch. - """ - if state is None: - zeros = torch.zeros( - input.size(0), - self.hidden_size, - dtype=input.dtype, - device=input.device, - ) - state = (zeros, zeros) - - hx, cx = state - gates = F.linear(input, self.weight_ih, self.bias_ih) + F.linear( - hx, self.weight_hh, self.bias_hh - ) - - in_gate, forget_gate, cell_gate, out_gate = gates.chunk(chunks=4, dim=1) - - in_gate = self.layernorm_i(in_gate) - forget_gate = self.layernorm_f(forget_gate) - cell_gate = self.layernorm_cx(cell_gate) - out_gate = self.layernorm_o(out_gate) - - in_gate = torch.sigmoid(in_gate) - forget_gate = torch.sigmoid(forget_gate) - cell_gate = torch.tanh(cell_gate) - out_gate = torch.sigmoid(out_gate) - - cy = (forget_gate * cx) + (in_gate * cell_gate) - cy = self.layernorm_cy(cy) - hy = out_gate * torch.tanh(cy) - - return hy, cy - - def extra_repr(self) -> str: - s = "{input_size}, {hidden_size}" - if "bias" in self.__dict__ and self.bias is not True: - s += ", bias={bias}" - return s.format(**self.__dict__) - - def reset_parameters(self) -> None: - stdv = 1.0 / math.sqrt(self.hidden_size) - for weight in self.parameters(): - nn.init.uniform_(weight, -stdv, stdv) - - -class LayerNormLSTMLayer(nn.Module): - def __init__( - self, - input_size: int, - hidden_size: int, - bias: bool = True, - ln: nn.Module = nn.LayerNorm, - device=None, - dtype=None, - ): - """ - See the args in LayerNormLSTMCell - """ - super().__init__() - self.cell = LayerNormLSTMCell( - input_size=input_size, - hidden_size=hidden_size, - bias=bias, - ln=ln, - device=device, - dtype=dtype, - ) - - def forward( - self, - input: torch.Tensor, - state: Tuple[torch.Tensor, torch.Tensor], - ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - Args: - input: - A 3-D tensor of shape (batch_size, seq_len, input_size). - Caution: - We use `batch_first=True` here. - state: - If not ``None``, it contains the hidden state (h, c) of this layer. - Both are of shape (batch_size, hidden_size). 
- Note: - We did not annotate `state` with `Optional[Tuple[...]]` since - torchscript will complain. - Return: - - output, a tensor of shape (batch_size, seq_len, hidden_size) - - (next_h, next_c) containing the hidden state of this layer - """ - inputs = input.unbind(1) - outputs = torch.jit.annotate(List[torch.Tensor], []) - for i in range(len(inputs)): - state = self.cell(inputs[i], state) - outputs += [state[0]] - return torch.stack(outputs, dim=1), state - - -class LayerNormLSTM(nn.Module): - def __init__( - self, - input_size: int, - hidden_size: int, - num_layers: int, - bias: bool = True, - ln: nn.Module = nn.LayerNorm, - device=None, - dtype=None, - ): - """ - See the args in LSTMLayer. - """ - super().__init__() - assert num_layers >= 1 - factory_kwargs = dict( - hidden_size=hidden_size, - bias=bias, - ln=ln, - device=device, - dtype=dtype, - ) - first_layer = LayerNormLSTMLayer( - input_size=input_size, **factory_kwargs - ) - layers = [first_layer] - for i in range(1, num_layers): - layers.append( - LayerNormLSTMLayer( - input_size=hidden_size, - **factory_kwargs, - ) - ) - self.layers = nn.ModuleList(layers) - self.num_layers = num_layers - - def forward( - self, - input: torch.Tensor, - states: List[Tuple[torch.Tensor, torch.Tensor]], - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: - """ - Args: - input: - A 3-D tensor of shape (batch_size, seq_len, input_size). - Caution: - We use `batch_first=True` here. - states: - One state per layer. Each entry contains the hidden state (h, c) - for a layer. Both are of shape (batch_size, hidden_size). - Returns: - Return a tuple containing: - - - output: A tensor of shape (batch_size, seq_len, hidden_size) - - List[(next_h, next_c)] containing the hidden states for all layers - - """ - output_states = torch.jit.annotate( - List[Tuple[torch.Tensor, torch.Tensor]], [] - ) - output = input - for i, rnn_layer in enumerate(self.layers): - state = states[i] - output, out_state = rnn_layer(output, state) - output_states += [out_state] - return output, output_states diff --git a/egs/librispeech/ASR/rnnt/test_rnn.py b/egs/librispeech/ASR/rnnt/test_rnn.py deleted file mode 100755 index 683183fb5..000000000 --- a/egs/librispeech/ASR/rnnt/test_rnn.py +++ /dev/null @@ -1,240 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import torch -import torch.nn as nn -from rnnt.rnn import LayerNormLSTM, LayerNormLSTMCell, LayerNormLSTMLayer - - -def test_layernorm_lstm_cell_jit(): - input_size = 10 - hidden_size = 20 - cell = LayerNormLSTMCell( - input_size=input_size, hidden_size=hidden_size, bias=True - ) - - torch.jit.script(cell) - - -def test_layernorm_lstm_cell_constructor(): - input_size = torch.randint(low=2, high=100, size=(1,)).item() - hidden_size = torch.randint(low=2, high=100, size=(1,)).item() - - self_cell = LayerNormLSTMCell(input_size, hidden_size, ln=nn.Identity) - torch_cell = nn.LSTMCell(input_size, hidden_size) - - for name, param in self_cell.named_parameters(): - assert param.shape == getattr(torch_cell, name).shape - - assert len(self_cell.state_dict()) == len(torch_cell.state_dict()) - - -def test_layernorm_lstm_cell_forward(): - input_size = torch.randint(low=2, high=100, size=(1,)).item() - hidden_size = torch.randint(low=2, high=100, size=(1,)).item() - bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 - - self_cell = LayerNormLSTMCell( - input_size, hidden_size, bias=bias, ln=nn.Identity - ) - torch_cell = nn.LSTMCell(input_size, hidden_size, bias=bias) - with torch.no_grad(): - for name, torch_param in torch_cell.named_parameters(): - self_param = getattr(self_cell, name) - torch_param.copy_(self_param) - - N = torch.randint(low=2, high=100, size=(1,)) - x = torch.rand(N, input_size).requires_grad_() - h = torch.rand(N, hidden_size) - c = torch.rand(N, hidden_size) - - x_clone = x.detach().clone().requires_grad_() - - self_h, self_c = self_cell(x.clone(), (h, c)) - torch_h, torch_c = torch_cell(x_clone, (h, c)) - - assert torch.allclose(self_h, torch_h) - assert torch.allclose(self_c, torch_c) - - self_hc = self_h * self_c - torch_hc = torch_h * torch_c - (self_hc.reshape(-1) * torch.arange(self_hc.numel())).sum().backward() - (torch_hc.reshape(-1) * torch.arange(torch_hc.numel())).sum().backward() - - assert torch.allclose(x.grad, x_clone.grad) - - -def test_lstm_layer_jit(): - input_size = 10 - hidden_size = 20 - layer = LayerNormLSTMLayer(input_size, hidden_size=hidden_size) - torch.jit.script(layer) - - -def test_lstm_layer_forward(): - input_size = torch.randint(low=2, high=100, size=(1,)).item() - hidden_size = torch.randint(low=2, high=100, size=(1,)).item() - bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 - self_layer = LayerNormLSTMLayer( - input_size, - hidden_size, - bias=bias, - ln=nn.Identity, - ) - - N = torch.randint(low=2, high=100, size=(1,)) - T = torch.randint(low=2, high=100, size=(1,)) - - x = torch.rand(N, T, input_size).requires_grad_() - h = torch.rand(N, hidden_size) - c = torch.rand(N, hidden_size) - - x_clone = x.detach().clone().requires_grad_() - - self_y, (self_h, self_c) = self_layer(x, (h, c)) - - # now for pytorch - torch_layer = nn.LSTM( - input_size=input_size, - hidden_size=hidden_size, - num_layers=1, - bias=bias, - batch_first=True, - dropout=0, - bidirectional=False, - ) - with torch.no_grad(): - for name, self_param in self_layer.cell.named_parameters(): - getattr(torch_layer, f"{name}_l0").copy_(self_param) - - torch_y, (torch_h, torch_c) = torch_layer( - x_clone, (h.unsqueeze(0), c.unsqueeze(0)) - ) - assert torch.allclose(self_y, torch_y) - assert torch.allclose(self_h, torch_h) - assert torch.allclose(self_c, torch_c) - - self_hc = self_h * self_c - torch_hc = torch_h * torch_c - self_hc_sum = (self_hc.reshape(-1) * torch.arange(self_hc.numel())).sum() - torch_hc_sum = (torch_hc.reshape(-1) * 
torch.arange(torch_hc.numel())).sum() - - self_y_sum = (self_y.reshape(-1) * torch.arange(self_y.numel())).sum() - torch_y_sum = (torch_y.reshape(-1) * torch.arange(torch_y.numel())).sum() - - (self_hc_sum * self_y_sum).backward() - (torch_hc_sum * torch_y_sum).backward() - - assert torch.allclose(x.grad, x_clone.grad, rtol=0.1) - - -def test_stacked_lstm_jit(): - input_size = 2 - hidden_size = 3 - num_layers = 4 - bias = True - - lstm = LayerNormLSTM( - input_size=input_size, - hidden_size=hidden_size, - num_layers=num_layers, - bias=bias, - ln=nn.Identity, - ) - torch.jit.script(lstm) - - -def test_stacked_lstm_forward(): - input_size = torch.randint(low=2, high=100, size=(1,)).item() - hidden_size = torch.randint(low=2, high=100, size=(1,)).item() - num_layers = torch.randint(low=2, high=100, size=(1,)).item() - bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 - - self_lstm = LayerNormLSTM( - input_size=input_size, - hidden_size=hidden_size, - num_layers=num_layers, - bias=bias, - ln=nn.Identity, - ) - torch_lstm = nn.LSTM( - input_size=input_size, - hidden_size=hidden_size, - num_layers=num_layers, - bias=bias, - batch_first=True, - bidirectional=False, - ) - assert len(self_lstm.state_dict()) == len(torch_lstm.state_dict()) - with torch.no_grad(): - for name, param in self_lstm.named_parameters(): - # name has the form layers.0.cell.weight_hh - parts = name.split(".") - layer_num = parts[1] - getattr(torch_lstm, f"{parts[-1]}_l{layer_num}").copy_(param) - - N = torch.randint(low=2, high=100, size=(1,)) - T = torch.randint(low=2, high=100, size=(1,)) - - x = torch.rand(N, T, input_size).requires_grad_() - hs = [torch.rand(N, hidden_size) for _ in range(num_layers)] - cs = [torch.rand(N, hidden_size) for _ in range(num_layers)] - states = list(zip(hs, cs)) - - x_clone = x.detach().clone().requires_grad_() - - self_y, self_states = self_lstm(x, states) - - h = torch.stack(hs) - c = torch.stack(cs) - torch_y, (torch_h, torch_c) = torch_lstm(x_clone, (h, c)) - - assert torch.allclose(self_y, torch_y) - - self_h = torch.stack([s[0] for s in self_states]) - self_c = torch.stack([s[1] for s in self_states]) - - assert torch.allclose(self_h, torch_h) - assert torch.allclose(self_c, torch_c) - - s = self_y.reshape(-1) - t = torch_y.reshape(-1) - - s_sum = (s * torch.arange(s.numel())).sum() - t_sum = (t * torch.arange(t.numel())).sum() - shc_sum = s_sum * self_h.sum() * self_c.sum() - thc_sum = t_sum * torch_h.sum() * torch_c.sum() - - shc_sum.backward() - thc_sum.backward() - assert torch.allclose(x.grad, x_clone.grad) - - -def main(): - test_layernorm_lstm_cell_jit() - test_layernorm_lstm_cell_constructor() - test_layernorm_lstm_cell_forward() - # - test_lstm_layer_jit() - test_lstm_layer_forward() - # - test_stacked_lstm_jit() - test_stacked_lstm_forward() - - -if __name__ == "__main__": - main() diff --git a/egs/librispeech/ASR/rnnt/__init__.py b/egs/librispeech/ASR/transducer/__init__.py similarity index 100% rename from egs/librispeech/ASR/rnnt/__init__.py rename to egs/librispeech/ASR/transducer/__init__.py diff --git a/egs/librispeech/ASR/transducer/rnn.py b/egs/librispeech/ASR/transducer/rnn.py new file mode 100644 index 000000000..5afb4e907 --- /dev/null +++ b/egs/librispeech/ASR/transducer/rnn.py @@ -0,0 +1,610 @@ +# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Apply layer normalization to the output of each gate in LSTM/GRU. + +This file uses +https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py +as a reference. +""" + +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# TODO(fangjun): Support projection, see https://arxiv.org/pdf/1402.1128.pdf +class LayerNormLSTMCell(nn.Module): + """This class places a `nn.LayerNorm` after the output of + each gate (right before the activation). + + See the following paper for more details + + 'Improving RNN Transducer Modeling for End-to-End Speech Recognition' + https://arxiv.org/abs/1909.12415 + + Examples:: + + >>> cell = LayerNormLSTMCell(10, 20) + >>> input = torch.rand(5, 10) + >>> h0 = torch.rand(5, 20) + >>> c0 = torch.rand(5, 20) + >>> h1, c1 = cell(input, (h0, c0)) + >>> output = h1 + >>> h1.shape + torch.Size([5, 20]) + >>> c1.shape + torch.Size([5, 20]) + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + ln: nn.Module = nn.LayerNorm, + device=None, + dtype=None, + ): + """ + Args: + input_size: + The number of expected features in the input `x`. `x` should + be of shape (batch_size, input_size). + hidden_size: + The number of features in the hidden state `h` and `c`. + Both `h` and `c` are of shape (batch_size, hidden_size). + bias: + If ``False``, then the cell does not use bias weights + `bias_ih` and `bias_hh`. + ln: + Defaults to `nn.LayerNorm`. The output of all gates are processed + by `ln`. We pass it as an argument so that we can replace it + with `nn.Identity` at the testing time. + """ + super().__init__() + factory_kwargs = {"device": device, "dtype": dtype} + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + + self.weight_ih = nn.Parameter( + torch.empty((4 * hidden_size, input_size), **factory_kwargs) + ) + + self.weight_hh = nn.Parameter( + torch.empty((4 * hidden_size, hidden_size), **factory_kwargs) + ) + + if bias: + self.bias_ih = nn.Parameter( + torch.empty(4 * hidden_size, **factory_kwargs) + ) + self.bias_hh = nn.Parameter( + torch.empty(4 * hidden_size, **factory_kwargs) + ) + else: + self.register_parameter("bias_ih", None) + self.register_parameter("bias_hh", None) + + self.layernorm_i = ln(hidden_size) + self.layernorm_f = ln(hidden_size) + self.layernorm_cx = ln(hidden_size) + self.layernorm_cy = ln(hidden_size) + self.layernorm_o = ln(hidden_size) + + self.reset_parameters() + + def forward( + self, + input: torch.Tensor, + state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + input: + A 2-D tensor of shape (batch_size, input_size). + state: + If not ``None``, it contains the hidden state (h, c) for each + element in the batch. Both are of shape (batch_size, hidden_size). 
+ If ``None``, it uses zeros for `h` and `c`. + Returns: + Return two tensors: + - `next_h`: It is of shape (batch_size, hidden_size) containing the + next hidden state for each element in the batch. + - `next_c`: It is of shape (batch_size, hidden_size) containing the + next cell state for each element in the batch. + """ + if state is None: + zeros = torch.zeros( + input.size(0), + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + state = (zeros, zeros) + + hx, cx = state + gates = F.linear(input, self.weight_ih, self.bias_ih) + F.linear( + hx, self.weight_hh, self.bias_hh + ) + + in_gate, forget_gate, cell_gate, out_gate = gates.chunk(chunks=4, dim=1) + + in_gate = self.layernorm_i(in_gate) + forget_gate = self.layernorm_f(forget_gate) + cell_gate = self.layernorm_cx(cell_gate) + out_gate = self.layernorm_o(out_gate) + + in_gate = torch.sigmoid(in_gate) + forget_gate = torch.sigmoid(forget_gate) + cell_gate = torch.tanh(cell_gate) + out_gate = torch.sigmoid(out_gate) + + cy = (forget_gate * cx) + (in_gate * cell_gate) + cy = self.layernorm_cy(cy) + hy = out_gate * torch.tanh(cy) + + return hy, cy + + def extra_repr(self) -> str: + s = "{input_size}, {hidden_size}" + if "bias" in self.__dict__ and self.bias is not True: + s += ", bias={bias}" + return s.format(**self.__dict__) + + def reset_parameters(self) -> None: + stdv = 1.0 / math.sqrt(self.hidden_size) + for weight in self.parameters(): + nn.init.uniform_(weight, -stdv, stdv) + + +class LayerNormLSTMLayer(nn.Module): + """ + Examples:: + + >>> layer = LayerNormLSTMLayer(10, 20) + >>> input = torch.rand(2, 5, 10) + >>> h0 = torch.rand(2, 20) + >>> c0 = torch.rand(2, 20) + >>> output, (hn, cn) = layer(input, (h0, c0)) + >>> output.shape + torch.Size([2, 5, 20]) + >>> hn.shape + torch.Size([2, 20]) + >>> cn.shape + torch.Size([2, 20]) + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + ln: nn.Module = nn.LayerNorm, + device=None, + dtype=None, + ): + """ + See the args in LayerNormLSTMCell + """ + super().__init__() + self.cell = LayerNormLSTMCell( + input_size=input_size, + hidden_size=hidden_size, + bias=bias, + ln=ln, + device=device, + dtype=dtype, + ) + + def forward( + self, + input: torch.Tensor, + state: Tuple[torch.Tensor, torch.Tensor], + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Args: + input: + A 3-D tensor of shape (batch_size, seq_len, input_size). + Caution: + We use `batch_first=True` here. + state: + If not ``None``, it contains the hidden state (h, c) of this layer. + Both are of shape (batch_size, hidden_size). + Note: + We did not annotate `state` with `Optional[Tuple[...]]` since + torchscript will complain. 
+ Return: + - output, a tensor of shape (batch_size, seq_len, hidden_size) + - (next_h, next_c) containing the hidden state of this layer + """ + inputs = input.unbind(1) + outputs = torch.jit.annotate(List[torch.Tensor], []) + for i in range(len(inputs)): + state = self.cell(inputs[i], state) + outputs.append(state[0]) + return torch.stack(outputs, dim=1), state + + +class LayerNormLSTM(nn.Module): + """ + Examples:: + + >>> lstm = LayerNormLSTM(10, 20, 8) + >>> input = torch.rand(2, 3, 10) + >>> h0 = torch.rand(8, 2, 20).unbind(0) + >>> c0 = torch.rand(8, 2, 20).unbind(0) + >>> states = list(zip(h0, c0)) + >>> output, next_states = lstm(input, states) + >>> output.shape + torch.Size([2, 3, 20]) + >>> hn = torch.stack([s[0] for s in next_states]) + >>> cn = torch.stack([s[1] for s in next_states]) + >>> hn.shape + torch.Size([8, 2, 20]) + >>> cn.shape + torch.Size([8, 2, 20]) + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int, + bias: bool = True, + ln: nn.Module = nn.LayerNorm, + device=None, + dtype=None, + ): + """ + See the args in LayerNormLSTMLayer. + """ + super().__init__() + assert num_layers >= 1 + factory_kwargs = dict( + hidden_size=hidden_size, + bias=bias, + ln=ln, + device=device, + dtype=dtype, + ) + first_layer = LayerNormLSTMLayer( + input_size=input_size, **factory_kwargs + ) + layers = [first_layer] + for i in range(1, num_layers): + layers.append( + LayerNormLSTMLayer( + input_size=hidden_size, + **factory_kwargs, + ) + ) + self.layers = nn.ModuleList(layers) + self.num_layers = num_layers + + def forward( + self, + input: torch.Tensor, + states: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + """ + Args: + input: + A 3-D tensor of shape (batch_size, seq_len, input_size). + Caution: + We use `batch_first=True` here. + states: + One state per layer. Each entry contains the hidden state (h, c) + for a layer. Both are of shape (batch_size, hidden_size). + Returns: + Return a tuple containing: + + - output: A tensor of shape (batch_size, seq_len, hidden_size) + - List[(next_h, next_c)] containing the hidden states for all layers + + """ + output_states = torch.jit.annotate( + List[Tuple[torch.Tensor, torch.Tensor]], [] + ) + output = input + for i, rnn_layer in enumerate(self.layers): + state = states[i] + output, out_state = rnn_layer(output, state) + output_states += [out_state] + return output, output_states + + +class LayerNormGRUCell(nn.Module): + """This class places a `nn.LayerNorm` after the output of + each gate (right before the activation). + + See the following paper for more details + + 'Improving RNN Transducer Modeling for End-to-End Speech Recognition' + https://arxiv.org/abs/1909.12415 + + Examples:: + + >>> cell = LayerNormGRUCell(10, 20) + >>> input = torch.rand(2, 10) + >>> h0 = torch.rand(2, 20) + >>> hn = cell(input, h0) + >>> hn.shape + torch.Size([2, 20]) + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + ln: nn.Module = nn.LayerNorm, + device=None, + dtype=None, + ): + """ + Args: + input_size: + The number of expected features in the input `x`. `x` should + be of shape (batch_size, input_size). + hidden_size: + The number of features in the hidden state `h` and `c`. + Both `h` and `c` are of shape (batch_size, hidden_size). + bias: + If ``False``, then the cell does not use bias weights + `bias_ih` and `bias_hh`. + ln: + Defaults to `nn.LayerNorm`. The output of all gates are processed + by `ln`. 
We pass it as an argument so that we can replace it + with `nn.Identity` at the testing time. + """ + super().__init__() + factory_kwargs = {"device": device, "dtype": dtype} + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + + self.weight_ih = nn.Parameter( + torch.empty((3 * hidden_size, input_size), **factory_kwargs) + ) + + self.weight_hh = nn.Parameter( + torch.empty((3 * hidden_size, hidden_size), **factory_kwargs) + ) + + if bias: + self.bias_ih = nn.Parameter( + torch.empty(3 * hidden_size, **factory_kwargs) + ) + self.bias_hh = nn.Parameter( + torch.empty(3 * hidden_size, **factory_kwargs) + ) + else: + self.register_parameter("bias_ih", None) + self.register_parameter("bias_hh", None) + + self.layernorm_r = ln(hidden_size) + self.layernorm_i = ln(hidden_size) + self.layernorm_n = ln(hidden_size) + + self.reset_parameters() + + def forward( + self, + input: torch.Tensor, + hx: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + input: + A 2-D tensor of shape (batch_size, input_size) containing + input features. + hx: + If not `None`, it is a tensor of shape (batch_size, hidden_size) + containing the initial hidden state for each element in the batch. + If `None`, it uses zeros for the hidden state. + Returns: + Return a tensor of shape (batch_size, hidden_size) containing the + next hidden state for each element in the batch + """ + if hx is None: + hx = torch.zeros( + input.size(0), + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + + i_r, i_i, i_n = F.linear(input, self.weight_ih, self.bias_ih).chunk( + chunks=3, dim=1 + ) + + h_r, h_i, h_n = F.linear(hx, self.weight_hh, self.bias_hh).chunk( + chunks=3, dim=1 + ) + + reset_gate = torch.sigmoid(self.layernorm_r(i_r + h_r)) + input_gate = torch.sigmoid(self.layernorm_i(i_i + h_i)) + new_gate = torch.tanh(self.layernorm_n(i_n + reset_gate * h_n)) + + # hy = (1 - input_gate) * new_gate + input_gate * hx + # = new_gate - input_gate * new_gate + input_gate * hx + # = new_gate + input_gate * (hx - new_gate) + hy = new_gate + input_gate * (hx - new_gate) + + return hy + + def extra_repr(self) -> str: + s = "{input_size}, {hidden_size}" + if "bias" in self.__dict__ and self.bias is not True: + s += ", bias={bias}" + return s.format(**self.__dict__) + + def reset_parameters(self) -> None: + stdv = 1.0 / math.sqrt(self.hidden_size) + for weight in self.parameters(): + nn.init.uniform_(weight, -stdv, stdv) + + +class LayerNormGRULayer(nn.Module): + """ + Examples:: + + >>> layer = LayerNormGRULayer(10, 20) + >>> input = torch.rand(2, 3, 10) + >>> hx = torch.rand(2, 20) + >>> output, hn = layer(input, hx) + >>> output.shape + torch.Size([2, 3, 20]) + >>> hn.shape + torch.Size([2, 20]) + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + ln: nn.Module = nn.LayerNorm, + device=None, + dtype=None, + ): + """ + See the args in LayerNormGRUCell + """ + super().__init__() + self.cell = LayerNormGRUCell( + input_size=input_size, + hidden_size=hidden_size, + bias=bias, + ln=ln, + device=device, + dtype=dtype, + ) + + def forward( + self, + input: torch.Tensor, + hx: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + input: + A 3-D tensor of shape (batch_size, seq_len, input_size). + Caution: + We use `batch_first=True` here. + hx: + If not ``None``, it is a tensor of shape (batch_size, hidden_size) + containing the hidden state for each element in the batch. 
+ Return: + - output, a tensor of shape (batch_size, seq_len, hidden_size) + - next_h, a tensor of shape (batch_size, hidden_size) containing the + final hidden state for each element in the batch. + """ + inputs = input.unbind(1) + outputs = torch.jit.annotate(List[torch.Tensor], []) + next_h = hx + for i in range(len(inputs)): + next_h = self.cell(inputs[i], next_h) + outputs.append(next_h) + return torch.stack(outputs, dim=1), next_h + + +class LayerNormGRU(nn.Module): + """ + Examples:: + + >>> input = torch.rand(2, 3, 10) + >>> h0 = torch.rand(8, 2, 20) + >>> states = h0.unbind(0) + >>> output, next_states = gru(input, states) + >>> output.shape + torch.Size([2, 3, 20]) + >>> hn = torch.stack(next_states) + >>> hn.shape + torch.Size([8, 2, 20]) + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int, + bias: bool = True, + ln: nn.Module = nn.LayerNorm, + device=None, + dtype=None, + ): + """ + See the args in LayerNormGRULayer. + """ + super().__init__() + assert num_layers >= 1 + factory_kwargs = dict( + hidden_size=hidden_size, + bias=bias, + ln=ln, + device=device, + dtype=dtype, + ) + first_layer = LayerNormGRULayer(input_size=input_size, **factory_kwargs) + layers = [first_layer] + for i in range(1, num_layers): + layers.append( + LayerNormGRULayer( + input_size=hidden_size, + **factory_kwargs, + ) + ) + self.layers = nn.ModuleList(layers) + self.num_layers = num_layers + + def forward( + self, + input: torch.Tensor, + states: List[torch.Tensor], + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Args: + input: + A tensor of shape (batch_size, seq_len, input_size) containing + input features. + Caution: + We use `batch_first=True` here. + states: + One state per layer. Each entry contains the hidden state for each + element in the batch. Each hidden state is of shape + (batch_size, hidden_size) + Returns: + Return a tuple containing: + + - output: A tensor of shape (batch_size, seq_len, hidden_size) + - List[next_state] containing the final hidden states for each + element in the batch + + """ + output_states = torch.jit.annotate(List[torch.Tensor], []) + output = input + for i, rnn_layer in enumerate(self.layers): + state = states[i] + output, out_state = rnn_layer(output, state) + output_states += [out_state] + return output, output_states diff --git a/egs/librispeech/ASR/transducer/test_rnn.py b/egs/librispeech/ASR/transducer/test_rnn.py new file mode 100755 index 000000000..7df7c7b78 --- /dev/null +++ b/egs/librispeech/ASR/transducer/test_rnn.py @@ -0,0 +1,453 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import torch.nn as nn +from transducer.rnn import ( + LayerNormGRU, + LayerNormGRUCell, + LayerNormGRULayer, + LayerNormLSTM, + LayerNormLSTMCell, + LayerNormLSTMLayer, +) + + +def assert_allclose(a: torch.Tensor, b: torch.Tensor, **kwargs): + assert torch.allclose(a, b, **kwargs), f"{(a - b).abs().max()}, {a.numel()}" + + +def test_layernorm_lstm_cell_jit(): + input_size = 10 + hidden_size = 20 + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + cell = LayerNormLSTMCell( + input_size=input_size, hidden_size=hidden_size, bias=bias + ) + + torch.jit.script(cell) + + +def test_layernorm_lstm_cell_constructor(): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=2, high=100, size=(1,)).item() + + self_cell = LayerNormLSTMCell(input_size, hidden_size, ln=nn.Identity) + torch_cell = nn.LSTMCell(input_size, hidden_size) + + for name, param in self_cell.named_parameters(): + assert param.shape == getattr(torch_cell, name).shape + + assert len(self_cell.state_dict()) == len(torch_cell.state_dict()) + + +def test_layernorm_lstm_cell_forward(): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=2, high=100, size=(1,)).item() + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + + self_cell = LayerNormLSTMCell( + input_size, hidden_size, bias=bias, ln=nn.Identity + ) + torch_cell = nn.LSTMCell(input_size, hidden_size, bias=bias) + with torch.no_grad(): + for name, torch_param in torch_cell.named_parameters(): + self_param = getattr(self_cell, name) + torch_param.copy_(self_param) + + N = torch.randint(low=2, high=100, size=(1,)) + x = torch.rand(N, input_size).requires_grad_() + h = torch.rand(N, hidden_size) + c = torch.rand(N, hidden_size) + + x_clone = x.detach().clone().requires_grad_() + + self_h, self_c = self_cell(x.clone(), (h, c)) + torch_h, torch_c = torch_cell(x_clone, (h, c)) + + assert_allclose(self_h, torch_h) + assert_allclose(self_c, torch_c) + + self_hc = self_h * self_c + torch_hc = torch_h * torch_c + (self_hc.reshape(-1) * torch.arange(self_hc.numel())).sum().backward() + (torch_hc.reshape(-1) * torch.arange(torch_hc.numel())).sum().backward() + + assert_allclose(x.grad, x_clone.grad) + + +def test_layernorm_lstm_layer_jit(): + input_size = 10 + hidden_size = 20 + layer = LayerNormLSTMLayer(input_size, hidden_size=hidden_size) + torch.jit.script(layer) + + +def test_layernorm_lstm_layer_forward(): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=2, high=100, size=(1,)).item() + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + self_layer = LayerNormLSTMLayer( + input_size, + hidden_size, + bias=bias, + ln=nn.Identity, + ) + + N = torch.randint(low=2, high=100, size=(1,)) + T = torch.randint(low=2, high=100, size=(1,)) + + x = torch.rand(N, T, input_size).requires_grad_() + h = torch.rand(N, hidden_size) + c = torch.rand(N, hidden_size) + + x_clone = x.detach().clone().requires_grad_() + + self_y, (self_h, self_c) = self_layer(x, (h, c)) + + torch_layer = nn.LSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=1, + bias=bias, + batch_first=True, + dropout=0, + bidirectional=False, + ) + with torch.no_grad(): + for name, self_param in self_layer.cell.named_parameters(): + getattr(torch_layer, f"{name}_l0").copy_(self_param) + + torch_y, (torch_h, torch_c) = torch_layer( + x_clone, (h.unsqueeze(0), c.unsqueeze(0)) + ) + assert_allclose(self_y, torch_y) + 
assert_allclose(self_h, torch_h) + assert_allclose(self_c, torch_c) + + self_hc = self_h * self_c + torch_hc = torch_h * torch_c + self_hc_sum = (self_hc.reshape(-1) * torch.arange(self_hc.numel())).sum() + torch_hc_sum = (torch_hc.reshape(-1) * torch.arange(torch_hc.numel())).sum() + + self_y_sum = (self_y.reshape(-1) * torch.arange(self_y.numel())).sum() + torch_y_sum = (torch_y.reshape(-1) * torch.arange(torch_y.numel())).sum() + + (self_hc_sum * self_y_sum).backward() + (torch_hc_sum * torch_y_sum).backward() + + assert_allclose(x.grad, x_clone.grad, atol=1e-5) + + +def test_layernorm_lstm_jit(): + input_size = 2 + hidden_size = 3 + num_layers = 4 + bias = True + + lstm = LayerNormLSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + ln=nn.Identity, + ) + torch.jit.script(lstm) + + +def test_layernorm_lstm_forward(): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=2, high=100, size=(1,)).item() + num_layers = torch.randint(low=2, high=100, size=(1,)).item() + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + + self_lstm = LayerNormLSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + ln=nn.Identity, + ) + torch_lstm = nn.LSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + batch_first=True, + bidirectional=False, + ) + assert len(self_lstm.state_dict()) == len(torch_lstm.state_dict()) + with torch.no_grad(): + for name, param in self_lstm.named_parameters(): + # name has the form layers.0.cell.weight_hh + parts = name.split(".") + layer_num = parts[1] + getattr(torch_lstm, f"{parts[-1]}_l{layer_num}").copy_(param) + + N = torch.randint(low=2, high=100, size=(1,)) + T = torch.randint(low=2, high=100, size=(1,)) + + x = torch.rand(N, T, input_size).requires_grad_() + hs = [torch.rand(N, hidden_size) for _ in range(num_layers)] + cs = [torch.rand(N, hidden_size) for _ in range(num_layers)] + states = list(zip(hs, cs)) + + x_clone = x.detach().clone().requires_grad_() + + self_y, self_states = self_lstm(x, states) + + h = torch.stack(hs) + c = torch.stack(cs) + torch_y, (torch_h, torch_c) = torch_lstm(x_clone, (h, c)) + + assert_allclose(self_y, torch_y) + + self_h = torch.stack([s[0] for s in self_states]) + self_c = torch.stack([s[1] for s in self_states]) + + assert_allclose(self_h, torch_h) + assert_allclose(self_c, torch_c) + + s = self_y.reshape(-1) + t = torch_y.reshape(-1) + + s_sum = (s * torch.arange(s.numel())).sum() + t_sum = (t * torch.arange(t.numel())).sum() + shc_sum = s_sum * self_h.sum() * self_c.sum() + thc_sum = t_sum * torch_h.sum() * torch_c.sum() + + shc_sum.backward() + thc_sum.backward() + assert_allclose(x.grad, x_clone.grad) + + +def test_layernorm_gru_cell_jit(): + input_size = 10 + hidden_size = 20 + cell = LayerNormGRUCell( + input_size=input_size, hidden_size=hidden_size, bias=True + ) + + torch.jit.script(cell) + + +def test_layernorm_gru_cell_constructor(): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=2, high=100, size=(1,)).item() + + self_cell = LayerNormGRUCell(input_size, hidden_size, ln=nn.Identity) + torch_cell = nn.GRUCell(input_size, hidden_size) + + for name, param in self_cell.named_parameters(): + assert param.shape == getattr(torch_cell, name).shape + + assert len(self_cell.state_dict()) == len(torch_cell.state_dict()) + + +def test_layernorm_gru_cell_forward(): + input_size = 
torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=2, high=100, size=(1,)).item() + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + + self_cell = LayerNormGRUCell( + input_size, hidden_size, bias=bias, ln=nn.Identity + ) + torch_cell = nn.GRUCell(input_size, hidden_size, bias=bias) + with torch.no_grad(): + for name, torch_param in torch_cell.named_parameters(): + self_param = getattr(self_cell, name) + torch_param.copy_(self_param) + + N = torch.randint(low=2, high=100, size=(1,)) + x = torch.rand(N, input_size).requires_grad_() + h = torch.rand(N, hidden_size) + + x_clone = x.detach().clone().requires_grad_() + + self_h = self_cell(x.clone(), h) + torch_h = torch_cell(x_clone, h) + + assert_allclose(self_h, torch_h, atol=1e-5) + + (self_h.reshape(-1) * torch.arange(self_h.numel())).sum().backward() + (torch_h.reshape(-1) * torch.arange(torch_h.numel())).sum().backward() + + assert_allclose(x.grad, x_clone.grad, atol=1e-4) + + +def test_layernorm_gru_layer_jit(): + input_size = 10 + hidden_size = 20 + layer = LayerNormGRULayer(input_size, hidden_size=hidden_size) + torch.jit.script(layer) + + +def test_layernorm_gru_layer_forward(): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=2, high=100, size=(1,)).item() + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + self_layer = LayerNormGRULayer( + input_size, + hidden_size, + bias=bias, + ln=nn.Identity, + ) + + N = torch.randint(low=2, high=100, size=(1,)) + T = torch.randint(low=2, high=100, size=(1,)) + + x = torch.rand(N, T, input_size).requires_grad_() + h = torch.rand(N, hidden_size) + + x_clone = x.detach().clone().requires_grad_() + + self_y, self_h = self_layer(x, h.clone()) + + torch_layer = nn.GRU( + input_size=input_size, + hidden_size=hidden_size, + num_layers=1, + bias=bias, + batch_first=True, + dropout=0, + bidirectional=False, + ) + with torch.no_grad(): + for name, self_param in self_layer.cell.named_parameters(): + getattr(torch_layer, f"{name}_l0").copy_(self_param) + + torch_y, torch_h = torch_layer(x_clone, h.unsqueeze(0)) + assert_allclose(self_y, torch_y, atol=1e-6) + assert_allclose(self_h, torch_h) + + self_y_sum = (self_y.reshape(-1) * torch.arange(self_y.numel())).sum() + torch_y_sum = (torch_y.reshape(-1) * torch.arange(torch_y.numel())).sum() + + self_y_sum.backward() + torch_y_sum.backward() + + assert_allclose(x.grad, x_clone.grad, atol=0.1) + + +def test_layernorm_gru_jit(): + input_size = 2 + hidden_size = 3 + num_layers = 4 + bias = True + + gru = LayerNormGRU( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + ln=nn.Identity, + ) + torch.jit.script(gru) + + +def test_layernorm_gru_forward(): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=2, high=100, size=(1,)).item() + num_layers = torch.randint(low=2, high=100, size=(1,)).item() + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + + self_gru = LayerNormGRU( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + ln=nn.Identity, + ) + torch_gru = nn.GRU( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + batch_first=True, + bidirectional=False, + ) + assert len(self_gru.state_dict()) == len(torch_gru.state_dict()) + with torch.no_grad(): + for name, param in self_gru.named_parameters(): + # name has the form layers.0.cell.weight_hh + parts = 
name.split(".") + layer_num = parts[1] + getattr(torch_gru, f"{parts[-1]}_l{layer_num}").copy_(param) + + N = torch.randint(low=2, high=100, size=(1,)) + T = torch.randint(low=2, high=100, size=(1,)) + + x = torch.rand(N, T, input_size).requires_grad_() + states = [torch.rand(N, hidden_size) for _ in range(num_layers)] + + x_clone = x.detach().clone().requires_grad_() + + self_y, self_states = self_gru(x, states) + + torch_y, torch_states = torch_gru(x_clone, torch.stack(states)) + + assert_allclose(self_y, torch_y, atol=1e-6) + + self_states = torch.stack(self_states) + + assert_allclose(self_states, torch_states, atol=1e-6) + + s = self_y.reshape(-1) + t = torch_y.reshape(-1) + + s_sum = (s * torch.arange(s.numel())).sum() + t_sum = (t * torch.arange(t.numel())).sum() + s_state_sum = s_sum + self_states.sum() + t_state_sum = t_sum + torch_states.sum() + + s_state_sum.backward() + t_state_sum.backward() + assert_allclose(x.grad, x_clone.grad) + + +def test_lstm(): + test_layernorm_lstm_cell_jit() + test_layernorm_lstm_cell_constructor() + test_layernorm_lstm_cell_forward() + # + test_layernorm_lstm_layer_jit() + test_layernorm_lstm_layer_forward() + # + test_layernorm_lstm_jit() + test_layernorm_lstm_forward() + + +def test_gru(): + test_layernorm_gru_cell_jit() + test_layernorm_gru_cell_constructor() + test_layernorm_gru_cell_forward() + # + test_layernorm_gru_layer_jit() + test_layernorm_gru_layer_forward() + # + test_layernorm_gru_jit() + test_layernorm_gru_forward() + + +def main(): + test_lstm() + test_gru() + + +if __name__ == "__main__": + torch.manual_seed(20211202) + main()