mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
moved an operand to the joiner model
This commit is contained in:
parent
056efeef30
commit
739e2a22c6
@ -120,7 +120,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
lm_pruned: input of shape (seq_len, batch_size * prune_range, decoder_embed_dim)
|
||||
am_pruned: input of shape (seq_len, batch_size * prune_range, encoder_embed_dim)
|
||||
pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
|
||||
key_padding_mask: a bool tensor of shape (batch_size * prune_range, seq_len). Positions
|
||||
key_padding_mask: a bool tensor of shape (batch_size * prune_range, seq_len). Positions
|
||||
that are True in this mask will be ignored as sources in the attention weighting.
|
||||
attn_mask: mask of shape (seq_len, seq_len)
|
||||
or (seq_len, batch_size * prune_range, batch_size * prune_range),
|
||||
@ -333,13 +333,15 @@ class AlignmentAttentionModule(nn.Module):
|
||||
|
||||
pos_emb = self.pos_encode(merged_am_pruned)
|
||||
|
||||
attn_weights = self.cross_attn_weights(merged_lm_pruned, merged_am_pruned, pos_emb)
|
||||
attn_weights = self.cross_attn_weights(
|
||||
merged_lm_pruned, merged_am_pruned, pos_emb
|
||||
)
|
||||
label_level_am_representation = self.cross_attn(merged_am_pruned, attn_weights)
|
||||
# (T, batch_size * prune_range, encoder_dim)
|
||||
|
||||
return label_level_am_representation \
|
||||
.reshape(T, batch_size, prune_range, encoder_dim) \
|
||||
.permute(1, 0, 2, 3)
|
||||
return label_level_am_representation.reshape(
|
||||
T, batch_size, prune_range, encoder_dim
|
||||
).permute(1, 0, 2, 3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -120,7 +120,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
||||
lm_pruned: input of shape (seq_len, batch_size * prune_range, decoder_embed_dim)
|
||||
am_pruned: input of shape (seq_len, batch_size * prune_range, encoder_embed_dim)
|
||||
pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 1, pos_dim)
|
||||
key_padding_mask: a bool tensor of shape (batch_size * prune_range, seq_len). Positions
|
||||
key_padding_mask: a bool tensor of shape (batch_size * prune_range, seq_len). Positions
|
||||
that are True in this mask will be ignored as sources in the attention weighting.
|
||||
attn_mask: mask of shape (seq_len, seq_len)
|
||||
or (seq_len, batch_size * prune_range, batch_size * prune_range),
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from alignment_attention_module import AlignmentAttentionModule
|
||||
from scaling import ScaledLinear
|
||||
|
||||
|
||||
@ -29,6 +30,7 @@ class Joiner(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.label_level_am_attention = AlignmentAttentionModule()
|
||||
self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
|
||||
self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
|
||||
self.output_linear = nn.Linear(joiner_dim, vocab_size)
|
||||
@ -52,12 +54,15 @@ class Joiner(nn.Module):
|
||||
Returns:
|
||||
Return a tensor of shape (N, T, s_range, C).
|
||||
"""
|
||||
assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape)
|
||||
assert encoder_out.ndim == decoder_out.ndim, (
|
||||
encoder_out.shape,
|
||||
decoder_out.shape,
|
||||
)
|
||||
|
||||
encoder_out = self.label_level_am_attention(encoder_out, decoder_out)
|
||||
|
||||
if project_input:
|
||||
logit = self.encoder_proj(encoder_out) + self.decoder_proj(
|
||||
decoder_out
|
||||
)
|
||||
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
|
||||
else:
|
||||
logit = encoder_out + decoder_out
|
||||
|
||||
|
@ -21,7 +21,6 @@ from typing import Optional, Tuple
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from alignment_attention_module import AlignmentAttentionModule
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
|
||||
@ -35,7 +34,6 @@ class AsrModel(nn.Module):
|
||||
encoder: EncoderInterface,
|
||||
decoder: Optional[nn.Module] = None,
|
||||
joiner: Optional[nn.Module] = None,
|
||||
label_level_am_attention: Optional[nn.Module] = None,
|
||||
encoder_dim: int = 384,
|
||||
decoder_dim: int = 512,
|
||||
vocab_size: int = 500,
|
||||
@ -113,8 +111,6 @@ class AsrModel(nn.Module):
|
||||
nn.LogSoftmax(dim=-1),
|
||||
)
|
||||
|
||||
self.label_level_am_attention = label_level_am_attention
|
||||
|
||||
def forward_encoder(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@ -264,14 +260,11 @@ class AsrModel(nn.Module):
|
||||
ranges=ranges,
|
||||
)
|
||||
|
||||
|
||||
label_level_am_pruned = self.label_level_am_attention(am_pruned, lm_pruned)
|
||||
|
||||
# logits : [B, T, prune_range, vocab_size]
|
||||
|
||||
# project_input=False since we applied the decoder's input projections
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(label_level_am_pruned, lm_pruned, project_input=False)
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
|
@ -603,16 +603,13 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
)
|
||||
return joiner
|
||||
|
||||
def get_attn_module(params: AttributeDict) -> nn.Module:
|
||||
attn_module = AlignmentAttentionModule()
|
||||
return attn_module
|
||||
|
||||
def get_model(params: AttributeDict) -> nn.Module:
|
||||
assert (
|
||||
params.use_transducer or params.use_ctc
|
||||
), (f"At least one of them should be True, "
|
||||
assert params.use_transducer or params.use_ctc, (
|
||||
f"At least one of them should be True, "
|
||||
f"but got params.use_transducer={params.use_transducer}, "
|
||||
f"params.use_ctc={params.use_ctc}")
|
||||
f"params.use_ctc={params.use_ctc}"
|
||||
)
|
||||
|
||||
encoder_embed = get_encoder_embed(params)
|
||||
encoder = get_encoder_model(params)
|
||||
@ -624,14 +621,11 @@ def get_model(params: AttributeDict) -> nn.Module:
|
||||
decoder = None
|
||||
joiner = None
|
||||
|
||||
attn = get_attn_module(params)
|
||||
|
||||
model = AsrModel(
|
||||
encoder_embed=encoder_embed,
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
joiner=joiner,
|
||||
label_level_am_attention=attn,
|
||||
encoder_dim=max(_to_int_tuple(params.encoder_dim)),
|
||||
decoder_dim=params.decoder_dim,
|
||||
vocab_size=params.vocab_size,
|
||||
@ -815,17 +809,16 @@ def compute_loss(
|
||||
# take down the scale on the simple loss from 1.0 at the start
|
||||
# to params.simple_loss scale by warm_step.
|
||||
simple_loss_scale = (
|
||||
s if batch_idx_train >= warm_step
|
||||
s
|
||||
if batch_idx_train >= warm_step
|
||||
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
|
||||
)
|
||||
pruned_loss_scale = (
|
||||
1.0 if batch_idx_train >= warm_step
|
||||
1.0
|
||||
if batch_idx_train >= warm_step
|
||||
else 0.1 + 0.9 * (batch_idx_train / warm_step)
|
||||
)
|
||||
loss += (
|
||||
simple_loss_scale * simple_loss
|
||||
+ pruned_loss_scale * pruned_loss
|
||||
)
|
||||
loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
||||
|
||||
if params.use_ctc:
|
||||
loss += params.ctc_loss_scale * ctc_loss
|
||||
|
Loading…
x
Reference in New Issue
Block a user