moved an operand to the joiner model

zr_jin 2023-07-23 22:30:28 +08:00
parent 056efeef30
commit 739e2a22c6
5 changed files with 27 additions and 34 deletions

View File

@@ -120,7 +120,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
lm_pruned: input of shape (seq_len, batch_size * prune_range, decoder_embed_dim)
am_pruned: input of shape (seq_len, batch_size * prune_range, encoder_embed_dim)
pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
key_padding_mask: a bool tensor of shape (batch_size * prune_range, seq_len). Positions
that are True in this mask will be ignored as sources in the attention weighting.
attn_mask: mask of shape (seq_len, seq_len)
or (seq_len, batch_size * prune_range, batch_size * prune_range),
@@ -333,13 +333,15 @@ class AlignmentAttentionModule(nn.Module):
pos_emb = self.pos_encode(merged_am_pruned)
- attn_weights = self.cross_attn_weights(merged_lm_pruned, merged_am_pruned, pos_emb)
+ attn_weights = self.cross_attn_weights(
+ merged_lm_pruned, merged_am_pruned, pos_emb
+ )
label_level_am_representation = self.cross_attn(merged_am_pruned, attn_weights)
# (T, batch_size * prune_range, encoder_dim)
- return label_level_am_representation \
- .reshape(T, batch_size, prune_range, encoder_dim) \
- .permute(1, 0, 2, 3)
+ return label_level_am_representation.reshape(
+ T, batch_size, prune_range, encoder_dim
+ ).permute(1, 0, 2, 3)
if __name__ == "__main__":
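As a quick orientation for the reshaped return value above, a minimal shape check with arbitrary sizes (the random tensor is only a stand-in for label_level_am_representation):

```python
import torch

T, batch_size, prune_range, encoder_dim = 10, 4, 5, 384  # arbitrary sizes
# Stand-in for label_level_am_representation: (T, batch_size * prune_range, encoder_dim)
x = torch.randn(T, batch_size * prune_range, encoder_dim)

# Same reshape/permute as the return statement: split the fused batch axis,
# then move batch_size ahead of the time axis.
out = x.reshape(T, batch_size, prune_range, encoder_dim).permute(1, 0, 2, 3)
assert out.shape == (batch_size, T, prune_range, encoder_dim)
```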

View File

@@ -120,7 +120,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
lm_pruned: input of shape (seq_len, batch_size * prune_range, decoder_embed_dim)
am_pruned: input of shape (seq_len, batch_size * prune_range, encoder_embed_dim)
pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
key_padding_mask: a bool tensor of shape (batch_size * prune_range, seq_len). Positions
that are True in this mask will be ignored as sources in the attention weighting.
attn_mask: mask of shape (seq_len, seq_len)
or (seq_len, batch_size * prune_range, batch_size * prune_range),
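The key_padding_mask described in this docstring shares the fused (batch_size * prune_range) batch axis with lm_pruned and am_pruned. Below is a minimal sketch of how such a mask could be derived from per-utterance encoder lengths; the helper name and the batch-major repetition order are assumptions for illustration, not something this diff defines:

```python
import torch


def make_key_padding_mask(lengths: torch.Tensor, prune_range: int, seq_len: int) -> torch.Tensor:
    # (batch_size, seq_len): True where the frame index is at or beyond the
    # utterance's valid length, i.e. a padded position to be ignored.
    mask = torch.arange(seq_len).unsqueeze(0) >= lengths.unsqueeze(1)
    # Repeat each utterance's row prune_range times so the mask lines up with
    # the fused (batch_size * prune_range) batch axis of lm_pruned / am_pruned.
    return mask.repeat_interleave(prune_range, dim=0)


lengths = torch.tensor([7, 10, 4])  # valid frames per utterance
key_padding_mask = make_key_padding_mask(lengths, prune_range=5, seq_len=10)
assert key_padding_mask.shape == (3 * 5, 10)
```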

View File

@@ -16,6 +16,7 @@
import torch
import torch.nn as nn
+ from alignment_attention_module import AlignmentAttentionModule
from scaling import ScaledLinear
@@ -29,6 +30,7 @@ class Joiner(nn.Module):
):
super().__init__()
+ self.label_level_am_attention = AlignmentAttentionModule()
self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
self.output_linear = nn.Linear(joiner_dim, vocab_size)
@@ -52,12 +54,15 @@ class Joiner(nn.Module):
Returns:
Return a tensor of shape (N, T, s_range, C).
"""
- assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape)
+ assert encoder_out.ndim == decoder_out.ndim, (
+ encoder_out.shape,
+ decoder_out.shape,
+ )
+ encoder_out = self.label_level_am_attention(encoder_out, decoder_out)
if project_input:
- logit = self.encoder_proj(encoder_out) + self.decoder_proj(
- decoder_out
- )
+ logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
else:
logit = encoder_out + decoder_out
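Taken together, the joiner-side changes in this file amount to the following condensed sketch of Joiner.forward after the move. ScaledLinear is replaced with plain nn.Linear and AlignmentAttentionModule with a pass-through stub purely so the sketch stays self-contained; initial scales and the attention math are therefore not faithful to the real modules:

```python
import torch
import torch.nn as nn


class _AttentionStub(nn.Module):
    # Stand-in for AlignmentAttentionModule, whose definition lives in
    # alignment_attention_module.py and is not shown in this hunk.
    def forward(self, am: torch.Tensor, lm: torch.Tensor) -> torch.Tensor:
        return am


class JoinerSketch(nn.Module):
    def __init__(self, encoder_dim: int, decoder_dim: int, joiner_dim: int, vocab_size: int):
        super().__init__()
        # New in this commit: the joiner owns the label-level AM attention.
        self.label_level_am_attention = _AttentionStub()
        self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)  # ScaledLinear in the real code
        self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)  # ScaledLinear in the real code
        self.output_linear = nn.Linear(joiner_dim, vocab_size)

    def forward(
        self,
        encoder_out: torch.Tensor,
        decoder_out: torch.Tensor,
        project_input: bool = True,
    ) -> torch.Tensor:
        assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape)
        # Attend over the AM frames at label level before the projections are summed.
        encoder_out = self.label_level_am_attention(encoder_out, decoder_out)
        if project_input:
            logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
        else:
            logit = encoder_out + decoder_out
        return self.output_linear(torch.tanh(logit))
```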

View File

@@ -21,7 +21,6 @@ from typing import Optional, Tuple
import k2
import torch
import torch.nn as nn
- from alignment_attention_module import AlignmentAttentionModule
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
@@ -35,7 +34,6 @@ class AsrModel(nn.Module):
encoder: EncoderInterface,
decoder: Optional[nn.Module] = None,
joiner: Optional[nn.Module] = None,
- label_level_am_attention: Optional[nn.Module] = None,
encoder_dim: int = 384,
decoder_dim: int = 512,
vocab_size: int = 500,
@@ -113,8 +111,6 @@ class AsrModel(nn.Module):
nn.LogSoftmax(dim=-1),
)
- self.label_level_am_attention = label_level_am_attention
def forward_encoder(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -264,14 +260,11 @@ class AsrModel(nn.Module):
ranges=ranges,
)
- label_level_am_pruned = self.label_level_am_attention(am_pruned, lm_pruned)
# logits : [B, T, prune_range, vocab_size]
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
- logits = self.joiner(label_level_am_pruned, lm_pruned, project_input=False)
+ logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
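The retained comment about project_input=False rests on the projections being linear: projecting the full encoder/decoder outputs once and then gathering the pruned slices gives the same result as gathering first and projecting inside the joiner, while touching far fewer elements. A small self-contained check of that reasoning; prune() is a simplified stand-in for the gather performed by k2.do_rnnt_pruning, and all names and sizes here are illustrative:

```python
import torch
import torch.nn as nn


def prune(x: torch.Tensor, ranges: torch.Tensor) -> torch.Tensor:
    # x: (B, U, C), ranges: (B, T, s_range) with values in [0, U)
    # returns (B, T, s_range, C), i.e. x gathered along its U axis.
    B, U, C = x.shape
    _, T, s_range = ranges.shape
    index = ranges.reshape(B, T * s_range, 1).expand(B, T * s_range, C)
    return torch.gather(x, dim=1, index=index).reshape(B, T, s_range, C)


B, U, T, s_range, C, joiner_dim = 2, 12, 9, 5, 16, 8
proj = nn.Linear(C, joiner_dim)
lm = torch.randn(B, U, C)
ranges = torch.randint(0, U, (B, T, s_range))

# "prune then project" (what project_input=True would do inside the joiner)
out_a = proj(prune(lm, ranges))
# "project then prune" (what the model does before do_rnnt_pruning,
# hence project_input=False when calling the joiner)
out_b = prune(proj(lm), ranges)

torch.testing.assert_close(out_a, out_b)
```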

View File

@@ -603,16 +603,13 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
)
return joiner
- def get_attn_module(params: AttributeDict) -> nn.Module:
- attn_module = AlignmentAttentionModule()
- return attn_module
def get_model(params: AttributeDict) -> nn.Module:
- assert (
- params.use_transducer or params.use_ctc
- ), (f"At least one of them should be True, "
+ assert params.use_transducer or params.use_ctc, (
+ f"At least one of them should be True, "
f"but got params.use_transducer={params.use_transducer}, "
- f"params.use_ctc={params.use_ctc}")
+ f"params.use_ctc={params.use_ctc}"
+ )
encoder_embed = get_encoder_embed(params)
encoder = get_encoder_model(params)
@@ -624,14 +621,11 @@ def get_model(params: AttributeDict) -> nn.Module:
decoder = None
joiner = None
- attn = get_attn_module(params)
model = AsrModel(
encoder_embed=encoder_embed,
encoder=encoder,
decoder=decoder,
joiner=joiner,
- label_level_am_attention=attn,
encoder_dim=max(_to_int_tuple(params.encoder_dim)),
decoder_dim=params.decoder_dim,
vocab_size=params.vocab_size,
@@ -815,17 +809,16 @@ def compute_loss(
# take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss scale by warm_step.
simple_loss_scale = (
- s if batch_idx_train >= warm_step
+ s
+ if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
pruned_loss_scale = (
- 1.0 if batch_idx_train >= warm_step
+ 1.0
+ if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
- loss += (
- simple_loss_scale * simple_loss
- + pruned_loss_scale * pruned_loss
- )
+ loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
if params.use_ctc:
loss += params.ctc_loss_scale * ctc_loss
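For reference, the two warm-up schedules that the reformatted expressions above implement, written as a standalone helper; warm_step=2000 and s=0.1 are illustrative values, not necessarily the recipe's defaults:

```python
def loss_scales(batch_idx_train: int, warm_step: int, s: float):
    """Same schedule as above: the simple-loss weight decays from 1.0 down to s
    over the first warm_step batches, while the pruned-loss weight ramps up
    from 0.1 to 1.0 over the same interval."""
    frac = batch_idx_train / warm_step
    simple = s if batch_idx_train >= warm_step else 1.0 - frac * (1.0 - s)
    pruned = 1.0 if batch_idx_train >= warm_step else 0.1 + 0.9 * frac
    return simple, pruned


# Illustrative values for warm_step and s (params.simple_loss_scale):
for step in (0, 1000, 2000):
    print(step, loss_scales(step, warm_step=2000, s=0.1))
# approximately: 0 -> (1.0, 0.1), 1000 -> (0.55, 0.55), 2000 -> (0.1, 1.0)
```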