Make attention dims configurable, not embed_dim//2, trying 256.

Daniel Povey 2022-10-17 11:03:29 +08:00
parent 325f5539f9
commit 03fe1ed200
2 changed files with 52 additions and 36 deletions


@@ -44,9 +44,10 @@ class Conformer(EncoderInterface):
Args:
num_features (int): Number of input features
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
d_model (int): embedding dimension
nhead (int): number of head
dim_feedforward (int): feedforward dimention
d_model (int, int): embedding dimension of the 2 encoder stacks
attention_dim (int, int): attention dimension of the 2 encoder stacks
nhead (int, int): number of heads
feedforward_dim (int, int): feedforward dimension in the 2 encoder stacks
num_encoder_layers (int): number of encoder layers
dropout (float): dropout rate
cnn_module_kernel (int): Kernel size of convolution module
@@ -60,6 +61,7 @@ class Conformer(EncoderInterface):
subsampling_factor: int = 4,
conformer_subsampling_factor: int = 4,
d_model: Tuple[int] = (384, 384),
attention_dim: Tuple[int] = (256, 256),
encoder_unmasked_dim: int = 256,
nhead: Tuple[int] = (8, 8),
feedforward_dim: Tuple[int] = (1536, 2048),
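For context, a minimal usage sketch of the new argument (not part of this commit; num_features=80 is an assumed value, and the remaining constructor arguments keep the defaults shown above):

encoder = Conformer(
    num_features=80,               # assumed input feature dimension
    d_model=(384, 384),            # embedding dims of the 2 encoder stacks
    attention_dim=(256, 256),      # new: no longer tied to embed_dim // 2
    nhead=(8, 8),
    feedforward_dim=(1536, 2048),
)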
@@ -92,6 +94,7 @@ class Conformer(EncoderInterface):
encoder_layer1 = ConformerEncoderLayer(
d_model[0],
attention_dim[0],
nhead[0],
feedforward_dim[0],
dropout,
@@ -110,6 +113,7 @@ class Conformer(EncoderInterface):
)
encoder_layer2 = ConformerEncoderLayer(
d_model[1],
attention_dim[1],
nhead[1],
feedforward_dim[1],
dropout,
@@ -248,19 +252,20 @@ class ConformerEncoderLayer(nn.Module):
>>> out = encoder_layer(src, pos_emb)
"""
def __init__(
self,
d_model: int,
nhead: int,
feedforward_dim: int = 2048,
dropout: float = 0.1,
cnn_module_kernel: int = 31,
self,
d_model: int,
attention_dim: int,
nhead: int,
feedforward_dim: int = 2048,
dropout: float = 0.1,
cnn_module_kernel: int = 31,
) -> None:
super(ConformerEncoderLayer, self).__init__()
self.d_model = d_model
self.self_attn = RelPositionMultiheadAttention(
d_model, nhead, dropout=0.0,
d_model, attention_dim, nhead, dropout=0.0,
)
self.feed_forward1 = FeedforwardModule(d_model,
@@ -807,6 +812,8 @@ class RelPositionMultiheadAttention(nn.Module):
Args:
embed_dim: total dimension of the model.
attention_dim: dimension used inside the attention module; it may be smaller or
larger than embed_dim, but must be a multiple of num_heads.
num_heads: parallel attention heads.
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
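To make the divisibility constraint concrete, a small worked example (not from the diff) using the new defaults; under the old scheme head_dim was instead derived as embed_dim // (num_heads * 2):

attention_dim, num_heads = 256, 8               # defaults introduced by this commit
head_dim = attention_dim // num_heads           # 32
assert head_dim * num_heads == attention_dim    # would fail, e.g., for attention_dim=250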
@@ -819,19 +826,21 @@ class RelPositionMultiheadAttention(nn.Module):
def __init__(
self,
embed_dim: int,
attention_dim: int,
num_heads: int,
dropout: float = 0.0,
) -> None:
super(RelPositionMultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.attention_dim = attention_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // (num_heads * 2)
self.head_dim = attention_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim // 2
self.head_dim * num_heads == attention_dim
), "embed_dim//2 must be divisible by num_heads"
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)
self.in_proj = nn.Linear(embed_dim, 3 * attention_dim, bias=True)
# self.whiten_values is applied on the values in forward()
self.whiten_values = Whiten(num_groups=num_heads,
@@ -845,14 +854,14 @@ class RelPositionMultiheadAttention(nn.Module):
grad_scale=0.025)
self.in_balancer = ActivationBalancer(3 * embed_dim // 2,
self.in_balancer = ActivationBalancer(3 * attention_dim,
channel_dim=-1, max_abs=5.0)
self.out_proj = ScaledLinear(
embed_dim // 2, embed_dim, bias=True, initial_scale=0.05
attention_dim, embed_dim, bias=True, initial_scale=0.05
)
self.in_proj2 = nn.Linear(embed_dim, embed_dim // 2, bias=False)
self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True,
self.in_proj2 = nn.Linear(embed_dim, attention_dim, bias=False)
self.out_proj2 = ScaledLinear(attention_dim, embed_dim, bias=True,
initial_scale=0.05)
# self.whiten_values2 is applied on the values in forward2()
self.whiten_values2 = Whiten(num_groups=num_heads,
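Summarizing the projection sizes after this change, a hedged sketch with the default embed_dim=384 and attention_dim=256 (plain nn.Linear stands in for the repo's ScaledLinear and ActivationBalancer wrappers):

import torch.nn as nn

embed_dim, attention_dim = 384, 256
in_proj   = nn.Linear(embed_dim, 3 * attention_dim)           # 384 -> 768, split into q, k, v
out_proj  = nn.Linear(attention_dim, embed_dim)               # 256 -> 384, back to the model dim
in_proj2  = nn.Linear(embed_dim, attention_dim, bias=False)   # value branch used by forward2()
out_proj2 = nn.Linear(attention_dim, embed_dim)               # 256 -> 384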
@@ -914,7 +923,7 @@ class RelPositionMultiheadAttention(nn.Module):
x, weights = self.multi_head_attention_forward(
self.in_balancer(self.in_proj(x)),
self.linear_pos(pos_emb),
self.embed_dim,
self.attention_dim,
self.num_heads,
self.in_proj.weight,
self.in_proj.bias,
@@ -965,7 +974,7 @@ class RelPositionMultiheadAttention(nn.Module):
self,
x: Tensor,
pos: Tensor,
embed_dim: int,
attention_dim: int,
num_heads: int,
in_proj_weight: Tensor,
in_proj_bias: Tensor,
@@ -980,7 +989,7 @@ class RelPositionMultiheadAttention(nn.Module):
Args:
x_proj: the projected input, to be split into query, key, value.
pos: head-specific biases arising from the positional embeddings.
embed_dim: total dimension of the model.
attention_dim: dimension used inside the attention mechanism
num_heads: parallel attention heads.
in_proj_weight, in_proj_bias: input projection weight and bias.
dropout_p: probability of an element to be zeroed.
@@ -994,8 +1003,8 @@ class RelPositionMultiheadAttention(nn.Module):
Shape:
Inputs:
- x: :math:`(L, N, 3 * E//2)` where L is the target sequence length, N is the batch size, E is
the embedding dimension. Will be split into (query, key, value).
- x: :math:`(L, N, 3 * A)` where L is the target sequence length, N is the batch size, A is
the attention dimension. Will be split into (query, key, value).
- pos: :math:`(N, 2*L-1, H)` or :math:`(1, 2*L-1, H)` where L is the sequence
length, N is the batch size, and H is the number of heads.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
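A quick shape check of the input convention above (assumed toy sizes; the chunk into q, k, v is illustrative, not the repo's exact split code):

import torch

L, N, A = 100, 8, 256                  # seq_len, batch, attention_dim
x_proj = torch.randn(L, N, 3 * A)      # the projected input described above
q, k, v = x_proj.chunk(3, dim=-1)      # each of shape (L, N, A)
assert q.shape == (L, N, A)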
@@ -1019,10 +1028,10 @@ class RelPositionMultiheadAttention(nn.Module):
seq_len, bsz, _ = x.size()
head_dim = embed_dim // (num_heads * 2)
head_dim = attention_dim // num_heads
assert (
head_dim * num_heads == embed_dim // 2
), "embed_dim must be divisible by num_heads"
head_dim * num_heads == attention_dim
), "attention_dim must be divisible by num_heads"
scaling = float(head_dim) ** -0.5
@@ -1142,7 +1151,7 @@ class RelPositionMultiheadAttention(nn.Module):
attn_output = (
attn_output.transpose(0, 1)
.contiguous()
.view(seq_len, bsz, embed_dim // 2)
.view(seq_len, bsz, attention_dim)
)
attn_output = nn.functional.linear(
attn_output, out_proj_weight, out_proj_bias
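A hedged walk-through of the reshape above with concrete sizes (assumed values; a random weight stands in for out_proj_weight):

import torch

bsz, num_heads, seq_len, head_dim, embed_dim = 8, 8, 100, 32, 384
attention_dim = num_heads * head_dim                     # 256
attn_output = torch.randn(bsz * num_heads, seq_len, head_dim)
attn_output = (
    attn_output.transpose(0, 1)                          # (seq_len, bsz * num_heads, head_dim)
    .contiguous()
    .view(seq_len, bsz, attention_dim)                   # heads merged back into attention_dim
)
out = torch.nn.functional.linear(attn_output, torch.randn(embed_dim, attention_dim))
assert out.shape == (seq_len, bsz, embed_dim)            # back to the model dimension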
@@ -1167,7 +1176,7 @@ class RelPositionMultiheadAttention(nn.Module):
"""
num_heads = self.num_heads
(seq_len, bsz, embed_dim) = x.shape
head_dim = embed_dim // (num_heads * 2)
head_dim = self.attention_dim // num_heads
# v: (tgt_len, bsz, attention_dim)
v = self.in_proj2(x)
v = self.whiten_values2(v) # does nothing in the forward pass.
@@ -1183,7 +1192,7 @@ class RelPositionMultiheadAttention(nn.Module):
attn_output = (
attn_output.transpose(0, 1)
.contiguous()
.view(seq_len, bsz, embed_dim // 2)
.view(seq_len, bsz, self.attention_dim)
)
# returned value is of shape (seq_len, bsz, embed_dim), like x.
return self.out_proj2(attn_output)


@@ -114,8 +114,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
"--encoder-dims",
type=str,
default="384,384",
help="Attention dimension in 2, blocks of conformer encoder layers, comma separated, "
"and the output dim of the encoder",
help="Embedding dimension in the 2 blocks of conformer encoder layers, comma separated"
)
parser.add_argument(
"--attention-dims",
type=str,
default="256,256",
help="Attention dimension in the 2 blocks of conformer encoder layers, comma separated"
)
parser.add_argument(
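On the command line this would be passed as, e.g., --attention-dims 256,256 alongside --encoder-dims 384,384. A hedged sketch of how the new flag and its default behave, using a throwaway parser rather than the script's full argument set:

import argparse

p = argparse.ArgumentParser()
p.add_argument("--encoder-dims", type=str, default="384,384")
p.add_argument("--attention-dims", type=str, default="256,256")
args = p.parse_args([])                 # no flags given, so defaults apply
assert args.attention_dims == "256,256"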
@@ -418,17 +424,18 @@ def get_params() -> AttributeDict:
def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Conformer and Transformer
def to_int_list(s: str):
return list(map(int, s.split(',')))
def to_int_tuple(s: str):
return tuple(map(int, s.split(',')))
encoder = Conformer(
num_features=params.feature_dim,
subsampling_factor=params.subsampling_factor,
conformer_subsampling_factor=params.conformer_subsampling_factor,
d_model=to_int_list(params.encoder_dims),
d_model=to_int_tuple(params.encoder_dims),
attention_dim=to_int_tuple(params.attention_dims),
encoder_unmasked_dim=params.encoder_unmasked_dim,
nhead=to_int_list(params.nhead),
feedforward_dim=to_int_list(params.feedforward_dims),
num_encoder_layers=to_int_list(params.num_encoder_layers),
nhead=to_int_tuple(params.nhead),
feedforward_dim=to_int_tuple(params.feedforward_dims),
num_encoder_layers=to_int_tuple(params.num_encoder_layers),
)
return encoder
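As a quick sanity check (not part of the commit), the default strings parse into the tuples the Conformer constructor now expects; to_int_tuple is repeated here so the snippet stands alone:

def to_int_tuple(s: str):
    return tuple(map(int, s.split(',')))

assert to_int_tuple("384,384") == (384, 384)   # --encoder-dims default
assert to_int_tuple("256,256") == (256, 256)   # --attention-dims default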