mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)
simplified code in ScaledLSTM
This commit is contained in:
parent 5c669b7716
commit 539a9d75d4
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import copy
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 from encoder_interface import EncoderInterface
@@ -73,7 +73,7 @@ class RNN(EncoderInterface):
         # (2) embedding: num_features -> d_model
         self.encoder_embed = Conv2dSubsampling(num_features, d_model)
 
-        self.encoder_layers = num_encoder_layers
+        self.num_encoder_layers = num_encoder_layers
         self.d_model = d_model
 
         encoder_layer = RNNEncoderLayer(
@@ -119,8 +119,8 @@ class RNN(EncoderInterface):
         return x, lengths
 
     @torch.jit.export
-    def get_init_state(self, device: torch.device) -> torch.Tensor:
-        """Get model initial state."""
+    def get_init_states(self, device: torch.device) -> torch.Tensor:
+        """Get model initial states."""
         init_states = torch.zeros(
             (2, self.num_encoder_layers, self.d_model), device=device
         )
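Note on the hunk above: the method is renamed from get_init_state to get_init_states, and it returns a single zero tensor covering every encoder layer. A minimal, self-contained sketch of what that tensor looks like (the function name and the example dimensions here are illustrative; only the shape and the zero initialisation come from the diff):

import torch


def get_init_states_sketch(
    num_encoder_layers: int, d_model: int, device: torch.device
) -> torch.Tensor:
    # Mirrors the body shown above: axis 0 separates the stacked hidden
    # states (index 0) from the stacked cell states (index 1), with one
    # row of size d_model per encoder layer.
    return torch.zeros((2, num_encoder_layers, d_model), device=device)


states = get_init_states_sketch(12, 512, torch.device("cpu"))
print(states.shape)  # torch.Size([2, 12, 512])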
@@ -283,7 +283,7 @@ class RNNEncoderLayer(nn.Module):
 
         # lstm module
         # The required shapes of h_0 and c_0 are both (1, N, E).
-        src_lstm, new_states = self.lstm(src, states.unbind(dim=0))
+        src_lstm, new_states = self.lstm(src, (states[0], states[1]))
         new_states = torch.stack(new_states, dim=0)
         src = src + self.dropout(src_lstm)
 
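In eager mode the two calls above are interchangeable; one plausible reason for the change is that `states.unbind(dim=0)` is typed as a List[Tensor] under TorchScript, while indexing yields the explicit Tuple[Tensor, Tensor] that the LSTM's `hx` argument expects. A small stand-alone sketch of the new call pattern (a plain nn.LSTM stands in for the recipe's ScaledLSTM; T, N and E are arbitrary example sizes):

import torch
import torch.nn as nn

T, N, E = 10, 3, 512
lstm = nn.LSTM(input_size=E, hidden_size=E, num_layers=1)

src = torch.randn(T, N, E)
# Stacked states: states[0] is h_0 and states[1] is c_0, each of shape (1, N, E).
states = torch.zeros(2, 1, N, E)

src_lstm, new_states = lstm(src, (states[0], states[1]))
new_states = torch.stack(new_states, dim=0)  # back to shape (2, 1, N, E)
print(src_lstm.shape, new_states.shape)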
@@ -22,7 +22,6 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 from torch import Tensor, _VF
-from torch.nn.utils.rnn import PackedSequence
 
 
 def _ntuple(n):
@@ -428,51 +427,30 @@ class ScaledLSTM(nn.LSTM):
         )
         return flat_weights
 
-    def forward(self, input, hx=None):
-        # This function is copied from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py  # noqa
-        # The only change is for calling `_VF.lstm()`:
+    def forward(
+        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
+    ):
+        # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py  # noqa
+        # The change for calling `_VF.lstm()` is:
         #   self._flat_weights -> self.get_flat_weights()
-        orig_input = input
-        # xxx: isinstance check needs to be in conditional for TorchScript to compile
-        if isinstance(orig_input, PackedSequence):
-            input, batch_sizes, sorted_indices, unsorted_indices = input
-            max_batch_size = batch_sizes[0]
-            max_batch_size = int(max_batch_size)
-        else:
-            batch_sizes = None
-            max_batch_size = (
-                input.size(0) if self.batch_first else input.size(1)
-            )
-            sorted_indices = None
-            unsorted_indices = None
 
         if hx is None:
-            num_directions = 2 if self.bidirectional else 1
-            real_hidden_size = (
-                self.proj_size if self.proj_size > 0 else self.hidden_size
-            )
             h_zeros = torch.zeros(
-                self.num_layers * num_directions,
-                max_batch_size,
-                real_hidden_size,
+                1,
+                input.size(1),
+                self.hidden_size,
                 dtype=input.dtype,
                 device=input.device,
             )
             c_zeros = torch.zeros(
-                self.num_layers * num_directions,
-                max_batch_size,
+                1,
+                input.size(1),
                 self.hidden_size,
                 dtype=input.dtype,
                 device=input.device,
             )
             hx = (h_zeros, c_zeros)
-        else:
-            # Each batch of the hidden state should match the input sequence that
-            # the user believes he/she is passing in.
-            hx = self.permute_hidden(hx, sorted_indices)
 
-        self.check_forward_args(input, hx, batch_sizes)
-        if batch_sizes is None:
+        self.check_forward_args(input, hx, None)
         result = _VF.lstm(
             input,
             hx,
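The hunk above drops the PackedSequence and bidirectional/proj_size handling, so the default hx reduces to a pair of zero tensors of shape (1, batch, hidden_size); input.size(1) is the batch dimension because the layer is used with batch_first=False. A rough stand-alone sketch of the surviving logic (an illustration of the shapes only, not the ScaledLSTM implementation itself, assuming a single unidirectional layer as the new shapes imply):

from typing import Optional, Tuple

import torch
from torch import Tensor


def default_states(
    input: Tensor, hidden_size: int, hx: Optional[Tuple[Tensor, Tensor]] = None
) -> Tuple[Tensor, Tensor]:
    # input is (T, N, E), so input.size(1) is the batch size N.
    if hx is None:
        h_zeros = torch.zeros(
            1, input.size(1), hidden_size,
            dtype=input.dtype, device=input.device,
        )
        c_zeros = torch.zeros(
            1, input.size(1), hidden_size,
            dtype=input.dtype, device=input.device,
        )
        hx = (h_zeros, c_zeros)
    return hx


x = torch.randn(10, 4, 512)
h_0, c_0 = default_states(x, hidden_size=512)
print(h_0.shape, c_0.shape)  # torch.Size([1, 4, 512]) twice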
@@ -484,28 +462,10 @@ class ScaledLSTM(nn.LSTM):
             self.bidirectional,
             self.batch_first,
         )
-        else:
-            result = _VF.lstm(
-                input,
-                batch_sizes,
-                hx,
-                self.get_flat_weights(),
-                self.bias,
-                self.num_layers,
-                self.dropout,
-                self.training,
-                self.bidirectional,
-            )
 
         output = result[0]
         hidden = result[1:]
-        # xxx: isinstance check needs to be in conditional for TorchScript to compile
-        if isinstance(orig_input, PackedSequence):
-            output_packed = PackedSequence(
-                output, batch_sizes, sorted_indices, unsorted_indices
-            )
-            return output_packed, self.permute_hidden(hidden, unsorted_indices)
-        else:
-            return output, self.permute_hidden(hidden, unsorted_indices)
+        return output, hidden
 
 
 class ActivationBalancer(torch.nn.Module):
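Taken together, the simplified forward keeps the ordinary LSTM calling convention: a plain (T, N, E) tensor in, (output, (h_n, c_n)) out, with PackedSequence input no longer supported. A short usage sketch of that contract (nn.LSTM again stands in for ScaledLSTM, whose weight scaling is outside this diff):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=512, hidden_size=512, num_layers=1)

x = torch.randn(10, 4, 512)     # plain (T, N, E) tensor, not a PackedSequence
output, (h_n, c_n) = lstm(x)    # hx defaults to zeros, as in the new forward
print(output.shape, h_n.shape, c_n.shape)
# torch.Size([10, 4, 512]) torch.Size([1, 4, 512]) torch.Size([1, 4, 512])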