mirror of https://github.com/k2-fsa/icefall.git
fix memory-dim issues

commit 4dc972899b
parent 7886da9b59
@@ -19,7 +19,7 @@
 # https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py

 import math
-from typing import List, Tuple
+from typing import List, Optional, Tuple

 import k2
 import torch
@@ -50,6 +50,7 @@ class AttentionDecoderModel(nn.Module):
         attention_dim: int = 512,
         nhead: int = 8,
         feedforward_dim: int = 2048,
+        memory_dim: int = 512,
         sos_id: int = 1,
         eos_id: int = 1,
         dropout: float = 0.1,
@@ -70,6 +71,7 @@ class AttentionDecoderModel(nn.Module):
             attention_dim=attention_dim,
             nhead=nhead,
             feedforward_dim=feedforward_dim,
+            memory_dim=memory_dim,
             dropout=dropout,
         )

@@ -170,6 +172,7 @@ class TransformerDecoder(nn.Module):
         attention_dim: int = 512,
         nhead: int = 8,
         feedforward_dim: int = 2048,
+        memory_dim: int = 512,
         dropout: float = 0.1,
     ):
         super().__init__()
@@ -181,7 +184,9 @@ class TransformerDecoder(nn.Module):
         self.num_layers = num_decoder_layers
         self.layers = nn.ModuleList(
             [
-                DecoderLayer(d_model, attention_dim, nhead, feedforward_dim, dropout)
+                DecoderLayer(
+                    d_model, attention_dim, nhead, feedforward_dim, memory_dim, dropout
+                )
                 for _ in range(num_decoder_layers)
             ]
         )
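The hunk above threads the new memory_dim argument from TransformerDecoder down into every DecoderLayer. The toy below is a self-contained sketch of that threading pattern with made-up ToyLayer/ToyDecoder classes (not the icefall implementations), assuming a 512-dim decoder attending over 384-dim encoder output.

```python
import torch

# Illustrative stand-ins for the classes in this diff; only the memory_dim
# plumbing is shown, everything else is omitted.
class ToyLayer(torch.nn.Module):
    def __init__(self, d_model: int, memory_dim: int):
        super().__init__()
        # keys/values are projected from the encoder output (memory_dim),
        # queries from the decoder states (d_model)
        self.kv_proj = torch.nn.Linear(memory_dim, d_model)

class ToyDecoder(torch.nn.Module):
    def __init__(self, d_model: int, num_layers: int, memory_dim: int):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [ToyLayer(d_model, memory_dim) for _ in range(num_layers)]
        )

decoder = ToyDecoder(d_model=512, num_layers=6, memory_dim=384)
print(decoder.layers[0].kv_proj)  # Linear(in_features=384, out_features=512, bias=True)
```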
@@ -243,6 +248,7 @@ class DecoderLayer(nn.Module):
         attention_dim: int = 512,
         nhead: int = 8,
         feedforward_dim: int = 2048,
+        memory_dim: int = 512,
         dropout: float = 0.1,
     ):
         """Construct an DecoderLayer object."""
@@ -252,7 +258,7 @@ class DecoderLayer(nn.Module):
         self.self_attn = MultiHeadedAttention(d_model, attention_dim, nhead, dropout=0.0)

         self.norm_src_attn = nn.LayerNorm(d_model)
-        self.src_attn = MultiHeadedAttention(d_model, attention_dim, nhead, dropout=0.0)
+        self.src_attn = MultiHeadedAttention(d_model, attention_dim, nhead, memory_dim=memory_dim, dropout=0.0)

         self.norm_ff = nn.LayerNorm(d_model)
         self.feed_forward = nn.Sequential(
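Only the source attention (src_attn) receives memory_dim, since its keys and values come from the encoder output while self_attn operates entirely on decoder states. PyTorch's built-in torch.nn.MultiheadAttention exposes the same idea through its kdim/vdim arguments; the snippet below is a rough analogy with assumed sizes (512-dim decoder, 384-dim encoder memory), not the icefall module.

```python
import torch

d_model, memory_dim, nhead = 512, 384, 8

# self-attention: q, k, v all come from 512-dim decoder states
self_attn = torch.nn.MultiheadAttention(d_model, nhead, batch_first=True)
# source attention: q from decoder (512-dim), k/v from encoder memory (384-dim)
src_attn = torch.nn.MultiheadAttention(
    d_model, nhead, kdim=memory_dim, vdim=memory_dim, batch_first=True
)

decoder_states = torch.randn(2, 10, d_model)     # (batch, target_len, d_model)
encoder_memory = torch.randn(2, 50, memory_dim)  # (batch, source_len, memory_dim)

out, _ = src_attn(decoder_states, encoder_memory, encoder_memory)
print(out.shape)  # torch.Size([2, 10, 512])
```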
@@ -301,7 +307,12 @@ class MultiHeadedAttention(nn.Module):
     """

     def __init__(
-        self, embed_dim: int, attention_dim: int, num_heads: int, dropout: float = 0.0
+        self,
+        embed_dim: int,
+        attention_dim: int,
+        num_heads: int,
+        memory_dim: Optional[int] = None,
+        dropout: float = 0.0
     ):
         """Construct an MultiHeadedAttention object."""
         super(MultiHeadedAttention, self).__init__()
@@ -317,8 +328,12 @@ class MultiHeadedAttention(nn.Module):
         )

         self.linear_q = nn.Linear(embed_dim, attention_dim, bias=True)
-        self.linear_k = nn.Linear(embed_dim, attention_dim, bias=True)
-        self.linear_v = nn.Linear(embed_dim, attention_dim, bias=True)
+        self.linear_k = nn.Linear(
+            embed_dim if memory_dim is None else memory_dim, attention_dim, bias=True
+        )
+        self.linear_v = nn.Linear(
+            embed_dim if memory_dim is None else memory_dim, attention_dim, bias=True
+        )
         self.scale = math.sqrt(self.head_dim)

         self.out_proj = nn.Linear(attention_dim, embed_dim, bias=True)
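This is the core of the fix: the key and value projections now take their input size from memory_dim when it is given, so the encoder output no longer has to match the decoder's embed_dim. Below is a minimal from-scratch sketch of that projection sizing (assumed shapes and simplified attention, not the full icefall MultiHeadedAttention).

```python
import math
import torch

class TinyCrossAttention(torch.nn.Module):
    """Minimal sketch: queries from embed_dim, keys/values from memory_dim."""

    def __init__(self, embed_dim, attention_dim, num_heads, memory_dim=None):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = attention_dim // num_heads
        kv_in = embed_dim if memory_dim is None else memory_dim
        self.linear_q = torch.nn.Linear(embed_dim, attention_dim)
        self.linear_k = torch.nn.Linear(kv_in, attention_dim)
        self.linear_v = torch.nn.Linear(kv_in, attention_dim)
        self.out_proj = torch.nn.Linear(attention_dim, embed_dim)

    def forward(self, query, memory):
        b, t, _ = query.shape
        s = memory.shape[1]
        # project and split into heads: (batch, heads, time, head_dim)
        q = self.linear_q(query).reshape(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.linear_k(memory).reshape(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.linear_v(memory).reshape(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(self.head_dim), dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, t, -1)
        return self.out_proj(out)

m = TinyCrossAttention(embed_dim=512, attention_dim=512, num_heads=8, memory_dim=384)
print(m(torch.randn(2, 10, 512), torch.randn(2, 50, 384)).shape)  # torch.Size([2, 10, 512])
```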
@@ -475,6 +490,7 @@ def _test_attention_decoder_model():
         attention_dim=512,
         nhead=8,
         feedforward_dim=2048,
+        memory_dim=384,
         dropout=0.1,
         sos_id=1,
         eos_id=1,
@@ -485,7 +501,7 @@ def _test_attention_decoder_model():
     print(f"Number of model parameters: {num_param}")

     m.eval()
-    encoder_out = torch.randn(2, 50, 512)
+    encoder_out = torch.randn(2, 50, 384)
     encoder_out_lens = torch.full((2,), 50)
     token_ids = [[1, 2, 3, 4], [2, 3, 10]]

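The test now feeds 384-dim encoder output while the decoder itself stays 512-dim, which is exactly the mismatch the old code could not handle. The snippet below is a standalone illustration (made-up Linear layers, not the test itself) of why a projection sized for embed_dim fails on memory of a different width.

```python
import torch

encoder_out = torch.randn(2, 50, 384)  # (batch, frames, encoder dim)

old_linear_k = torch.nn.Linear(512, 512)  # assumed memory dim == embed_dim
new_linear_k = torch.nn.Linear(384, 512)  # sized by memory_dim instead

print(new_linear_k(encoder_out).shape)  # torch.Size([2, 50, 512])
try:
    old_linear_k(encoder_out)
except RuntimeError as err:
    print("old projection fails:", err)
```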
@@ -665,9 +665,6 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:


 def get_attention_decoder_model(params: AttributeDict) -> nn.Module:
-    encoder_dim = max(_to_int_tuple(params.encoder_dim))
-    assert params.attention_decoder_dim == encoder_dim, (params.attention_decoder_dim, encoder_dim)
-
     decoder = AttentionDecoderModel(
         vocab_size=params.vocab_size,
         decoder_dim=params.attention_decoder_dim,
@@ -675,6 +672,7 @@ def get_attention_decoder_model(params: AttributeDict) -> nn.Module:
         attention_dim=params.attention_decoder_attention_dim,
         nhead=params.attention_decoder_num_heads,
         feedforward_dim=params.attention_decoder_feedforward_dim,
+        memory_dim=max(_to_int_tuple(params.encoder_dim)),
         sos_id=params.sos_id,
         eos_id=params.eos_id,
         ignore_id=params.ignore_id,
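With memory_dim passed explicitly, the earlier assertion that attention_decoder_dim must equal the largest encoder dimension is dropped; instead the largest encoder dimension is handed to the decoder as memory_dim. A hedged sketch of the assumed helper (_to_int_tuple is taken to parse icefall's comma-separated dimension strings) with a hypothetical zipformer configuration:

```python
from typing import Tuple

# Assumed to mirror the helper used in icefall's train scripts.
def _to_int_tuple(s: str) -> Tuple[int, ...]:
    return tuple(map(int, s.split(",")))

encoder_dim = "192,256,384,512,384,256"  # hypothetical per-stack dims
memory_dim = max(_to_int_tuple(encoder_dim))
print(memory_dim)  # 512 -- used as the width of the encoder output fed to the decoder
```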
@@ -1298,6 +1298,7 @@ def rescore_with_attention_decoder_no_ngram(
         attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
         attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
         attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
+        attention_scale_list += [5.0, 6.0, 7.0, 8.0, 9.0]
     else:
         attention_scale_list = [attention_scale]

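The extended attention_scale_list widens the grid searched during n-best rescoring. The loop below is a hedged sketch of how such a sweep is typically consumed: each scale blends the acoustic/lattice scores with the attention-decoder scores and the best hypothesis is picked per scale (the scores are made up; the real function operates on k2 lattices).

```python
import torch

# Made-up scores for three n-best hypotheses.
am_scores = torch.tensor([-12.3, -11.8, -13.0])
attn_scores = torch.tensor([-9.1, -10.2, -8.7])

for attention_scale in [1.0, 2.0, 5.0, 6.0, 7.0, 8.0, 9.0]:
    tot_scores = am_scores + attention_scale * attn_scores
    best = int(tot_scores.argmax())
    print(f"attention_scale={attention_scale}: best hypothesis index {best}")
```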