diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e897c3fb5..b9c0956f0 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -103,6 +103,9 @@ jobs: cd egs/librispeech/ASR/conformer_ctc pytest -v -s + cd .. + pytest -v -s ./transducer + - name: Run tests if: startsWith(matrix.os, 'macos') run: | @@ -113,6 +116,9 @@ jobs: export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH pytest -v -s ./test - # runt tests for conformer ctc + # run tests for conformer ctc cd egs/librispeech/ASR/conformer_ctc pytest -v -s + + cd .. + pytest -v -s ./transducer diff --git a/egs/librispeech/ASR/transducer/__init__.py b/egs/librispeech/ASR/transducer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/librispeech/ASR/transducer/rnn.py b/egs/librispeech/ASR/transducer/rnn.py new file mode 100644 index 000000000..8e695db50 --- /dev/null +++ b/egs/librispeech/ASR/transducer/rnn.py @@ -0,0 +1,659 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Apply layer normalization to the output of each gate in LSTM/GRU. + +This file uses +https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py +as a reference. +""" + +import math +from typing import List, Optional, Tuple, Type + +import torch +import torch.nn as nn +import torch.nn.functional as F +from typeguard import check_argument_types + + +class LayerNormLSTMCell(nn.Module): + """This class places a `nn.LayerNorm` after the output of + each gate (right before the activation). + + See the following paper for more details + + 'Improving RNN Transducer Modeling for End-to-End Speech Recognition' + https://arxiv.org/abs/1909.12415 + + Examples:: + + >>> cell = LayerNormLSTMCell(10, 20) + >>> input = torch.rand(5, 10) + >>> h0 = torch.rand(5, 20) + >>> c0 = torch.rand(5, 20) + >>> h1, c1 = cell(input, (h0, c0)) + >>> output = h1 + >>> h1.shape + torch.Size([5, 20]) + >>> c1.shape + torch.Size([5, 20]) + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + ln: Type[nn.Module] = nn.LayerNorm, + proj_size: int = 0, + device=None, + dtype=None, + ): + """ + Args: + input_size: + The number of expected features in the input `x`. `x` should + be of shape (batch_size, input_size). + hidden_size: + The number of features in the hidden state `h` and `c`. + Both `h` and `c` are of shape (batch_size, hidden_size) when + proj_size is 0. If proj_size is not zero, the shape of `h` + is (batch_size, proj_size). + bias: + If ``False``, then the cell does not use bias weights + `bias_ih` and `bias_hh`. + ln: + Defaults to `nn.LayerNorm`. The output of all gates are processed + by `ln`. We pass it as an argument so that we can replace it + with `nn.Identity` at the testing time. + proj_size: + If not zero, it applies an affine transform to the output. In this + case, the shape of `h` is (batch_size, proj_size). 
+ See https://arxiv.org/pdf/1402.1128.pdf + """ + assert check_argument_types() + super().__init__() + factory_kwargs = {"device": device, "dtype": dtype} + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + self.proj_size = proj_size + + if proj_size < 0: + raise ValueError( + f"proj_size {proj_size} should be a positive integer " + "or zero to disable projections" + ) + + if proj_size >= hidden_size: + raise ValueError( + f"proj_size {proj_size} has to be smaller " + f"than hidden_size {hidden_size}" + ) + + real_hidden_size = proj_size if proj_size > 0 else hidden_size + + self.weight_ih = nn.Parameter( + torch.empty((4 * hidden_size, input_size), **factory_kwargs) + ) + + self.weight_hh = nn.Parameter( + torch.empty((4 * hidden_size, real_hidden_size), **factory_kwargs) + ) + + if bias: + self.bias_ih = nn.Parameter( + torch.empty(4 * hidden_size, **factory_kwargs) + ) + self.bias_hh = nn.Parameter( + torch.empty(4 * hidden_size, **factory_kwargs) + ) + else: + self.register_parameter("bias_ih", None) + self.register_parameter("bias_hh", None) + + if proj_size > 0: + self.weight_hr = nn.Parameter( + torch.empty((proj_size, hidden_size), **factory_kwargs) + ) + else: + self.register_parameter("weight_hr", None) + + self.layernorm_i = ln(hidden_size) + self.layernorm_f = ln(hidden_size) + self.layernorm_cx = ln(hidden_size) + self.layernorm_cy = ln(hidden_size) + self.layernorm_o = ln(hidden_size) + + self.reset_parameters() + + def forward( + self, + input: torch.Tensor, + state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + input: + A 2-D tensor of shape (batch_size, input_size). + state: + If not ``None``, it contains the hidden state (h, c) for each + element in the batch. Both are of shape (batch_size, hidden_size) + if proj_size is 0. If proj_size is not zero, the shape of `h` is + (batch_size, proj_size). + If ``None``, it uses zeros for `h` and `c`. + Returns: + Return two tensors: + - `next_h`: It is of shape (batch_size, hidden_size) if proj_size + is 0, else (batch_size, proj_size), containing the next hidden + state for each element in the batch. + - `next_c`: It is of shape (batch_size, hidden_size) containing the + next cell state for each element in the batch. 
+ """ + if state is None: + zeros = torch.zeros( + input.size(0), + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + state = (zeros, zeros) + + hx, cx = state + gates = F.linear(input, self.weight_ih, self.bias_ih) + F.linear( + hx, self.weight_hh, self.bias_hh + ) + + in_gate, forget_gate, cell_gate, out_gate = gates.chunk(chunks=4, dim=1) + + in_gate = self.layernorm_i(in_gate) + forget_gate = self.layernorm_f(forget_gate) + cell_gate = self.layernorm_cx(cell_gate) + out_gate = self.layernorm_o(out_gate) + + in_gate = torch.sigmoid(in_gate) + forget_gate = torch.sigmoid(forget_gate) + cell_gate = torch.tanh(cell_gate) + out_gate = torch.sigmoid(out_gate) + + cy = (forget_gate * cx) + (in_gate * cell_gate) + cy = self.layernorm_cy(cy) + hy = out_gate * torch.tanh(cy) + + if self.weight_hr is not None: + hy = torch.matmul(hy, self.weight_hr.t()) + + return hy, cy + + def extra_repr(self) -> str: + s = "{input_size}, {hidden_size}" + if "bias" in self.__dict__ and self.bias is not True: + s += ", bias={bias}" + return s.format(**self.__dict__) + + def reset_parameters(self) -> None: + stdv = 1.0 / math.sqrt(self.hidden_size) + for name, weight in self.named_parameters(): + if "layernorm" not in name: + nn.init.uniform_(weight, -stdv, stdv) + + +class LayerNormLSTMLayer(nn.Module): + """ + Examples:: + + >>> layer = LayerNormLSTMLayer(10, 20) + >>> input = torch.rand(2, 5, 10) + >>> h0 = torch.rand(2, 20) + >>> c0 = torch.rand(2, 20) + >>> output, (hn, cn) = layer(input, (h0, c0)) + >>> output.shape + torch.Size([2, 5, 20]) + >>> hn.shape + torch.Size([2, 20]) + >>> cn.shape + torch.Size([2, 20]) + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + ln: Type[nn.Module] = nn.LayerNorm, + proj_size: int = 0, + device=None, + dtype=None, + ): + """ + See the args in LayerNormLSTMCell + """ + assert check_argument_types() + super().__init__() + self.cell = LayerNormLSTMCell( + input_size=input_size, + hidden_size=hidden_size, + bias=bias, + ln=ln, + proj_size=proj_size, + device=device, + dtype=dtype, + ) + + def forward( + self, + input: torch.Tensor, + state: Tuple[torch.Tensor, torch.Tensor], + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Args: + input: + A 3-D tensor of shape (batch_size, seq_len, input_size). + Caution: + We use `batch_first=True` here. + state: + If not ``None``, it contains the hidden state (h, c) of this layer. + Both are of shape (batch_size, hidden_size) if proj_size is 0. + If proj_size is not 0, the shape of `h` is (batch_size, proj_size). + Note: + We did not annotate `state` with `Optional[Tuple[...]]` since + torchscript will complain. 
+        Return:
+          - output, a tensor of shape (batch_size, seq_len, hidden_size)
+          - (next_h, next_c) containing the next hidden state
+        """
+        inputs = input.unbind(1)
+        outputs = torch.jit.annotate(List[torch.Tensor], [])
+        for i in range(len(inputs)):
+            state = self.cell(inputs[i], state)
+            outputs.append(state[0])
+        return torch.stack(outputs, dim=1), state
+
+
+class LayerNormLSTM(nn.Module):
+    """
+    Examples::
+
+        >>> lstm = LayerNormLSTM(10, 20, 8)
+        >>> input = torch.rand(2, 3, 10)
+        >>> h0 = torch.rand(8, 2, 20).unbind(0)
+        >>> c0 = torch.rand(8, 2, 20).unbind(0)
+        >>> states = list(zip(h0, c0))
+        >>> output, next_states = lstm(input, states)
+        >>> output.shape
+        torch.Size([2, 3, 20])
+        >>> hn = torch.stack([s[0] for s in next_states])
+        >>> cn = torch.stack([s[1] for s in next_states])
+        >>> hn.shape
+        torch.Size([8, 2, 20])
+        >>> cn.shape
+        torch.Size([8, 2, 20])
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        num_layers: int,
+        bias: bool = True,
+        proj_size: int = 0,
+        ln: Type[nn.Module] = nn.LayerNorm,
+        device=None,
+        dtype=None,
+    ):
+        """
+        See the args in LayerNormLSTMLayer.
+        """
+        assert check_argument_types()
+        super().__init__()
+        assert num_layers >= 1
+        factory_kwargs = dict(
+            hidden_size=hidden_size,
+            bias=bias,
+            ln=ln,
+            proj_size=proj_size,
+            device=device,
+            dtype=dtype,
+        )
+        first_layer = LayerNormLSTMLayer(
+            input_size=input_size, **factory_kwargs
+        )
+        layers = [first_layer]
+        for i in range(1, num_layers):
+            layers.append(
+                LayerNormLSTMLayer(
+                    input_size=proj_size if proj_size > 0 else hidden_size,
+                    **factory_kwargs,
+                )
+            )
+        self.layers = nn.ModuleList(layers)
+        self.num_layers = num_layers
+
+    def forward(
+        self,
+        input: torch.Tensor,
+        states: List[Tuple[torch.Tensor, torch.Tensor]],
+    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+        """
+        Args:
+          input:
+            A 3-D tensor of shape (batch_size, seq_len, input_size).
+            Caution:
+              We use `batch_first=True` here.
+          states:
+            One state per layer. Each entry contains the hidden state (h, c)
+            for a layer. Both are of shape (batch_size, hidden_size) if
+            proj_size is 0. If proj_size is not 0, the shape of `h` is
+            (batch_size, proj_size).
+        Returns:
+          Return a tuple containing:
+
+            - output: A tensor of shape (batch_size, seq_len, hidden_size)
+            - List[(next_h, next_c)] containing the hidden states for all layers
+
+        """
+        output_states = torch.jit.annotate(
+            List[Tuple[torch.Tensor, torch.Tensor]], []
+        )
+        output = input
+        for i, rnn_layer in enumerate(self.layers):
+            state = states[i]
+            output, out_state = rnn_layer(output, state)
+            output_states += [out_state]
+        return output, output_states
+
+
+class LayerNormGRUCell(nn.Module):
+    """This class places a `nn.LayerNorm` after the output of
+    each gate (right before the activation).
+
+    See the following paper for more details
+
+    'Improving RNN Transducer Modeling for End-to-End Speech Recognition'
+    https://arxiv.org/abs/1909.12415
+
+    Examples::
+
+        >>> cell = LayerNormGRUCell(10, 20)
+        >>> input = torch.rand(2, 10)
+        >>> h0 = torch.rand(2, 20)
+        >>> hn = cell(input, h0)
+        >>> hn.shape
+        torch.Size([2, 20])
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        bias: bool = True,
+        ln: Type[nn.Module] = nn.LayerNorm,
+        device=None,
+        dtype=None,
+    ):
+        """
+        Args:
+          input_size:
+            The number of expected features in the input `x`. `x` should
+            be of shape (batch_size, input_size).
+          hidden_size:
+            The number of features in the hidden state `h`.
+            `h` is of shape (batch_size, hidden_size).
+          bias:
+            If ``False``, then the cell does not use bias weights
+            `bias_ih` and `bias_hh`.
+          ln:
+            Defaults to `nn.LayerNorm`. The outputs of all gates are processed
+            by `ln`. We pass it as an argument so that we can replace it
+            with `nn.Identity` at test time.
+        """
+        assert check_argument_types()
+        super().__init__()
+        factory_kwargs = {"device": device, "dtype": dtype}
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.bias = bias
+
+        self.weight_ih = nn.Parameter(
+            torch.empty((3 * hidden_size, input_size), **factory_kwargs)
+        )
+
+        self.weight_hh = nn.Parameter(
+            torch.empty((3 * hidden_size, hidden_size), **factory_kwargs)
+        )
+
+        if bias:
+            self.bias_ih = nn.Parameter(
+                torch.empty(3 * hidden_size, **factory_kwargs)
+            )
+            self.bias_hh = nn.Parameter(
+                torch.empty(3 * hidden_size, **factory_kwargs)
+            )
+        else:
+            self.register_parameter("bias_ih", None)
+            self.register_parameter("bias_hh", None)
+
+        self.layernorm_r = ln(hidden_size)
+        self.layernorm_i = ln(hidden_size)
+        self.layernorm_n = ln(hidden_size)
+
+        self.reset_parameters()
+
+    def forward(
+        self,
+        input: torch.Tensor,
+        hx: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Args:
+          input:
+            A 2-D tensor of shape (batch_size, input_size) containing
+            input features.
+          hx:
+            If not `None`, it is a tensor of shape (batch_size, hidden_size)
+            containing the initial hidden state for each element in the batch.
+            If `None`, it uses zeros for the hidden state.
+        Returns:
+          Return a tensor of shape (batch_size, hidden_size) containing the
+          next hidden state for each element in the batch.
+        """
+        if hx is None:
+            hx = torch.zeros(
+                input.size(0),
+                self.hidden_size,
+                dtype=input.dtype,
+                device=input.device,
+            )
+
+        i_r, i_i, i_n = F.linear(input, self.weight_ih, self.bias_ih).chunk(
+            chunks=3, dim=1
+        )
+
+        h_r, h_i, h_n = F.linear(hx, self.weight_hh, self.bias_hh).chunk(
+            chunks=3, dim=1
+        )
+
+        reset_gate = torch.sigmoid(self.layernorm_r(i_r + h_r))
+        input_gate = torch.sigmoid(self.layernorm_i(i_i + h_i))
+        new_gate = torch.tanh(self.layernorm_n(i_n + reset_gate * h_n))
+
+        # hy = (1 - input_gate) * new_gate + input_gate * hx
+        #    = new_gate - input_gate * new_gate + input_gate * hx
+        #    = new_gate + input_gate * (hx - new_gate)
+        hy = new_gate + input_gate * (hx - new_gate)
+
+        return hy
+
+    def extra_repr(self) -> str:
+        s = "{input_size}, {hidden_size}"
+        if "bias" in self.__dict__ and self.bias is not True:
+            s += ", bias={bias}"
+        return s.format(**self.__dict__)
+
+    def reset_parameters(self) -> None:
+        stdv = 1.0 / math.sqrt(self.hidden_size)
+        for weight in self.parameters():
+            nn.init.uniform_(weight, -stdv, stdv)
+
+
+class LayerNormGRULayer(nn.Module):
+    """
+    Examples::
+
+        >>> layer = LayerNormGRULayer(10, 20)
+        >>> input = torch.rand(2, 3, 10)
+        >>> hx = torch.rand(2, 20)
+        >>> output, hn = layer(input, hx)
+        >>> output.shape
+        torch.Size([2, 3, 20])
+        >>> hn.shape
+        torch.Size([2, 20])
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        bias: bool = True,
+        ln: Type[nn.Module] = nn.LayerNorm,
+        device=None,
+        dtype=None,
+    ):
+        """
+        See the args in LayerNormGRUCell
+        """
+        assert check_argument_types()
+        super().__init__()
+        self.cell = LayerNormGRUCell(
+            input_size=input_size,
+            hidden_size=hidden_size,
+            bias=bias,
+            ln=ln,
+            device=device,
+            dtype=dtype,
+        )
+
+    def forward(
+        self,
+        input: torch.Tensor,
+        hx: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+          input:
+            A 3-D tensor of shape (batch_size, seq_len, input_size).
+            Caution:
+              We use `batch_first=True` here.
+          hx:
+            A tensor of shape (batch_size, hidden_size) containing the
+            initial hidden state for each element in the batch.
+        Return:
+          - output, a tensor of shape (batch_size, seq_len, hidden_size)
+          - next_h, a tensor of shape (batch_size, hidden_size) containing the
+            final hidden state for each element in the batch.
+        """
+        inputs = input.unbind(1)
+        outputs = torch.jit.annotate(List[torch.Tensor], [])
+        next_h = hx
+        for i in range(len(inputs)):
+            next_h = self.cell(inputs[i], next_h)
+            outputs.append(next_h)
+        return torch.stack(outputs, dim=1), next_h
+
+
+class LayerNormGRU(nn.Module):
+    """
+    Examples::
+
+        >>> gru = LayerNormGRU(10, 20, 8)
+        >>> input = torch.rand(2, 3, 10)
+        >>> states = torch.rand(8, 2, 20).unbind(0)
+        >>> output, next_states = gru(input, states)
+        >>> output.shape
+        torch.Size([2, 3, 20])
+        >>> hn = torch.stack(next_states)
+        >>> hn.shape
+        torch.Size([8, 2, 20])
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        num_layers: int,
+        bias: bool = True,
+        ln: Type[nn.Module] = nn.LayerNorm,
+        device=None,
+        dtype=None,
+    ):
+        """
+        See the args in LayerNormGRULayer.
+        """
+        assert check_argument_types()
+        super().__init__()
+        assert num_layers >= 1
+        factory_kwargs = dict(
+            hidden_size=hidden_size,
+            bias=bias,
+            ln=ln,
+            device=device,
+            dtype=dtype,
+        )
+        first_layer = LayerNormGRULayer(input_size=input_size, **factory_kwargs)
+        layers = [first_layer]
+        for i in range(1, num_layers):
+            layers.append(
+                LayerNormGRULayer(
+                    input_size=hidden_size,
+                    **factory_kwargs,
+                )
+            )
+        self.layers = nn.ModuleList(layers)
+        self.num_layers = num_layers
+
+    def forward(
+        self,
+        input: torch.Tensor,
+        states: List[torch.Tensor],
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """
+        Args:
+          input:
+            A tensor of shape (batch_size, seq_len, input_size) containing
+            input features.
+            Caution:
+              We use `batch_first=True` here.
+          states:
+            One state per layer. Each entry contains the hidden state for each
+            element in the batch. Each hidden state is of shape
+            (batch_size, hidden_size)
+        Returns:
+          Return a tuple containing:
+
+            - output: A tensor of shape (batch_size, seq_len, hidden_size)
+            - List[next_state] containing the final hidden states for each
+              element in the batch
+
+        """
+        output_states = torch.jit.annotate(List[torch.Tensor], [])
+        output = input
+        for i, rnn_layer in enumerate(self.layers):
+            state = states[i]
+            output, out_state = rnn_layer(output, state)
+            output_states += [out_state]
+        return output, output_states
diff --git a/egs/librispeech/ASR/transducer/test_rnn.py b/egs/librispeech/ASR/transducer/test_rnn.py
new file mode 100755
index 000000000..c7d524f7d
--- /dev/null
+++ b/egs/librispeech/ASR/transducer/test_rnn.py
@@ -0,0 +1,765 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import torch +import torch.nn as nn +from transducer.rnn import ( + LayerNormGRU, + LayerNormGRUCell, + LayerNormGRULayer, + LayerNormLSTM, + LayerNormLSTMCell, + LayerNormLSTMLayer, +) + + +def get_devices(): + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda", 0)) + return devices + + +def assert_allclose(a: torch.Tensor, b: torch.Tensor, atol=1e-6, **kwargs): + assert torch.allclose( + a, b, atol=atol, **kwargs + ), f"{(a - b).abs().max()}, {a.numel()}" + + +def test_layernorm_lstm_cell_jit(device="cpu"): + input_size = 10 + hidden_size = 20 + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + cell = LayerNormLSTMCell( + input_size=input_size, + hidden_size=hidden_size, + bias=bias, + device=device, + ) + + torch.jit.script(cell) + + +def test_layernorm_lstm_cell_constructor(device="cpu"): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=2, high=100, size=(1,)).item() + + self_cell = LayerNormLSTMCell( + input_size, + hidden_size, + ln=nn.Identity, + device=device, + ) + torch_cell = nn.LSTMCell( + input_size, + hidden_size, + ).to(device) + + for name, param in self_cell.named_parameters(): + assert param.shape == getattr(torch_cell, name).shape + + assert len(self_cell.state_dict()) == len(torch_cell.state_dict()) + + +def test_layernorm_lstm_cell_with_projection_jit(device="cpu"): + input_size = 10 + hidden_size = 20 + proj_size = 5 + self_cell = LayerNormLSTMCell( + input_size, + hidden_size, + proj_size=proj_size, + device=device, + ) + torch.jit.script(self_cell) + + +def test_layernorm_lstm_cell_forward(device="cpu"): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=2, high=100, size=(1,)).item() + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + + self_cell = LayerNormLSTMCell( + input_size, + hidden_size, + bias=bias, + ln=nn.Identity, + device=device, + ) + torch_cell = nn.LSTMCell( + input_size, + hidden_size, + bias=bias, + ).to(device) + with torch.no_grad(): + for name, torch_param in torch_cell.named_parameters(): + self_param = getattr(self_cell, name) + torch_param.copy_(self_param) + + N = torch.randint(low=2, high=100, size=(1,)) + x = torch.rand(N, input_size, device=device).requires_grad_() + h = torch.rand(N, hidden_size, device=device) + c = torch.rand(N, hidden_size, device=device) + + x_clone = x.detach().clone().requires_grad_() + + self_h, self_c = self_cell(x.clone(), (h, c)) + torch_h, torch_c = torch_cell(x_clone, (h, c)) + + assert_allclose(self_h, torch_h) + assert_allclose(self_c, torch_c) + + self_hc = self_h * self_c + torch_hc = torch_h * torch_c + ( + self_hc.reshape(-1) * torch.arange(self_hc.numel(), device=device) + ).sum().backward() + ( + torch_hc.reshape(-1) * torch.arange(torch_hc.numel(), device=device) + ).sum().backward() + + assert_allclose(x.grad, x_clone.grad, atol=1e-3) + + +def test_layernorm_lstm_cell_with_projection_forward(device="cpu"): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=10, high=100, size=(1,)).item() + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + proj_size = torch.randint(low=2, high=hidden_size, size=(1,)).item() + + self_cell = LayerNormLSTMCell( + input_size, + hidden_size, + bias=bias, + ln=nn.Identity, + proj_size=proj_size, + device=device, + ) + torch_cell = nn.LSTM( + input_size, + hidden_size, + bias=bias, + proj_size=proj_size, + batch_first=True, + 
).to(device) + with torch.no_grad(): + for name, self_param in self_cell.named_parameters(): + getattr(torch_cell, f"{name}_l0").copy_(self_param) + + N = torch.randint(low=2, high=100, size=(1,)) + x = torch.rand(N, input_size, device=device).requires_grad_() + h = torch.rand(N, proj_size, device=device) + c = torch.rand(N, hidden_size, device=device) + + x_clone = x.detach().clone().requires_grad_() + + self_h, self_c = self_cell(x.clone(), (h, c)) + _, (torch_h, torch_c) = torch_cell( + x_clone.unsqueeze(1), (h.unsqueeze(0), c.unsqueeze(0)) + ) + + torch_h = torch_h.squeeze(0) + torch_c = torch_c.squeeze(0) + + assert_allclose(self_h, torch_h) + assert_allclose(self_c, torch_c) + + (self_h.sum() * self_c.sum()).backward() + (torch_h.sum() * torch_c.sum()).backward() + + assert_allclose(x.grad, x_clone.grad, atol=1e-5) + + +def test_layernorm_lstm_layer_jit(device="cpu"): + input_size = 10 + hidden_size = 20 + layer = LayerNormLSTMLayer( + input_size, + hidden_size=hidden_size, + device=device, + ) + torch.jit.script(layer) + + +def test_layernorm_lstm_layer_with_project_jit(device="cpu"): + input_size = 10 + hidden_size = 20 + proj_size = 5 + layer = LayerNormLSTMLayer( + input_size, + hidden_size=hidden_size, + proj_size=proj_size, + device=device, + ) + torch.jit.script(layer) + + +def test_layernorm_lstm_layer_with_projection_forward(device="cpu"): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=10, high=100, size=(1,)).item() + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + proj_size = torch.randint(low=2, high=hidden_size, size=(1,)).item() + + self_layer = LayerNormLSTMLayer( + input_size, + hidden_size, + bias=bias, + proj_size=proj_size, + ln=nn.Identity, + device=device, + ) + + N = torch.randint(low=2, high=100, size=(1,)) + T = torch.randint(low=2, high=100, size=(1,)) + + x = torch.rand(N, T, input_size, device=device).requires_grad_() + h = torch.rand(N, proj_size, device=device) + c = torch.rand(N, hidden_size, device=device) + + x_clone = x.detach().clone().requires_grad_() + + self_y, (self_h, self_c) = self_layer(x, (h, c)) + + torch_layer = nn.LSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=1, + bias=bias, + proj_size=proj_size, + batch_first=True, + dropout=0, + bidirectional=False, + ).to(device) + with torch.no_grad(): + for name, self_param in self_layer.cell.named_parameters(): + getattr(torch_layer, f"{name}_l0").copy_(self_param) + + torch_y, (torch_h, torch_c) = torch_layer( + x_clone, (h.unsqueeze(0), c.unsqueeze(0)) + ) + assert_allclose(self_y, torch_y) + assert_allclose(self_h, torch_h) + assert_allclose(self_c, torch_c) + + self_y.sum().backward() + torch_y.sum().backward() + + assert_allclose(x.grad, x_clone.grad, atol=1e-5) + + +def test_layernorm_lstm_layer_forward(device="cpu"): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=2, high=100, size=(1,)).item() + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + self_layer = LayerNormLSTMLayer( + input_size, + hidden_size, + bias=bias, + ln=nn.Identity, + device=device, + ) + + N = torch.randint(low=2, high=100, size=(1,)) + T = torch.randint(low=2, high=100, size=(1,)) + + x = torch.rand(N, T, input_size, device=device).requires_grad_() + h = torch.rand(N, hidden_size, device=device) + c = torch.rand(N, hidden_size, device=device) + + x_clone = x.detach().clone().requires_grad_() + + self_y, (self_h, self_c) = self_layer(x, (h, c)) + + torch_layer 
= nn.LSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=1, + bias=bias, + batch_first=True, + dropout=0, + bidirectional=False, + ).to(device) + with torch.no_grad(): + for name, self_param in self_layer.cell.named_parameters(): + getattr(torch_layer, f"{name}_l0").copy_(self_param) + + torch_y, (torch_h, torch_c) = torch_layer( + x_clone, (h.unsqueeze(0), c.unsqueeze(0)) + ) + assert_allclose(self_y, torch_y) + assert_allclose(self_h, torch_h) + assert_allclose(self_c, torch_c) + + self_hc = self_h * self_c + torch_hc = torch_h * torch_c + self_hc_sum = ( + self_hc.reshape(-1) * torch.arange(self_hc.numel(), device=device) + ).sum() + torch_hc_sum = ( + torch_hc.reshape(-1) * torch.arange(torch_hc.numel(), device=device) + ).sum() + + self_y_sum = ( + self_y.reshape(-1) * torch.arange(self_y.numel(), device=device) + ).sum() + torch_y_sum = ( + torch_y.reshape(-1) * torch.arange(torch_y.numel(), device=device) + ).sum() + + (self_hc_sum + self_y_sum).backward() + (torch_hc_sum + torch_y_sum).backward() + + assert_allclose(x.grad, x_clone.grad, atol=0.1) + + +def test_layernorm_lstm_jit(device="cpu"): + input_size = 2 + hidden_size = 3 + num_layers = 4 + bias = True + + lstm = LayerNormLSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + ln=nn.Identity, + device=device, + ) + torch.jit.script(lstm) + + +def test_layernorm_lstm_with_projection_jit(device="cpu"): + input_size = 2 + hidden_size = 5 + proj_size = 3 + num_layers = 4 + bias = True + + lstm = LayerNormLSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + proj_size=proj_size, + ln=nn.Identity, + device=device, + ) + torch.jit.script(lstm) + + +def test_layernorm_lstm_forward(device="cpu"): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=2, high=100, size=(1,)).item() + num_layers = torch.randint(low=2, high=100, size=(1,)).item() + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + + self_lstm = LayerNormLSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + ln=nn.Identity, + device=device, + ) + torch_lstm = nn.LSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + batch_first=True, + bidirectional=False, + ).to(device) + assert len(self_lstm.state_dict()) == len(torch_lstm.state_dict()) + with torch.no_grad(): + for name, param in self_lstm.named_parameters(): + # name has the form layers.0.cell.weight_hh + parts = name.split(".") + layer_num = parts[1] + getattr(torch_lstm, f"{parts[-1]}_l{layer_num}").copy_(param) + + N = torch.randint(low=2, high=100, size=(1,)) + T = torch.randint(low=2, high=100, size=(1,)) + + x = torch.rand(N, T, input_size, device=device).requires_grad_() + hs = [torch.rand(N, hidden_size, device=device) for _ in range(num_layers)] + cs = [torch.rand(N, hidden_size, device=device) for _ in range(num_layers)] + states = list(zip(hs, cs)) + + x_clone = x.detach().clone().requires_grad_() + + self_y, self_states = self_lstm(x, states) + + h = torch.stack(hs) + c = torch.stack(cs) + torch_y, (torch_h, torch_c) = torch_lstm(x_clone, (h, c)) + + assert_allclose(self_y, torch_y) + + self_h = torch.stack([s[0] for s in self_states]) + self_c = torch.stack([s[1] for s in self_states]) + + assert_allclose(self_h, torch_h) + assert_allclose(self_c, torch_c) + + s = self_y.reshape(-1) + t = torch_y.reshape(-1) + + s_sum = (s * torch.arange(s.numel(), 
device=device)).sum() + t_sum = (t * torch.arange(t.numel(), device=device)).sum() + shc_sum = s_sum + self_h.sum() + self_c.sum() + thc_sum = t_sum + torch_h.sum() + torch_c.sum() + + shc_sum.backward() + thc_sum.backward() + assert_allclose(x.grad, x_clone.grad) + + +def test_layernorm_lstm_with_projection_forward(device="cpu"): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=10, high=100, size=(1,)).item() + proj_size = torch.randint(low=2, high=hidden_size, size=(1,)).item() + num_layers = torch.randint(low=2, high=100, size=(1,)).item() + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + + self_lstm = LayerNormLSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + proj_size=proj_size, + ln=nn.Identity, + device=device, + ) + torch_lstm = nn.LSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + proj_size=proj_size, + batch_first=True, + bidirectional=False, + ).to(device) + assert len(self_lstm.state_dict()) == len(torch_lstm.state_dict()) + with torch.no_grad(): + for name, param in self_lstm.named_parameters(): + # name has the form layers.0.cell.weight_hh + parts = name.split(".") + layer_num = parts[1] + getattr(torch_lstm, f"{parts[-1]}_l{layer_num}").copy_(param) + + N = torch.randint(low=2, high=100, size=(1,)) + T = torch.randint(low=2, high=100, size=(1,)) + + x = torch.rand(N, T, input_size, device=device).requires_grad_() + hs = [torch.rand(N, proj_size, device=device) for _ in range(num_layers)] + cs = [torch.rand(N, hidden_size, device=device) for _ in range(num_layers)] + states = list(zip(hs, cs)) + + x_clone = x.detach().clone().requires_grad_() + + self_y, self_states = self_lstm(x, states) + + h = torch.stack(hs) + c = torch.stack(cs) + torch_y, (torch_h, torch_c) = torch_lstm(x_clone, (h, c)) + + assert_allclose(self_y, torch_y) + + self_h = torch.stack([s[0] for s in self_states]) + self_c = torch.stack([s[1] for s in self_states]) + + assert_allclose(self_h, torch_h) + assert_allclose(self_c, torch_c) + + s = self_y.reshape(-1) + t = torch_y.reshape(-1) + + s_sum = (s * torch.arange(s.numel(), device=device)).sum() + t_sum = (t * torch.arange(t.numel(), device=device)).sum() + shc_sum = s_sum + self_h.sum() + self_c.sum() + thc_sum = t_sum + torch_h.sum() + torch_c.sum() + + shc_sum.backward() + thc_sum.backward() + assert_allclose(x.grad, x_clone.grad) + + +def test_layernorm_gru_cell_jit(device="cpu"): + input_size = 10 + hidden_size = 20 + cell = LayerNormGRUCell( + input_size=input_size, + hidden_size=hidden_size, + bias=True, + device=device, + ) + + torch.jit.script(cell) + + +def test_layernorm_gru_cell_constructor(device="cpu"): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=2, high=100, size=(1,)).item() + + self_cell = LayerNormGRUCell( + input_size, + hidden_size, + ln=nn.Identity, + device=device, + ) + torch_cell = nn.GRUCell( + input_size, + hidden_size, + ).to(device) + + for name, param in self_cell.named_parameters(): + assert param.shape == getattr(torch_cell, name).shape + + assert len(self_cell.state_dict()) == len(torch_cell.state_dict()) + + +def test_layernorm_gru_cell_forward(device="cpu"): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=2, high=100, size=(1,)).item() + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + + self_cell = LayerNormGRUCell( + input_size, + 
hidden_size, + bias=bias, + ln=nn.Identity, + device=device, + ) + torch_cell = nn.GRUCell( + input_size, + hidden_size, + bias=bias, + ).to(device) + with torch.no_grad(): + for name, torch_param in torch_cell.named_parameters(): + self_param = getattr(self_cell, name) + torch_param.copy_(self_param) + + N = torch.randint(low=2, high=100, size=(1,)) + x = torch.rand(N, input_size, device=device).requires_grad_() + h = torch.rand(N, hidden_size, device=device) + + x_clone = x.detach().clone().requires_grad_() + + self_h = self_cell(x.clone(), h) + torch_h = torch_cell(x_clone, h) + + assert_allclose(self_h, torch_h, atol=1e-5) + + ( + self_h.reshape(-1) * torch.arange(self_h.numel(), device=device) + ).sum().backward() + ( + torch_h.reshape(-1) * torch.arange(torch_h.numel(), device=device) + ).sum().backward() + + assert_allclose(x.grad, x_clone.grad, atol=1e-3) + + +def test_layernorm_gru_layer_jit(device="cpu"): + input_size = 10 + hidden_size = 20 + layer = LayerNormGRULayer( + input_size, + hidden_size=hidden_size, + device=device, + ) + torch.jit.script(layer) + + +def test_layernorm_gru_layer_forward(device="cpu"): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=2, high=100, size=(1,)).item() + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + self_layer = LayerNormGRULayer( + input_size, + hidden_size, + bias=bias, + ln=nn.Identity, + device=device, + ) + + N = torch.randint(low=2, high=100, size=(1,)) + T = torch.randint(low=2, high=100, size=(1,)) + + x = torch.rand(N, T, input_size, device=device).requires_grad_() + h = torch.rand(N, hidden_size, device=device) + + x_clone = x.detach().clone().requires_grad_() + + self_y, self_h = self_layer(x, h.clone()) + + torch_layer = nn.GRU( + input_size=input_size, + hidden_size=hidden_size, + num_layers=1, + bias=bias, + batch_first=True, + dropout=0, + bidirectional=False, + ).to(device) + with torch.no_grad(): + for name, self_param in self_layer.cell.named_parameters(): + getattr(torch_layer, f"{name}_l0").copy_(self_param) + + torch_y, torch_h = torch_layer(x_clone, h.unsqueeze(0)) + assert_allclose(self_y, torch_y) + assert_allclose(self_h, torch_h) + + self_y_sum = ( + self_y.reshape(-1) * torch.arange(self_y.numel(), device=device) + ).sum() + torch_y_sum = ( + torch_y.reshape(-1) * torch.arange(torch_y.numel(), device=device) + ).sum() + + self_y_sum.backward() + torch_y_sum.backward() + + assert_allclose(x.grad, x_clone.grad, atol=0.1) + + +def test_layernorm_gru_jit(device="cpu"): + input_size = 2 + hidden_size = 3 + num_layers = 4 + bias = True + + gru = LayerNormGRU( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + ln=nn.Identity, + device=device, + ) + torch.jit.script(gru) + + +def test_layernorm_gru_forward(device="cpu"): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=2, high=100, size=(1,)).item() + num_layers = torch.randint(low=2, high=100, size=(1,)).item() + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + + self_gru = LayerNormGRU( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + ln=nn.Identity, + device=device, + ) + torch_gru = nn.GRU( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + batch_first=True, + bidirectional=False, + ).to(device) + assert len(self_gru.state_dict()) == len(torch_gru.state_dict()) + with torch.no_grad(): + for name, param in 
self_gru.named_parameters(): + # name has the form layers.0.cell.weight_hh + parts = name.split(".") + layer_num = parts[1] + getattr(torch_gru, f"{parts[-1]}_l{layer_num}").copy_(param) + + N = torch.randint(low=2, high=100, size=(1,)) + T = torch.randint(low=2, high=100, size=(1,)) + + x = torch.rand(N, T, input_size, device=device).requires_grad_() + states = [ + torch.rand(N, hidden_size, device=device) for _ in range(num_layers) + ] + + x_clone = x.detach().clone().requires_grad_() + + self_y, self_states = self_gru(x, states) + + torch_y, torch_states = torch_gru(x_clone, torch.stack(states)) + + assert_allclose(self_y, torch_y) + + self_states = torch.stack(self_states) + + assert_allclose(self_states, torch_states) + + s = self_y.reshape(-1) + t = torch_y.reshape(-1) + + s_sum = (s * torch.arange(s.numel(), device=device)).sum() + t_sum = (t * torch.arange(t.numel(), device=device)).sum() + s_state_sum = s_sum + self_states.sum() + t_state_sum = t_sum + torch_states.sum() + + s_state_sum.backward() + t_state_sum.backward() + assert_allclose(x.grad, x_clone.grad, atol=1e-2) + + +def _test_lstm(device): + test_layernorm_lstm_cell_jit(device) + test_layernorm_lstm_cell_constructor(device) + test_layernorm_lstm_cell_with_projection_jit(device) + test_layernorm_lstm_cell_forward(device) + test_layernorm_lstm_cell_with_projection_forward(device) + # + test_layernorm_lstm_layer_jit(device) + test_layernorm_lstm_layer_with_project_jit(device) + test_layernorm_lstm_layer_forward(device) + test_layernorm_lstm_layer_with_projection_forward(device) + + test_layernorm_lstm_jit(device) + test_layernorm_lstm_with_projection_jit(device) + test_layernorm_lstm_forward(device) + test_layernorm_lstm_with_projection_forward(device) + + +def _test_gru(device): + test_layernorm_gru_cell_jit(device) + test_layernorm_gru_cell_constructor(device) + test_layernorm_gru_cell_forward(device) + # + test_layernorm_gru_layer_jit(device) + test_layernorm_gru_layer_forward(device) + # + test_layernorm_gru_jit(device) + test_layernorm_gru_forward(device) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def main(): + for device in get_devices(): + print("device", device) + _test_lstm(device) + _test_gru(device) + + +if __name__ == "__main__": + torch.manual_seed(20211202) + main() diff --git a/requirements.txt b/requirements.txt index 710048fed..4eaa86a67 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ kaldilm kaldialign sentencepiece>=0.1.96 tensorboard +typeguard
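
Below is a minimal usage sketch for the new `LayerNormLSTM` and `LayerNormGRU` modules. It assumes the working directory is `egs/librispeech/ASR` (as in the new CI step), so that `transducer.rnn` is importable the same way `test_rnn.py` imports it; constructor arguments and shapes follow the doctests in `rnn.py`.

import torch
from transducer.rnn import LayerNormGRU, LayerNormLSTM

# An 8-layer LSTM with per-gate LayerNorm: input features of size 10,
# hidden size 20.  forward() expects batch-first input and one (h, c)
# tuple per layer.
lstm = LayerNormLSTM(input_size=10, hidden_size=20, num_layers=8)
x = torch.rand(2, 3, 10)  # (batch_size, seq_len, input_size)
h0 = torch.rand(8, 2, 20).unbind(0)
c0 = torch.rand(8, 2, 20).unbind(0)
y, next_states = lstm(x, list(zip(h0, c0)))
assert y.shape == torch.Size([2, 3, 20])
assert len(next_states) == 8  # one (next_h, next_c) per layer

# The GRU variant takes one hidden-state tensor per layer instead of
# (h, c) tuples.
gru = LayerNormGRU(input_size=10, hidden_size=20, num_layers=8)
states = list(torch.rand(8, 2, 20).unbind(0))
y, next_h = gru(x, states)
assert y.shape == torch.Size([2, 3, 20])
assert len(next_h) == 8  # one final hidden state per layer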