Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)
add using proj_size
This commit is contained in:
parent 9bb0c7988f
commit 6871c96ffa
@@ -38,9 +38,11 @@ class RNN(EncoderInterface):
       subsampling_factor (int):
         Subsampling factor of encoder (convolution layers before lstm layers) (default=4).  # noqa
       d_model (int):
-        Hidden dimension for lstm layers, also output dimension (default=512).
+        Output dimension (default=512).
       dim_feedforward (int):
         Feedforward dimension (default=2048).
+      rnn_hidden_size (int):
+        Hidden dimension for lstm layers (default=1024).
       num_encoder_layers (int):
         Number of encoder layers (default=12).
       dropout (float):
@@ -58,6 +60,7 @@ class RNN(EncoderInterface):
         subsampling_factor: int = 4,
         d_model: int = 512,
         dim_feedforward: int = 2048,
+        rnn_hidden_size: int = 1024,
         num_encoder_layers: int = 12,
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
@@ -79,9 +82,14 @@ class RNN(EncoderInterface):
 
         self.num_encoder_layers = num_encoder_layers
         self.d_model = d_model
+        self.rnn_hidden_size = rnn_hidden_size
 
         encoder_layer = RNNEncoderLayer(
-            d_model, dim_feedforward, dropout, layer_dropout
+            d_model=d_model,
+            dim_feedforward=dim_feedforward,
+            rnn_hidden_size=rnn_hidden_size,
+            dropout=dropout,
+            layer_dropout=layer_dropout,
         )
         self.encoder = RNNEncoder(
             encoder_layer,
@@ -135,17 +143,26 @@ class RNN(EncoderInterface):
         return x, lengths
 
     @torch.jit.export
-    def get_init_states(self, device: torch.device) -> torch.Tensor:
+    def get_init_states(
+        self, device: torch.device
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Get model initial states."""
-        init_states = torch.zeros(
-            (2, self.num_encoder_layers, self.d_model), device=device
+        # for rnn hidden states
+        hidden_states = torch.zeros(
+            (self.num_encoder_layers, self.d_model), device=device
         )
-        return init_states
+        cell_states = torch.zeros(
+            (self.num_encoder_layers, self.rnn_hidden_size), device=device
+        )
+        return (hidden_states, cell_states)
 
     @torch.jit.export
     def infer(
-        self, x: torch.Tensor, x_lens: torch.Tensor, states: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+        states: Tuple[torch.Tensor, torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """
         Args:
           x:
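Note on the API change above: once the LSTM uses a projection, the hidden state (last dim `d_model`) and the cell state (last dim `rnn_hidden_size`) no longer share a feature size, so they cannot be stacked into a single `(2, num_layers, ...)` tensor; a tuple is the natural container. A minimal sketch (shapes assumed from the hunk above, plus a batch dim `N`):

```python
import torch

num_layers, N, d_model, rnn_hidden_size = 12, 5, 512, 1024
h = torch.zeros(num_layers, N, d_model)          # hidden states: last dim d_model
c = torch.zeros(num_layers, N, rnn_hidden_size)  # cell states: last dim rnn_hidden_size
# torch.stack((h, c)) would raise RuntimeError: stack expects each tensor to be equal size
states = (h, c)  # a tuple holds both despite the mismatched last dims
```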
@@ -155,9 +172,11 @@ class RNN(EncoderInterface):
             A tensor of shape (N,), containing the number of frames in `x`
             before padding.
           states:
-            Its shape is (2, num_encoder_layers, N, E).
-            states[0] and states[1] are cached hidden states and cell states for
-            all layers, respectively.
+            It is a list of 2 tensors.
+            states[0] is the hidden states of all layers,
+            with shape of (num_layers, N, d_model);
+            states[1] is the cell states of all layers,
+            with shape of (num_layers, N, rnn_hidden_size).
 
         Returns:
           A tuple of 3 tensors:
@@ -165,15 +184,22 @@ class RNN(EncoderInterface):
             sequence lengths.
           - lengths: a tensor of shape (batch_size,) containing the number of
             frames in `embeddings` before padding.
-          - updated states, with shape of (2, num_encoder_layers, N, E).
+          - updated states, whose shape is same as the input states.
         """
         assert not self.training
-        assert states.shape == (
-            2,
+        assert len(states) == 2
+        # for hidden state
+        assert states[0].shape == (
             self.num_encoder_layers,
             x.size(0),
             self.d_model,
-        ), states.shape
+        )
+        # for cell state
+        assert states[1].shape == (
+            self.num_encoder_layers,
+            x.size(0),
+            self.rnn_hidden_size,
+        )
 
         # lengths = ((x_lens - 1) // 2 - 1) // 2  # issue an warning
         #
@@ -201,6 +227,8 @@ class RNNEncoderLayer(nn.Module):
         The number of expected features in the input (required).
       dim_feedforward:
         The dimension of feedforward network model (default=2048).
+      rnn_hidden_size:
+        The hidden dimension of rnn layer.
       dropout:
         The dropout value (default=0.1).
       layer_dropout:
@@ -211,15 +239,22 @@ class RNNEncoderLayer(nn.Module):
         self,
         d_model: int,
         dim_feedforward: int,
+        rnn_hidden_size: int,
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
     ) -> None:
         super(RNNEncoderLayer, self).__init__()
         self.layer_dropout = layer_dropout
         self.d_model = d_model
+        self.rnn_hidden_size = rnn_hidden_size
 
+        assert rnn_hidden_size >= d_model
         self.lstm = ScaledLSTM(
-            input_size=d_model, hidden_size=d_model, dropout=0.0
+            input_size=d_model,
+            hidden_size=rnn_hidden_size,
+            proj_size=d_model if rnn_hidden_size > d_model else 0,
+            num_layers=1,
+            dropout=0.0,
         )
         self.feed_forward = nn.Sequential(
             ScaledLinear(d_model, dim_feedforward),
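For context, a minimal sketch with plain `torch.nn.LSTM` (not icefall's `ScaledLSTM`; sizes are the defaults from this diff) of the `proj_size` semantics the new layer relies on: with `proj_size=p > 0`, PyTorch projects the hidden state from `hidden_size` down to `p`, so the layer output and `h_n` get last dimension `p` while `c_n` keeps `hidden_size`. PyTorch requires `proj_size < hidden_size`, hence the `else 0` branch (projection disabled) when `rnn_hidden_size == d_model`.

```python
import torch
import torch.nn as nn

d_model, rnn_hidden_size = 512, 1024
lstm = nn.LSTM(
    input_size=d_model,
    hidden_size=rnn_hidden_size,
    proj_size=d_model if rnn_hidden_size > d_model else 0,  # must be < hidden_size
    num_layers=1,
)
src = torch.randn(20, 5, d_model)  # (S, N, d_model)
output, (h, c) = lstm(src)
assert output.shape == (20, 5, d_model)    # output projected back to d_model
assert h.shape == (1, 5, d_model)          # hidden state uses proj_size
assert c.shape == (1, 5, rnn_hidden_size)  # cell state keeps hidden_size
```

The projection is also what keeps the residual connection `src = src + self.dropout(src_lstm)` shape-consistent at `d_model`.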
@@ -279,28 +314,30 @@ class RNNEncoderLayer(nn.Module):
 
     @torch.jit.export
     def infer(
-        self, src: torch.Tensor, states: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        self, src: torch.Tensor, states: Tuple[torch.Tensor, torch.Tensor]
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """
         Pass the input through the encoder layer.
 
         Args:
           src:
             The sequence to the encoder layer (required).
-            Its shape is (S, N, E), where S is the sequence length,
-            N is the batch size, and E is the feature number.
+            Its shape is (S, N, d_model), where S is the sequence length,
+            N is the batch size.
           states:
-            Its shape is (2, 1, N, E).
-            states[0] and states[1] are cached hidden state and cell state,
-            respectively.
+            It is a tuple of 2 tensors.
+            states[0] is the hidden state, with shape of (1, N, d_model);
+            states[1] is the cell state, with shape of (1, N, rnn_hidden_size).
         """
         assert not self.training
-        assert states.shape == (2, 1, src.size(1), src.size(2))
+        assert len(states) == 2
+        # for hidden state
+        assert states[0].shape == (1, src.size(1), self.d_model)
+        # for cell state
+        assert states[1].shape == (1, src.size(1), self.rnn_hidden_size)
 
         # lstm module
-        # The required shapes of h_0 and c_0 are both (1, N, E).
-        src_lstm, new_states = self.lstm(src, (states[0], states[1]))
-        new_states = torch.stack(new_states, dim=0)
+        src_lstm, new_states = self.lstm(src, states)
         src = src + self.dropout(src_lstm)
 
         # feed forward module
@@ -333,6 +370,8 @@ class RNNEncoder(nn.Module):
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
         )
         self.num_layers = num_layers
+        self.d_model = encoder_layer.d_model
+        self.rnn_hidden_size = encoder_layer.rnn_hidden_size
 
         self.use_random_combiner = False
         if aux_layers is not None:
@@ -377,34 +416,55 @@ class RNNEncoder(nn.Module):
 
     @torch.jit.export
     def infer(
-        self, src: torch.Tensor, states: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        self, src: torch.Tensor, states: Tuple[torch.Tensor, torch.Tensor]
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """
         Pass the input through the encoder layer.
 
         Args:
           src:
             The sequence to the encoder layer (required).
-            Its shape is (S, N, E), where S is the sequence length,
-            N is the batch size, and E is the feature number.
+            Its shape is (S, N, d_model), where S is the sequence length,
+            N is the batch size.
           states:
-            Its shape is (2, num_layers, N, E).
-            states[0] and states[1] are cached hidden states and cell states for
-            all layers, respectively.
+            It is a list of 2 tensors.
+            states[0] is the hidden states of all layers,
+            with shape of (num_layers, N, d_model);
+            states[1] is the cell states of all layers,
+            with shape of (num_layers, N, rnn_hidden_size).
         """
         assert not self.training
-        assert states.shape == (2, self.num_layers, src.size(1), src.size(2))
+        assert len(states) == 2
+        # for hidden state
+        assert states[0].shape == (self.num_layers, src.size(1), self.d_model)
+        # for cell state
+        assert states[1].shape == (
+            self.num_layers,
+            src.size(1),
+            self.rnn_hidden_size,
+        )
 
-        new_states_list = []
         output = src
+        new_hidden_states = []
+        new_cell_states = []
         for layer_index, mod in enumerate(self.layers):
-            # new_states: (2, 1, N, E)
-            output, new_states = mod.infer(
-                output, states[:, layer_index : layer_index + 1, :, :]
+            layer_states = (
+                states[0][
+                    layer_index : layer_index + 1, :, :
+                ],  # h: (1, N, d_model)
+                states[1][
+                    layer_index : layer_index + 1, :, :
+                ],  # c: (1, N, rnn_hidden_size)
             )
-            new_states_list.append(new_states)
+            output, (h, c) = mod.infer(output, layer_states)
+            new_hidden_states.append(h)
+            new_cell_states.append(c)
 
-        return output, torch.cat(new_states_list, dim=1)
+        new_states = (
+            torch.cat(new_hidden_states, dim=0),
+            torch.cat(new_cell_states, dim=0),
+        )
+        return output, new_states
 
 
 class Conv2dSubsampling(nn.Module):
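A runnable sketch of the per-layer state plumbing above, with plain `nn.LSTM` layers standing in for `RNNEncoderLayer` (an assumption for brevity): slice layer `i`'s `(h, c)` out of the stacked states, run the layer, then concatenate the updated states back along dim 0, so the returned states keep the input shapes.

```python
import torch
import torch.nn as nn

num_layers, N, d_model, rnn_hidden_size = 3, 2, 4, 8
layers = nn.ModuleList(
    nn.LSTM(d_model, rnn_hidden_size, proj_size=d_model) for _ in range(num_layers)
)
src = torch.randn(5, N, d_model)  # (S, N, d_model)
states = (
    torch.zeros(num_layers, N, d_model),          # h for all layers
    torch.zeros(num_layers, N, rnn_hidden_size),  # c for all layers
)

output = src
new_h, new_c = [], []
for i, mod in enumerate(layers):
    layer_states = (states[0][i : i + 1], states[1][i : i + 1])  # (1, N, *)
    output, (h, c) = mod(output, layer_states)
    new_h.append(h)
    new_c.append(c)
new_states = (torch.cat(new_h, dim=0), torch.cat(new_c, dim=0))
assert new_states[0].shape == states[0].shape
assert new_states[1].shape == states[1].shape
```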
@@ -740,8 +800,14 @@ def _test_random_combine_main():
 
 
 if __name__ == "__main__":
-    feature_dim = 50
-    m = RNN(num_features=feature_dim, d_model=128)
+    feature_dim = 80
+    m = RNN(
+        num_features=feature_dim,
+        d_model=512,
+        rnn_hidden_size=1024,
+        dim_feedforward=2048,
+        num_encoder_layers=12,
+    )
     batch_size = 5
     seq_len = 20
     # Just make sure the forward pass runs.
@@ -750,5 +816,7 @@ if __name__ == "__main__":
         torch.full((batch_size,), seq_len, dtype=torch.int64),
         warmup=0.5,
     )
+    num_param = sum([p.numel() for p in m.parameters()])
+    print(f"Number of model parameters: {num_param}")
 
     _test_random_combine_main()
@@ -90,16 +90,30 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-encoder-layers",
         type=int,
-        default=20,
+        default=12,
        help="Number of RNN encoder layers..",
     )
 
+    parser.add_argument(
+        "--encoder-dim",
+        type=int,
+        default=512,
+        help="Encoder output dimesion.",
+    )
+
+    parser.add_argument(
+        "--rnn-hidden-size",
+        type=int,
+        default=1024,
+        help="Hidden dim for LSTM layers.",
+    )
+
     parser.add_argument(
         "--aux-layer-period",
         type=int,
         default=3,
         help="""Peroid of auxiliary layers used for randomly combined during training.
-        If not larger than 0, will not use the random combiner.
+        If not larger than 0 (e.g., -1), will not use the random combiner.
         """,
     )
 
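A tiny self-contained check (local `parser`, not the real training script) that the new flag parses and defaults as intended:

```python
import argparse

parser = argparse.ArgumentParser()
# mirrors the arguments added in the hunk above
parser.add_argument("--encoder-dim", type=int, default=512)
parser.add_argument("--rnn-hidden-size", type=int, default=1024)

args = parser.parse_args(["--rnn-hidden-size", "2048"])
assert args.rnn_hidden_size == 2048  # overridden on the command line
assert args.encoder_dim == 512       # default preserved
```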
@@ -340,8 +354,6 @@ def get_params() -> AttributeDict:
 
     - subsampling_factor: The subsampling factor for the model.
 
-    - encoder_dim: Hidden dim for multi-head attention model.
-
     - num_decoder_layers: Number of decoder layer of transformer decoder.
 
     - warm_step: The warm_step for Noam optimizer.
@@ -359,7 +371,6 @@ def get_params() -> AttributeDict:
             # parameters for conformer
             "feature_dim": 80,
             "subsampling_factor": 4,
-            "encoder_dim": 512,
             "dim_feedforward": 2048,
             # parameters for decoder
             "decoder_dim": 512,
@@ -380,6 +391,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         num_features=params.feature_dim,
         subsampling_factor=params.subsampling_factor,
         d_model=params.encoder_dim,
+        rnn_hidden_size=params.rnn_hidden_size,
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         aux_layer_period=params.aux_layer_period,
@@ -837,7 +849,7 @@ def run(rank, world_size, args):
     params = get_params()
     params.update(vars(args))
     if params.full_libri is False:
-        params.valid_interval = 1600
+        params.valid_interval = 800
 
     fix_random_seed(params.seed)
     if world_size > 1:
@@ -903,6 +915,10 @@ def run(rank, world_size, args):
         logging.info("Loading scheduler state dict")
         scheduler.load_state_dict(checkpoints["scheduler"])
 
+    # # overwrite it
+    # scheduler.base_lrs = [params.initial_lr for _ in scheduler.base_lrs]
+    # print(scheduler.base_lrs)
+
     if params.print_diagnostics:
         diagnostic = diagnostics.attach_diagnostics(model)
 
@@ -379,7 +379,7 @@ class ScaledConv2d(nn.Conv2d):
 
 class ScaledLSTM(nn.LSTM):
     # See docs for ScaledLinear.
-    # This class implements single-layer LSTM with scaling mechanism, using `torch._VF.lstm`
+    # This class implements LSTM with scaling mechanism, using `torch._VF.lstm`
     # Please refer to https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py
     def __init__(
         self,
@@ -388,10 +388,8 @@ class ScaledLSTM(nn.LSTM):
         initial_speed: float = 1.0,
         **kwargs
     ):
-        # Hardcode num_layers=1, bidirectional=False, proj_size=0 here
-        super(ScaledLSTM, self).__init__(
-            *args, num_layers=1, bidirectional=False, proj_size=0, **kwargs
-        )
+        # Hardcode bidirectional=False
+        super(ScaledLSTM, self).__init__(*args, bidirectional=False, **kwargs)
         initial_scale = torch.tensor(initial_scale).log()
         self._scales_names = []
         self._scales = []
@@ -495,14 +493,14 @@ class ScaledLSTM(nn.LSTM):
         # self._flat_weights -> self._get_flat_weights()
         if hx is None:
             h_zeros = torch.zeros(
-                1,
+                self.num_layers,
                 input.size(1),
-                self.hidden_size,
+                self.proj_size if self.proj_size > 0 else self.hidden_size,
                 dtype=input.dtype,
                 device=input.device,
             )
             c_zeros = torch.zeros(
-                1,
+                self.num_layers,
                 input.size(1),
                 self.hidden_size,
                 dtype=input.dtype,
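The `h_zeros` change mirrors `torch.nn.LSTM`'s own default-state logic: when projection is enabled, `h_0` must have the projected size while `c_0` keeps the full hidden size, and both now use `self.num_layers` since `ScaledLSTM` is no longer hardcoded to a single layer. A small sketch with assumed sizes:

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=4, hidden_size=8, proj_size=4, num_layers=2)
x = torch.randn(10, 3, 4)  # (S, N, input_size)
real_hidden_size = lstm.proj_size if lstm.proj_size > 0 else lstm.hidden_size
h0 = torch.zeros(lstm.num_layers, x.size(1), real_hidden_size)  # (num_layers, N, proj_size)
c0 = torch.zeros(lstm.num_layers, x.size(1), lstm.hidden_size)  # (num_layers, N, hidden_size)
out, (h, c) = lstm(x, (h0, c0))  # accepted by the projected LSTM
```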