Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Make attention dims configurable, not embed_dim//2, trying 256.
Parent: 325f5539f9
Commit: 03fe1ed200
@@ -44,9 +44,10 @@ class Conformer(EncoderInterface):
     Args:
       num_features (int): Number of input features
       subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
-      d_model (int): embedding dimension
-      nhead (int): number of head
-      dim_feedforward (int): feedforward dimention
+      d_model: (int,int): embedding dimension of 2 encoder stacks
+      attention_dim: (int,int): attention dimension of 2 encoder stacks
+      nhead (int, int): number of heads
+      dim_feedforward (int, int): feedforward dimention in 2 encoder stacks
       num_encoder_layers (int): number of encoder layers
       dropout (float): dropout rate
       cnn_module_kernel (int): Kernel size of convolution module
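
The docstring change above reflects that the model-size options are now given per encoder stack. A minimal sketch (illustrative only, not part of the commit; values taken from the new defaults) of how the tuples line up stack by stack:

# Not from the patch: shows how the per-stack tuples pair up.
d_model = (384, 384)
attention_dim = (256, 256)
nhead = (8, 8)
feedforward_dim = (1536, 2048)

for i, (d, a, h, ff) in enumerate(zip(d_model, attention_dim, nhead, feedforward_dim)):
    print(f"stack {i}: d_model={d}, attention_dim={a}, nhead={h}, feedforward_dim={ff}")
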
@@ -60,6 +61,7 @@ class Conformer(EncoderInterface):
         subsampling_factor: int = 4,
         conformer_subsampling_factor: int = 4,
         d_model: Tuple[int] = (384, 384),
+        attention_dim: Tuple[int] = (256, 256),
         encoder_unmasked_dim: int = 256,
         nhead: Tuple[int] = (8, 8),
         feedforward_dim: Tuple[int] = (1536, 2048),
@@ -92,6 +94,7 @@ class Conformer(EncoderInterface):

         encoder_layer1 = ConformerEncoderLayer(
             d_model[0],
+            attention_dim[0],
             nhead[0],
             feedforward_dim[0],
             dropout,
@@ -110,6 +113,7 @@ class Conformer(EncoderInterface):
         )
         encoder_layer2 = ConformerEncoderLayer(
             d_model[1],
+            attention_dim[1],
             nhead[1],
             feedforward_dim[1],
             dropout,
@@ -248,19 +252,20 @@ class ConformerEncoderLayer(nn.Module):
     >>> out = encoder_layer(src, pos_emb)
     """
     def __init__(
-        self,
-        d_model: int,
-        nhead: int,
-        feedforward_dim: int = 2048,
-        dropout: float = 0.1,
-        cnn_module_kernel: int = 31,
+        self,
+        d_model: int,
+        attention_dim: int,
+        nhead: int,
+        feedforward_dim: int = 2048,
+        dropout: float = 0.1,
+        cnn_module_kernel: int = 31,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()

         self.d_model = d_model

         self.self_attn = RelPositionMultiheadAttention(
-            d_model, nhead, dropout=0.0,
+            d_model, attention_dim, nhead, dropout=0.0,
         )

         self.feed_forward1 = FeedforwardModule(d_model,
@@ -807,6 +812,8 @@ class RelPositionMultiheadAttention(nn.Module):

     Args:
         embed_dim: total dimension of the model.
+        attention_dim: dimension in the attention module, may be less or more than embed_dim
+            but must be a multiple of num_heads.
         num_heads: parallel attention heads.
         dropout: a Dropout layer on attn_output_weights. Default: 0.0.

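
To see why decoupling matters, here is a worked comparison (illustrative arithmetic, using the defaults from this commit) of the head dimension under the old embed_dim//2 scheme and the new configurable attention_dim:

embed_dim, num_heads = 384, 8

old_head_dim = embed_dim // (num_heads * 2)   # old scheme: tied to embed_dim // 2 -> 24
attention_dim = 256                           # new scheme: configurable, default 256
new_head_dim = attention_dim // num_heads     # -> 32
assert new_head_dim * num_heads == attention_dim, "attention_dim must be divisible by num_heads"
print(old_head_dim, new_head_dim)             # 24 32
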
@@ -819,19 +826,21 @@ class RelPositionMultiheadAttention(nn.Module):
     def __init__(
         self,
         embed_dim: int,
+        attention_dim: int,
         num_heads: int,
         dropout: float = 0.0,
     ) -> None:
         super(RelPositionMultiheadAttention, self).__init__()
         self.embed_dim = embed_dim
+        self.attention_dim = attention_dim
         self.num_heads = num_heads
         self.dropout = dropout
-        self.head_dim = embed_dim // (num_heads * 2)
+        self.head_dim = attention_dim // num_heads
         assert (
-            self.head_dim * num_heads == self.embed_dim // 2
+            self.head_dim * num_heads == attention_dim
         ), "embed_dim//2 must be divisible by num_heads"

-        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)
+        self.in_proj = nn.Linear(embed_dim, 3 * attention_dim, bias=True)

         # self.whiten_values is applied on the values in forward()
         self.whiten_values = Whiten(num_groups=num_heads,
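
A shape sketch of the reworked projections (assumes PyTorch; nn.Linear stands in for icefall's ScaledLinear, and this is not the full module):

import torch
import torch.nn as nn

embed_dim, attention_dim, num_heads = 384, 256, 8
seq_len, bsz = 10, 2

in_proj = nn.Linear(embed_dim, 3 * attention_dim, bias=True)   # was 3 * embed_dim // 2
out_proj = nn.Linear(attention_dim, embed_dim, bias=True)      # stand-in for ScaledLinear

x = torch.randn(seq_len, bsz, embed_dim)
q, k, v = in_proj(x).chunk(3, dim=-1)      # each (seq_len, bsz, attention_dim)
print(q.shape, out_proj(v).shape)          # torch.Size([10, 2, 256]) torch.Size([10, 2, 384])
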
@@ -845,14 +854,14 @@ class RelPositionMultiheadAttention(nn.Module):
                                     grad_scale=0.025)


-        self.in_balancer = ActivationBalancer(3 * embed_dim // 2,
+        self.in_balancer = ActivationBalancer(3 * attention_dim,
                                               channel_dim=-1, max_abs=5.0)
         self.out_proj = ScaledLinear(
-            embed_dim // 2, embed_dim, bias=True, initial_scale=0.05
+            attention_dim, embed_dim, bias=True, initial_scale=0.05
         )

-        self.in_proj2 = nn.Linear(embed_dim, embed_dim // 2, bias=False)
-        self.out_proj2 = ScaledLinear(embed_dim // 2, embed_dim, bias=True,
+        self.in_proj2 = nn.Linear(embed_dim, attention_dim, bias=False)
+        self.out_proj2 = ScaledLinear(attention_dim, embed_dim, bias=True,
                                      initial_scale=0.05)
         # self.whiten_values2 is applied on the values in forward2()
         self.whiten_values2 = Whiten(num_groups=num_heads,
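
For reference, the input projection grows slightly under the new defaults; a quick parameter count (plain arithmetic, not taken from the patch):

embed_dim, attention_dim = 384, 256

old_out = 3 * embed_dim // 2                # 576 output features
new_out = 3 * attention_dim                 # 768 output features

old_params = embed_dim * old_out + old_out  # weight + bias = 221,760
new_params = embed_dim * new_out + new_out  # weight + bias = 295,680
print(old_params, new_params)
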
@@ -914,7 +923,7 @@ class RelPositionMultiheadAttention(nn.Module):
         x, weights = self.multi_head_attention_forward(
             self.in_balancer(self.in_proj(x)),
             self.linear_pos(pos_emb),
-            self.embed_dim,
+            self.attention_dim,
             self.num_heads,
             self.in_proj.weight,
             self.in_proj.bias,
@@ -965,7 +974,7 @@ class RelPositionMultiheadAttention(nn.Module):
         self,
         x: Tensor,
         pos: Tensor,
-        embed_dim: int,
+        attention_dim: int,
         num_heads: int,
         in_proj_weight: Tensor,
         in_proj_bias: Tensor,
@@ -980,7 +989,7 @@ class RelPositionMultiheadAttention(nn.Module):
         Args:
             x_proj: the projected input, to be split into query, key, value.
             pos: head-specific biases arising from the positional embeddings.
-            embed_dim: total dimension of the model.
+            attention_dim: dimension inside attention mechanism
             num_heads: parallel attention heads.
             in_proj_weight, in_proj_bias: input projection weight and bias.
             dropout_p: probability of an element to be zeroed.
@@ -994,8 +1003,8 @@ class RelPositionMultiheadAttention(nn.Module):

         Shape:
             Inputs:
-            - x: :math:`(L, N, 3 * E//2)` where L is the target sequence length, N is the batch size, E is
-              the embedding dimension. Will be split into (query, key, value).
+            - x: :math:`(L, N, 3 * A)` where L is the target sequence length, N is the batch size, A is
+              the attention dimension. Will be split into (query, key, value).
            - pos: :math:`(N, 2*L-1, H)` or :math:`(1, 2*L-1, H)` where L is the sequence
              length, N is the batch size, and H is the number of heads.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
@@ -1019,10 +1028,10 @@ class RelPositionMultiheadAttention(nn.Module):

         seq_len, bsz, _ = x.size()

-        head_dim = embed_dim // (num_heads * 2)
+        head_dim = attention_dim // num_heads
         assert (
-            head_dim * num_heads == embed_dim // 2
-        ), "embed_dim must be divisible by num_heads"
+            head_dim * num_heads == attention_dim
+        ), "attention_dim must be divisible by num_heads"

         scaling = float(head_dim) ** -0.5

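
A sketch of the per-head bookkeeping this hunk changes (illustrative; it mirrors the variable names in the diff, not the whole function):

import torch

attention_dim, num_heads = 256, 8
seq_len, bsz = 10, 2

head_dim = attention_dim // num_heads     # 32 (was embed_dim // (num_heads * 2))
scaling = float(head_dim) ** -0.5         # ~0.177, applied to the query

q = torch.randn(seq_len, bsz, attention_dim) * scaling
q = q.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
print(q.shape)                            # torch.Size([16, 10, 32])
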
@@ -1142,7 +1151,7 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = (
             attn_output.transpose(0, 1)
             .contiguous()
-            .view(seq_len, bsz, embed_dim // 2)
+            .view(seq_len, bsz, attention_dim)
         )
         attn_output = nn.functional.linear(
             attn_output, out_proj_weight, out_proj_bias
@@ -1167,7 +1176,7 @@ class RelPositionMultiheadAttention(nn.Module):
         """
         num_heads = self.num_heads
         (seq_len, bsz, embed_dim) = x.shape
-        head_dim = embed_dim // (num_heads * 2)
+        head_dim = self.attention_dim // num_heads
         # v: (tgt_len, bsz, embed_dim // 2)
         v = self.in_proj2(x)
         v = self.whiten_values2(v)  # does nothing in the forward pass.
@@ -1183,7 +1192,7 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = (
             attn_output.transpose(0, 1)
             .contiguous()
-            .view(seq_len, bsz, embed_dim // 2)
+            .view(seq_len, bsz, self.attention_dim)
         )
         # returned value is of shape (seq_len, bsz, embed_dim), like x.
         return self.out_proj2(attn_output)
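
A rough, self-contained sketch of the forward2() value path after this change (random attention weights stand in for the ones the real code reuses from forward(); nn.Linear stands in for ScaledLinear):

import torch
import torch.nn as nn

embed_dim, attention_dim, num_heads = 384, 256, 8
seq_len, bsz = 10, 2
head_dim = attention_dim // num_heads

in_proj2 = nn.Linear(embed_dim, attention_dim, bias=False)
out_proj2 = nn.Linear(attention_dim, embed_dim, bias=True)

x = torch.randn(seq_len, bsz, embed_dim)
attn_weights = torch.softmax(torch.randn(bsz * num_heads, seq_len, seq_len), dim=-1)

v = in_proj2(x)                                               # (seq_len, bsz, attention_dim)
v = v.contiguous().view(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
attn_output = torch.bmm(attn_weights, v)                      # (bsz*num_heads, seq_len, head_dim)
attn_output = (
    attn_output.transpose(0, 1).contiguous().view(seq_len, bsz, attention_dim)
)
print(out_proj2(attn_output).shape)                           # torch.Size([10, 2, 384])
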
@@ -114,8 +114,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         "--encoder-dims",
         type=str,
         default="384,384",
-        help="Attention dimension in 2, blocks of conformer encoder layers, comma separated, "
-        "and the output dim of the encoder",
+        help="Embedding dimension in the 2 blocks of conformer encoder layers, comma separated"
     )

+    parser.add_argument(
+        "--attention-dims",
+        type=str,
+        default="256,256",
+        help="Attention dimension in the 2 blocks of conformer encoder layers, comma separated"
+    )
+
     parser.add_argument(
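
A standalone sketch of how the two options parse (not the training script itself; to_int_tuple is the helper added in the next hunk):

import argparse

def to_int_tuple(s: str):
    return tuple(map(int, s.split(',')))

parser = argparse.ArgumentParser()
parser.add_argument("--encoder-dims", type=str, default="384,384",
                    help="Embedding dimension in the 2 blocks of conformer encoder layers, comma separated")
parser.add_argument("--attention-dims", type=str, default="256,256",
                    help="Attention dimension in the 2 blocks of conformer encoder layers, comma separated")

args = parser.parse_args(["--attention-dims", "256,192"])
print(to_int_tuple(args.encoder_dims))    # (384, 384)
print(to_int_tuple(args.attention_dims))  # (256, 192)
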
@@ -418,17 +424,18 @@ def get_params() -> AttributeDict:

 def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
-    def to_int_list(s: str):
-        return list(map(int, s.split(',')))
+    def to_int_tuple(s: str):
+        return tuple(map(int, s.split(',')))
     encoder = Conformer(
         num_features=params.feature_dim,
         subsampling_factor=params.subsampling_factor,
         conformer_subsampling_factor=params.conformer_subsampling_factor,
-        d_model=to_int_list(params.encoder_dims),
+        d_model=to_int_tuple(params.encoder_dims),
+        attention_dim=to_int_tuple(params.attention_dims),
         encoder_unmasked_dim=params.encoder_unmasked_dim,
-        nhead=to_int_list(params.nhead),
-        feedforward_dim=to_int_list(params.feedforward_dims),
-        num_encoder_layers=to_int_list(params.num_encoder_layers),
+        nhead=to_int_tuple(params.nhead),
+        feedforward_dim=to_int_tuple(params.feedforward_dims),
+        num_encoder_layers=to_int_tuple(params.num_encoder_layers),
     )
     return encoder