moved an operand to the joiner model

zr_jin 2023-07-23 22:30:28 +08:00
parent 056efeef30
commit 739e2a22c6
5 changed files with 27 additions and 34 deletions

View File

@@ -120,7 +120,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
           lm_pruned: input of shape (seq_len, batch_size * prune_range, decoder_embed_dim)
           am_pruned: input of shape (seq_len, batch_size * prune_range, encoder_embed_dim)
           pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
           key_padding_mask: a bool tensor of shape (batch_size * prune_range, seq_len). Positions
              that are True in this mask will be ignored as sources in the attention weighting.
           attn_mask: mask of shape (seq_len, seq_len)
              or (seq_len, batch_size * prune_range, batch_size * prune_range),
@@ -333,13 +333,15 @@ class AlignmentAttentionModule(nn.Module):
         pos_emb = self.pos_encode(merged_am_pruned)
-        attn_weights = self.cross_attn_weights(merged_lm_pruned, merged_am_pruned, pos_emb)
+        attn_weights = self.cross_attn_weights(
+            merged_lm_pruned, merged_am_pruned, pos_emb
+        )
         label_level_am_representation = self.cross_attn(merged_am_pruned, attn_weights)
         # (T, batch_size * prune_range, encoder_dim)
-        return label_level_am_representation \
-            .reshape(T, batch_size, prune_range, encoder_dim) \
-            .permute(1, 0, 2, 3)
+        return label_level_am_representation.reshape(
+            T, batch_size, prune_range, encoder_dim
+        ).permute(1, 0, 2, 3)


 if __name__ == "__main__":
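Note: the reformatted return statement above is where AlignmentAttentionModule hands its result back to the caller, mapping the merged (T, batch_size * prune_range, encoder_dim) representation to the (batch_size, T, prune_range, encoder_dim) layout the joiner consumes. A minimal, self-contained shape check of just that step (sizes are illustrative, not the recipe's values):

import torch

T, batch_size, prune_range, encoder_dim = 10, 4, 5, 384  # illustrative sizes

# Stand-in for label_level_am_representation: (T, batch_size * prune_range, encoder_dim)
merged = torch.randn(T, batch_size * prune_range, encoder_dim)

# Same reshape + permute as in the diff: split the merged dim, move batch first.
out = merged.reshape(T, batch_size, prune_range, encoder_dim).permute(1, 0, 2, 3)
assert out.shape == (batch_size, T, prune_range, encoder_dim)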

View File

@@ -120,7 +120,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
           lm_pruned: input of shape (seq_len, batch_size * prune_range, decoder_embed_dim)
           am_pruned: input of shape (seq_len, batch_size * prune_range, encoder_embed_dim)
           pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
           key_padding_mask: a bool tensor of shape (batch_size * prune_range, seq_len). Positions
              that are True in this mask will be ignored as sources in the attention weighting.
           attn_mask: mask of shape (seq_len, seq_len)
              or (seq_len, batch_size * prune_range, batch_size * prune_range),

View File

@@ -16,6 +16,7 @@
 import torch
 import torch.nn as nn
+from alignment_attention_module import AlignmentAttentionModule
 from scaling import ScaledLinear

@@ -29,6 +30,7 @@ class Joiner(nn.Module):
     ):
         super().__init__()
+        self.label_level_am_attention = AlignmentAttentionModule()
         self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
         self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
         self.output_linear = nn.Linear(joiner_dim, vocab_size)

@@ -52,12 +54,15 @@ class Joiner(nn.Module):
         Returns:
           Return a tensor of shape (N, T, s_range, C).
         """
-        assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape)
+        assert encoder_out.ndim == decoder_out.ndim, (
+            encoder_out.shape,
+            decoder_out.shape,
+        )
+        encoder_out = self.label_level_am_attention(encoder_out, decoder_out)
         if project_input:
-            logit = self.encoder_proj(encoder_out) + self.decoder_proj(
-                decoder_out
-            )
+            logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
         else:
             logit = encoder_out + decoder_out
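Taken together, the three joiner hunks mean the joiner now owns the label-level alignment attention: it is constructed in __init__ and applied to encoder_out before the projection/add step. The sketch below condenses that new control flow into a runnable stand-in; the attention module is replaced by an identity placeholder, ScaledLinear by nn.Linear, and the final tanh + output_linear stage is assumed to match the unchanged remainder of Joiner.forward.

import torch
import torch.nn as nn


class JoinerSketch(nn.Module):
    """Condensed stand-in for the updated Joiner (not the real implementation)."""

    def __init__(self, encoder_dim=384, decoder_dim=512, joiner_dim=512, vocab_size=500):
        super().__init__()
        # Placeholder for AlignmentAttentionModule(): the real module attends the
        # pruned acoustic frames to the label-level positions; here it is a no-op.
        self.label_level_am_attention = lambda am, lm: am
        self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
        self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
        self.output_linear = nn.Linear(joiner_dim, vocab_size)

    def forward(self, encoder_out, decoder_out, project_input=True):
        assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape)
        # New in this commit: the attention is applied inside the joiner.
        encoder_out = self.label_level_am_attention(encoder_out, decoder_out)
        if project_input:
            logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
        else:
            logit = encoder_out + decoder_out
        return self.output_linear(torch.tanh(logit))


joiner = JoinerSketch()
am = torch.randn(2, 10, 5, 384)  # (N, T, s_range, encoder_dim)
lm = torch.randn(2, 10, 5, 512)  # (N, T, s_range, decoder_dim)
print(joiner(am, lm).shape)      # torch.Size([2, 10, 5, 500])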

View File

@@ -21,7 +21,6 @@ from typing import Optional, Tuple
 import k2
 import torch
 import torch.nn as nn
-from alignment_attention_module import AlignmentAttentionModule
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear

@@ -35,7 +34,6 @@ class AsrModel(nn.Module):
         encoder: EncoderInterface,
         decoder: Optional[nn.Module] = None,
         joiner: Optional[nn.Module] = None,
-        label_level_am_attention: Optional[nn.Module] = None,
         encoder_dim: int = 384,
         decoder_dim: int = 512,
         vocab_size: int = 500,

@@ -113,8 +111,6 @@ class AsrModel(nn.Module):
                 nn.LogSoftmax(dim=-1),
             )
-        self.label_level_am_attention = label_level_am_attention

     def forward_encoder(
         self, x: torch.Tensor, x_lens: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:

@@ -264,14 +260,11 @@ class AsrModel(nn.Module):
             ranges=ranges,
         )
-        label_level_am_pruned = self.label_level_am_attention(am_pruned, lm_pruned)

         # logits : [B, T, prune_range, vocab_size]
         # project_input=False since we applied the decoder's input projections
         # prior to do_rnnt_pruning (this is an optimization for speed).
-        logits = self.joiner(label_level_am_pruned, lm_pruned, project_input=False)
+        logits = self.joiner(am_pruned, lm_pruned, project_input=False)

         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
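The net effect in the model file is visible in the last hunk: the explicit attention call before the joiner is gone, and the pruned acoustic frames go straight into the joiner, which now attends them internally (see the joiner change above). Roughly, using the names from the diff:

# Before (inside AsrModel.forward_transducer):
label_level_am_pruned = self.label_level_am_attention(am_pruned, lm_pruned)
logits = self.joiner(label_level_am_pruned, lm_pruned, project_input=False)

# After: the joiner applies the label-level attention itself.
logits = self.joiner(am_pruned, lm_pruned, project_input=False)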

View File

@@ -603,16 +603,13 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
     )
     return joiner

-def get_attn_module(params: AttributeDict) -> nn.Module:
-    attn_module = AlignmentAttentionModule()
-    return attn_module

 def get_model(params: AttributeDict) -> nn.Module:
-    assert (
-        params.use_transducer or params.use_ctc
-    ), (f"At least one of them should be True, "
-    f"but got params.use_transducer={params.use_transducer}, "
-    f"params.use_ctc={params.use_ctc}")
+    assert params.use_transducer or params.use_ctc, (
+        f"At least one of them should be True, "
+        f"but got params.use_transducer={params.use_transducer}, "
+        f"params.use_ctc={params.use_ctc}"
+    )

     encoder_embed = get_encoder_embed(params)
     encoder = get_encoder_model(params)

@@ -624,14 +621,11 @@ def get_model(params: AttributeDict) -> nn.Module:
         decoder = None
         joiner = None

-    attn = get_attn_module(params)
     model = AsrModel(
         encoder_embed=encoder_embed,
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
-        label_level_am_attention=attn,
         encoder_dim=max(_to_int_tuple(params.encoder_dim)),
         decoder_dim=params.decoder_dim,
         vocab_size=params.vocab_size,

@@ -815,17 +809,16 @@ def compute_loss(
             # take down the scale on the simple loss from 1.0 at the start
             # to params.simple_loss scale by warm_step.
             simple_loss_scale = (
-                s if batch_idx_train >= warm_step
+                s
+                if batch_idx_train >= warm_step
                 else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
             )
             pruned_loss_scale = (
-                1.0 if batch_idx_train >= warm_step
+                1.0
+                if batch_idx_train >= warm_step
                 else 0.1 + 0.9 * (batch_idx_train / warm_step)
             )
-            loss += (
-                simple_loss_scale * simple_loss
-                + pruned_loss_scale * pruned_loss
-            )
+            loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

         if params.use_ctc:
             loss += params.ctc_loss_scale * ctc_loss
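For reference, the two reformatted conditional expressions in compute_loss keep the same warm-up schedule: each scale interpolates linearly with batch_idx_train until warm_step and stays constant afterwards. A standalone check of the arithmetic (warm_step and s below are illustrative values, not the recipe defaults):

def loss_scales(batch_idx_train: int, warm_step: int, s: float):
    """Mirror the two conditional expressions from compute_loss."""
    simple_loss_scale = (
        s
        if batch_idx_train >= warm_step
        else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
    )
    pruned_loss_scale = (
        1.0
        if batch_idx_train >= warm_step
        else 0.1 + 0.9 * (batch_idx_train / warm_step)
    )
    return simple_loss_scale, pruned_loss_scale


# At the start of training the simple loss dominates; after warm_step the
# pruned loss is weighted fully and the simple loss is scaled down to s.
print(loss_scales(0, 2000, 0.5))     # (1.0, 0.1)
print(loss_scales(1000, 2000, 0.5))  # (0.75, 0.55)
print(loss_scales(2000, 2000, 0.5))  # (0.5, 1.0)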