diff --git a/egs/librispeech/ASR/incremental_transf/.identity_train.py.swp b/egs/librispeech/ASR/incremental_transf/.identity_train.py.swp index 33277a7f7..3449aace7 100644 Binary files a/egs/librispeech/ASR/incremental_transf/.identity_train.py.swp and b/egs/librispeech/ASR/incremental_transf/.identity_train.py.swp differ diff --git a/egs/librispeech/ASR/incremental_transf/.model.py.swp b/egs/librispeech/ASR/incremental_transf/.model.py.swp index abb97f3cf..18b72fcd8 100644 Binary files a/egs/librispeech/ASR/incremental_transf/.model.py.swp and b/egs/librispeech/ASR/incremental_transf/.model.py.swp differ diff --git a/egs/librispeech/ASR/incremental_transf/model.py b/egs/librispeech/ASR/incremental_transf/model.py index 9f9b39769..61ecb08d1 100644 --- a/egs/librispeech/ASR/incremental_transf/model.py +++ b/egs/librispeech/ASR/incremental_transf/model.py @@ -244,11 +244,12 @@ class Interformer(nn.Module): x: torch.Tensor, x_lens: torch.Tensor, ): - encoder_out, x_lens, pt_layer_outputs = self.pt_encoder(x, - x_lens, - warmup=warmup, - get_layer_output=True - ) + with torch.no_grad(): + encoder_out, x_lens, pt_layer_outputs = self.pt_encoder(x, + x_lens, + warmup=warmup, + get_layer_output=True + ) inter_layer_outputs = self.inter_encoder(pt_layer_outputs) loss = 0