mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
extract middle layer memory embedding
This commit is contained in:
parent
e31c14b335
commit
6bb949eb5a
@ -373,6 +373,7 @@ class ConformerEncoder(nn.TransformerEncoder):
|
||||
pos_emb: Tensor,
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
mid_layer_list: List[int] = None,
|
||||
) -> Tensor:
|
||||
r"""Pass the input through the encoder layers in turn.
|
||||
|
||||
@ -381,6 +382,8 @@ class ConformerEncoder(nn.TransformerEncoder):
|
||||
pos_emb: Positional embedding tensor (required).
|
||||
mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
mid_layer_list: layers to tract memory embeddings,
|
||||
which will be feeded into codebook loss modules.
|
||||
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
@ -390,20 +393,30 @@ class ConformerEncoder(nn.TransformerEncoder):
|
||||
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||
|
||||
"""
|
||||
if mid_layer_list is not None:
|
||||
mid_mem_embeddings = []
|
||||
|
||||
output = src
|
||||
|
||||
for mod in self.layers:
|
||||
for mod_idx, mod in enumerate(self.layers):
|
||||
output = mod(
|
||||
output,
|
||||
pos_emb,
|
||||
src_mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
)
|
||||
if mod_idx in mid_layer_list:
|
||||
mid_mem_embeddings.append(output)
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
|
||||
return output
|
||||
if mid_layer_list is not None:
|
||||
assert len(mid_layer_list) == len(mid_mem_embeddings)
|
||||
return output, mid_mem_embeddings
|
||||
|
||||
else:
|
||||
return output
|
||||
|
||||
|
||||
class RelPositionalEncoding(torch.nn.Module):
|
||||
|
Loading…
x
Reference in New Issue
Block a user