mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
from local
This commit is contained in:
parent
d61e27625f
commit
e8c4d4fe6c
Binary file not shown.
Binary file not shown.
@ -237,6 +237,7 @@ class Interformer(nn.Module):
|
|||||||
|
|
||||||
self.pt_encoder = pt_encoder
|
self.pt_encoder = pt_encoder
|
||||||
self.inter_encoder = inter_encoder
|
self.inter_encoder = inter_encoder
|
||||||
|
self.mse = nn.MSELoss()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -263,37 +264,4 @@ class Interformer(nn.Module):
|
|||||||
return_grad=True,
|
return_grad=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ranges : [B, T, prune_range]
|
|
||||||
ranges = k2.get_rnnt_prune_ranges(
|
|
||||||
px_grad=px_grad,
|
|
||||||
py_grad=py_grad,
|
|
||||||
boundary=boundary,
|
|
||||||
s_range=prune_range,
|
|
||||||
)
|
|
||||||
|
|
||||||
# am_pruned : [B, T, prune_range, encoder_dim]
|
|
||||||
# lm_pruned : [B, T, prune_range, decoder_dim]
|
|
||||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
|
||||||
am=self.joiner.encoder_proj(encoder_out),
|
|
||||||
lm=self.joiner.decoder_proj(decoder_out),
|
|
||||||
ranges=ranges,
|
|
||||||
)
|
|
||||||
|
|
||||||
# logits : [B, T, prune_range, vocab_size]
|
|
||||||
|
|
||||||
# project_input=False since we applied the decoder's input projections
|
|
||||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
|
||||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
|
||||||
|
|
||||||
with torch.cuda.amp.autocast(enabled=False):
|
|
||||||
pruned_loss = k2.rnnt_loss_pruned(
|
|
||||||
logits=logits.float(),
|
|
||||||
symbols=y_padded,
|
|
||||||
ranges=ranges,
|
|
||||||
termination_symbol=blank_id,
|
|
||||||
boundary=boundary,
|
|
||||||
delay_penalty=delay_penalty,
|
|
||||||
reduction=reduction,
|
|
||||||
)
|
|
||||||
|
|
||||||
return (simple_loss, pruned_loss)
|
return (simple_loss, pruned_loss)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user