From e8c4d4fe6c1afec65abb3193c085bab9ef8bb1f5 Mon Sep 17 00:00:00 2001 From: dohe0342 Date: Mon, 9 Jan 2023 19:44:30 +0900 Subject: [PATCH] from local --- .../ASR/incremental_transf/.conformer.py.swp | Bin 114688 -> 114688 bytes .../ASR/incremental_transf/.model.py.swp | Bin 24576 -> 24576 bytes .../ASR/incremental_transf/model.py | 34 +----------------- 3 files changed, 1 insertion(+), 33 deletions(-) diff --git a/egs/librispeech/ASR/incremental_transf/.conformer.py.swp b/egs/librispeech/ASR/incremental_transf/.conformer.py.swp index c70f2b9218662a177e50d0ad37152e0df086fd7b..4f9620246845a51e8c9dbe6eb823e515c49b6e2c 100644 GIT binary patch delta 36 qcmZo@U~gz(7fUh-^Ym4))H7fJ0s#hwQ}1^t3;t^q+bYKRYCizBMGIU2 delta 36 qcmZo@U~gz(7fUh-^Ym4))H7fJ0s#hwocFtv&-`r^+bYKRYCizD3JbUZ diff --git a/egs/librispeech/ASR/incremental_transf/.model.py.swp b/egs/librispeech/ASR/incremental_transf/.model.py.swp index dd7695fe9428fcecbaa0bd16788e2c9942e1bf16..4c51a91342046780c24aaddcc57d1e22fc0445b4 100644 GIT binary patch delta 100 zcmZoTz}Rqral?CNMuyEFm=kyybv6eIY?kL;!O6f-!o|SAAi%(⋙7Jw&W!s?;;Q% x0OCzRJQIkkfw*MzQG@lIlc!i^3*{E4D%dLI<>~nbyZYo87i(y4{%FOo3IKEz92Wop delta 87 zcmZoTz}Rqral?CN#?Z|lm=kyywKoR}Y?kLe!^yxfg^PiKL4bild$ORxY{`#6-UA@M j1jIXmcs3A&RCI4XYOtPjvZkp#Bg5nptL)AHtr%4SQ_vV= diff --git a/egs/librispeech/ASR/incremental_transf/model.py b/egs/librispeech/ASR/incremental_transf/model.py index cc5fac241..396e99408 100644 --- a/egs/librispeech/ASR/incremental_transf/model.py +++ b/egs/librispeech/ASR/incremental_transf/model.py @@ -237,6 +237,7 @@ class Interformer(nn.Module): self.pt_encoder = pt_encoder self.inter_encoder = inter_encoder + self.mse = nn.MSELoss() def forward( self, @@ -263,37 +264,4 @@ class Interformer(nn.Module): return_grad=True, ) - # ranges : [B, T, prune_range] - ranges = k2.get_rnnt_prune_ranges( - px_grad=px_grad, - py_grad=py_grad, - boundary=boundary, - s_range=prune_range, - ) - - # am_pruned : [B, T, prune_range, encoder_dim] - # lm_pruned : [B, T, prune_range, decoder_dim] - am_pruned, lm_pruned = k2.do_rnnt_pruning( - am=self.joiner.encoder_proj(encoder_out), - lm=self.joiner.decoder_proj(decoder_out), - ranges=ranges, - ) - - # logits : [B, T, prune_range, vocab_size] - - # project_input=False since we applied the decoder's input projections - # prior to do_rnnt_pruning (this is an optimization for speed). - logits = self.joiner(am_pruned, lm_pruned, project_input=False) - - with torch.cuda.amp.autocast(enabled=False): - pruned_loss = k2.rnnt_loss_pruned( - logits=logits.float(), - symbols=y_padded, - ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - delay_penalty=delay_penalty, - reduction=reduction, - ) - return (simple_loss, pruned_loss)