refactor attention decoder
This commit is contained in:
parent 1503351833
commit 0be32f3da0
@@ -36,7 +36,7 @@ class AttentionDecoderModel(nn.Module):
       vocab_size (int): Number of classes.
       decoder_dim: (int,int): embedding dimension of 2 encoder stacks
       attention_dim: (int,int): attention dimension of 2 encoder stacks
-      nhead (int, int): number of heads
+      num_heads (int, int): number of heads
       dim_feedforward (int, int): feedforward dimension in 2 encoder stacks
       num_encoder_layers (int): number of encoder layers
       dropout (float): dropout rate
@ -48,7 +48,7 @@ class AttentionDecoderModel(nn.Module):
|
||||
decoder_dim: int = 512,
|
||||
num_decoder_layers: int = 6,
|
||||
attention_dim: int = 512,
|
||||
nhead: int = 8,
|
||||
num_heads: int = 8,
|
||||
feedforward_dim: int = 2048,
|
||||
memory_dim: int = 512,
|
||||
sos_id: int = 1,
|
||||
@@ -69,7 +69,7 @@ class AttentionDecoderModel(nn.Module):
             d_model=decoder_dim,
             num_decoder_layers=num_decoder_layers,
             attention_dim=attention_dim,
-            nhead=nhead,
+            num_heads=num_heads,
             feedforward_dim=feedforward_dim,
             memory_dim=memory_dim,
             dropout=dropout,
@@ -111,7 +111,12 @@ class AttentionDecoderModel(nn.Module):
         ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(ys, ys_lens)
 
         # decoder forward
-        decoder_out = self.decoder(encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens)
+        decoder_out = self.decoder(
+            x=ys_in_pad,
+            x_lens=ys_in_lens,
+            memory=encoder_out,
+            memory_lens=encoder_out_lens,
+        )
 
         loss = self.loss_fun(x=decoder_out, target=ys_out_pad)
         return loss
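The refactored call passes the decoder inputs by keyword, so each tensor's role is explicit at the call site. A minimal sketch of the new calling convention, with made-up shapes (the real tensors come from the encoder output and the tokenized transcripts):

import torch

# Hypothetical sizes, for illustration only.
batch, tgt_len, src_len, memory_dim, vocab_size = 2, 10, 50, 512, 500

ys_in_pad = torch.randint(0, vocab_size, (batch, tgt_len))  # padded token ids
ys_in_lens = torch.tensor([10, 7])                          # token counts before padding
encoder_out = torch.randn(batch, src_len, memory_dim)       # acoustic memory
encoder_out_lens = torch.tensor([50, 42])                   # frame counts before padding

# decoder_out = self.decoder(
#     x=ys_in_pad,
#     x_lens=ys_in_lens,
#     memory=encoder_out,
#     memory_lens=encoder_out_lens,
# )  # -> (batch, tgt_len, vocab_size)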
@@ -137,7 +142,12 @@ class AttentionDecoderModel(nn.Module):
         ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(ys, ys_lens)
 
         # decoder forward
-        decoder_out = self.decoder(encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens)
+        decoder_out = self.decoder(
+            x=ys_in_pad,
+            x_lens=ys_in_lens,
+            memory=encoder_out,
+            memory_lens=encoder_out_lens,
+        )
 
         batch_size, _, num_classes = decoder_out.size()
         nll = nn.functional.cross_entropy(
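For orientation, the truncated nn.functional.cross_entropy call above turns the decoder logits into a per-token negative log-likelihood. A hedged, self-contained sketch of that computation (the remaining arguments of the real call, such as its ignore index for padded targets, are not visible in this hunk):

import torch
import torch.nn as nn

batch_size, tgt_len, num_classes = 2, 5, 500
decoder_out = torch.randn(batch_size, tgt_len, num_classes)        # logits
ys_out_pad = torch.randint(0, num_classes, (batch_size, tgt_len))  # targets

nll = nn.functional.cross_entropy(
    decoder_out.view(-1, num_classes),  # flatten to (batch * tgt_len, C)
    ys_out_pad.view(-1),                # flatten the targets to match
    reduction="none",                   # keep one value per token
).view(batch_size, tgt_len)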
@@ -152,14 +162,13 @@ class AttentionDecoderModel(nn.Module):
 
 class TransformerDecoder(nn.Module):
     """Transformer decoder module.
-
     It is modified from https://github.com/espnet/espnet/blob/master/espnet2/asr/decoder/transformer_decoder.py.
 
     Args:
       vocab_size: output dim
       d_model: decoder dimension
       num_decoder_layers: number of decoder layers
       attention_dim: total dimension of multi head attention
-      n_head: number of attention heads
+      num_heads: number of attention heads
       feedforward_dim: hidden dimension of feed_forward module
       dropout: dropout rate
     """
@@ -170,7 +179,7 @@ class TransformerDecoder(nn.Module):
         d_model: int = 512,
         num_decoder_layers: int = 6,
         attention_dim: int = 512,
-        nhead: int = 8,
+        num_heads: int = 8,
         feedforward_dim: int = 2048,
         memory_dim: int = 512,
         dropout: float = 0.1,
@@ -178,14 +187,19 @@ class TransformerDecoder(nn.Module):
         super().__init__()
         self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
 
-        # Using absolute positional encoding
+        # Absolute positional encoding
         self.pos = PositionalEncoding(d_model, dropout_rate=0.1)
 
         self.num_layers = num_decoder_layers
         self.layers = nn.ModuleList(
             [
                 DecoderLayer(
-                    d_model, attention_dim, nhead, feedforward_dim, memory_dim, dropout
+                    d_model=d_model,
+                    attention_dim=attention_dim,
+                    num_heads=num_heads,
+                    feedforward_dim=feedforward_dim,
+                    memory_dim=memory_dim,
+                    dropout=dropout,
                 )
                 for _ in range(num_decoder_layers)
             ]
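PositionalEncoding is defined elsewhere in this file. For reference, the standard sinusoidal table it is built on can be sketched as follows; this sketch omits the dropout and any input scaling the real class applies:

import math
import torch

def sinusoidal_pe(seq_len: int, d_model: int) -> torch.Tensor:
    # Even columns get sin, odd columns get cos, with geometrically
    # increasing wavelengths across the feature dimension.
    position = torch.arange(seq_len).unsqueeze(1).float()
    div_term = torch.exp(
        torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
    )
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe  # (seq_len, d_model), added to the token embeddings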
@@ -195,49 +209,67 @@ class TransformerDecoder(nn.Module):
 
     def forward(
         self,
-        memory: torch.Tensor,
-        memory_lens: torch.Tensor,
-        ys_in_pad: torch.Tensor,
-        ys_in_lens: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Forward decoder.
-
-        Args:
-            memory: encoded memory, (batch, maxlen_in, feat)
-            memory_lens: (batch,)
-            ys_in_pad: input token ids, (batch, maxlen_out)
-            ys_in_lens: (batch, )
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+        memory: Optional[torch.Tensor] = None,
+        memory_lens: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Args:
+            x: Input tensor of shape (batch, tgt_len).
+            x_lens: A tensor of shape (batch,) containing the number of tokens in `x`
+                before padding.
+            memory:
+                Memory sequence of shape (batch, src_len, memory_dim).
+            memory_lens:
+                A tensor of shape (batch,) containing the number of frames in
+                `memory` before padding.
 
         Returns:
-            tgt: decoded token score before softmax (batch, maxlen_out, vocab_size)
+            Decoded token logits before softmax (batch, tgt_len, vocab_size)
         """
-        tgt = ys_in_pad
-        # tgt_mask: (B, 1, L)
-        tgt_mask = make_pad_mask(ys_in_lens)[:, None, :].to(tgt.device)
-        # m: (1, L, L)
-        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
-        # tgt_mask: (B, L, L)
-        tgt_mask = tgt_mask | (~m)
+        x = self.embed(x)  # (batch, tgt_len, embed_dim)
+        x = self.pos(x)  # (batch, tgt_len, embed_dim)
 
-        memory_mask = make_pad_mask(memory_lens)[:, None, :].to(memory.device)
+        x = x.permute(1, 0, 2)  # (tgt_len, batch, embed_dim)
 
-        tgt = self.embed(tgt)
-        tgt = self.pos(tgt)
+        # construct attn_mask for self-attn modules
+        padding_mask = make_pad_mask(x_lens)  # (batch, tgt_len)
+        causal_mask = subsequent_mask(x.shape[0], device=x.device)  # (seq_len, seq_len)
+        attn_mask = torch.logical_or(
+            padding_mask.unsqueeze(1),  # (batch, 1, seq_len)
+            torch.logical_not(causal_mask).unsqueeze(0)  # (1, seq_len, seq_len)
+        )  # (batch, seq_len, seq_len)
+
+        if memory is not None:
+            memory = memory.permute(1, 0, 2)  # (src_len, batch, memory_dim)
+            # construct memory_attn_mask for cross-attn modules
+            memory_padding_mask = make_pad_mask(memory_lens)  # (batch, src_len)
+            memory_attn_mask = memory_padding_mask.unsqueeze(1)  # (batch, 1, src_len)
+        else:
+            memory_attn_mask = None
 
         for i, mod in enumerate(self.layers):
-            tgt = mod(tgt, tgt_mask, memory, memory_mask)
+            x = mod(
+                x,
+                attn_mask=attn_mask,
+                memory=memory,
+                memory_attn_mask=memory_attn_mask,
+            )
 
-        tgt = self.output_layer(tgt)
-        return tgt
+        x = x.permute(1, 0, 2)  # (batch, tgt_len, embed_dim)
+        x = self.output_layer(x)  # (batch, tgt_len, vocab_size)
+
+        return x
 
 
 class DecoderLayer(nn.Module):
     """Single decoder layer module.
 
     Args:
-      d_model: equal to encoder_dim
+      d_model: equal to decoder_dim, total dimension of the decoder
       attention_dim: total dimension of multi head attention
-      n_head: number of attention heads
+      num_heads: number of attention heads
       feedforward_dim: hidden dimension of feed_forward module
      dropout: dropout rate
     """
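The rewritten forward folds padding and causality into one boolean attn_mask that is True wherever attention must be blocked: the padding mask is broadcast across query positions and OR-ed with the inverted causal mask. A runnable sketch with stand-in implementations of the two helpers (the real make_pad_mask and subsequent_mask live in icefall; these only mimic the behavior relied on here):

import torch

def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    # True at padding positions; shape (batch, max_len).
    max_len = int(lengths.max())
    return torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)

def subsequent_mask(size: int, device: torch.device = torch.device("cpu")) -> torch.Tensor:
    # True where attention is allowed (lower triangle); shape (size, size).
    return torch.tril(torch.ones(size, size, dtype=torch.bool, device=device))

x_lens = torch.tensor([4, 2])
padding_mask = make_pad_mask(x_lens)              # (2, 4)
causal_mask = subsequent_mask(4)                  # (4, 4)
attn_mask = torch.logical_or(
    padding_mask.unsqueeze(1),                    # (2, 1, 4)
    torch.logical_not(causal_mask).unsqueeze(0),  # (1, 4, 4)
)                                                 # (2, 4, 4); True -> filled with -inf
# For the length-2 sequence, every row blocks padded positions 2 and 3
# in addition to all future positions.
print(attn_mask[1].int())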
@@ -246,7 +278,7 @@ class DecoderLayer(nn.Module):
         self,
         d_model: int = 512,
         attention_dim: int = 512,
-        nhead: int = 8,
+        num_heads: int = 8,
         feedforward_dim: int = 2048,
         memory_dim: int = 512,
         dropout: float = 0.1,
@ -255,10 +287,14 @@ class DecoderLayer(nn.Module):
|
||||
super(DecoderLayer, self).__init__()
|
||||
|
||||
self.norm_self_attn = nn.LayerNorm(d_model)
|
||||
self.self_attn = MultiHeadedAttention(d_model, attention_dim, nhead, dropout=0.0)
|
||||
self.self_attn = MultiHeadAttention(
|
||||
d_model, attention_dim, num_heads, dropout=0.0
|
||||
)
|
||||
|
||||
self.norm_src_attn = nn.LayerNorm(d_model)
|
||||
self.src_attn = MultiHeadedAttention(d_model, attention_dim, nhead, memory_dim=memory_dim, dropout=0.0)
|
||||
self.src_attn = MultiHeadAttention(
|
||||
d_model, attention_dim, num_heads, memory_dim=memory_dim, dropout=0.0
|
||||
)
|
||||
|
||||
self.norm_ff = nn.LayerNorm(d_model)
|
||||
self.feed_forward = nn.Sequential(
|
||||
@@ -270,40 +306,53 @@ class DecoderLayer(nn.Module):
 
         self.dropout = nn.Dropout(dropout)
 
-    def forward(self, tgt, tgt_mask, memory, memory_mask):
-        """Compute decoded features.
+    def forward(
+        self,
+        x: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        memory: Optional[torch.Tensor] = None,
+        memory_attn_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
         Args:
-            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
-            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
-            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
-            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
-
-        Returns:
-            torch.Tensor: Output tensor(#batch, maxlen_out, size).
+            x: Input sequence of shape (seq_len, batch, embed_dim).
+            attn_mask: A binary mask for self-attention module indicating which
+                elements will be filled with -inf.
+                Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len).
+            memory: Memory sequence of shape (seq_len, batch, memory_dim).
+            memory_attn_mask: A binary mask for cross-attention module indicating which
+                elements will be filled with -inf.
+                Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len).
         """
         # self-attn module
-        tgt_norm = self.norm_self_attn(tgt)
-        tgt = tgt + self.dropout(self.self_attn(tgt_norm, tgt_norm, tgt_norm, tgt_mask))
+        qkv = self.norm_self_attn(x)
+        self_attn_out = self.self_attn(
+            query=qkv, key=qkv, value=qkv, attn_mask=attn_mask
+        )
+        x = x + self.dropout(self_attn_out)
 
         # cross-attn module
-        tgt = tgt + self.dropout(self.src_attn(self.norm_src_attn(tgt), memory, memory, memory_mask))
+        q = self.norm_src_attn(x)
+        src_attn_out = self.src_attn(
+            query=q, key=memory, value=memory, attn_mask=memory_attn_mask
+        )
+        x = x + self.dropout(src_attn_out)
 
         # feed-forward module
-        tgt = tgt + self.dropout(self.feed_forward(self.norm_ff(tgt)))
+        x = x + self.dropout(self.feed_forward(self.norm_ff(x)))
 
-        return tgt
+        return x
 
 
-class MultiHeadedAttention(nn.Module):
+class MultiHeadAttention(nn.Module):
     """Multi-Head Attention layer.
 
     Args:
       embed_dim: total dimension of the model.
-      attention_dim: dimension in the attention module, may be less or more than embed_dim
-        but must be a multiple of num_heads.
-      num_heads: parallel attention heads.
-      dropout: a Dropout layer on attn_output_weights. Default: 0.0.
+      attention_dim: dimension in the attention module, but must be a multiple of num_heads.
+      num_heads: number of parallel attention heads.
+      memory_dim: dimension of memory embedding, optional.
+      dropout: a Dropout layer on attn_output_weights.
     """
 
     def __init__(
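Each of the three sub-modules in DecoderLayer.forward now follows the same pre-norm residual pattern: normalize the input, apply the sub-module, apply dropout, then add the result back. A minimal sketch of that pattern in isolation (PreNormResidual is an illustrative name, not a class in this file):

import torch
import torch.nn as nn

class PreNormResidual(nn.Module):
    # `sublayer` stands in for self_attn, src_attn, or feed_forward.
    def __init__(self, d_model: int, sublayer: nn.Module, dropout: float = 0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.sublayer = sublayer
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.dropout(self.sublayer(self.norm(x)))

block = PreNormResidual(8, nn.Linear(8, 8))
out = block(torch.randn(5, 2, 8))  # (seq_len, batch, d_model) shape is preserved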
@@ -312,20 +361,18 @@ class MultiHeadedAttention(nn.Module):
         attention_dim: int,
         num_heads: int,
         memory_dim: Optional[int] = None,
-        dropout: float = 0.0
+        dropout: float = 0.0,
     ):
-        """Construct an MultiHeadedAttention object."""
-        super(MultiHeadedAttention, self).__init__()
+        super(MultiHeadAttention, self).__init__()
         self.embed_dim = embed_dim
         self.attention_dim = attention_dim
         self.num_heads = num_heads
-        self.dropout = dropout
         self.head_dim = attention_dim // num_heads
         assert self.head_dim * num_heads == attention_dim, (
-            self.head_dim,
-            num_heads,
-            attention_dim,
+            self.head_dim, num_heads, attention_dim
         )
+        self.dropout = dropout
+        self.name = None  # will be overwritten in training code; for diagnostics.
 
         self.linear_q = nn.Linear(embed_dim, attention_dim, bias=True)
         self.linear_k = nn.Linear(
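The assertion guards the head split: attention_dim must divide evenly into num_heads slices of head_dim each. For example:

attention_dim, num_heads = 512, 8
head_dim = attention_dim // num_heads  # 64
assert head_dim * num_heads == attention_dim  # 8 * 64 == 512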
@@ -334,74 +381,89 @@ class MultiHeadedAttention(nn.Module):
         self.linear_v = nn.Linear(
             embed_dim if memory_dim is None else memory_dim, attention_dim, bias=True
         )
-        self.scale = math.sqrt(self.head_dim)
 
         self.out_proj = nn.Linear(attention_dim, embed_dim, bias=True)
 
-    def forward(self, query, key, value, mask):
-        """Compute scaled dot product attention.
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        attn_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Compute dot product attention.
 
         Args:
-            query (torch.Tensor): Query tensor (#batch, time1, size).
-            key (torch.Tensor): Key tensor (#batch, time2, size).
-            value (torch.Tensor): Value tensor (#batch, time2, size).
-            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
-                (#batch, time1, time2).
+            query: Query tensor of shape (tgt_len, batch, embed_dim).
+            key: Key tensor of shape (src_len, batch, embed_dim or memory_dim).
+            value: Value tensor of shape (src_len, batch, embed_dim or memory_dim).
+            key_padding_mask: A binary mask indicating which elements are padding.
+                Its shape is (batch, src_len).
+            attn_mask: A binary mask indicating which elements will be filled with -inf.
+                Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len).
 
         Returns:
-            torch.Tensor: Output tensor (#batch, time1, d_model).
-
+            Output tensor of shape (tgt_len, batch, embed_dim).
         """
-        bsz, tgt_len, _ = query.size()
-        src_len = key.size(1)
         num_heads = self.num_heads
         head_dim = self.head_dim
 
-        q = self.linear_q(query)
-        k = self.linear_k(key)
-        v = self.linear_v(value)
+        tgt_len, batch, _ = query.shape
+        src_len = key.shape[0]
 
-        q = q.reshape(bsz, tgt_len, num_heads, head_dim)
-        q = q.transpose(1, 2)  # (batch, head, time1, head_dim)
-        k = k.reshape(bsz, src_len, num_heads, head_dim)
-        k = k.permute(0, 2, 3, 1)  # (batch, head, head_dim, time2)
-        v = v.reshape(bsz, src_len, num_heads, head_dim)
-        v = v.transpose(1, 2).reshape(bsz * num_heads, src_len, head_dim)
+        q = self.linear_q(query)  # (tgt_len, batch, num_heads * head_dim)
+        k = self.linear_k(key)  # (src_len, batch, num_heads * head_dim)
+        v = self.linear_v(value)  # (src_len, batch, num_heads * head_dim)
 
-        # (batch, head, time1, time2)
-        attn_output_weights = torch.matmul(q, k) / self.scale
+        q = q.reshape(tgt_len, batch, num_heads, head_dim)
+        q = q.permute(1, 2, 0, 3)  # (batch, head, tgt_len, head_dim)
+        k = k.reshape(src_len, batch, num_heads, head_dim)
+        k = k.permute(1, 2, 3, 0)  # (batch, head, head_dim, src_len)
+        v = v.reshape(src_len, batch, num_heads, head_dim)
+        v = v.reshape(src_len, batch * num_heads, head_dim).transpose(0, 1)
 
-        # attn_output_weights = torch.matmul(q, k)
-        # # This is a harder way of limiting the attention scores to not be too large.
-        # # It incurs a penalty if any of them has an absolute value greater than 50.0.
-        # # this should be outside the normal range of the attention scores. We use
-        # # this mechanism instead of, say, a limit on entropy, because once the entropy
-        # # gets very small gradients through the softmax can become very small, and
-        # # some mechanisms like that become ineffective.
-        attn_output_weights = penalize_abs_values_gt(
-            attn_output_weights, limit=50.0, penalty=1.0e-04
-        )
+        # Note: could remove the scaling operation when using ScaledAdam
+        # (batch, head, tgt_len, src_len)
+        attn_weights = torch.matmul(q, k) / math.sqrt(head_dim)
 
-        if mask is not None:
-            attn_output_weights = attn_output_weights.masked_fill(
-                mask.unsqueeze(1), float("-inf")
-            )
+        # From zipformer.py:
+        # This is a harder way of limiting the attention scores to not be too large.
+        # It incurs a penalty if any of them has an absolute value greater than 50.0.
+        # this should be outside the normal range of the attention scores. We use
+        # this mechanism instead of, say, a limit on entropy, because once the entropy
+        # gets very small gradients through the softmax can become very small, and
+        # some mechanisms like that become ineffective.
+        attn_weights = penalize_abs_values_gt(attn_weights, limit=50.0, penalty=1.0e-04)
 
-        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
+        if key_padding_mask is not None:
+            assert key_padding_mask.shape == (batch, src_len), key_padding_mask.shape
+            attn_weights = attn_weights.masked_fill(
+                key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"),
+            )
 
-        attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
-        attn_output_weights = nn.functional.dropout(
-            attn_output_weights, p=self.dropout, training=self.training
-        )
+        if attn_mask is not None:
+            assert (
+                attn_mask.shape == (batch, 1, src_len)
+                or attn_mask.shape == (batch, tgt_len, src_len)
+            ), attn_mask.shape
+            attn_weights = attn_weights.masked_fill(attn_mask.unsqueeze(1), float("-inf"))
 
-        # (bsz * head, time1, head_dim_v)
-        attn_output = torch.bmm(attn_output_weights, v)
-        assert attn_output.shape == (bsz * num_heads, tgt_len, head_dim)
-        attn_output = (
-            attn_output.reshape(bsz, num_heads, tgt_len, head_dim)
-            .transpose(1, 2)
-            .reshape(bsz, tgt_len, self.attention_dim)
-        )
+        attn_weights = attn_weights.view(batch * num_heads, tgt_len, src_len)
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
 
+        attn_weights = nn.functional.dropout(
+            attn_weights, p=self.dropout, training=self.training
+        )
+
+        # (batch * head, tgt_len, head_dim)
+        attn_output = torch.bmm(attn_weights, v)
+        assert attn_output.shape == (batch * num_heads, tgt_len, head_dim), attn_output.shape
+
+        attn_output = attn_output.transpose(0, 1).contiguous()
+        attn_output = attn_output.view(tgt_len, batch, num_heads * head_dim)
+
+        # (tgt_len, batch, embed_dim)
         attn_output = self.out_proj(attn_output)
 
         return attn_output
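The refactored forward keeps every tensor in (seq_len, batch, dim) layout until the output projection. A toy walk-through of the reshape bookkeeping with illustrative sizes (the score penalty, masking, dropout, and output projection are omitted here):

import math
import torch

tgt_len, src_len, batch, num_heads, head_dim = 5, 7, 2, 4, 16

q = torch.randn(tgt_len, batch, num_heads * head_dim)
k = torch.randn(src_len, batch, num_heads * head_dim)
v = torch.randn(src_len, batch, num_heads * head_dim)

q = q.reshape(tgt_len, batch, num_heads, head_dim).permute(1, 2, 0, 3)  # (B, H, T, d)
k = k.reshape(src_len, batch, num_heads, head_dim).permute(1, 2, 3, 0)  # (B, H, d, S)
v = v.reshape(src_len, batch * num_heads, head_dim).transpose(0, 1)     # (B*H, S, d)

attn_weights = torch.matmul(q, k) / math.sqrt(head_dim)                 # (B, H, T, S)
attn_weights = attn_weights.view(batch * num_heads, tgt_len, src_len)
attn_weights = torch.softmax(attn_weights, dim=-1)

attn_output = torch.bmm(attn_weights, v)                                # (B*H, T, d)
attn_output = attn_output.transpose(0, 1).contiguous()
attn_output = attn_output.view(tgt_len, batch, num_heads * head_dim)
assert attn_output.shape == (5, 2, 64)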
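penalize_abs_values_gt comes from icefall's scaling.py. In spirit, it is an identity in the forward pass that injects an extra gradient pushing any score whose absolute value exceeds the limit back into range. A simplified sketch of that idea, assuming a penalty of the form penalty * relu(|x| - limit) (the real implementation differs in its details):

import torch

class PenalizeAbsValuesGt(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, limit: float, penalty: float):
        ctx.save_for_backward(x)
        ctx.limit = limit
        ctx.penalty = penalty
        return x  # identity: the penalty only shapes the backward pass

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Gradient of penalty * relu(|x| - limit): sign(x) where |x| > limit.
        extra_grad = ctx.penalty * x.sign() * (x.abs() > ctx.limit)
        return grad_output + extra_grad, None, None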
@@ -488,7 +550,7 @@ def _test_attention_decoder_model():
         decoder_dim=512,
         num_decoder_layers=6,
         attention_dim=512,
-        nhead=8,
+        num_heads=8,
         feedforward_dim=2048,
         memory_dim=384,
         dropout=0.1,
@@ -662,7 +662,7 @@ def get_attention_decoder_model(params: AttributeDict) -> nn.Module:
         decoder_dim=params.attention_decoder_dim,
         num_decoder_layers=params.attention_decoder_num_layers,
         attention_dim=params.attention_decoder_attention_dim,
-        nhead=params.attention_decoder_num_heads,
+        num_heads=params.attention_decoder_num_heads,
         feedforward_dim=params.attention_decoder_feedforward_dim,
         memory_dim=max(_to_int_tuple(params.encoder_dim)),
         sos_id=params.sos_id,