mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-27 02:34:21 +00:00)
fix memory-dim issues

commit 4dc972899b (parent 7886da9b59)
@@ -19,7 +19,7 @@
 # https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py

 import math
-from typing import List, Tuple
+from typing import List, Optional, Tuple

 import k2
 import torch
@@ -50,6 +50,7 @@ class AttentionDecoderModel(nn.Module):
         attention_dim: int = 512,
         nhead: int = 8,
         feedforward_dim: int = 2048,
+        memory_dim: int = 512,
         sos_id: int = 1,
         eos_id: int = 1,
         dropout: float = 0.1,
@@ -70,6 +71,7 @@ class AttentionDecoderModel(nn.Module):
             attention_dim=attention_dim,
             nhead=nhead,
             feedforward_dim=feedforward_dim,
+            memory_dim=memory_dim,
             dropout=dropout,
         )

@@ -170,6 +172,7 @@ class TransformerDecoder(nn.Module):
         attention_dim: int = 512,
         nhead: int = 8,
         feedforward_dim: int = 2048,
+        memory_dim: int = 512,
         dropout: float = 0.1,
     ):
         super().__init__()
@@ -181,7 +184,9 @@ class TransformerDecoder(nn.Module):
         self.num_layers = num_decoder_layers
         self.layers = nn.ModuleList(
             [
-                DecoderLayer(d_model, attention_dim, nhead, feedforward_dim, dropout)
+                DecoderLayer(
+                    d_model, attention_dim, nhead, feedforward_dim, memory_dim, dropout
+                )
                 for _ in range(num_decoder_layers)
             ]
         )
@@ -243,6 +248,7 @@ class DecoderLayer(nn.Module):
         attention_dim: int = 512,
         nhead: int = 8,
         feedforward_dim: int = 2048,
+        memory_dim: int = 512,
         dropout: float = 0.1,
     ):
         """Construct an DecoderLayer object."""
@@ -252,7 +258,7 @@ class DecoderLayer(nn.Module):
         self.self_attn = MultiHeadedAttention(d_model, attention_dim, nhead, dropout=0.0)

         self.norm_src_attn = nn.LayerNorm(d_model)
-        self.src_attn = MultiHeadedAttention(d_model, attention_dim, nhead, dropout=0.0)
+        self.src_attn = MultiHeadedAttention(d_model, attention_dim, nhead, memory_dim=memory_dim, dropout=0.0)

         self.norm_ff = nn.LayerNorm(d_model)
         self.feed_forward = nn.Sequential(
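
Only the cross-attention (src_attn) receives the new memory_dim: its keys and values are projected from the encoder output (the "memory"), while the self-attention block keeps projecting everything from the d_model-sized decoder stream. Below is a minimal sketch of that split, using torch.nn.MultiheadAttention's kdim/vdim arguments for illustration rather than the file's own MultiHeadedAttention; the layer structure and names are assumptions, not a copy of DecoderLayer.

# Illustration only: self-attention stays at d_model, cross-attention reads
# keys/values from a memory whose width may differ (memory_dim).
import torch
import torch.nn as nn


class TinyDecoderLayerSketch(nn.Module):
    def __init__(self, d_model: int, nhead: int, memory_dim: int):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.src_attn = nn.MultiheadAttention(
            d_model, nhead, kdim=memory_dim, vdim=memory_dim, batch_first=True
        )

    def forward(self, tgt: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        tgt = tgt + self.self_attn(tgt, tgt, tgt, need_weights=False)[0]
        tgt = tgt + self.src_attn(tgt, memory, memory, need_weights=False)[0]
        return tgt


layer = TinyDecoderLayerSketch(d_model=512, nhead=8, memory_dim=384)
out = layer(torch.randn(2, 5, 512), torch.randn(2, 50, 384))
print(out.shape)  # torch.Size([2, 5, 512])
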
@@ -301,7 +307,12 @@ class MultiHeadedAttention(nn.Module):
     """

     def __init__(
-        self, embed_dim: int, attention_dim: int, num_heads: int, dropout: float = 0.0
+        self,
+        embed_dim: int,
+        attention_dim: int,
+        num_heads: int,
+        memory_dim: Optional[int] = None,
+        dropout: float = 0.0
     ):
         """Construct an MultiHeadedAttention object."""
         super(MultiHeadedAttention, self).__init__()
@@ -317,8 +328,12 @@ class MultiHeadedAttention(nn.Module):
         )

         self.linear_q = nn.Linear(embed_dim, attention_dim, bias=True)
-        self.linear_k = nn.Linear(embed_dim, attention_dim, bias=True)
-        self.linear_v = nn.Linear(embed_dim, attention_dim, bias=True)
+        self.linear_k = nn.Linear(
+            embed_dim if memory_dim is None else memory_dim, attention_dim, bias=True
+        )
+        self.linear_v = nn.Linear(
+            embed_dim if memory_dim is None else memory_dim, attention_dim, bias=True
+        )
         self.scale = math.sqrt(self.head_dim)

         self.out_proj = nn.Linear(attention_dim, embed_dim, bias=True)
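
The core of the fix is this hunk: the query projection still consumes the embed_dim-sized decoder stream, while the key and value projections now consume memory_dim when it is given, so a 384-dim encoder output can be attended by a 512-dim decoder without resizing either side. A quick standalone shape check with those (hypothetical) dimensions:

# Standalone shape check (hypothetical dims, not the file's own test).
import torch
import torch.nn as nn

embed_dim, attention_dim, memory_dim = 512, 512, 384
linear_q = nn.Linear(embed_dim, attention_dim, bias=True)
linear_k = nn.Linear(memory_dim, attention_dim, bias=True)  # was embed_dim before this commit
linear_v = nn.Linear(memory_dim, attention_dim, bias=True)  # was embed_dim before this commit

tgt = torch.randn(2, 10, embed_dim)      # decoder stream (queries)
memory = torch.randn(2, 50, memory_dim)  # encoder output (keys/values)

q, k, v = linear_q(tgt), linear_k(memory), linear_v(memory)
assert q.shape == (2, 10, attention_dim)
assert k.shape == v.shape == (2, 50, attention_dim)
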
@@ -475,6 +490,7 @@ def _test_attention_decoder_model():
         attention_dim=512,
         nhead=8,
         feedforward_dim=2048,
+        memory_dim=384,
         dropout=0.1,
         sos_id=1,
         eos_id=1,
@@ -485,7 +501,7 @@ def _test_attention_decoder_model():
     print(f"Number of model parameters: {num_param}")

     m.eval()
-    encoder_out = torch.randn(2, 50, 512)
+    encoder_out = torch.randn(2, 50, 384)
     encoder_out_lens = torch.full((2,), 50)
     token_ids = [[1, 2, 3, 4], [2, 3, 10]]

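
The self-test now exercises the mismatched-dimension path end to end: the model is built with memory_dim=384 while its decoder dimension stays at 512, and the fake encoder output is generated at width 384 to match. The ragged token_ids together with sos_id/eos_id are what an attention-decoder loss typically turns into sos-prefixed inputs and eos-terminated targets; the snippet below is a hedged, generic illustration of that padding, not the class's actual implementation.

# Hedged, generic illustration of sos/eos target preparation; names and the
# padding scheme are assumptions, not the file's code.
import torch

sos_id, eos_id, ignore_id = 1, 1, -1
token_ids = [[1, 2, 3, 4], [2, 3, 10]]        # as in the self-test

max_len = max(len(t) for t in token_ids) + 1  # room for sos / eos
ys_in = torch.full((len(token_ids), max_len), eos_id, dtype=torch.long)
ys_out = torch.full((len(token_ids), max_len), ignore_id, dtype=torch.long)
for i, t in enumerate(token_ids):
    ys_in[i, 0] = sos_id                       # decoder input starts with sos
    ys_in[i, 1 : len(t) + 1] = torch.tensor(t)
    ys_out[i, : len(t)] = torch.tensor(t)      # target ends with eos;
    ys_out[i, len(t)] = eos_id                 # padded positions are ignored
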
@@ -665,9 +665,6 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:


 def get_attention_decoder_model(params: AttributeDict) -> nn.Module:
-    encoder_dim = max(_to_int_tuple(params.encoder_dim))
-    assert params.attention_decoder_dim == encoder_dim, (params.attention_decoder_dim, encoder_dim)
-
     decoder = AttentionDecoderModel(
         vocab_size=params.vocab_size,
         decoder_dim=params.attention_decoder_dim,
@@ -675,6 +672,7 @@ def get_attention_decoder_model(params: AttributeDict) -> nn.Module:
         attention_dim=params.attention_decoder_attention_dim,
         nhead=params.attention_decoder_num_heads,
         feedforward_dim=params.attention_decoder_feedforward_dim,
+        memory_dim=max(_to_int_tuple(params.encoder_dim)),
         sos_id=params.sos_id,
         eos_id=params.eos_id,
         ignore_id=params.ignore_id,
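
memory_dim is now derived from the encoder configuration instead of asserting that the decoder dimension equals the encoder dimension. In the zipformer recipes params.encoder_dim is a per-stack specification, and the encoder output the decoder attends to has the width of its largest entry, which is why max(...) is taken. A small sketch under the assumption that _to_int_tuple parses a comma-separated spec, as its name suggests:

# Sketch only: this _to_int_tuple is a stand-in matching its apparent behaviour,
# and the encoder-dim string is a hypothetical per-stack zipformer config.
def _to_int_tuple(s: str):
    return tuple(int(x) for x in s.split(","))

encoder_dim = "192,256,384,512,384,256"
memory_dim = max(_to_int_tuple(encoder_dim))
print(memory_dim)  # 512 -> width of the encoder output fed to cross-attention
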
@@ -1298,6 +1298,7 @@ def rescore_with_attention_decoder_no_ngram(
         attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
         attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
         attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
+        attention_scale_list += [5.0, 6.0, 7.0, 8.0, 9.0]
     else:
         attention_scale_list = [attention_scale]

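
The extra scales extend the grid the rescoring pass sweeps over: each value weights the attention-decoder scores against the scores already attached to the hypotheses, and every setting is reported under its own key so the best one can be chosen on a held-out set. A hedged sketch of that pattern with illustrative names, not the function's actual variables:

# Illustrative only: the score tensors and key format are stand-ins.
import torch

base_scores = torch.tensor([2.0, 1.5, 1.8])  # e.g. AM/lattice score per hypothesis
attn_scores = torch.tensor([0.2, 0.9, 0.4])  # attention-decoder score per hypothesis

ans = {}
for attention_scale in [1.0, 2.0, 5.0, 9.0]:
    tot_scores = base_scores + attention_scale * attn_scores
    ans[f"attention_scale_{attention_scale}"] = int(tot_scores.argmax())
print(ans)  # best hypothesis index per scale
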