diff --git a/egs/librispeech/ASR/incremental_transf/.conformer.py.swp b/egs/librispeech/ASR/incremental_transf/.conformer.py.swp index e89a7c65a..716a29e38 100644 Binary files a/egs/librispeech/ASR/incremental_transf/.conformer.py.swp and b/egs/librispeech/ASR/incremental_transf/.conformer.py.swp differ diff --git a/egs/librispeech/ASR/incremental_transf/.model.py.swp b/egs/librispeech/ASR/incremental_transf/.model.py.swp index 0342b02ce..93db73def 100644 Binary files a/egs/librispeech/ASR/incremental_transf/.model.py.swp and b/egs/librispeech/ASR/incremental_transf/.model.py.swp differ diff --git a/egs/librispeech/ASR/incremental_transf/model.py b/egs/librispeech/ASR/incremental_transf/model.py index 396e99408..29f520300 100644 --- a/egs/librispeech/ASR/incremental_transf/model.py +++ b/egs/librispeech/ASR/incremental_transf/model.py @@ -244,13 +244,19 @@ class Interformer(nn.Module): x: torch.Tensor, x_lens: torch.Tensor, ): - encoder_out, x_lens, layer_outputs = self.pt_encoder(x, + encoder_out, x_lens, pt_layer_outputs = self.pt_encoder(x, x_lens, warmup=warmup, get_layer_output=True ) + inter_layer_outputs = self.inter_encoder(pt_layer_outputs) + loss = 0 + with torch.cuda.amp.autocast(enabled=False): + for inter_output, pt_output in zip(inter_layer_outputs, pt_layer_outputs): + mse_loss = self.mse( + ''' simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -263,5 +269,6 @@ class Interformer(nn.Module): delay_penalty=delay_penalty, return_grad=True, ) + ''' return (simple_loss, pruned_loss)