from local

dohe0342 2023-01-09 19:44:30 +09:00
parent d61e27625f
commit e8c4d4fe6c
3 changed files with 1 addition and 33 deletions


@@ -237,6 +237,7 @@ class Interformer(nn.Module):
         self.pt_encoder = pt_encoder
         self.inter_encoder = inter_encoder
+        self.mse = nn.MSELoss()
 
     def forward(
         self,
@@ -263,37 +264,4 @@ class Interformer(nn.Module):
             return_grad=True,
         )
 
-        # ranges : [B, T, prune_range]
-        ranges = k2.get_rnnt_prune_ranges(
-            px_grad=px_grad,
-            py_grad=py_grad,
-            boundary=boundary,
-            s_range=prune_range,
-        )
-
-        # am_pruned : [B, T, prune_range, encoder_dim]
-        # lm_pruned : [B, T, prune_range, decoder_dim]
-        am_pruned, lm_pruned = k2.do_rnnt_pruning(
-            am=self.joiner.encoder_proj(encoder_out),
-            lm=self.joiner.decoder_proj(decoder_out),
-            ranges=ranges,
-        )
-
-        # logits : [B, T, prune_range, vocab_size]
-
-        # project_input=False since we applied the decoder's input projections
-        # prior to do_rnnt_pruning (this is an optimization for speed).
-        logits = self.joiner(am_pruned, lm_pruned, project_input=False)
-
-        with torch.cuda.amp.autocast(enabled=False):
-            pruned_loss = k2.rnnt_loss_pruned(
-                logits=logits.float(),
-                symbols=y_padded,
-                ranges=ranges,
-                termination_symbol=blank_id,
-                boundary=boundary,
-                delay_penalty=delay_penalty,
-                reduction=reduction,
-            )
-
         return (simple_loss, pruned_loss)
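
For context: the block deleted above is k2's pruned RNN-T loss pipeline, in which the gradients of a cheap smoothed "simple" loss (px_grad/py_grad) select a narrow band of labels per frame, and the full joiner loss is then evaluated only inside that band. Below is a minimal sketch of that pipeline on dummy tensors, assuming only that k2 and PyTorch are installed. All dimensions (B, T, S, vocab_size, prune_range), the random inputs, and the additive toy joiner are illustrative assumptions, not code from this repository; only the k2 calls mirror the ones in the diff.

# Hedged sketch of the pruned RNN-T loss path removed by this commit.
# Assumed/dummy: every dimension, the random tensors, and the toy joiner.
import k2
import torch

B, T, S = 2, 50, 10            # batch, encoder frames, label length (assumed)
vocab_size, blank_id = 500, 0
prune_range = 5                # s_range: labels kept per frame after pruning

# Stand-ins for the projected encoder/decoder outputs (the real model uses
# joiner.encoder_proj / joiner.decoder_proj; these are just random).
am = torch.randn(B, T, vocab_size, requires_grad=True)      # acoustic side
lm = torch.randn(B, S + 1, vocab_size, requires_grad=True)  # label side
y_padded = torch.randint(1, vocab_size, (B, S))             # label ids
boundary = torch.zeros(B, 4, dtype=torch.int64)
boundary[:, 2] = S   # per-utterance label length
boundary[:, 3] = T   # per-utterance frame count

# 1) Simple (smoothed) loss; its px/py gradients indicate where the
#    probability mass sits in the full [S+1, T] lattice.
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
    lm=lm,
    am=am,
    symbols=y_padded,
    termination_symbol=blank_id,
    boundary=boundary,
    reduction="sum",
    return_grad=True,
)

# 2) ranges : [B, T, prune_range], the label window kept at each frame.
ranges = k2.get_rnnt_prune_ranges(
    px_grad=px_grad,
    py_grad=py_grad,
    boundary=boundary,
    s_range=prune_range,
)

# 3) Gather only the (frame, label) pairs inside the pruned window.
am_pruned, lm_pruned = k2.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)

# 4) Toy joiner: the real model projects and combines; here we just add,
#    giving logits of shape [B, T, prune_range, vocab_size].
logits = am_pruned + lm_pruned

# 5) Full RNN-T loss evaluated only on the pruned lattice.
pruned_loss = k2.rnnt_loss_pruned(
    logits=logits.float(),
    symbols=y_padded,
    ranges=ranges,
    termination_symbol=blank_id,
    boundary=boundary,
    reduction="sum",
)
print(simple_loss.item(), pruned_loss.item())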