simplified code in ScaledLSTM

yaozengwei 2022-07-17 17:07:14 +08:00
parent 5c669b7716
commit 539a9d75d4
2 changed files with 29 additions and 69 deletions

File 1 of 2

@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import copy
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 from encoder_interface import EncoderInterface
@@ -73,7 +73,7 @@ class RNN(EncoderInterface):
         # (2) embedding: num_features -> d_model
         self.encoder_embed = Conv2dSubsampling(num_features, d_model)
 
-        self.encoder_layers = num_encoder_layers
+        self.num_encoder_layers = num_encoder_layers
         self.d_model = d_model
 
         encoder_layer = RNNEncoderLayer(
@@ -119,8 +119,8 @@ class RNN(EncoderInterface):
         return x, lengths
 
     @torch.jit.export
-    def get_init_state(self, device: torch.device) -> torch.Tensor:
-        """Get model initial state."""
+    def get_init_states(self, device: torch.device) -> torch.Tensor:
+        """Get model initial states."""
         init_states = torch.zeros(
             (2, self.num_encoder_layers, self.d_model), device=device
         )
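
The rename to get_init_states matches what the method returns: a single stacked tensor holding both the hidden and cell states for all encoder layers. A minimal sketch of that layout, assuming only the shape visible in the hunk above (the `model` and `device` names are illustrative, not part of this commit):

import torch

device = torch.device("cpu")
# `model` is assumed to be an instance of the RNN encoder shown above.
states = model.get_init_states(device)
assert states.shape == (2, model.num_encoder_layers, model.d_model)
h, c = states[0], states[1]  # hidden vs. cell part, each (num_encoder_layers, d_model)
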
@@ -283,7 +283,7 @@ class RNNEncoderLayer(nn.Module):
         # lstm module
         # The required shapes of h_0 and c_0 are both (1, N, E).
-        src_lstm, new_states = self.lstm(src, states.unbind(dim=0))
+        src_lstm, new_states = self.lstm(src, (states[0], states[1]))
         new_states = torch.stack(new_states, dim=0)
         src = src + self.dropout(src_lstm)
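
The change from `states.unbind(dim=0)` to explicit indexing is presumably for TorchScript compatibility with the newly typed `forward` in the second file: under scripting, `unbind` is typed as `List[Tensor]`, which cannot be passed where a `Tuple[Tensor, Tensor]` is expected. A standalone sketch of the distinction (the function name is made up for illustration):

from typing import Tuple

import torch
from torch import Tensor


@torch.jit.script
def split_states(states: Tensor) -> Tuple[Tensor, Tensor]:
    # Indexing yields a true Tuple[Tensor, Tensor] under TorchScript;
    # returning states.unbind(dim=0) here would be typed List[Tensor]
    # and fail to compile against this signature.
    return states[0], states[1]
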

File 2 of 2

@@ -22,7 +22,6 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 from torch import Tensor, _VF
-from torch.nn.utils.rnn import PackedSequence
 
 
 def _ntuple(n):
@@ -428,51 +427,30 @@ class ScaledLSTM(nn.LSTM):
         )
         return flat_weights
 
-    def forward(self, input, hx=None):
-        # This function is copied from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py  # noqa
-        # The only change is for calling `_VF.lstm()`:
+    def forward(
+        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
+    ):
+        # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py  # noqa
+        # The change for calling `_VF.lstm()` is:
         #   self._flat_weights -> self.get_flat_weights()
-        orig_input = input
-        # xxx: isinstance check needs to be in conditional for TorchScript to compile
-        if isinstance(orig_input, PackedSequence):
-            input, batch_sizes, sorted_indices, unsorted_indices = input
-            max_batch_size = batch_sizes[0]
-            max_batch_size = int(max_batch_size)
-        else:
-            batch_sizes = None
-            max_batch_size = (
-                input.size(0) if self.batch_first else input.size(1)
-            )
-            sorted_indices = None
-            unsorted_indices = None
 
         if hx is None:
-            num_directions = 2 if self.bidirectional else 1
-            real_hidden_size = (
-                self.proj_size if self.proj_size > 0 else self.hidden_size
-            )
             h_zeros = torch.zeros(
-                self.num_layers * num_directions,
-                max_batch_size,
-                real_hidden_size,
+                1,
+                input.size(1),
+                self.hidden_size,
                 dtype=input.dtype,
                 device=input.device,
             )
             c_zeros = torch.zeros(
-                self.num_layers * num_directions,
-                max_batch_size,
+                1,
+                input.size(1),
                 self.hidden_size,
                 dtype=input.dtype,
                 device=input.device,
             )
             hx = (h_zeros, c_zeros)
-        else:
-            # Each batch of the hidden state should match the input sequence that
-            # the user believes he/she is passing in.
-            hx = self.permute_hidden(hx, sorted_indices)
 
-        self.check_forward_args(input, hx, batch_sizes)
-        if batch_sizes is None:
+        self.check_forward_args(input, hx, None)
         result = _VF.lstm(
             input,
             hx,
@@ -484,28 +462,10 @@ class ScaledLSTM(nn.LSTM):
             self.bidirectional,
             self.batch_first,
         )
-        else:
-            result = _VF.lstm(
-                input,
-                batch_sizes,
-                hx,
-                self.get_flat_weights(),
-                self.bias,
-                self.num_layers,
-                self.dropout,
-                self.training,
-                self.bidirectional,
-            )
         output = result[0]
         hidden = result[1:]
-        # xxx: isinstance check needs to be in conditional for TorchScript to compile
-        if isinstance(orig_input, PackedSequence):
-            output_packed = PackedSequence(
-                output, batch_sizes, sorted_indices, unsorted_indices
-            )
-            return output_packed, self.permute_hidden(hidden, unsorted_indices)
-        else:
-            return output, self.permute_hidden(hidden, unsorted_indices)
+        return output, hidden
 
 
 class ActivationBalancer(torch.nn.Module):
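
With the PackedSequence path removed, forward now accepts only a plain (seq_len, batch, input_size) tensor, and the hard-coded leading 1 in the default hx implies a single unidirectional layer with no proj_size. A hedged usage sketch under those assumptions (all sizes are illustrative):

import torch

# ScaledLSTM as defined above; num_layers=1, bidirectional=False,
# proj_size=0 are assumed, matching the (1, N, hidden_size) zero states
# built in the simplified forward().
lstm = ScaledLSTM(input_size=512, hidden_size=512)

T, N = 20, 4
x = torch.randn(T, N, 512)  # (seq_len, batch, input_size); batch_first=False

# hx defaults to zero hidden/cell states of shape (1, N, hidden_size).
out, (h, c) = lstm(x)
assert out.shape == (T, N, 512)
assert h.shape == c.shape == (1, N, 512)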