diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py
index db17f4f0d..7e2d7fadd 100644
--- a/egs/librispeech/ASR/conformer_ctc/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc/conformer.py
@@ -373,6 +373,7 @@ class ConformerEncoder(nn.TransformerEncoder):
         pos_emb: Tensor,
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
+        mid_layer_list: Optional[List[int]] = None,
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.
 
@@ -381,6 +382,8 @@ class ConformerEncoder(nn.TransformerEncoder):
             pos_emb: Positional embedding tensor (required).
             mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
+            mid_layer_list: indexes of the layers whose memory embeddings
+                are extracted and fed into the codebook loss modules.
 
         Shape:
             src: (S, N, E).
@@ -390,20 +393,30 @@ class ConformerEncoder(nn.TransformerEncoder):
             S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
 
         """
+        if mid_layer_list is not None:
+            mid_mem_embeddings = []
+
         output = src
 
-        for mod in self.layers:
+        for mod_idx, mod in enumerate(self.layers):
             output = mod(
                 output,
                 pos_emb,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
             )
+            if mid_layer_list is not None and mod_idx in mid_layer_list:
+                mid_mem_embeddings.append(output)
 
         if self.norm is not None:
             output = self.norm(output)
 
-        return output
+        if mid_layer_list is not None:
+            assert len(mid_layer_list) == len(mid_mem_embeddings)
+            return output, mid_mem_embeddings
+
+        else:
+            return output
 
 
 class RelPositionalEncoding(torch.nn.Module):
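
To illustrate how the new argument is intended to be used at the call site, here is a minimal usage sketch (not part of the patch). The constructor arguments, tensor shapes, and layer indices below are assumptions chosen for illustration; only the mid_layer_list keyword and the (output, mid_mem_embeddings) return value come from the change above.

    import torch
    from conformer import ConformerEncoder, ConformerEncoderLayer, RelPositionalEncoding

    d_model, nhead, num_layers = 256, 4, 12      # illustrative sizes
    encoder_layer = ConformerEncoderLayer(d_model, nhead)
    encoder = ConformerEncoder(encoder_layer, num_layers)
    pos_enc = RelPositionalEncoding(d_model, 0.1)

    x = torch.randn(8, 100, d_model)             # (N, T, E)
    x, pos_emb = pos_enc(x)                      # relative positional embedding
    x = x.permute(1, 0, 2)                       # (T, N, E), as the encoder expects

    # Without mid_layer_list the call behaves as before and returns a Tensor.
    out = encoder(x, pos_emb)

    # With mid_layer_list, the output of each listed layer is collected and
    # returned alongside the final output, ready for the codebook loss modules.
    out, mid_mem_embeddings = encoder(x, pos_emb, mid_layer_list=[5, 11])
    assert len(mid_mem_embeddings) == 2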