diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e897c3fb5..a9e4ff2b7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -113,6 +113,11 @@ jobs: export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH pytest -v -s ./test - # runt tests for conformer ctc + pwd=$PWD + # run tests for conformer ctc cd egs/librispeech/ASR/conformer_ctc pytest -v -s + + cd $PWD + cd egs/librispeech/ASR + pytest -v -s ./rnnt diff --git a/egs/librispeech/ASR/rnnt/__init__.py b/egs/librispeech/ASR/rnnt/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/librispeech/ASR/rnnt/rnn.py b/egs/librispeech/ASR/rnnt/rnn.py new file mode 100644 index 000000000..31849225b --- /dev/null +++ b/egs/librispeech/ASR/rnnt/rnn.py @@ -0,0 +1,284 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Apply layer normalization to the output of each gate in LSTM/GRU. + +This file uses +https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py +as a reference. +""" + +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# TODO(fangjun): Support projection, see https://arxiv.org/pdf/1402.1128.pdf +class LayerNormLSTMCell(nn.Module): + """This class places a `nn.LayerNorm` after the output of + each gate (right before the activation). + + See the following paper for more details + + 'Improving RNN Transducer Modeling for End-to-End Speech Recognition' + https://arxiv.org/abs/1909.12415 + """ + + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + ln: nn.Module = nn.LayerNorm, + device=None, + dtype=None, + ): + """ + Args: + input_size: + The number of expected features in the input `x`. `x` should + be of shape (batch_size, input_size). + hidden_size: + The number of features in the hidden state `h` and `c`. + Both `h` and `c` are of shape (batch_size, hidden_size). + bias: + If ``False``, then the cell does not use bias weights + `bias_ih` and `bias_hh`. + ln: + Defaults to `nn.LayerNorm`. The output of all gates are processed + by `ln`. We pass it as an argument so that we can replace it + with `nn.Identity` at the testing time. + """ + super().__init__() + factory_kwargs = {"device": device, "dtype": dtype} + self.input_size = input_size + self.hidden_size = hidden_size + self.bias = bias + + self.weight_ih = nn.Parameter( + torch.empty((4 * hidden_size, input_size), **factory_kwargs) + ) + + self.weight_hh = nn.Parameter( + torch.empty((4 * hidden_size, hidden_size), **factory_kwargs) + ) + + if bias: + self.bias_ih = nn.Parameter( + torch.empty(4 * hidden_size, **factory_kwargs) + ) + self.bias_hh = nn.Parameter( + torch.empty(4 * hidden_size, **factory_kwargs) + ) + else: + self.register_parameter("bias_ih", None) + self.register_parameter("bias_hh", None) + + self.layernorm_i = ln(hidden_size) + self.layernorm_f = ln(hidden_size) + self.layernorm_cx = ln(hidden_size) + self.layernorm_cy = ln(hidden_size) + self.layernorm_o = ln(hidden_size) + + self.reset_parameters() + + def forward( + self, + input: torch.Tensor, + state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + input: + A 2-D tensor of shape (batch_size, input_size). + state: + If not ``None``, it contains the hidden state (h, c); both + are of shape (batch_size, hidden_size). If ``None``, it uses + zeros for `h` and `c`. + Returns: + Return two tensors: + - `next_h`: It is of shape (batch_size, hidden_size) containing the + next hidden state for each element in the batch. + - `next_c`: It is of shape (batch_size, hidden_size) containing the + next cell state for each element in the batch. + """ + if state is None: + zeros = torch.zeros( + input.size(0), + self.hidden_size, + dtype=input.dtype, + device=input.device, + ) + state = (zeros, zeros) + + hx, cx = state + gates = F.linear(input, self.weight_ih, self.bias_ih) + F.linear( + hx, self.weight_hh, self.bias_hh + ) + + in_gate, forget_gate, cell_gate, out_gate = gates.chunk(chunks=4, dim=1) + + in_gate = self.layernorm_i(in_gate) + forget_gate = self.layernorm_f(forget_gate) + cell_gate = self.layernorm_cx(cell_gate) + out_gate = self.layernorm_o(out_gate) + + in_gate = torch.sigmoid(in_gate) + forget_gate = torch.sigmoid(forget_gate) + cell_gate = torch.tanh(cell_gate) + out_gate = torch.sigmoid(out_gate) + + cy = (forget_gate * cx) + (in_gate * cell_gate) + cy = self.layernorm_cy(cy) + hy = out_gate * torch.tanh(cy) + + return hy, cy + + def extra_repr(self) -> str: + s = "{input_size}, {hidden_size}" + if "bias" in self.__dict__ and self.bias is not True: + s += ", bias={bias}" + return s.format(**self.__dict__) + + def reset_parameters(self) -> None: + stdv = 1.0 / math.sqrt(self.hidden_size) + for weight in self.parameters(): + nn.init.uniform_(weight, -stdv, stdv) + + +class LayerNormLSTMLayer(nn.Module): + def __init__( + self, + input_size: int, + hidden_size: int, + bias: bool = True, + ln: nn.Module = nn.LayerNorm, + device=None, + dtype=None, + ): + """ + See the args in LayerNormLSTMCell + """ + super().__init__() + self.cell = LayerNormLSTMCell( + input_size=input_size, + hidden_size=hidden_size, + bias=bias, + ln=ln, + device=device, + dtype=dtype, + ) + + def forward( + self, + input: torch.Tensor, + state: Tuple[torch.Tensor, torch.Tensor], + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Args: + input: + A 3-D tensor of shape (batch_size, seq_len, input_size). + Caution: + We use `batch_first=True` here. + state: + If not ``None``, it contains the hidden state (h, c) of this layer. + Both are of shape (batch_size, hidden_size). + Note: + We did not annotate `state` with `Optional[Tuple[...]]` since + torchscript will complain. + Return: + - output, a tensor of shape (batch_size, seq_len, hidden_size) + - (next_h, next_c) containing the hidden state of this layer + """ + inputs = input.unbind(1) + outputs = torch.jit.annotate(List[torch.Tensor], []) + for i in range(len(inputs)): + state = self.cell(inputs[i], state) + outputs += [state[0]] + return torch.stack(outputs, dim=1), state + + +class LayerNormLSTM(nn.Module): + def __init__( + self, + input_size: int, + hidden_size: int, + num_layers: int, + bias: bool = True, + ln: nn.Module = nn.LayerNorm, + device=None, + dtype=None, + ): + """ + See the args in LSTMLayer. + """ + super().__init__() + assert num_layers >= 1 + factory_kwargs = dict( + hidden_size=hidden_size, + bias=bias, + ln=ln, + device=device, + dtype=dtype, + ) + first_layer = LayerNormLSTMLayer( + input_size=input_size, **factory_kwargs + ) + layers = [first_layer] + for i in range(1, num_layers): + layers.append( + LayerNormLSTMLayer( + input_size=hidden_size, + **factory_kwargs, + ) + ) + self.layers = nn.ModuleList(layers) + self.num_layers = num_layers + + def forward( + self, + input: torch.Tensor, + states: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + """ + Args: + input: + A 3-D tensor of shape (batch_size, seq_len, input_size). + Caution: + We use `batch_first=True` here. + states: + One state per layer. Each entry contains the hidden state (h, c) + for a layer. Both are of shape (batch_size, hidden_size). + Returns: + Return a tuple containing: + + - output: A tensor of shape (batch_size, seq_len, hidden_size) + - List[(next_h, next_c)] containing the hidden states for all layers + + """ + output_states = torch.jit.annotate( + List[Tuple[torch.Tensor, torch.Tensor]], [] + ) + output = input + for i, rnn_layer in enumerate(self.layers): + state = states[i] + output, out_state = rnn_layer(output, state) + output_states += [out_state] + return output, output_states diff --git a/egs/librispeech/ASR/rnnt/test_rnn.py b/egs/librispeech/ASR/rnnt/test_rnn.py new file mode 100755 index 000000000..683183fb5 --- /dev/null +++ b/egs/librispeech/ASR/rnnt/test_rnn.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from rnnt.rnn import LayerNormLSTM, LayerNormLSTMCell, LayerNormLSTMLayer + + +def test_layernorm_lstm_cell_jit(): + input_size = 10 + hidden_size = 20 + cell = LayerNormLSTMCell( + input_size=input_size, hidden_size=hidden_size, bias=True + ) + + torch.jit.script(cell) + + +def test_layernorm_lstm_cell_constructor(): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=2, high=100, size=(1,)).item() + + self_cell = LayerNormLSTMCell(input_size, hidden_size, ln=nn.Identity) + torch_cell = nn.LSTMCell(input_size, hidden_size) + + for name, param in self_cell.named_parameters(): + assert param.shape == getattr(torch_cell, name).shape + + assert len(self_cell.state_dict()) == len(torch_cell.state_dict()) + + +def test_layernorm_lstm_cell_forward(): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=2, high=100, size=(1,)).item() + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + + self_cell = LayerNormLSTMCell( + input_size, hidden_size, bias=bias, ln=nn.Identity + ) + torch_cell = nn.LSTMCell(input_size, hidden_size, bias=bias) + with torch.no_grad(): + for name, torch_param in torch_cell.named_parameters(): + self_param = getattr(self_cell, name) + torch_param.copy_(self_param) + + N = torch.randint(low=2, high=100, size=(1,)) + x = torch.rand(N, input_size).requires_grad_() + h = torch.rand(N, hidden_size) + c = torch.rand(N, hidden_size) + + x_clone = x.detach().clone().requires_grad_() + + self_h, self_c = self_cell(x.clone(), (h, c)) + torch_h, torch_c = torch_cell(x_clone, (h, c)) + + assert torch.allclose(self_h, torch_h) + assert torch.allclose(self_c, torch_c) + + self_hc = self_h * self_c + torch_hc = torch_h * torch_c + (self_hc.reshape(-1) * torch.arange(self_hc.numel())).sum().backward() + (torch_hc.reshape(-1) * torch.arange(torch_hc.numel())).sum().backward() + + assert torch.allclose(x.grad, x_clone.grad) + + +def test_lstm_layer_jit(): + input_size = 10 + hidden_size = 20 + layer = LayerNormLSTMLayer(input_size, hidden_size=hidden_size) + torch.jit.script(layer) + + +def test_lstm_layer_forward(): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=2, high=100, size=(1,)).item() + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + self_layer = LayerNormLSTMLayer( + input_size, + hidden_size, + bias=bias, + ln=nn.Identity, + ) + + N = torch.randint(low=2, high=100, size=(1,)) + T = torch.randint(low=2, high=100, size=(1,)) + + x = torch.rand(N, T, input_size).requires_grad_() + h = torch.rand(N, hidden_size) + c = torch.rand(N, hidden_size) + + x_clone = x.detach().clone().requires_grad_() + + self_y, (self_h, self_c) = self_layer(x, (h, c)) + + # now for pytorch + torch_layer = nn.LSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=1, + bias=bias, + batch_first=True, + dropout=0, + bidirectional=False, + ) + with torch.no_grad(): + for name, self_param in self_layer.cell.named_parameters(): + getattr(torch_layer, f"{name}_l0").copy_(self_param) + + torch_y, (torch_h, torch_c) = torch_layer( + x_clone, (h.unsqueeze(0), c.unsqueeze(0)) + ) + assert torch.allclose(self_y, torch_y) + assert torch.allclose(self_h, torch_h) + assert torch.allclose(self_c, torch_c) + + self_hc = self_h * self_c + torch_hc = torch_h * torch_c + self_hc_sum = (self_hc.reshape(-1) * torch.arange(self_hc.numel())).sum() + torch_hc_sum = (torch_hc.reshape(-1) * torch.arange(torch_hc.numel())).sum() + + self_y_sum = (self_y.reshape(-1) * torch.arange(self_y.numel())).sum() + torch_y_sum = (torch_y.reshape(-1) * torch.arange(torch_y.numel())).sum() + + (self_hc_sum * self_y_sum).backward() + (torch_hc_sum * torch_y_sum).backward() + + assert torch.allclose(x.grad, x_clone.grad, rtol=0.1) + + +def test_stacked_lstm_jit(): + input_size = 2 + hidden_size = 3 + num_layers = 4 + bias = True + + lstm = LayerNormLSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + ln=nn.Identity, + ) + torch.jit.script(lstm) + + +def test_stacked_lstm_forward(): + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=2, high=100, size=(1,)).item() + num_layers = torch.randint(low=2, high=100, size=(1,)).item() + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + + self_lstm = LayerNormLSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + ln=nn.Identity, + ) + torch_lstm = nn.LSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + bias=bias, + batch_first=True, + bidirectional=False, + ) + assert len(self_lstm.state_dict()) == len(torch_lstm.state_dict()) + with torch.no_grad(): + for name, param in self_lstm.named_parameters(): + # name has the form layers.0.cell.weight_hh + parts = name.split(".") + layer_num = parts[1] + getattr(torch_lstm, f"{parts[-1]}_l{layer_num}").copy_(param) + + N = torch.randint(low=2, high=100, size=(1,)) + T = torch.randint(low=2, high=100, size=(1,)) + + x = torch.rand(N, T, input_size).requires_grad_() + hs = [torch.rand(N, hidden_size) for _ in range(num_layers)] + cs = [torch.rand(N, hidden_size) for _ in range(num_layers)] + states = list(zip(hs, cs)) + + x_clone = x.detach().clone().requires_grad_() + + self_y, self_states = self_lstm(x, states) + + h = torch.stack(hs) + c = torch.stack(cs) + torch_y, (torch_h, torch_c) = torch_lstm(x_clone, (h, c)) + + assert torch.allclose(self_y, torch_y) + + self_h = torch.stack([s[0] for s in self_states]) + self_c = torch.stack([s[1] for s in self_states]) + + assert torch.allclose(self_h, torch_h) + assert torch.allclose(self_c, torch_c) + + s = self_y.reshape(-1) + t = torch_y.reshape(-1) + + s_sum = (s * torch.arange(s.numel())).sum() + t_sum = (t * torch.arange(t.numel())).sum() + shc_sum = s_sum * self_h.sum() * self_c.sum() + thc_sum = t_sum * torch_h.sum() * torch_c.sum() + + shc_sum.backward() + thc_sum.backward() + assert torch.allclose(x.grad, x_clone.grad) + + +def main(): + test_layernorm_lstm_cell_jit() + test_layernorm_lstm_cell_constructor() + test_layernorm_lstm_cell_forward() + # + test_lstm_layer_jit() + test_lstm_layer_forward() + # + test_stacked_lstm_jit() + test_stacked_lstm_forward() + + +if __name__ == "__main__": + main()