mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Apply layer normalization to the output of each gate in LSTM.
This commit is contained in:
parent
89b84208aa
commit
8a038b8f1a
7
.github/workflows/test.yml
vendored
7
.github/workflows/test.yml
vendored
@ -113,6 +113,11 @@ jobs:
|
||||
export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH
|
||||
pytest -v -s ./test
|
||||
|
||||
# runt tests for conformer ctc
|
||||
pwd=$PWD
|
||||
# run tests for conformer ctc
|
||||
cd egs/librispeech/ASR/conformer_ctc
|
||||
pytest -v -s
|
||||
|
||||
cd $PWD
|
||||
cd egs/librispeech/ASR
|
||||
pytest -v -s ./rnnt
|
||||
|
0
egs/librispeech/ASR/rnnt/__init__.py
Normal file
0
egs/librispeech/ASR/rnnt/__init__.py
Normal file
284
egs/librispeech/ASR/rnnt/rnn.py
Normal file
284
egs/librispeech/ASR/rnnt/rnn.py
Normal file
@ -0,0 +1,284 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Apply layer normalization to the output of each gate in LSTM/GRU.
|
||||
|
||||
This file uses
|
||||
https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py
|
||||
as a reference.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# TODO(fangjun): Support projection, see https://arxiv.org/pdf/1402.1128.pdf
|
||||
class LayerNormLSTMCell(nn.Module):
|
||||
"""This class places a `nn.LayerNorm` after the output of
|
||||
each gate (right before the activation).
|
||||
|
||||
See the following paper for more details
|
||||
|
||||
'Improving RNN Transducer Modeling for End-to-End Speech Recognition'
|
||||
https://arxiv.org/abs/1909.12415
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
hidden_size: int,
|
||||
bias: bool = True,
|
||||
ln: nn.Module = nn.LayerNorm,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
input_size:
|
||||
The number of expected features in the input `x`. `x` should
|
||||
be of shape (batch_size, input_size).
|
||||
hidden_size:
|
||||
The number of features in the hidden state `h` and `c`.
|
||||
Both `h` and `c` are of shape (batch_size, hidden_size).
|
||||
bias:
|
||||
If ``False``, then the cell does not use bias weights
|
||||
`bias_ih` and `bias_hh`.
|
||||
ln:
|
||||
Defaults to `nn.LayerNorm`. The output of all gates are processed
|
||||
by `ln`. We pass it as an argument so that we can replace it
|
||||
with `nn.Identity` at the testing time.
|
||||
"""
|
||||
super().__init__()
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.bias = bias
|
||||
|
||||
self.weight_ih = nn.Parameter(
|
||||
torch.empty((4 * hidden_size, input_size), **factory_kwargs)
|
||||
)
|
||||
|
||||
self.weight_hh = nn.Parameter(
|
||||
torch.empty((4 * hidden_size, hidden_size), **factory_kwargs)
|
||||
)
|
||||
|
||||
if bias:
|
||||
self.bias_ih = nn.Parameter(
|
||||
torch.empty(4 * hidden_size, **factory_kwargs)
|
||||
)
|
||||
self.bias_hh = nn.Parameter(
|
||||
torch.empty(4 * hidden_size, **factory_kwargs)
|
||||
)
|
||||
else:
|
||||
self.register_parameter("bias_ih", None)
|
||||
self.register_parameter("bias_hh", None)
|
||||
|
||||
self.layernorm_i = ln(hidden_size)
|
||||
self.layernorm_f = ln(hidden_size)
|
||||
self.layernorm_cx = ln(hidden_size)
|
||||
self.layernorm_cy = ln(hidden_size)
|
||||
self.layernorm_o = ln(hidden_size)
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
input:
|
||||
A 2-D tensor of shape (batch_size, input_size).
|
||||
state:
|
||||
If not ``None``, it contains the hidden state (h, c); both
|
||||
are of shape (batch_size, hidden_size). If ``None``, it uses
|
||||
zeros for `h` and `c`.
|
||||
Returns:
|
||||
Return two tensors:
|
||||
- `next_h`: It is of shape (batch_size, hidden_size) containing the
|
||||
next hidden state for each element in the batch.
|
||||
- `next_c`: It is of shape (batch_size, hidden_size) containing the
|
||||
next cell state for each element in the batch.
|
||||
"""
|
||||
if state is None:
|
||||
zeros = torch.zeros(
|
||||
input.size(0),
|
||||
self.hidden_size,
|
||||
dtype=input.dtype,
|
||||
device=input.device,
|
||||
)
|
||||
state = (zeros, zeros)
|
||||
|
||||
hx, cx = state
|
||||
gates = F.linear(input, self.weight_ih, self.bias_ih) + F.linear(
|
||||
hx, self.weight_hh, self.bias_hh
|
||||
)
|
||||
|
||||
in_gate, forget_gate, cell_gate, out_gate = gates.chunk(chunks=4, dim=1)
|
||||
|
||||
in_gate = self.layernorm_i(in_gate)
|
||||
forget_gate = self.layernorm_f(forget_gate)
|
||||
cell_gate = self.layernorm_cx(cell_gate)
|
||||
out_gate = self.layernorm_o(out_gate)
|
||||
|
||||
in_gate = torch.sigmoid(in_gate)
|
||||
forget_gate = torch.sigmoid(forget_gate)
|
||||
cell_gate = torch.tanh(cell_gate)
|
||||
out_gate = torch.sigmoid(out_gate)
|
||||
|
||||
cy = (forget_gate * cx) + (in_gate * cell_gate)
|
||||
cy = self.layernorm_cy(cy)
|
||||
hy = out_gate * torch.tanh(cy)
|
||||
|
||||
return hy, cy
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = "{input_size}, {hidden_size}"
|
||||
if "bias" in self.__dict__ and self.bias is not True:
|
||||
s += ", bias={bias}"
|
||||
return s.format(**self.__dict__)
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
stdv = 1.0 / math.sqrt(self.hidden_size)
|
||||
for weight in self.parameters():
|
||||
nn.init.uniform_(weight, -stdv, stdv)
|
||||
|
||||
|
||||
class LayerNormLSTMLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
hidden_size: int,
|
||||
bias: bool = True,
|
||||
ln: nn.Module = nn.LayerNorm,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""
|
||||
See the args in LayerNormLSTMCell
|
||||
"""
|
||||
super().__init__()
|
||||
self.cell = LayerNormLSTMCell(
|
||||
input_size=input_size,
|
||||
hidden_size=hidden_size,
|
||||
bias=bias,
|
||||
ln=ln,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
state: Tuple[torch.Tensor, torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
input:
|
||||
A 3-D tensor of shape (batch_size, seq_len, input_size).
|
||||
Caution:
|
||||
We use `batch_first=True` here.
|
||||
state:
|
||||
If not ``None``, it contains the hidden state (h, c) of this layer.
|
||||
Both are of shape (batch_size, hidden_size).
|
||||
Note:
|
||||
We did not annotate `state` with `Optional[Tuple[...]]` since
|
||||
torchscript will complain.
|
||||
Return:
|
||||
- output, a tensor of shape (batch_size, seq_len, hidden_size)
|
||||
- (next_h, next_c) containing the hidden state of this layer
|
||||
"""
|
||||
inputs = input.unbind(1)
|
||||
outputs = torch.jit.annotate(List[torch.Tensor], [])
|
||||
for i in range(len(inputs)):
|
||||
state = self.cell(inputs[i], state)
|
||||
outputs += [state[0]]
|
||||
return torch.stack(outputs, dim=1), state
|
||||
|
||||
|
||||
class LayerNormLSTM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
hidden_size: int,
|
||||
num_layers: int,
|
||||
bias: bool = True,
|
||||
ln: nn.Module = nn.LayerNorm,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""
|
||||
See the args in LSTMLayer.
|
||||
"""
|
||||
super().__init__()
|
||||
assert num_layers >= 1
|
||||
factory_kwargs = dict(
|
||||
hidden_size=hidden_size,
|
||||
bias=bias,
|
||||
ln=ln,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
first_layer = LayerNormLSTMLayer(
|
||||
input_size=input_size, **factory_kwargs
|
||||
)
|
||||
layers = [first_layer]
|
||||
for i in range(1, num_layers):
|
||||
layers.append(
|
||||
LayerNormLSTMLayer(
|
||||
input_size=hidden_size,
|
||||
**factory_kwargs,
|
||||
)
|
||||
)
|
||||
self.layers = nn.ModuleList(layers)
|
||||
self.num_layers = num_layers
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
states: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
"""
|
||||
Args:
|
||||
input:
|
||||
A 3-D tensor of shape (batch_size, seq_len, input_size).
|
||||
Caution:
|
||||
We use `batch_first=True` here.
|
||||
states:
|
||||
One state per layer. Each entry contains the hidden state (h, c)
|
||||
for a layer. Both are of shape (batch_size, hidden_size).
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
|
||||
- output: A tensor of shape (batch_size, seq_len, hidden_size)
|
||||
- List[(next_h, next_c)] containing the hidden states for all layers
|
||||
|
||||
"""
|
||||
output_states = torch.jit.annotate(
|
||||
List[Tuple[torch.Tensor, torch.Tensor]], []
|
||||
)
|
||||
output = input
|
||||
for i, rnn_layer in enumerate(self.layers):
|
||||
state = states[i]
|
||||
output, out_state = rnn_layer(output, state)
|
||||
output_states += [out_state]
|
||||
return output, output_states
|
240
egs/librispeech/ASR/rnnt/test_rnn.py
Executable file
240
egs/librispeech/ASR/rnnt/test_rnn.py
Executable file
@ -0,0 +1,240 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from rnnt.rnn import LayerNormLSTM, LayerNormLSTMCell, LayerNormLSTMLayer
|
||||
|
||||
|
||||
def test_layernorm_lstm_cell_jit():
|
||||
input_size = 10
|
||||
hidden_size = 20
|
||||
cell = LayerNormLSTMCell(
|
||||
input_size=input_size, hidden_size=hidden_size, bias=True
|
||||
)
|
||||
|
||||
torch.jit.script(cell)
|
||||
|
||||
|
||||
def test_layernorm_lstm_cell_constructor():
|
||||
input_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
|
||||
self_cell = LayerNormLSTMCell(input_size, hidden_size, ln=nn.Identity)
|
||||
torch_cell = nn.LSTMCell(input_size, hidden_size)
|
||||
|
||||
for name, param in self_cell.named_parameters():
|
||||
assert param.shape == getattr(torch_cell, name).shape
|
||||
|
||||
assert len(self_cell.state_dict()) == len(torch_cell.state_dict())
|
||||
|
||||
|
||||
def test_layernorm_lstm_cell_forward():
|
||||
input_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
|
||||
|
||||
self_cell = LayerNormLSTMCell(
|
||||
input_size, hidden_size, bias=bias, ln=nn.Identity
|
||||
)
|
||||
torch_cell = nn.LSTMCell(input_size, hidden_size, bias=bias)
|
||||
with torch.no_grad():
|
||||
for name, torch_param in torch_cell.named_parameters():
|
||||
self_param = getattr(self_cell, name)
|
||||
torch_param.copy_(self_param)
|
||||
|
||||
N = torch.randint(low=2, high=100, size=(1,))
|
||||
x = torch.rand(N, input_size).requires_grad_()
|
||||
h = torch.rand(N, hidden_size)
|
||||
c = torch.rand(N, hidden_size)
|
||||
|
||||
x_clone = x.detach().clone().requires_grad_()
|
||||
|
||||
self_h, self_c = self_cell(x.clone(), (h, c))
|
||||
torch_h, torch_c = torch_cell(x_clone, (h, c))
|
||||
|
||||
assert torch.allclose(self_h, torch_h)
|
||||
assert torch.allclose(self_c, torch_c)
|
||||
|
||||
self_hc = self_h * self_c
|
||||
torch_hc = torch_h * torch_c
|
||||
(self_hc.reshape(-1) * torch.arange(self_hc.numel())).sum().backward()
|
||||
(torch_hc.reshape(-1) * torch.arange(torch_hc.numel())).sum().backward()
|
||||
|
||||
assert torch.allclose(x.grad, x_clone.grad)
|
||||
|
||||
|
||||
def test_lstm_layer_jit():
|
||||
input_size = 10
|
||||
hidden_size = 20
|
||||
layer = LayerNormLSTMLayer(input_size, hidden_size=hidden_size)
|
||||
torch.jit.script(layer)
|
||||
|
||||
|
||||
def test_lstm_layer_forward():
|
||||
input_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
|
||||
self_layer = LayerNormLSTMLayer(
|
||||
input_size,
|
||||
hidden_size,
|
||||
bias=bias,
|
||||
ln=nn.Identity,
|
||||
)
|
||||
|
||||
N = torch.randint(low=2, high=100, size=(1,))
|
||||
T = torch.randint(low=2, high=100, size=(1,))
|
||||
|
||||
x = torch.rand(N, T, input_size).requires_grad_()
|
||||
h = torch.rand(N, hidden_size)
|
||||
c = torch.rand(N, hidden_size)
|
||||
|
||||
x_clone = x.detach().clone().requires_grad_()
|
||||
|
||||
self_y, (self_h, self_c) = self_layer(x, (h, c))
|
||||
|
||||
# now for pytorch
|
||||
torch_layer = nn.LSTM(
|
||||
input_size=input_size,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=1,
|
||||
bias=bias,
|
||||
batch_first=True,
|
||||
dropout=0,
|
||||
bidirectional=False,
|
||||
)
|
||||
with torch.no_grad():
|
||||
for name, self_param in self_layer.cell.named_parameters():
|
||||
getattr(torch_layer, f"{name}_l0").copy_(self_param)
|
||||
|
||||
torch_y, (torch_h, torch_c) = torch_layer(
|
||||
x_clone, (h.unsqueeze(0), c.unsqueeze(0))
|
||||
)
|
||||
assert torch.allclose(self_y, torch_y)
|
||||
assert torch.allclose(self_h, torch_h)
|
||||
assert torch.allclose(self_c, torch_c)
|
||||
|
||||
self_hc = self_h * self_c
|
||||
torch_hc = torch_h * torch_c
|
||||
self_hc_sum = (self_hc.reshape(-1) * torch.arange(self_hc.numel())).sum()
|
||||
torch_hc_sum = (torch_hc.reshape(-1) * torch.arange(torch_hc.numel())).sum()
|
||||
|
||||
self_y_sum = (self_y.reshape(-1) * torch.arange(self_y.numel())).sum()
|
||||
torch_y_sum = (torch_y.reshape(-1) * torch.arange(torch_y.numel())).sum()
|
||||
|
||||
(self_hc_sum * self_y_sum).backward()
|
||||
(torch_hc_sum * torch_y_sum).backward()
|
||||
|
||||
assert torch.allclose(x.grad, x_clone.grad, rtol=0.1)
|
||||
|
||||
|
||||
def test_stacked_lstm_jit():
|
||||
input_size = 2
|
||||
hidden_size = 3
|
||||
num_layers = 4
|
||||
bias = True
|
||||
|
||||
lstm = LayerNormLSTM(
|
||||
input_size=input_size,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
bias=bias,
|
||||
ln=nn.Identity,
|
||||
)
|
||||
torch.jit.script(lstm)
|
||||
|
||||
|
||||
def test_stacked_lstm_forward():
|
||||
input_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
hidden_size = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
num_layers = torch.randint(low=2, high=100, size=(1,)).item()
|
||||
bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0
|
||||
|
||||
self_lstm = LayerNormLSTM(
|
||||
input_size=input_size,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
bias=bias,
|
||||
ln=nn.Identity,
|
||||
)
|
||||
torch_lstm = nn.LSTM(
|
||||
input_size=input_size,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
bias=bias,
|
||||
batch_first=True,
|
||||
bidirectional=False,
|
||||
)
|
||||
assert len(self_lstm.state_dict()) == len(torch_lstm.state_dict())
|
||||
with torch.no_grad():
|
||||
for name, param in self_lstm.named_parameters():
|
||||
# name has the form layers.0.cell.weight_hh
|
||||
parts = name.split(".")
|
||||
layer_num = parts[1]
|
||||
getattr(torch_lstm, f"{parts[-1]}_l{layer_num}").copy_(param)
|
||||
|
||||
N = torch.randint(low=2, high=100, size=(1,))
|
||||
T = torch.randint(low=2, high=100, size=(1,))
|
||||
|
||||
x = torch.rand(N, T, input_size).requires_grad_()
|
||||
hs = [torch.rand(N, hidden_size) for _ in range(num_layers)]
|
||||
cs = [torch.rand(N, hidden_size) for _ in range(num_layers)]
|
||||
states = list(zip(hs, cs))
|
||||
|
||||
x_clone = x.detach().clone().requires_grad_()
|
||||
|
||||
self_y, self_states = self_lstm(x, states)
|
||||
|
||||
h = torch.stack(hs)
|
||||
c = torch.stack(cs)
|
||||
torch_y, (torch_h, torch_c) = torch_lstm(x_clone, (h, c))
|
||||
|
||||
assert torch.allclose(self_y, torch_y)
|
||||
|
||||
self_h = torch.stack([s[0] for s in self_states])
|
||||
self_c = torch.stack([s[1] for s in self_states])
|
||||
|
||||
assert torch.allclose(self_h, torch_h)
|
||||
assert torch.allclose(self_c, torch_c)
|
||||
|
||||
s = self_y.reshape(-1)
|
||||
t = torch_y.reshape(-1)
|
||||
|
||||
s_sum = (s * torch.arange(s.numel())).sum()
|
||||
t_sum = (t * torch.arange(t.numel())).sum()
|
||||
shc_sum = s_sum * self_h.sum() * self_c.sum()
|
||||
thc_sum = t_sum * torch_h.sum() * torch_c.sum()
|
||||
|
||||
shc_sum.backward()
|
||||
thc_sum.backward()
|
||||
assert torch.allclose(x.grad, x_clone.grad)
|
||||
|
||||
|
||||
def main():
|
||||
test_layernorm_lstm_cell_jit()
|
||||
test_layernorm_lstm_cell_constructor()
|
||||
test_layernorm_lstm_cell_forward()
|
||||
#
|
||||
test_lstm_layer_jit()
|
||||
test_lstm_layer_forward()
|
||||
#
|
||||
test_stacked_lstm_jit()
|
||||
test_stacked_lstm_forward()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user