mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
fixed
This commit is contained in:
parent
ab9affd3e5
commit
7824789a52
@ -112,6 +112,8 @@ class AsrModel(nn.Module):
|
||||
nn.LogSoftmax(dim=-1),
|
||||
)
|
||||
|
||||
self.label_level_am_attention = AlignmentAttentionModule()
|
||||
|
||||
def forward_encoder(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@ -261,8 +263,8 @@ class AsrModel(nn.Module):
|
||||
ranges=ranges,
|
||||
)
|
||||
|
||||
label_level_am_attention = AlignmentAttentionModule()
|
||||
label_level_am_pruned = label_level_am_attention(am_pruned, lm_pruned)
|
||||
|
||||
label_level_am_pruned = self.label_level_am_attention(am_pruned, lm_pruned)
|
||||
|
||||
# logits : [B, T, prune_range, vocab_size]
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user