mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
moved an operand to the joiner model
This commit is contained in:
parent
056efeef30
commit
739e2a22c6
@ -333,13 +333,15 @@ class AlignmentAttentionModule(nn.Module):
|
|||||||
|
|
||||||
pos_emb = self.pos_encode(merged_am_pruned)
|
pos_emb = self.pos_encode(merged_am_pruned)
|
||||||
|
|
||||||
attn_weights = self.cross_attn_weights(merged_lm_pruned, merged_am_pruned, pos_emb)
|
attn_weights = self.cross_attn_weights(
|
||||||
|
merged_lm_pruned, merged_am_pruned, pos_emb
|
||||||
|
)
|
||||||
label_level_am_representation = self.cross_attn(merged_am_pruned, attn_weights)
|
label_level_am_representation = self.cross_attn(merged_am_pruned, attn_weights)
|
||||||
# (T, batch_size * prune_range, encoder_dim)
|
# (T, batch_size * prune_range, encoder_dim)
|
||||||
|
|
||||||
return label_level_am_representation \
|
return label_level_am_representation.reshape(
|
||||||
.reshape(T, batch_size, prune_range, encoder_dim) \
|
T, batch_size, prune_range, encoder_dim
|
||||||
.permute(1, 0, 2, 3)
|
).permute(1, 0, 2, 3)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from alignment_attention_module import AlignmentAttentionModule
|
||||||
from scaling import ScaledLinear
|
from scaling import ScaledLinear
|
||||||
|
|
||||||
|
|
||||||
@ -29,6 +30,7 @@ class Joiner(nn.Module):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
self.label_level_am_attention = AlignmentAttentionModule()
|
||||||
self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
|
self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
|
||||||
self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
|
self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
|
||||||
self.output_linear = nn.Linear(joiner_dim, vocab_size)
|
self.output_linear = nn.Linear(joiner_dim, vocab_size)
|
||||||
@ -52,12 +54,15 @@ class Joiner(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
Return a tensor of shape (N, T, s_range, C).
|
Return a tensor of shape (N, T, s_range, C).
|
||||||
"""
|
"""
|
||||||
assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape)
|
assert encoder_out.ndim == decoder_out.ndim, (
|
||||||
|
encoder_out.shape,
|
||||||
|
decoder_out.shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_out = self.label_level_am_attention(encoder_out, decoder_out)
|
||||||
|
|
||||||
if project_input:
|
if project_input:
|
||||||
logit = self.encoder_proj(encoder_out) + self.decoder_proj(
|
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
|
||||||
decoder_out
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logit = encoder_out + decoder_out
|
logit = encoder_out + decoder_out
|
||||||
|
|
||||||
|
@ -21,7 +21,6 @@ from typing import Optional, Tuple
|
|||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from alignment_attention_module import AlignmentAttentionModule
|
|
||||||
from encoder_interface import EncoderInterface
|
from encoder_interface import EncoderInterface
|
||||||
from scaling import ScaledLinear
|
from scaling import ScaledLinear
|
||||||
|
|
||||||
@ -35,7 +34,6 @@ class AsrModel(nn.Module):
|
|||||||
encoder: EncoderInterface,
|
encoder: EncoderInterface,
|
||||||
decoder: Optional[nn.Module] = None,
|
decoder: Optional[nn.Module] = None,
|
||||||
joiner: Optional[nn.Module] = None,
|
joiner: Optional[nn.Module] = None,
|
||||||
label_level_am_attention: Optional[nn.Module] = None,
|
|
||||||
encoder_dim: int = 384,
|
encoder_dim: int = 384,
|
||||||
decoder_dim: int = 512,
|
decoder_dim: int = 512,
|
||||||
vocab_size: int = 500,
|
vocab_size: int = 500,
|
||||||
@ -113,8 +111,6 @@ class AsrModel(nn.Module):
|
|||||||
nn.LogSoftmax(dim=-1),
|
nn.LogSoftmax(dim=-1),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.label_level_am_attention = label_level_am_attention
|
|
||||||
|
|
||||||
def forward_encoder(
|
def forward_encoder(
|
||||||
self, x: torch.Tensor, x_lens: torch.Tensor
|
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
@ -264,14 +260,11 @@ class AsrModel(nn.Module):
|
|||||||
ranges=ranges,
|
ranges=ranges,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
label_level_am_pruned = self.label_level_am_attention(am_pruned, lm_pruned)
|
|
||||||
|
|
||||||
# logits : [B, T, prune_range, vocab_size]
|
# logits : [B, T, prune_range, vocab_size]
|
||||||
|
|
||||||
# project_input=False since we applied the decoder's input projections
|
# project_input=False since we applied the decoder's input projections
|
||||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||||
logits = self.joiner(label_level_am_pruned, lm_pruned, project_input=False)
|
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||||
|
|
||||||
with torch.cuda.amp.autocast(enabled=False):
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
pruned_loss = k2.rnnt_loss_pruned(
|
pruned_loss = k2.rnnt_loss_pruned(
|
||||||
|
@ -603,16 +603,13 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
|
|||||||
)
|
)
|
||||||
return joiner
|
return joiner
|
||||||
|
|
||||||
def get_attn_module(params: AttributeDict) -> nn.Module:
|
|
||||||
attn_module = AlignmentAttentionModule()
|
|
||||||
return attn_module
|
|
||||||
|
|
||||||
def get_model(params: AttributeDict) -> nn.Module:
|
def get_model(params: AttributeDict) -> nn.Module:
|
||||||
assert (
|
assert params.use_transducer or params.use_ctc, (
|
||||||
params.use_transducer or params.use_ctc
|
f"At least one of them should be True, "
|
||||||
), (f"At least one of them should be True, "
|
|
||||||
f"but got params.use_transducer={params.use_transducer}, "
|
f"but got params.use_transducer={params.use_transducer}, "
|
||||||
f"params.use_ctc={params.use_ctc}")
|
f"params.use_ctc={params.use_ctc}"
|
||||||
|
)
|
||||||
|
|
||||||
encoder_embed = get_encoder_embed(params)
|
encoder_embed = get_encoder_embed(params)
|
||||||
encoder = get_encoder_model(params)
|
encoder = get_encoder_model(params)
|
||||||
@ -624,14 +621,11 @@ def get_model(params: AttributeDict) -> nn.Module:
|
|||||||
decoder = None
|
decoder = None
|
||||||
joiner = None
|
joiner = None
|
||||||
|
|
||||||
attn = get_attn_module(params)
|
|
||||||
|
|
||||||
model = AsrModel(
|
model = AsrModel(
|
||||||
encoder_embed=encoder_embed,
|
encoder_embed=encoder_embed,
|
||||||
encoder=encoder,
|
encoder=encoder,
|
||||||
decoder=decoder,
|
decoder=decoder,
|
||||||
joiner=joiner,
|
joiner=joiner,
|
||||||
label_level_am_attention=attn,
|
|
||||||
encoder_dim=max(_to_int_tuple(params.encoder_dim)),
|
encoder_dim=max(_to_int_tuple(params.encoder_dim)),
|
||||||
decoder_dim=params.decoder_dim,
|
decoder_dim=params.decoder_dim,
|
||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
@ -815,17 +809,16 @@ def compute_loss(
|
|||||||
# take down the scale on the simple loss from 1.0 at the start
|
# take down the scale on the simple loss from 1.0 at the start
|
||||||
# to params.simple_loss scale by warm_step.
|
# to params.simple_loss scale by warm_step.
|
||||||
simple_loss_scale = (
|
simple_loss_scale = (
|
||||||
s if batch_idx_train >= warm_step
|
s
|
||||||
|
if batch_idx_train >= warm_step
|
||||||
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
|
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
|
||||||
)
|
)
|
||||||
pruned_loss_scale = (
|
pruned_loss_scale = (
|
||||||
1.0 if batch_idx_train >= warm_step
|
1.0
|
||||||
|
if batch_idx_train >= warm_step
|
||||||
else 0.1 + 0.9 * (batch_idx_train / warm_step)
|
else 0.1 + 0.9 * (batch_idx_train / warm_step)
|
||||||
)
|
)
|
||||||
loss += (
|
loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
||||||
simple_loss_scale * simple_loss
|
|
||||||
+ pruned_loss_scale * pruned_loss
|
|
||||||
)
|
|
||||||
|
|
||||||
if params.use_ctc:
|
if params.use_ctc:
|
||||||
loss += params.ctc_loss_scale * ctc_loss
|
loss += params.ctc_loss_scale * ctc_loss
|
||||||
|
Loading…
x
Reference in New Issue
Block a user