add using proj_size

yaozengwei 2022-07-25 19:04:08 +08:00
parent 9bb0c7988f
commit 6871c96ffa
3 changed files with 139 additions and 57 deletions

View File

@@ -38,9 +38,11 @@ class RNN(EncoderInterface):
     subsampling_factor (int):
       Subsampling factor of encoder (convolution layers before lstm layers) (default=4).  # noqa
     d_model (int):
-      Hidden dimension for lstm layers, also output dimension (default=512).
+      Output dimension (default=512).
     dim_feedforward (int):
       Feedforward dimension (default=2048).
+    rnn_hidden_size (int):
+      Hidden dimension for lstm layers (default=1024).
     num_encoder_layers (int):
       Number of encoder layers (default=12).
     dropout (float):
@@ -58,6 +60,7 @@ class RNN(EncoderInterface):
         subsampling_factor: int = 4,
         d_model: int = 512,
         dim_feedforward: int = 2048,
+        rnn_hidden_size: int = 1024,
         num_encoder_layers: int = 12,
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
@@ -79,9 +82,14 @@ class RNN(EncoderInterface):
         self.num_encoder_layers = num_encoder_layers
         self.d_model = d_model
+        self.rnn_hidden_size = rnn_hidden_size

         encoder_layer = RNNEncoderLayer(
-            d_model, dim_feedforward, dropout, layer_dropout
+            d_model=d_model,
+            dim_feedforward=dim_feedforward,
+            rnn_hidden_size=rnn_hidden_size,
+            dropout=dropout,
+            layer_dropout=layer_dropout,
         )

         self.encoder = RNNEncoder(
             encoder_layer,
@@ -135,17 +143,26 @@ class RNN(EncoderInterface):
         return x, lengths

     @torch.jit.export
-    def get_init_states(self, device: torch.device) -> torch.Tensor:
+    def get_init_states(
+        self, device: torch.device
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Get model initial states."""
-        init_states = torch.zeros(
-            (2, self.num_encoder_layers, self.d_model), device=device
-        )
-        return init_states
+        # for rnn hidden states
+        hidden_states = torch.zeros(
+            (self.num_encoder_layers, self.d_model), device=device
+        )
+        cell_states = torch.zeros(
+            (self.num_encoder_layers, self.rnn_hidden_size), device=device
+        )
+        return (hidden_states, cell_states)

     @torch.jit.export
     def infer(
-        self, x: torch.Tensor, x_lens: torch.Tensor, states: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+        states: Tuple[torch.Tensor, torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """
         Args:
           x:
@@ -155,9 +172,11 @@ class RNN(EncoderInterface):
             A tensor of shape (N,), containing the number of frames in `x`
             before padding.
           states:
-            Its shape is (2, num_encoder_layers, N, E).
-            states[0] and states[1] are cached hidden states and cell states for
-            all layers, respectively.
+            It is a tuple of 2 tensors.
+            states[0] is the hidden states of all layers,
+            with shape of (num_layers, N, d_model);
+            states[1] is the cell states of all layers,
+            with shape of (num_layers, N, rnn_hidden_size).

         Returns:
           A tuple of 3 tensors:
@@ -165,15 +184,22 @@ class RNN(EncoderInterface):
             sequence lengths.
           - lengths: a tensor of shape (batch_size,) containing the number of
             frames in `embeddings` before padding.
-          - updated states, with shape of (2, num_encoder_layers, N, E).
+          - updated states, whose shapes are the same as the input states.
         """
         assert not self.training
-        assert states.shape == (
-            2,
-            self.num_encoder_layers,
-            x.size(0),
-            self.d_model,
-        ), states.shape
+        assert len(states) == 2
+        # for hidden states
+        assert states[0].shape == (
+            self.num_encoder_layers,
+            x.size(0),
+            self.d_model,
+        )
+        # for cell states
+        assert states[1].shape == (
+            self.num_encoder_layers,
+            x.size(0),
+            self.rnn_hidden_size,
+        )

         # lengths = ((x_lens - 1) // 2 - 1) // 2  # issues a warning
         #
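
Note: taken together with get_init_states above, the streaming call pattern becomes the following minimal sketch. Here `model` (an already-constructed RNN) and `stream_of_chunks` are placeholders, and since the per-stream init states carry no batch dimension, expanding them to N streams is our own assumed step:

import torch

device = torch.device("cpu")
N = 2  # number of parallel streams

h, c = model.get_init_states(device)
# h: (num_encoder_layers, d_model); c: (num_encoder_layers, rnn_hidden_size)
hidden = h.unsqueeze(1).expand(-1, N, -1).contiguous()  # (num_layers, N, d_model)
cell = c.unsqueeze(1).expand(-1, N, -1).contiguous()    # (num_layers, N, rnn_hidden_size)

model.eval()
with torch.no_grad():
    # chunk shape assumed (N, T, num_features); chunk_lens shape (N,)
    for chunk, chunk_lens in stream_of_chunks:
        out, out_lens, (hidden, cell) = model.infer(chunk, chunk_lens, (hidden, cell))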
@@ -201,6 +227,8 @@ class RNNEncoderLayer(nn.Module):
         The number of expected features in the input (required).
       dim_feedforward:
         The dimension of feedforward network model (default=2048).
+      rnn_hidden_size:
+        The hidden dimension of the rnn layer.
       dropout:
         The dropout value (default=0.1).
       layer_dropout:
@@ -211,15 +239,22 @@ class RNNEncoderLayer(nn.Module):
         self,
         d_model: int,
         dim_feedforward: int,
+        rnn_hidden_size: int,
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
     ) -> None:
         super(RNNEncoderLayer, self).__init__()
         self.layer_dropout = layer_dropout
         self.d_model = d_model
+        self.rnn_hidden_size = rnn_hidden_size
+        assert rnn_hidden_size >= d_model

         self.lstm = ScaledLSTM(
-            input_size=d_model, hidden_size=d_model, dropout=0.0
+            input_size=d_model,
+            hidden_size=rnn_hidden_size,
+            proj_size=d_model if rnn_hidden_size > d_model else 0,
+            num_layers=1,
+            dropout=0.0,
         )
         self.feed_forward = nn.Sequential(
             ScaledLinear(d_model, dim_feedforward),
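
Note: this mirrors the semantics of `proj_size` in stock `torch.nn.LSTM`, which `ScaledLSTM` subclasses. With `proj_size > 0`, the output and the hidden state are projected down to `proj_size` while the cell state keeps the full `hidden_size`; PyTorch also requires `proj_size < hidden_size`, which is why the layer falls back to `proj_size=0` when `rnn_hidden_size == d_model` and asserts `rnn_hidden_size >= d_model`. A quick sanity check with plain `nn.LSTM`:

import torch
import torch.nn as nn

# proj_size projects output and hidden state back to d_model (512 here),
# while the cell state keeps the full hidden_size (1024).
lstm = nn.LSTM(input_size=512, hidden_size=1024, proj_size=512, num_layers=1)
x = torch.randn(20, 3, 512)  # (S, N, input_size)
out, (h, c) = lstm(x)
print(out.shape)  # torch.Size([20, 3, 512])
print(h.shape)    # torch.Size([1, 3, 512])
print(c.shape)    # torch.Size([1, 3, 1024])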
@@ -279,28 +314,30 @@ class RNNEncoderLayer(nn.Module):
     @torch.jit.export
     def infer(
-        self, src: torch.Tensor, states: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        self, src: torch.Tensor, states: Tuple[torch.Tensor, torch.Tensor]
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """
         Pass the input through the encoder layer.

         Args:
           src:
             The sequence to the encoder layer (required).
-            Its shape is (S, N, E), where S is the sequence length,
-            N is the batch size, and E is the feature number.
+            Its shape is (S, N, d_model), where S is the sequence length
+            and N is the batch size.
           states:
-            Its shape is (2, 1, N, E).
-            states[0] and states[1] are cached hidden state and cell state,
-            respectively.
+            It is a tuple of 2 tensors.
+            states[0] is the hidden state, with shape of (1, N, d_model);
+            states[1] is the cell state, with shape of (1, N, rnn_hidden_size).
         """
         assert not self.training
-        assert states.shape == (2, 1, src.size(1), src.size(2))
+        assert len(states) == 2
+        # for hidden state
+        assert states[0].shape == (1, src.size(1), self.d_model)
+        # for cell state
+        assert states[1].shape == (1, src.size(1), self.rnn_hidden_size)

         # lstm module
-        # The required shapes of h_0 and c_0 are both (1, N, E).
-        src_lstm, new_states = self.lstm(src, (states[0], states[1]))
-        new_states = torch.stack(new_states, dim=0)
+        src_lstm, new_states = self.lstm(src, states)
         src = src + self.dropout(src_lstm)

         # feed forward module
@@ -333,6 +370,8 @@ class RNNEncoder(nn.Module):
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
         )
         self.num_layers = num_layers
+        self.d_model = encoder_layer.d_model
+        self.rnn_hidden_size = encoder_layer.rnn_hidden_size

         self.use_random_combiner = False
         if aux_layers is not None:
@@ -377,34 +416,55 @@ class RNNEncoder(nn.Module):
     @torch.jit.export
     def infer(
-        self, src: torch.Tensor, states: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        self, src: torch.Tensor, states: Tuple[torch.Tensor, torch.Tensor]
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """
         Pass the input through the encoder layer.

         Args:
           src:
             The sequence to the encoder layer (required).
-            Its shape is (S, N, E), where S is the sequence length,
-            N is the batch size, and E is the feature number.
+            Its shape is (S, N, d_model), where S is the sequence length
+            and N is the batch size.
           states:
-            Its shape is (2, num_layers, N, E).
-            states[0] and states[1] are cached hidden states and cell states for
-            all layers, respectively.
+            It is a tuple of 2 tensors.
+            states[0] is the hidden states of all layers,
+            with shape of (num_layers, N, d_model);
+            states[1] is the cell states of all layers,
+            with shape of (num_layers, N, rnn_hidden_size).
         """
         assert not self.training
-        assert states.shape == (2, self.num_layers, src.size(1), src.size(2))
+        assert len(states) == 2
+        # for hidden states
+        assert states[0].shape == (self.num_layers, src.size(1), self.d_model)
+        # for cell states
+        assert states[1].shape == (
+            self.num_layers,
+            src.size(1),
+            self.rnn_hidden_size,
+        )

-        new_states_list = []
         output = src
+        new_hidden_states = []
+        new_cell_states = []

         for layer_index, mod in enumerate(self.layers):
-            # new_states: (2, 1, N, E)
-            output, new_states = mod.infer(
-                output, states[:, layer_index : layer_index + 1, :, :]
-            )
-            new_states_list.append(new_states)
+            layer_states = (
+                states[0][
+                    layer_index : layer_index + 1, :, :
+                ],  # h: (1, N, d_model)
+                states[1][
+                    layer_index : layer_index + 1, :, :
+                ],  # c: (1, N, rnn_hidden_size)
+            )
+            output, (h, c) = mod.infer(output, layer_states)
+            new_hidden_states.append(h)
+            new_cell_states.append(c)

-        return output, torch.cat(new_states_list, dim=1)
+        new_states = (
+            torch.cat(new_hidden_states, dim=0),
+            torch.cat(new_cell_states, dim=0),
+        )
+        return output, new_states

 class Conv2dSubsampling(nn.Module):
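
Note: one motivation for returning a tuple here rather than the old stacked tensor is that, with `proj_size` enabled, the hidden and cell caches no longer share a shape, so they cannot live in one `(2, num_layers, N, E)` tensor. A small illustration, with shapes taken from the defaults above:

import torch

num_layers, N, d_model, rnn_hidden_size = 12, 2, 512, 1024
h = torch.zeros(num_layers, N, d_model)          # hidden cache
c = torch.zeros(num_layers, N, rnn_hidden_size)  # cell cache

# The old single-tensor cache required h and c to have identical shapes:
try:
    torch.stack([h, c], dim=0)
except RuntimeError as err:
    print("stack fails:", err)  # sizes differ in the last dimension

states = (h, c)  # the tuple keeps the two differently-sized caches together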
@@ -740,8 +800,14 @@ def _test_random_combine_main():

 if __name__ == "__main__":
-    feature_dim = 50
-    m = RNN(num_features=feature_dim, d_model=128)
+    feature_dim = 80
+    m = RNN(
+        num_features=feature_dim,
+        d_model=512,
+        rnn_hidden_size=1024,
+        dim_feedforward=2048,
+        num_encoder_layers=12,
+    )
     batch_size = 5
     seq_len = 20
     # Just make sure the forward pass runs.
@@ -750,5 +816,7 @@ if __name__ == "__main__":
         torch.full((batch_size,), seq_len, dtype=torch.int64),
         warmup=0.5,
     )
+    num_param = sum([p.numel() for p in m.parameters()])
+    print(f"Number of model parameters: {num_param}")

     _test_random_combine_main()

View File

@@ -90,16 +90,30 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-encoder-layers",
         type=int,
-        default=20,
+        default=12,
         help="Number of RNN encoder layers.",
     )

+    parser.add_argument(
+        "--encoder-dim",
+        type=int,
+        default=512,
+        help="Encoder output dimension.",
+    )
+
+    parser.add_argument(
+        "--rnn-hidden-size",
+        type=int,
+        default=1024,
+        help="Hidden dim for LSTM layers.",
+    )
+
     parser.add_argument(
         "--aux-layer-period",
         type=int,
         default=3,
         help="""Period of auxiliary layers used for random combination during training.
-        If not larger than 0, will not use the random combiner.
+        If not larger than 0 (e.g., -1), will not use the random combiner.
         """,
     )
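
Note: a quick way to sanity-check the new flags (a hypothetical snippet, assuming `add_model_arguments` from this file is importable):

import argparse

# Parse the new model flags in isolation; the values mirror the defaults above.
parser = argparse.ArgumentParser()
add_model_arguments(parser)
args = parser.parse_args(
    ["--num-encoder-layers", "12", "--encoder-dim", "512", "--rnn-hidden-size", "1024"]
)
print(args.encoder_dim, args.rnn_hidden_size)  # 512 1024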
@@ -340,8 +354,6 @@ def get_params() -> AttributeDict:
         - subsampling_factor: The subsampling factor for the model.

-        - encoder_dim: Hidden dim for multi-head attention model.
-
         - num_decoder_layers: Number of decoder layer of transformer decoder.

         - warm_step: The warm_step for Noam optimizer.
@@ -359,7 +371,6 @@ def get_params() -> AttributeDict:
             # parameters for conformer
             "feature_dim": 80,
             "subsampling_factor": 4,
-            "encoder_dim": 512,
             "dim_feedforward": 2048,
             # parameters for decoder
             "decoder_dim": 512,
@@ -380,6 +391,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         num_features=params.feature_dim,
         subsampling_factor=params.subsampling_factor,
         d_model=params.encoder_dim,
+        rnn_hidden_size=params.rnn_hidden_size,
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         aux_layer_period=params.aux_layer_period,
@@ -837,7 +849,7 @@ def run(rank, world_size, args):
     params = get_params()
     params.update(vars(args))
     if params.full_libri is False:
-        params.valid_interval = 1600
+        params.valid_interval = 800

     fix_random_seed(params.seed)
     if world_size > 1:
@@ -903,6 +915,10 @@ def run(rank, world_size, args):
             logging.info("Loading scheduler state dict")
             scheduler.load_state_dict(checkpoints["scheduler"])

+            # # overwrite it
+            # scheduler.base_lrs = [params.initial_lr for _ in scheduler.base_lrs]
+            # print(scheduler.base_lrs)
+
     if params.print_diagnostics:
         diagnostic = diagnostics.attach_diagnostics(model)

View File

@@ -379,7 +379,7 @@ class ScaledConv2d(nn.Conv2d):

 class ScaledLSTM(nn.LSTM):
     # See docs for ScaledLinear.
-    # This class implements single-layer LSTM with scaling mechanism, using `torch._VF.lstm`
+    # This class implements LSTM with scaling mechanism, using `torch._VF.lstm`
     # Please refer to https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py
     def __init__(
         self,
@@ -388,10 +388,8 @@ class ScaledLSTM(nn.LSTM):
         initial_speed: float = 1.0,
         **kwargs
     ):
-        # Hardcode num_layers=1, bidirectional=False, proj_size=0 here
-        super(ScaledLSTM, self).__init__(
-            *args, num_layers=1, bidirectional=False, proj_size=0, **kwargs
-        )
+        # Hardcode bidirectional=False
+        super(ScaledLSTM, self).__init__(*args, bidirectional=False, **kwargs)
         initial_scale = torch.tensor(initial_scale).log()
         self._scales_names = []
         self._scales = []
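
Note: with the hardcoding removed, `num_layers` and `proj_size` now flow through to `nn.LSTM`; only `bidirectional` stays fixed. A construction sketch, with argument values taken from the RNNEncoderLayer call above:

# ScaledLSTM now accepts the same num_layers/proj_size kwargs as nn.LSTM.
lstm = ScaledLSTM(
    input_size=512,    # d_model
    hidden_size=1024,  # rnn_hidden_size
    proj_size=512,     # project h/output back to d_model (needs proj_size < hidden_size)
    num_layers=1,
    dropout=0.0,
)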
@@ -495,14 +493,14 @@ class ScaledLSTM(nn.LSTM):
         # self._flat_weights -> self._get_flat_weights()
         if hx is None:
             h_zeros = torch.zeros(
-                1,
+                self.num_layers,
                 input.size(1),
-                self.hidden_size,
+                self.proj_size if self.proj_size > 0 else self.hidden_size,
                 dtype=input.dtype,
                 device=input.device,
             )
             c_zeros = torch.zeros(
-                1,
+                self.num_layers,
                 input.size(1),
                 self.hidden_size,
                 dtype=input.dtype,