diff --git a/egs/librispeech/ASR/zipformer/attention_decoder.py b/egs/librispeech/ASR/zipformer/attention_decoder.py
index 1c8e25262..1400b9161 100644
--- a/egs/librispeech/ASR/zipformer/attention_decoder.py
+++ b/egs/librispeech/ASR/zipformer/attention_decoder.py
@@ -36,7 +36,7 @@ class AttentionDecoderModel(nn.Module):
         vocab_size (int): Number of classes.
         decoder_dim: (int,int): embedding dimension of 2 encoder stacks
         attention_dim: (int,int): attention dimension of 2 encoder stacks
-        nhead (int, int): number of heads
+        num_heads (int, int): number of heads
         dim_feedforward (int, int): feedforward dimension in 2 encoder stacks
         num_encoder_layers (int): number of encoder layers
         dropout (float): dropout rate
@@ -48,7 +48,7 @@ class AttentionDecoderModel(nn.Module):
         decoder_dim: int = 512,
         num_decoder_layers: int = 6,
         attention_dim: int = 512,
-        nhead: int = 8,
+        num_heads: int = 8,
         feedforward_dim: int = 2048,
         memory_dim: int = 512,
         sos_id: int = 1,
@@ -69,7 +69,7 @@ class AttentionDecoderModel(nn.Module):
             d_model=decoder_dim,
             num_decoder_layers=num_decoder_layers,
             attention_dim=attention_dim,
-            nhead=nhead,
+            num_heads=num_heads,
             feedforward_dim=feedforward_dim,
             memory_dim=memory_dim,
             dropout=dropout,
@@ -111,7 +111,12 @@ class AttentionDecoderModel(nn.Module):
         ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(ys, ys_lens)

         # decoder forward
-        decoder_out = self.decoder(encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens)
+        decoder_out = self.decoder(
+            x=ys_in_pad,
+            x_lens=ys_in_lens,
+            memory=encoder_out,
+            memory_lens=encoder_out_lens,
+        )

         loss = self.loss_fun(x=decoder_out, target=ys_out_pad)
         return loss
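Note on the change above: the argument order of TransformerDecoder.forward changes from
(memory, memory_lens, ys_in_pad, ys_in_lens) to (x, x_lens, memory, memory_lens), so the
call site switches to keyword arguments; a positional call would now silently swap the
token ids and the encoder output. `_pre_ys_in_out` itself is untouched by this diff. For
readers new to the recipe, below is a minimal sketch of the SOS/EOS teacher-forcing
preparation such a helper typically performs (the body, the `pad_id` argument, and the
loop are illustrative assumptions, not code from this repository):

    import torch

    def pre_ys_in_out(ys, ys_lens, sos_id=1, eos_id=1, pad_id=0):
        """Sketch: prepend SOS to get the decoder input, append EOS to get the target."""
        batch, max_len = ys.shape
        sos = torch.full((batch, 1), sos_id, dtype=ys.dtype)
        ys_in_pad = torch.cat([sos, ys], dim=1)  # (batch, max_len + 1)
        ys_in_lens = ys_lens + 1
        ys_out_pad = torch.full((batch, max_len + 1), pad_id, dtype=ys.dtype)
        for b in range(batch):
            n = int(ys_lens[b])
            ys_out_pad[b, :n] = ys[b, :n]  # copy the reference tokens
            ys_out_pad[b, n] = eos_id  # each target ends with EOS
        return ys_in_pad, ys_in_lens, ys_out_pad

    ys = torch.tensor([[5, 6, 7], [8, 9, 0]])  # 0 marks padding
    ys_in, ys_in_lens, ys_out = pre_ys_in_out(ys, torch.tensor([3, 2]))
    # ys_in:  [[1, 5, 6, 7], [1, 8, 9, 0]]
    # ys_out: [[5, 6, 7, 1], [8, 9, 1, 0]]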
@@ -137,7 +142,12 @@ class AttentionDecoderModel(nn.Module):
         ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(ys, ys_lens)

         # decoder forward
-        decoder_out = self.decoder(encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens)
+        decoder_out = self.decoder(
+            x=ys_in_pad,
+            x_lens=ys_in_lens,
+            memory=encoder_out,
+            memory_lens=encoder_out_lens,
+        )

         batch_size, _, num_classes = decoder_out.size()
         nll = nn.functional.cross_entropy(
@@ -152,14 +162,13 @@

 class TransformerDecoder(nn.Module):
     """Transformer decoder module.
-    It is modified from https://github.com/espnet/espnet/blob/master/espnet2/asr/decoder/transformer_decoder.py.

     Args:
         vocab_size: output dim
         d_model: decoder dimension
         num_decoder_layers: number of decoder layers
         attention_dim: total dimension of multi head attention
-        n_head: number of attention heads
+        num_heads: number of attention heads
         feedforward_dim: hidden dimension of feed_forward module
         dropout: dropout rate
     """
@@ -170,7 +179,7 @@ class TransformerDecoder(nn.Module):
         d_model: int = 512,
         num_decoder_layers: int = 6,
         attention_dim: int = 512,
-        nhead: int = 8,
+        num_heads: int = 8,
         feedforward_dim: int = 2048,
         memory_dim: int = 512,
         dropout: float = 0.1,
@@ -178,14 +187,19 @@
         super().__init__()
         self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

-        # Using absolute positional encoding
+        # Absolute positional encoding
         self.pos = PositionalEncoding(d_model, dropout_rate=0.1)

         self.num_layers = num_decoder_layers
         self.layers = nn.ModuleList(
             [
                 DecoderLayer(
-                    d_model, attention_dim, nhead, feedforward_dim, memory_dim, dropout
+                    d_model=d_model,
+                    attention_dim=attention_dim,
+                    num_heads=num_heads,
+                    feedforward_dim=feedforward_dim,
+                    memory_dim=memory_dim,
+                    dropout=dropout,
                 )
                 for _ in range(num_decoder_layers)
             ]
@@ -195,49 +209,67 @@ class TransformerDecoder(nn.Module):
     def forward(
         self,
-        memory: torch.Tensor,
-        memory_lens: torch.Tensor,
-        ys_in_pad: torch.Tensor,
-        ys_in_lens: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Forward decoder.
-
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+        memory: Optional[torch.Tensor] = None,
+        memory_lens: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
         Args:
-            memory: encoded memory, (batch, maxlen_in, feat)
-            memory_lens: (batch,)
-            ys_in_pad: input token ids, (batch, maxlen_out)
-            ys_in_lens: (batch, )
+            x: Input tensor of shape (batch, tgt_len).
+            x_lens: A tensor of shape (batch,) containing the number of tokens in `x`
+                before padding.
+            memory:
+                Memory sequence of shape (batch, src_len, memory_dim).
+            memory_lens:
+                A tensor of shape (batch,) containing the number of frames in
+                `memory` before padding.

         Returns:
-            tgt: decoded token score before softmax (batch, maxlen_out, vocab_size)
+            Decoded token logits before softmax (batch, tgt_len, vocab_size)
         """
-        tgt = ys_in_pad
-        # tgt_mask: (B, 1, L)
-        tgt_mask = make_pad_mask(ys_in_lens)[:, None, :].to(tgt.device)
-        # m: (1, L, L)
-        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
-        # tgt_mask: (B, L, L)
-        tgt_mask = tgt_mask | (~m)
+        x = self.embed(x)  # (batch, tgt_len, embed_dim)
+        x = self.pos(x)  # (batch, tgt_len, embed_dim)

-        memory_mask = make_pad_mask(memory_lens)[:, None, :].to(memory.device)
+        x = x.permute(1, 0, 2)  # (tgt_len, batch, embed_dim)

-        tgt = self.embed(tgt)
-        tgt = self.pos(tgt)
+        # construct attn_mask for self-attn modules
+        padding_mask = make_pad_mask(x_lens)  # (batch, tgt_len)
+        causal_mask = subsequent_mask(x.shape[0], device=x.device)  # (seq_len, seq_len)
+        attn_mask = torch.logical_or(
+            padding_mask.unsqueeze(1),  # (batch, 1, seq_len)
+            torch.logical_not(causal_mask).unsqueeze(0),  # (1, seq_len, seq_len)
+        )  # (batch, seq_len, seq_len)
+
+        if memory is not None:
+            memory = memory.permute(1, 0, 2)  # (src_len, batch, memory_dim)
+            # construct memory_attn_mask for cross-attn modules
+            memory_padding_mask = make_pad_mask(memory_lens)  # (batch, src_len)
+            memory_attn_mask = memory_padding_mask.unsqueeze(1)  # (batch, 1, src_len)
+        else:
+            memory_attn_mask = None

         for i, mod in enumerate(self.layers):
-            tgt = mod(tgt, tgt_mask, memory, memory_mask)
+            x = mod(
+                x,
+                attn_mask=attn_mask,
+                memory=memory,
+                memory_attn_mask=memory_attn_mask,
+            )

-        tgt = self.output_layer(tgt)
-        return tgt
+        x = x.permute(1, 0, 2)  # (batch, tgt_len, d_model)
+        x = self.output_layer(x)
+
+        return x


 class DecoderLayer(nn.Module):
     """Single decoder layer module.

     Args:
-        d_model: equal to encoder_dim
+        d_model: equal to decoder_dim, total dimension of the decoder
         attention_dim: total dimension of multi head attention
-        n_head: number of attention heads
+        num_heads: number of attention heads
         feedforward_dim: hidden dimension of feed_forward module
         dropout: dropout rate
     """

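The rewritten TransformerDecoder.forward above folds the old two-step tgt_mask
construction into a single boolean attn_mask in which True marks a position that
MultiHeadAttention will fill with -inf before the softmax. A self-contained illustration
of that computation follows; the two helper definitions only mirror what icefall's
make_pad_mask and subsequent_mask are documented to return and are illustrative
stand-ins, not the library code:

    import torch

    def make_pad_mask(lengths, max_len=None):
        # True at padded positions (stand-in for the icefall helper).
        max_len = max_len or int(lengths.max())
        return torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)

    def subsequent_mask(size, device="cpu"):
        # True at allowed (non-future) positions: a lower-triangular matrix.
        return torch.tril(torch.ones(size, size, dtype=torch.bool, device=device))

    x_lens = torch.tensor([3, 2])  # two sequences, padded to length 3
    padding_mask = make_pad_mask(x_lens)              # (2, 3)
    causal_mask = subsequent_mask(3)                  # (3, 3)
    attn_mask = torch.logical_or(
        padding_mask.unsqueeze(1),                    # (2, 1, 3)
        torch.logical_not(causal_mask).unsqueeze(0),  # (1, 3, 3)
    )  # (2, 3, 3): True means "mask out"
    print(attn_mask[1].int().tolist())
    # [[0, 1, 1], [0, 0, 1], [0, 0, 1]]  future and padded positions of the shorter sequence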
@@ -246,7 +278,7 @@
         self,
         d_model: int = 512,
         attention_dim: int = 512,
-        nhead: int = 8,
+        num_heads: int = 8,
         feedforward_dim: int = 2048,
         memory_dim: int = 512,
         dropout: float = 0.1,
@@ -255,10 +287,14 @@
         super(DecoderLayer, self).__init__()

         self.norm_self_attn = nn.LayerNorm(d_model)
-        self.self_attn = MultiHeadedAttention(d_model, attention_dim, nhead, dropout=0.0)
+        self.self_attn = MultiHeadAttention(
+            d_model, attention_dim, num_heads, dropout=0.0
+        )

         self.norm_src_attn = nn.LayerNorm(d_model)
-        self.src_attn = MultiHeadedAttention(d_model, attention_dim, nhead, memory_dim=memory_dim, dropout=0.0)
+        self.src_attn = MultiHeadAttention(
+            d_model, attention_dim, num_heads, memory_dim=memory_dim, dropout=0.0
+        )

         self.norm_ff = nn.LayerNorm(d_model)
         self.feed_forward = nn.Sequential(
@@ -270,40 +306,53 @@

         self.dropout = nn.Dropout(dropout)

-    def forward(self, tgt, tgt_mask, memory, memory_mask):
-        """Compute decoded features.
-
+    def forward(
+        self,
+        x: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        memory: Optional[torch.Tensor] = None,
+        memory_attn_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
         Args:
-            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
-            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
-            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
-            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
-
-        Returns:
-            torch.Tensor: Output tensor(#batch, maxlen_out, size).
+            x: Input sequence of shape (seq_len, batch, embed_dim).
+            attn_mask: A binary mask for the self-attention module indicating
+                which elements will be filled with -inf.
+                Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len).
+            memory: Memory sequence of shape (seq_len, batch, memory_dim).
+            memory_attn_mask: A binary mask for the cross-attention module indicating
+                which elements will be filled with -inf.
+                Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len).
         """
         # self-attn module
-        tgt_norm = self.norm_self_attn(tgt)
-        tgt = tgt + self.dropout(self.self_attn(tgt_norm, tgt_norm, tgt_norm, tgt_mask))
+        qkv = self.norm_self_attn(x)
+        self_attn_out = self.self_attn(
+            query=qkv, key=qkv, value=qkv, attn_mask=attn_mask
+        )
+        x = x + self.dropout(self_attn_out)

         # cross-attn module
-        tgt = tgt + self.dropout(self.src_attn(self.norm_src_attn(tgt), memory, memory, memory_mask))
+        q = self.norm_src_attn(x)
+        src_attn_out = self.src_attn(
+            query=q, key=memory, value=memory, attn_mask=memory_attn_mask
+        )
+        x = x + self.dropout(src_attn_out)

         # feed-forward module
-        tgt = tgt + self.dropout(self.feed_forward(self.norm_ff(tgt)))
+        x = x + self.dropout(self.feed_forward(self.norm_ff(x)))

-        return tgt
+        return x


-class MultiHeadedAttention(nn.Module):
+class MultiHeadAttention(nn.Module):
     """Multi-Head Attention layer.

     Args:
         embed_dim: total dimension of the model.
-        attention_dim: dimension in the attention module, may be less or more than embed_dim
-            but must be a multiple of num_heads.
-        num_heads: parallel attention heads.
-        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
+        attention_dim: dimension in the attention module; must be a multiple of num_heads.
+        num_heads: number of parallel attention heads.
+        memory_dim: dimension of memory embedding, optional.
+        dropout: a Dropout layer on attn_output_weights.
     """

     def __init__(
@@ -312,20 +361,18 @@ class MultiHeadedAttention(nn.Module):
         attention_dim: int,
         num_heads: int,
         memory_dim: Optional[int] = None,
-        dropout: float = 0.0
+        dropout: float = 0.0,
     ):
-        """Construct an MultiHeadedAttention object."""
-        super(MultiHeadedAttention, self).__init__()
+        super(MultiHeadAttention, self).__init__()
         self.embed_dim = embed_dim
         self.attention_dim = attention_dim
         self.num_heads = num_heads
-        self.dropout = dropout
         self.head_dim = attention_dim // num_heads
         assert self.head_dim * num_heads == attention_dim, (
-            self.head_dim,
-            num_heads,
-            attention_dim,
+            self.head_dim, num_heads, attention_dim
         )
+        self.dropout = dropout
+        self.name = None  # will be overwritten in training code; for diagnostics.

         self.linear_q = nn.Linear(embed_dim, attention_dim, bias=True)
         self.linear_k = nn.Linear(
@@ -334,74 +381,89 @@ class MultiHeadedAttention(nn.Module):
         self.linear_v = nn.Linear(
             embed_dim if memory_dim is None else memory_dim, attention_dim, bias=True
         )
-        self.scale = math.sqrt(self.head_dim)
         self.out_proj = nn.Linear(attention_dim, embed_dim, bias=True)

-    def forward(self, query, key, value, mask):
-        """Compute scaled dot product attention.
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_padding_mask: Optional[torch.Tensor] = None,
+        attn_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Compute dot product attention.

         Args:
-            query (torch.Tensor): Query tensor (#batch, time1, size).
-            key (torch.Tensor): Key tensor (#batch, time2, size).
-            value (torch.Tensor): Value tensor (#batch, time2, size).
-            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
-                (#batch, time1, time2).
+            query: Query tensor of shape (tgt_len, batch, embed_dim).
+            key: Key tensor of shape (src_len, batch, embed_dim or memory_dim).
+            value: Value tensor of shape (src_len, batch, embed_dim or memory_dim).
+            key_padding_mask: A binary mask indicating which elements are padding.
+                Its shape is (batch, src_len).
+            attn_mask: A binary mask indicating which elements will be filled with -inf.
+                Its shape is (batch, 1, src_len) or (batch, tgt_len, src_len).

         Returns:
-            torch.Tensor: Output tensor (#batch, time1, d_model).
-
+            Output tensor of shape (tgt_len, batch, embed_dim).
         """
-        bsz, tgt_len, _ = query.size()
-        src_len = key.size(1)
         num_heads = self.num_heads
         head_dim = self.head_dim

-        q = self.linear_q(query)
-        k = self.linear_k(key)
-        v = self.linear_v(value)
+        tgt_len, batch, _ = query.shape
+        src_len = key.shape[0]

-        q = q.reshape(bsz, tgt_len, num_heads, head_dim)
-        q = q.transpose(1, 2)  # (batch, head, time1, head_dim)
-        k = k.reshape(bsz, src_len, num_heads, head_dim)
-        k = k.permute(0, 2, 3, 1)  # (batch, head, head_dim, time2)
-        v = v.reshape(bsz, src_len, num_heads, head_dim)
-        v = v.transpose(1, 2).reshape(bsz * num_heads, src_len, head_dim)
+        q = self.linear_q(query)  # (tgt_len, batch, num_heads * head_dim)
+        k = self.linear_k(key)  # (src_len, batch, num_heads * head_dim)
+        v = self.linear_v(value)  # (src_len, batch, num_heads * head_dim)

-        # (batch, head, time1, time2)
-        attn_output_weights = torch.matmul(q, k) / self.scale
+        q = q.reshape(tgt_len, batch, num_heads, head_dim)
+        q = q.permute(1, 2, 0, 3)  # (batch, head, tgt_len, head_dim)
+        k = k.reshape(src_len, batch, num_heads, head_dim)
+        k = k.permute(1, 2, 3, 0)  # (batch, head, head_dim, src_len)
+        v = v.reshape(src_len, batch, num_heads, head_dim)
+        v = v.reshape(src_len, batch * num_heads, head_dim).transpose(0, 1)

-        # attn_output_weights = torch.matmul(q, k)
-        # # This is a harder way of limiting the attention scores to not be too large.
-        # # It incurs a penalty if any of them has an absolute value greater than 50.0.
-        # # this should be outside the normal range of the attention scores.  We use
-        # # this mechanism instead of, say, a limit on entropy, because once the entropy
-        # # gets very small gradients through the softmax can become very small, and
-        # # some mechanisms like that become ineffective.
-        attn_output_weights = penalize_abs_values_gt(
-            attn_output_weights, limit=50.0, penalty=1.0e-04
-        )
+        # Note: could remove the scaling operation when using ScaledAdam
+        # (batch, head, tgt_len, src_len)
+        attn_weights = torch.matmul(q, k) / math.sqrt(head_dim)

-        if mask is not None:
-            attn_output_weights = attn_output_weights.masked_fill(
-                mask.unsqueeze(1), float("-inf")
+        # From zipformer.py:
+        # This is a harder way of limiting the attention scores to not be too large.
+        # It incurs a penalty if any of them has an absolute value greater than 50.0.
+        # this should be outside the normal range of the attention scores.  We use
+        # this mechanism instead of, say, a limit on entropy, because once the entropy
+        # gets very small gradients through the softmax can become very small, and
+        # some mechanisms like that become ineffective.
+        attn_weights = penalize_abs_values_gt(attn_weights, limit=50.0, penalty=1.0e-04)
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.shape == (batch, src_len), key_padding_mask.shape
+            attn_weights = attn_weights.masked_fill(
+                key_padding_mask.unsqueeze(1).unsqueeze(2),
+                float("-inf"),
             )

-        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
+        if attn_mask is not None:
+            assert (
+                attn_mask.shape == (batch, 1, src_len)
+                or attn_mask.shape == (batch, tgt_len, src_len)
+            ), attn_mask.shape
+            attn_weights = attn_weights.masked_fill(attn_mask.unsqueeze(1), float("-inf"))

-        attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
-        attn_output_weights = nn.functional.dropout(
-            attn_output_weights, p=self.dropout, training=self.training
+        attn_weights = attn_weights.view(batch * num_heads, tgt_len, src_len)
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        attn_weights = nn.functional.dropout(
+            attn_weights, p=self.dropout, training=self.training
         )

-        # (bsz * head, time1, head_dim_v)
-        attn_output = torch.bmm(attn_output_weights, v)
-        assert attn_output.shape == (bsz * num_heads, tgt_len, head_dim)
-        attn_output = (
-            attn_output.reshape(bsz, num_heads, tgt_len, head_dim)
-            .transpose(1, 2)
-            .reshape(bsz, tgt_len, self.attention_dim)
-        )
+        # (batch * head, tgt_len, head_dim)
+        attn_output = torch.bmm(attn_weights, v)
+        assert attn_output.shape == (batch * num_heads, tgt_len, head_dim), attn_output.shape
+
+        attn_output = attn_output.transpose(0, 1).contiguous()
+        attn_output = attn_output.view(tgt_len, batch, num_heads * head_dim)
+
+        # (tgt_len, batch, embed_dim)
         attn_output = self.out_proj(attn_output)

         return attn_output
@@ -488,7 +550,7 @@ def _test_attention_decoder_model():
         decoder_dim=512,
         num_decoder_layers=6,
         attention_dim=512,
-        nhead=8,
+        num_heads=8,
         feedforward_dim=2048,
         memory_dim=384,
         dropout=0.1,
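That concludes the attention_decoder.py changes. A quick smoke test of the refactored
modules under the new keyword-based API, mirroring the dimensions used in
`_test_attention_decoder_model` above (the import path is an assumption; True entries in
the mask mean "fill with -inf"):

    import torch
    from attention_decoder import DecoderLayer  # assumed import path

    layer = DecoderLayer(
        d_model=512, attention_dim=512, num_heads=8,
        feedforward_dim=2048, memory_dim=384, dropout=0.1,
    )
    x = torch.randn(10, 2, 512)       # (tgt_len, batch, embed_dim)
    memory = torch.randn(20, 2, 384)  # (src_len, batch, memory_dim)
    # Hide the last 5 memory frames of both utterances from cross-attention.
    memory_attn_mask = torch.zeros(2, 1, 20, dtype=torch.bool)
    memory_attn_mask[:, :, 15:] = True
    out = layer(x, attn_mask=None, memory=memory, memory_attn_mask=memory_attn_mask)
    assert out.shape == (10, 2, 512)  # the (seq_len, batch, dim) layout is preserved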
diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py
index c70b55ba4..704afda9c 100755
--- a/egs/librispeech/ASR/zipformer/train.py
+++ b/egs/librispeech/ASR/zipformer/train.py
@@ -662,7 +662,7 @@ def get_attention_decoder_model(params: AttributeDict) -> nn.Module:
         decoder_dim=params.attention_decoder_dim,
         num_decoder_layers=params.attention_decoder_num_layers,
         attention_dim=params.attention_decoder_attention_dim,
-        nhead=params.attention_decoder_num_heads,
+        num_heads=params.attention_decoder_num_heads,
         feedforward_dim=params.attention_decoder_feedforward_dim,
         memory_dim=max(_to_int_tuple(params.encoder_dim)),
         sos_id=params.sos_id,
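One detail worth noting in the train.py hunk: memory_dim is derived as
max(_to_int_tuple(params.encoder_dim)) because the Zipformer emits its output at the
dimension of its largest encoder stack, so the decoder's cross-attention input
projections must be sized accordingly. A one-line illustration (the encoder-dim string
below is the recipe's usual default, quoted as an assumption):

    # Hypothetical illustration of the memory_dim derivation above:
    encoder_dim = "192,256,384,512,384,256"  # comma-separated per-stack dims
    dims = tuple(int(x) for x in encoder_dim.split(","))  # what _to_int_tuple returns
    memory_dim = max(dims)  # -> 512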