extract middle layer memory embedding

This commit is contained in:
Guo Liyong 2021-12-31 15:38:53 +08:00
parent e31c14b335
commit 6bb949eb5a

View File

@ -373,6 +373,7 @@ class ConformerEncoder(nn.TransformerEncoder):
pos_emb: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
mid_layer_list: List[int] = None,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
@ -381,6 +382,8 @@ class ConformerEncoder(nn.TransformerEncoder):
pos_emb: Positional embedding tensor (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
mid_layer_list: layers to tract memory embeddings,
which will be feeded into codebook loss modules.
Shape:
src: (S, N, E).
@ -390,20 +393,30 @@ class ConformerEncoder(nn.TransformerEncoder):
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
"""
if mid_layer_list is not None:
mid_mem_embeddings = []
output = src
for mod in self.layers:
for mod_idx, mod in enumerate(self.layers):
output = mod(
output,
pos_emb,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
)
if mod_idx in mid_layer_list:
mid_mem_embeddings.append(output)
if self.norm is not None:
output = self.norm(output)
return output
if mid_layer_list is not None:
assert len(mid_layer_list) == len(mid_mem_embeddings)
return output, mid_mem_embeddings
else:
return output
class RelPositionalEncoding(torch.nn.Module):