From 4dc972899b60fe2ff05744bdd2709b3b2599cc0d Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Tue, 21 Nov 2023 11:59:01 +0800
Subject: [PATCH] fix memory-dim issues

---
 .../ASR/zipformer/attention_decoder.py | 30 ++++++++++++++-----
 egs/librispeech/ASR/zipformer/train.py |  4 +--
 icefall/decode.py                      |  1 +
 3 files changed, 25 insertions(+), 10 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/attention_decoder.py b/egs/librispeech/ASR/zipformer/attention_decoder.py
index 6c6cabec5..1c8e25262 100644
--- a/egs/librispeech/ASR/zipformer/attention_decoder.py
+++ b/egs/librispeech/ASR/zipformer/attention_decoder.py
@@ -19,7 +19,7 @@
 # https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
 
 import math
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import k2
 import torch
@@ -50,6 +50,7 @@ class AttentionDecoderModel(nn.Module):
         attention_dim: int = 512,
         nhead: int = 8,
         feedforward_dim: int = 2048,
+        memory_dim: int = 512,
         sos_id: int = 1,
         eos_id: int = 1,
         dropout: float = 0.1,
@@ -70,6 +71,7 @@ class AttentionDecoderModel(nn.Module):
             attention_dim=attention_dim,
             nhead=nhead,
             feedforward_dim=feedforward_dim,
+            memory_dim=memory_dim,
             dropout=dropout,
         )
 
@@ -170,6 +172,7 @@ class TransformerDecoder(nn.Module):
         attention_dim: int = 512,
         nhead: int = 8,
         feedforward_dim: int = 2048,
+        memory_dim: int = 512,
         dropout: float = 0.1,
     ):
         super().__init__()
@@ -181,7 +184,9 @@ class TransformerDecoder(nn.Module):
         self.num_layers = num_decoder_layers
         self.layers = nn.ModuleList(
             [
-                DecoderLayer(d_model, attention_dim, nhead, feedforward_dim, dropout)
+                DecoderLayer(
+                    d_model, attention_dim, nhead, feedforward_dim, memory_dim, dropout
+                )
                 for _ in range(num_decoder_layers)
             ]
         )
@@ -243,6 +248,7 @@ class DecoderLayer(nn.Module):
         attention_dim: int = 512,
         nhead: int = 8,
         feedforward_dim: int = 2048,
+        memory_dim: int = 512,
         dropout: float = 0.1,
     ):
         """Construct an DecoderLayer object."""
@@ -252,7 +258,7 @@
         self.self_attn = MultiHeadedAttention(d_model, attention_dim, nhead, dropout=0.0)
 
         self.norm_src_attn = nn.LayerNorm(d_model)
-        self.src_attn = MultiHeadedAttention(d_model, attention_dim, nhead, dropout=0.0)
+        self.src_attn = MultiHeadedAttention(d_model, attention_dim, nhead, memory_dim=memory_dim, dropout=0.0)
 
         self.norm_ff = nn.LayerNorm(d_model)
         self.feed_forward = nn.Sequential(
@@ -301,7 +307,12 @@ class MultiHeadedAttention(nn.Module):
     """
 
     def __init__(
-        self, embed_dim: int, attention_dim: int, num_heads: int, dropout: float = 0.0
+        self,
+        embed_dim: int,
+        attention_dim: int,
+        num_heads: int,
+        memory_dim: Optional[int] = None,
+        dropout: float = 0.0
     ):
         """Construct an MultiHeadedAttention object."""
         super(MultiHeadedAttention, self).__init__()
@@ -317,8 +328,12 @@ class MultiHeadedAttention(nn.Module):
         )
 
         self.linear_q = nn.Linear(embed_dim, attention_dim, bias=True)
-        self.linear_k = nn.Linear(embed_dim, attention_dim, bias=True)
-        self.linear_v = nn.Linear(embed_dim, attention_dim, bias=True)
+        self.linear_k = nn.Linear(
+            embed_dim if memory_dim is None else memory_dim, attention_dim, bias=True
+        )
+        self.linear_v = nn.Linear(
+            embed_dim if memory_dim is None else memory_dim, attention_dim, bias=True
+        )
         self.scale = math.sqrt(self.head_dim)
 
         self.out_proj = nn.Linear(attention_dim, embed_dim, bias=True)
@@ -475,6 +490,7 @@ def _test_attention_decoder_model():
         attention_dim=512,
         nhead=8,
         feedforward_dim=2048,
+        memory_dim=384,
         dropout=0.1,
         sos_id=1,
         eos_id=1,
@@ -485,7 +501,7 @@ def _test_attention_decoder_model():
     print(f"Number of model parameters: {num_param}")
     m.eval()
 
-    encoder_out = torch.randn(2, 50, 512)
+    encoder_out = torch.randn(2, 50, 384)
     encoder_out_lens = torch.full((2,), 50)
 
     token_ids = [[1, 2, 3, 4], [2, 3, 10]]
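Note: the core of the change above is that MultiHeadedAttention now sizes its key/value projections to an optional memory_dim, so a decoder running at decoder_dim=512 can cross-attend to encoder output of a different width (384 in the updated test). The stand-alone sketch below only illustrates that idea; the class name CrossAttentionSketch, the shapes, and the absence of masking and dropout are simplifications for demonstration, not the icefall implementation.

import math
from typing import Optional

import torch
import torch.nn as nn


class CrossAttentionSketch(nn.Module):
    """Cross-attention whose key/value projections read from memory_dim."""

    def __init__(
        self,
        embed_dim: int,        # decoder model dimension, e.g. 512
        attention_dim: int,    # total attention dimension across heads
        num_heads: int,
        memory_dim: Optional[int] = None,  # encoder output dimension, e.g. 384
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = attention_dim // num_heads
        kv_dim = embed_dim if memory_dim is None else memory_dim
        self.linear_q = nn.Linear(embed_dim, attention_dim)
        self.linear_k = nn.Linear(kv_dim, attention_dim)
        self.linear_v = nn.Linear(kv_dim, attention_dim)
        self.out_proj = nn.Linear(attention_dim, embed_dim)
        self.scale = math.sqrt(self.head_dim)

    def forward(self, x: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        # x: (batch, tgt_len, embed_dim); memory: (batch, src_len, memory_dim)
        b, t, _ = x.shape
        s = memory.size(1)
        q = self.linear_q(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.linear_k(memory).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.linear_v(memory).view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        attn = torch.softmax(q @ k.transpose(-2, -1) / self.scale, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, t, -1)
        return self.out_proj(out)


if __name__ == "__main__":
    # Decoder states of width 512 attending over encoder frames of width 384.
    attn = CrossAttentionSketch(embed_dim=512, attention_dim=512, num_heads=8, memory_dim=384)
    decoder_states = torch.randn(2, 10, 512)
    encoder_out = torch.randn(2, 50, 384)
    print(attn(decoder_states, encoder_out).shape)  # torch.Size([2, 10, 512])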
diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py
index d0f25ed01..f81c6beac 100755
--- a/egs/librispeech/ASR/zipformer/train.py
+++ b/egs/librispeech/ASR/zipformer/train.py
@@ -665,9 +665,6 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
 
 
 def get_attention_decoder_model(params: AttributeDict) -> nn.Module:
-    encoder_dim = max(_to_int_tuple(params.encoder_dim))
-    assert params.attention_decoder_dim == encoder_dim, (params.attention_decoder_dim, encoder_dim)
-
     decoder = AttentionDecoderModel(
         vocab_size=params.vocab_size,
         decoder_dim=params.attention_decoder_dim,
@@ -675,6 +672,7 @@ def get_attention_decoder_model(params: AttributeDict) -> nn.Module:
         attention_dim=params.attention_decoder_attention_dim,
         nhead=params.attention_decoder_num_heads,
         feedforward_dim=params.attention_decoder_feedforward_dim,
+        memory_dim=max(_to_int_tuple(params.encoder_dim)),
         sos_id=params.sos_id,
         eos_id=params.eos_id,
         ignore_id=params.ignore_id,
diff --git a/icefall/decode.py b/icefall/decode.py
index 1d0991d87..3abd5648a 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -1298,6 +1298,7 @@ def rescore_with_attention_decoder_no_ngram(
         attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
         attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
         attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
+        attention_scale_list += [5.0, 6.0, 7.0, 8.0, 9.0]
     else:
         attention_scale_list = [attention_scale]
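Note: on the training side, memory_dim is now derived from the Zipformer's per-stack encoder dimensions, and the old assertion that attention_decoder_dim equal the encoder dimension is removed; the icefall/decode.py hunk only widens the grid of attention scales tried during rescoring. The sketch below re-derives memory_dim under an assumed --encoder-dim value; the config string and the simplified _to_int_tuple helper are illustrative, not the exact train.py code.

from typing import Tuple


def _to_int_tuple(s: str) -> Tuple[int, ...]:
    # Simplified re-implementation of the train.py helper:
    # "192,256,384,512,384,256" -> (192, 256, 384, 512, 384, 256)
    return tuple(int(x) for x in s.split(","))


encoder_dim = "192,256,384,512,384,256"  # example --encoder-dim value (assumption)
attention_decoder_dim = 512              # decoder width; after this patch it need not equal memory_dim

memory_dim = max(_to_int_tuple(encoder_dim))
print(memory_dim)  # 512: width of the encoder output the attention decoder cross-attends to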