extract middle layer memory embedding

Guo Liyong 2021-12-31 15:38:53 +08:00
parent e31c14b335
commit 6bb949eb5a


@@ -373,6 +373,7 @@ class ConformerEncoder(nn.TransformerEncoder):
         pos_emb: Tensor,
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
+        mid_layer_list: List[int] = None,
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.
@@ -381,6 +382,8 @@ class ConformerEncoder(nn.TransformerEncoder):
             pos_emb: Positional embedding tensor (required).
             mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
+            mid_layer_list: indexes of the layers from which to extract memory
+                embeddings, which will be fed into the codebook loss modules.
         Shape:
             src: (S, N, E).
@@ -390,20 +393,30 @@ class ConformerEncoder(nn.TransformerEncoder):
            S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
        """
+        if mid_layer_list is not None:
+            mid_mem_embeddings = []
         output = src

-        for mod in self.layers:
+        for mod_idx, mod in enumerate(self.layers):
             output = mod(
                 output,
                 pos_emb,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
             )
+            if mid_layer_list is not None and mod_idx in mid_layer_list:
+                mid_mem_embeddings.append(output)

         if self.norm is not None:
             output = self.norm(output)

-        return output
+        if mid_layer_list is not None:
+            assert len(mid_layer_list) == len(mid_mem_embeddings)
+            return output, mid_mem_embeddings
+        else:
+            return output


 class RelPositionalEncoding(torch.nn.Module):
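
For context, the sketch below shows how a caller might use the new mid_layer_list argument. It is a minimal, self-contained toy and not icefall code: ToyEncoder, its plain Linear layers, and all tensor shapes are assumptions made for illustration; the real ConformerEncoder.forward additionally takes pos_emb, mask, and src_key_padding_mask.

# Minimal sketch of the mid_layer_list capture pattern added in this commit.
# ToyEncoder, its Linear layers, and the shapes below are illustrative
# assumptions only; they stand in for the real ConformerEncoder.
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    def __init__(self, num_layers: int, d_model: int) -> None:
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(d_model, d_model) for _ in range(num_layers)]
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(
        self,
        src: torch.Tensor,
        mid_layer_list: Optional[List[int]] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        if mid_layer_list is not None:
            mid_mem_embeddings = []
        output = src
        for mod_idx, mod in enumerate(self.layers):
            output = mod(output)
            # Record the embedding right after each requested layer.
            if mid_layer_list is not None and mod_idx in mid_layer_list:
                mid_mem_embeddings.append(output)
        output = self.norm(output)
        if mid_layer_list is not None:
            assert len(mid_layer_list) == len(mid_mem_embeddings)
            return output, mid_mem_embeddings
        return output


src = torch.randn(100, 8, 256)  # (S, N, E)
encoder = ToyEncoder(num_layers=12, d_model=256)
out, mids = encoder(src, mid_layer_list=[3, 7])  # capture layers 3 and 7
print(out.shape, [m.shape for m in mids])
# The captured tensors would then be fed into the codebook loss modules
# mentioned in the docstring above.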