moved an operand to the joiner model

zr_jin 2023-07-23 22:30:28 +08:00
parent 056efeef30
commit 739e2a22c6
5 changed files with 27 additions and 34 deletions

View File

@@ -333,13 +333,15 @@ class AlignmentAttentionModule(nn.Module):
         pos_emb = self.pos_encode(merged_am_pruned)
-        attn_weights = self.cross_attn_weights(merged_lm_pruned, merged_am_pruned, pos_emb)
+        attn_weights = self.cross_attn_weights(
+            merged_lm_pruned, merged_am_pruned, pos_emb
+        )
         label_level_am_representation = self.cross_attn(merged_am_pruned, attn_weights)
         # (T, batch_size * prune_range, encoder_dim)
-        return label_level_am_representation \
-            .reshape(T, batch_size, prune_range, encoder_dim) \
-            .permute(1, 0, 2, 3)
+        return label_level_am_representation.reshape(
+            T, batch_size, prune_range, encoder_dim
+        ).permute(1, 0, 2, 3)


 if __name__ == "__main__":
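Note on the reformatted return above: the attention output is laid out as (T, batch_size * prune_range, encoder_dim) and is rearranged into (batch_size, T, prune_range, encoder_dim) before it reaches the joiner. A minimal, self-contained sketch of that shape transformation, using toy dimensions rather than the recipe's defaults:

import torch

# Toy dimensions chosen only for illustration.
T, batch_size, prune_range, encoder_dim = 6, 2, 3, 4
label_level_am_representation = torch.randn(T, batch_size * prune_range, encoder_dim)

# Same reshape/permute as in the diff: (T, B * s_range, C) -> (B, T, s_range, C).
out = label_level_am_representation.reshape(
    T, batch_size, prune_range, encoder_dim
).permute(1, 0, 2, 3)

assert out.shape == (batch_size, T, prune_range, encoder_dim)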

View File

@@ -16,6 +16,7 @@
 import torch
 import torch.nn as nn
+from alignment_attention_module import AlignmentAttentionModule
 from scaling import ScaledLinear
@@ -29,6 +30,7 @@ class Joiner(nn.Module):
     ):
         super().__init__()
+        self.label_level_am_attention = AlignmentAttentionModule()
         self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
         self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
         self.output_linear = nn.Linear(joiner_dim, vocab_size)
@@ -52,12 +54,15 @@ class Joiner(nn.Module):
         Returns:
           Return a tensor of shape (N, T, s_range, C).
         """
-        assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape)
+        assert encoder_out.ndim == decoder_out.ndim, (
+            encoder_out.shape,
+            decoder_out.shape,
+        )
+        encoder_out = self.label_level_am_attention(encoder_out, decoder_out)

         if project_input:
-            logit = self.encoder_proj(encoder_out) + self.decoder_proj(
-                decoder_out
-            )
+            logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
         else:
             logit = encoder_out + decoder_out
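For readers skimming the diff: the joiner now owns the alignment attention and applies it to encoder_out before the projections. A rough, self-contained sketch of the resulting forward pass; the attention module is replaced by a stub so the snippet runs on its own, and the class is a simplification, not the actual implementation:

import torch
import torch.nn as nn


class _AttentionStub(nn.Module):
    """Stand-in for AlignmentAttentionModule, used only to keep this sketch runnable."""

    def forward(self, am: torch.Tensor, lm: torch.Tensor) -> torch.Tensor:
        return am


class JoinerSketch(nn.Module):
    """Simplified view of the joiner after this commit."""

    def __init__(self, encoder_dim: int, decoder_dim: int, joiner_dim: int, vocab_size: int):
        super().__init__()
        self.label_level_am_attention = _AttentionStub()
        self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
        self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
        self.output_linear = nn.Linear(joiner_dim, vocab_size)

    def forward(
        self,
        encoder_out: torch.Tensor,
        decoder_out: torch.Tensor,
        project_input: bool = True,
    ) -> torch.Tensor:
        assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape)
        # The label-level attention now runs inside the joiner, before projection.
        encoder_out = self.label_level_am_attention(encoder_out, decoder_out)
        if project_input:
            logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
        else:
            logit = encoder_out + decoder_out
        return self.output_linear(torch.tanh(logit))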

View File

@@ -21,7 +21,6 @@ from typing import Optional, Tuple
 import k2
 import torch
 import torch.nn as nn
-from alignment_attention_module import AlignmentAttentionModule
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear
@@ -35,7 +34,6 @@ class AsrModel(nn.Module):
         encoder: EncoderInterface,
         decoder: Optional[nn.Module] = None,
         joiner: Optional[nn.Module] = None,
-        label_level_am_attention: Optional[nn.Module] = None,
         encoder_dim: int = 384,
         decoder_dim: int = 512,
         vocab_size: int = 500,
@@ -113,8 +111,6 @@ class AsrModel(nn.Module):
                 nn.LogSoftmax(dim=-1),
             )

-        self.label_level_am_attention = label_level_am_attention

     def forward_encoder(
         self, x: torch.Tensor, x_lens: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -264,14 +260,11 @@ class AsrModel(nn.Module):
             ranges=ranges,
         )

-        label_level_am_pruned = self.label_level_am_attention(am_pruned, lm_pruned)

         # logits : [B, T, prune_range, vocab_size]

         # project_input=False since we applied the decoder's input projections
         # prior to do_rnnt_pruning (this is an optimization for speed).
-        logits = self.joiner(label_level_am_pruned, lm_pruned, project_input=False)
+        logits = self.joiner(am_pruned, lm_pruned, project_input=False)

         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
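The ASR model side of the move: the intermediate label_level_am_pruned tensor is gone and the joiner receives am_pruned directly. A hedged illustration of the changed call site with made-up shapes; the toy joiner below only mimics output shapes and is not the recipe's joiner:

import torch

# Made-up sizes purely for illustrating the tensors around the changed call.
B, T, prune_range, encoder_dim, decoder_dim, vocab_size = 2, 10, 5, 8, 8, 16

am_pruned = torch.randn(B, T, prune_range, encoder_dim)
lm_pruned = torch.randn(B, T, prune_range, decoder_dim)


def toy_joiner(am: torch.Tensor, lm: torch.Tensor, project_input: bool = False) -> torch.Tensor:
    # In the real joiner the label-level attention is applied to `am` here,
    # then the two streams are combined and mapped to vocab_size logits.
    return torch.zeros(*am.shape[:-1], vocab_size)


# Previously: logits = self.joiner(label_level_am_pruned, lm_pruned, project_input=False)
# Now:        logits = self.joiner(am_pruned, lm_pruned, project_input=False)
logits = toy_joiner(am_pruned, lm_pruned, project_input=False)
assert logits.shape == (B, T, prune_range, vocab_size)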

View File

@@ -603,16 +603,13 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
     )
     return joiner


-def get_attn_module(params: AttributeDict) -> nn.Module:
-    attn_module = AlignmentAttentionModule()
-    return attn_module


 def get_model(params: AttributeDict) -> nn.Module:
-    assert (
-        params.use_transducer or params.use_ctc
-    ), (f"At least one of them should be True, "
-        f"but got params.use_transducer={params.use_transducer}, "
-        f"params.use_ctc={params.use_ctc}")
+    assert params.use_transducer or params.use_ctc, (
+        f"At least one of them should be True, "
+        f"but got params.use_transducer={params.use_transducer}, "
+        f"params.use_ctc={params.use_ctc}"
+    )

     encoder_embed = get_encoder_embed(params)
     encoder = get_encoder_model(params)
@@ -624,14 +621,11 @@ def get_model(params: AttributeDict) -> nn.Module:
         decoder = None
         joiner = None

-    attn = get_attn_module(params)

     model = AsrModel(
         encoder_embed=encoder_embed,
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
-        label_level_am_attention=attn,
         encoder_dim=max(_to_int_tuple(params.encoder_dim)),
         decoder_dim=params.decoder_dim,
         vocab_size=params.vocab_size,
@@ -815,17 +809,16 @@ def compute_loss(
             # take down the scale on the simple loss from 1.0 at the start
             # to params.simple_loss scale by warm_step.
             simple_loss_scale = (
-                s if batch_idx_train >= warm_step
+                s
+                if batch_idx_train >= warm_step
                 else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
             )
             pruned_loss_scale = (
-                1.0 if batch_idx_train >= warm_step
+                1.0
+                if batch_idx_train >= warm_step
                 else 0.1 + 0.9 * (batch_idx_train / warm_step)
             )
-            loss += (
-                simple_loss_scale * simple_loss
-                + pruned_loss_scale * pruned_loss
-            )
+            loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

         if params.use_ctc:
             loss += params.ctc_loss_scale * ctc_loss
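The compute_loss hunk is formatting-only, but the warm-up schedule it touches is easy to misread, so here is a small restatement of the same formulas as a standalone function with a couple of spot checks; the warm_step and s values are arbitrary examples, not the recipe's defaults:

from typing import Tuple


def loss_scales(batch_idx_train: int, warm_step: int, s: float) -> Tuple[float, float]:
    """Same schedule as in compute_loss: the simple-loss scale fades from 1.0 to s,
    and the pruned-loss scale ramps from 0.1 to 1.0, over the first warm_step batches."""
    simple_loss_scale = (
        s
        if batch_idx_train >= warm_step
        else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
    )
    pruned_loss_scale = (
        1.0
        if batch_idx_train >= warm_step
        else 0.1 + 0.9 * (batch_idx_train / warm_step)
    )
    return simple_loss_scale, pruned_loss_scale


# At step 0 the simple loss dominates; after warm_step the scales settle at (s, 1.0).
assert loss_scales(0, 2000, 0.5) == (1.0, 0.1)
assert loss_scales(2000, 2000, 0.5) == (0.5, 1.0)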