This commit is contained in:
zr_jin 2023-07-23 18:13:30 +08:00
parent ab9affd3e5
commit 7824789a52

View File

@ -112,6 +112,8 @@ class AsrModel(nn.Module):
nn.LogSoftmax(dim=-1),
)
self.label_level_am_attention = AlignmentAttentionModule()
def forward_encoder(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
@ -261,8 +263,8 @@ class AsrModel(nn.Module):
ranges=ranges,
)
label_level_am_attention = AlignmentAttentionModule()
label_level_am_pruned = label_level_am_attention(am_pruned, lm_pruned)
label_level_am_pruned = self.label_level_am_attention(am_pruned, lm_pruned)
# logits : [B, T, prune_range, vocab_size]