fix memory-dim issues

Author: yaozengwei
Date:   2023-11-21 11:59:01 +08:00
Parent: 7886da9b59
Commit: 4dc972899b

3 changed files with 25 additions and 10 deletions

View File

@@ -19,7 +19,7 @@
 # https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
 import math
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 import k2
 import torch
@@ -50,6 +50,7 @@ class AttentionDecoderModel(nn.Module):
         attention_dim: int = 512,
         nhead: int = 8,
         feedforward_dim: int = 2048,
+        memory_dim: int = 512,
         sos_id: int = 1,
         eos_id: int = 1,
         dropout: float = 0.1,
@@ -70,6 +71,7 @@ class AttentionDecoderModel(nn.Module):
             attention_dim=attention_dim,
             nhead=nhead,
             feedforward_dim=feedforward_dim,
+            memory_dim=memory_dim,
             dropout=dropout,
         )
@@ -170,6 +172,7 @@ class TransformerDecoder(nn.Module):
         attention_dim: int = 512,
         nhead: int = 8,
         feedforward_dim: int = 2048,
+        memory_dim: int = 512,
         dropout: float = 0.1,
     ):
         super().__init__()
@@ -181,7 +184,9 @@ class TransformerDecoder(nn.Module):
         self.num_layers = num_decoder_layers
         self.layers = nn.ModuleList(
             [
-                DecoderLayer(d_model, attention_dim, nhead, feedforward_dim, dropout)
+                DecoderLayer(
+                    d_model, attention_dim, nhead, feedforward_dim, memory_dim, dropout
+                )
                 for _ in range(num_decoder_layers)
             ]
         )
@@ -243,6 +248,7 @@ class DecoderLayer(nn.Module):
         attention_dim: int = 512,
         nhead: int = 8,
         feedforward_dim: int = 2048,
+        memory_dim: int = 512,
         dropout: float = 0.1,
     ):
         """Construct an DecoderLayer object."""
@@ -252,7 +258,7 @@ class DecoderLayer(nn.Module):
         self.self_attn = MultiHeadedAttention(d_model, attention_dim, nhead, dropout=0.0)
         self.norm_src_attn = nn.LayerNorm(d_model)
-        self.src_attn = MultiHeadedAttention(d_model, attention_dim, nhead, dropout=0.0)
+        self.src_attn = MultiHeadedAttention(d_model, attention_dim, nhead, memory_dim=memory_dim, dropout=0.0)
         self.norm_ff = nn.LayerNorm(d_model)
         self.feed_forward = nn.Sequential(
@@ -301,7 +307,12 @@ class MultiHeadedAttention(nn.Module):
     """
     def __init__(
-        self, embed_dim: int, attention_dim: int, num_heads: int, dropout: float = 0.0
+        self,
+        embed_dim: int,
+        attention_dim: int,
+        num_heads: int,
+        memory_dim: Optional[int] = None,
+        dropout: float = 0.0
     ):
         """Construct an MultiHeadedAttention object."""
         super(MultiHeadedAttention, self).__init__()
@@ -317,8 +328,12 @@ class MultiHeadedAttention(nn.Module):
         )
         self.linear_q = nn.Linear(embed_dim, attention_dim, bias=True)
-        self.linear_k = nn.Linear(embed_dim, attention_dim, bias=True)
-        self.linear_v = nn.Linear(embed_dim, attention_dim, bias=True)
+        self.linear_k = nn.Linear(
+            embed_dim if memory_dim is None else memory_dim, attention_dim, bias=True
+        )
+        self.linear_v = nn.Linear(
+            embed_dim if memory_dim is None else memory_dim, attention_dim, bias=True
+        )
         self.scale = math.sqrt(self.head_dim)
         self.out_proj = nn.Linear(attention_dim, embed_dim, bias=True)
@@ -475,6 +490,7 @@ def _test_attention_decoder_model():
         attention_dim=512,
         nhead=8,
         feedforward_dim=2048,
+        memory_dim=384,
         dropout=0.1,
         sos_id=1,
         eos_id=1,
@@ -485,7 +501,7 @@ def _test_attention_decoder_model():
     print(f"Number of model parameters: {num_param}")
     m.eval()
-    encoder_out = torch.randn(2, 50, 512)
+    encoder_out = torch.randn(2, 50, 384)
     encoder_out_lens = torch.full((2,), 50)
     token_ids = [[1, 2, 3, 4], [2, 3, 10]]
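The reason the key/value projections need a separate memory_dim is that in the decoder's cross-attention the queries come from decoder states of width embed_dim, while the keys and values are computed from the encoder output, whose width can differ (384 vs. 512 in the test above). Below is a minimal single-head sketch of that shape logic; it is not the repository's MultiHeadedAttention class, and all dimensions and tensors are illustrative.

# Minimal single-head sketch (assumed dimensions, not the icefall class itself):
# queries are projected from the decoder side (embed_dim), keys/values from the
# encoder-output side (memory_dim), so the two projections need different input sizes.
import torch
import torch.nn as nn

embed_dim, memory_dim, attention_dim = 512, 384, 512

linear_q = nn.Linear(embed_dim, attention_dim, bias=True)   # decoder states -> queries
linear_k = nn.Linear(memory_dim, attention_dim, bias=True)  # encoder output -> keys
linear_v = nn.Linear(memory_dim, attention_dim, bias=True)  # encoder output -> values

decoder_states = torch.randn(2, 10, embed_dim)  # (batch, num_tokens, embed_dim)
encoder_out = torch.randn(2, 50, memory_dim)    # (batch, num_frames, memory_dim)

q = linear_q(decoder_states)                    # (2, 10, 512)
k = linear_k(encoder_out)                       # (2, 50, 512)
v = linear_v(encoder_out)                       # (2, 50, 512)

attn_weights = torch.softmax(q @ k.transpose(1, 2) / attention_dim**0.5, dim=-1)
cross_attn_out = attn_weights @ v               # (2, 10, 512), one vector per query token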

View File

@@ -665,9 +665,6 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
 def get_attention_decoder_model(params: AttributeDict) -> nn.Module:
-    encoder_dim = max(_to_int_tuple(params.encoder_dim))
-    assert params.attention_decoder_dim == encoder_dim, (params.attention_decoder_dim, encoder_dim)
     decoder = AttentionDecoderModel(
         vocab_size=params.vocab_size,
         decoder_dim=params.attention_decoder_dim,
@@ -675,6 +672,7 @@ def get_attention_decoder_model(params: AttributeDict) -> nn.Module:
         attention_dim=params.attention_decoder_attention_dim,
         nhead=params.attention_decoder_num_heads,
         feedforward_dim=params.attention_decoder_feedforward_dim,
+        memory_dim=max(_to_int_tuple(params.encoder_dim)),
         sos_id=params.sos_id,
         eos_id=params.eos_id,
         ignore_id=params.ignore_id,
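The removed assert forced the attention-decoder dimension to equal the encoder output dimension; passing memory_dim explicitly lifts that restriction. Below is a hedged sketch of what the added expression evaluates to, assuming params.encoder_dim is the usual comma-separated per-stack string and that _to_int_tuple simply parses it into integers (both are assumptions, simplified from the recipe).

# Hedged sketch of the added train.py expression.  Assumptions: encoder_dim is a
# comma-separated string with one dimension per zipformer stack, and _to_int_tuple
# parses it into ints (the helper is re-implemented here for illustration).
def _to_int_tuple(s: str):
    return tuple(map(int, s.split(",")))

encoder_dim = "192,256,384,512,384,256"    # example value, not taken from this commit
memory_dim = max(_to_int_tuple(encoder_dim))
print(memory_dim)                          # 512: the encoder output (memory) feature size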

View File

@@ -1298,6 +1298,7 @@ def rescore_with_attention_decoder_no_ngram(
         attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
         attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
         attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
+        attention_scale_list += [5.0, 6.0, 7.0, 8.0, 9.0]
     else:
         attention_scale_list = [attention_scale]
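This last hunk only widens the grid of attention scales tried when no fixed attention_scale is given. A simplified sketch of how such a scale enters the rescoring decision follows; it is illustrative only, not the exact rescore_with_attention_decoder_no_ngram logic, and all scores are made up.

# For each candidate scale, the attention-decoder scores are weighted against the
# lattice scores and the best path per scale is kept; larger scales trust the
# attention decoder more.  All numbers below are hypothetical.
lattice_scores = [-12.3, -15.1, -14.0]     # per-path scores from the lattice
attn_scores = [-4.0, -2.5, -3.1]           # per-path attention-decoder scores

for attention_scale in [1.0, 2.0, 5.0, 9.0]:
    tot_scores = [l + attention_scale * a for l, a in zip(lattice_scores, attn_scores)]
    best_path = max(range(len(tot_scores)), key=tot_scores.__getitem__)
    print(f"attention_scale={attention_scale}: best path {best_path}")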