
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Apply layer normalization to the output of each gate in LSTM/GRU.
This file uses
https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py
as a reference.
"""
import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# TODO(fangjun): Support projection, see https://arxiv.org/pdf/1402.1128.pdf
class LayerNormLSTMCell(nn.Module):
    """This class places a `nn.LayerNorm` after the output of
    each gate (right before the activation).

    See the following paper for more details

    'Improving RNN Transducer Modeling for End-to-End Speech Recognition'
    https://arxiv.org/abs/1909.12415
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        ln: nn.Module = nn.LayerNorm,
        device=None,
        dtype=None,
    ):
        """
        Args:
          input_size:
            The number of expected features in the input `x`. `x` should
            be of shape (batch_size, input_size).
          hidden_size:
            The number of features in the hidden state `h` and `c`.
            Both `h` and `c` are of shape (batch_size, hidden_size).
          bias:
            If ``False``, then the cell does not use bias weights
            `bias_ih` and `bias_hh`.
          ln:
            Defaults to `nn.LayerNorm`. The outputs of all gates are
            processed by `ln`. We pass it as an argument so that we can
            replace it with `nn.Identity` at testing time.
        """
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        self.weight_ih = nn.Parameter(
            torch.empty((4 * hidden_size, input_size), **factory_kwargs)
        )
        self.weight_hh = nn.Parameter(
            torch.empty((4 * hidden_size, hidden_size), **factory_kwargs)
        )
        if bias:
            self.bias_ih = nn.Parameter(
                torch.empty(4 * hidden_size, **factory_kwargs)
            )
            self.bias_hh = nn.Parameter(
                torch.empty(4 * hidden_size, **factory_kwargs)
            )
        else:
            self.register_parameter("bias_ih", None)
            self.register_parameter("bias_hh", None)

        self.layernorm_i = ln(hidden_size)
        self.layernorm_f = ln(hidden_size)
        self.layernorm_cx = ln(hidden_size)
        self.layernorm_cy = ln(hidden_size)
        self.layernorm_o = ln(hidden_size)

        self.reset_parameters()

    def forward(
        self,
        input: torch.Tensor,
        state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          input:
            A 2-D tensor of shape (batch_size, input_size).
          state:
            If not ``None``, it contains the hidden state (h, c); both
            are of shape (batch_size, hidden_size). If ``None``, it uses
            zeros for `h` and `c`.
        Returns:
          Return two tensors:
            - `next_h`: It is of shape (batch_size, hidden_size) containing
              the next hidden state for each element in the batch.
            - `next_c`: It is of shape (batch_size, hidden_size) containing
              the next cell state for each element in the batch.
        """
        if state is None:
            zeros = torch.zeros(
                input.size(0),
                self.hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
            state = (zeros, zeros)

        hx, cx = state
        # Compute the pre-activations of all four gates with two affine maps.
        gates = F.linear(input, self.weight_ih, self.bias_ih) + F.linear(
            hx, self.weight_hh, self.bias_hh
        )
        in_gate, forget_gate, cell_gate, out_gate = gates.chunk(
            chunks=4, dim=1
        )

        # Layer-normalize each gate right before its activation.
        in_gate = self.layernorm_i(in_gate)
        forget_gate = self.layernorm_f(forget_gate)
        cell_gate = self.layernorm_cx(cell_gate)
        out_gate = self.layernorm_o(out_gate)

        in_gate = torch.sigmoid(in_gate)
        forget_gate = torch.sigmoid(forget_gate)
        cell_gate = torch.tanh(cell_gate)
        out_gate = torch.sigmoid(out_gate)

        # The updated cell state is also layer-normalized before computing `h`.
        cy = (forget_gate * cx) + (in_gate * cell_gate)
        cy = self.layernorm_cy(cy)
        hy = out_gate * torch.tanh(cy)

        return hy, cy

    def extra_repr(self) -> str:
        s = "{input_size}, {hidden_size}"
        if "bias" in self.__dict__ and self.bias is not True:
            s += ", bias={bias}"
        return s.format(**self.__dict__)

    def reset_parameters(self) -> None:
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)
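
# A minimal usage sketch for `LayerNormLSTMCell`. The helper name and the
# tensor sizes below are illustrative only; shapes follow the docstrings above.
def _demo_layer_norm_lstm_cell() -> None:
    cell = LayerNormLSTMCell(input_size=10, hidden_size=20)
    x = torch.randn(2, 10)  # (batch_size, input_size)
    h, c = cell(x)  # `state` defaults to None, i.e. zero-initialized (h, c)
    assert h.shape == (2, 20)
    assert c.shape == (2, 20)
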
class LayerNormLSTMLayer(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        ln: nn.Module = nn.LayerNorm,
        device=None,
        dtype=None,
    ):
        """
        See the args in LayerNormLSTMCell.
        """
        super().__init__()
        self.cell = LayerNormLSTMCell(
            input_size=input_size,
            hidden_size=hidden_size,
            bias=bias,
            ln=ln,
            device=device,
            dtype=dtype,
        )

    def forward(
        self,
        input: torch.Tensor,
        state: Tuple[torch.Tensor, torch.Tensor],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
          input:
            A 3-D tensor of shape (batch_size, seq_len, input_size).
            Caution:
              We use `batch_first=True` here.
          state:
            If not ``None``, it contains the hidden state (h, c) of this
            layer. Both are of shape (batch_size, hidden_size).
            Note:
              We did not annotate `state` with `Optional[Tuple[...]]` since
              torchscript will complain.
        Return:
          - output, a tensor of shape (batch_size, seq_len, hidden_size)
          - (next_h, next_c) containing the hidden state of this layer
        """
        inputs = input.unbind(1)
        outputs = torch.jit.annotate(List[torch.Tensor], [])
        for i in range(len(inputs)):
            state = self.cell(inputs[i], state)
            outputs += [state[0]]
        return torch.stack(outputs, dim=1), state
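
# A minimal usage sketch for `LayerNormLSTMLayer`. Unlike the cell, `state`
# has no default here, so an explicit (h, c) pair is required. The helper
# name and the tensor sizes below are illustrative only.
def _demo_layer_norm_lstm_layer() -> None:
    layer = LayerNormLSTMLayer(input_size=10, hidden_size=20)
    x = torch.randn(2, 5, 10)  # (batch_size, seq_len, input_size)
    h0 = torch.zeros(2, 20)
    c0 = torch.zeros(2, 20)
    output, (h, c) = layer(x, (h0, c0))
    assert output.shape == (2, 5, 20)
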
class LayerNormLSTM(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bias: bool = True,
        ln: nn.Module = nn.LayerNorm,
        device=None,
        dtype=None,
    ):
        """
        See the args in LayerNormLSTMLayer.
        """
        super().__init__()
        assert num_layers >= 1
        factory_kwargs = dict(
            hidden_size=hidden_size,
            bias=bias,
            ln=ln,
            device=device,
            dtype=dtype,
        )
        first_layer = LayerNormLSTMLayer(
            input_size=input_size, **factory_kwargs
        )
        layers = [first_layer]

        for i in range(1, num_layers):
            layers.append(
                LayerNormLSTMLayer(
                    input_size=hidden_size,
                    **factory_kwargs,
                )
            )
        self.layers = nn.ModuleList(layers)
        self.num_layers = num_layers

    def forward(
        self,
        input: torch.Tensor,
        states: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Args:
          input:
            A 3-D tensor of shape (batch_size, seq_len, input_size).
            Caution:
              We use `batch_first=True` here.
          states:
            One state per layer. Each entry contains the hidden state (h, c)
            for a layer. Both are of shape (batch_size, hidden_size).
        Returns:
          Return a tuple containing:
            - output: A tensor of shape (batch_size, seq_len, hidden_size)
            - List[(next_h, next_c)] containing the hidden states for all
              layers
        """
        output_states = torch.jit.annotate(
            List[Tuple[torch.Tensor, torch.Tensor]], []
        )
        output = input
        for i, rnn_layer in enumerate(self.layers):
            state = states[i]
            output, out_state = rnn_layer(output, state)
            output_states += [out_state]
        return output, output_states
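
# A minimal usage sketch for the stacked `LayerNormLSTM`, run only when this
# file is executed directly. The sizes below are illustrative only; one
# (h, c) pair must be supplied per layer.
if __name__ == "__main__":
    _demo_layer_norm_lstm_cell()
    _demo_layer_norm_lstm_layer()

    num_layers = 3
    lstm = LayerNormLSTM(input_size=10, hidden_size=20, num_layers=num_layers)
    x = torch.randn(2, 5, 10)  # (batch_size, seq_len, input_size)
    states = [
        (torch.zeros(2, 20), torch.zeros(2, 20)) for _ in range(num_layers)
    ]
    output, next_states = lstm(x, states)
    assert output.shape == (2, 5, 20)
    assert len(next_states) == num_layers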