Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)

Commit 6bb949eb5a (parent e31c14b335): extract middle layer memory embedding
@@ -373,6 +373,7 @@ class ConformerEncoder(nn.TransformerEncoder):
         pos_emb: Tensor,
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
+        mid_layer_list: Optional[List[int]] = None,
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.
 
@@ -381,6 +382,8 @@ class ConformerEncoder(nn.TransformerEncoder):
             pos_emb: Positional embedding tensor (required).
             mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
+            mid_layer_list: layers from which to extract memory embeddings,
+                which will be fed into codebook loss modules.
 
         Shape:
             src: (S, N, E).
@@ -390,20 +393,30 @@ class ConformerEncoder(nn.TransformerEncoder):
             S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
 
         """
+        if mid_layer_list is not None:
+            mid_mem_embeddings = []
+
         output = src
 
-        for mod in self.layers:
+        for mod_idx, mod in enumerate(self.layers):
             output = mod(
                 output,
                 pos_emb,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
             )
+            if mid_layer_list is not None and mod_idx in mid_layer_list:
+                mid_mem_embeddings.append(output)
 
         if self.norm is not None:
             output = self.norm(output)
 
-        return output
+        if mid_layer_list is not None:
+            assert len(mid_layer_list) == len(mid_mem_embeddings)
+            return output, mid_mem_embeddings
+
+        else:
+            return output
 
 
 class RelPositionalEncoding(torch.nn.Module):
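
For context, a minimal usage sketch of the new mid_layer_list argument follows. It assumes the ConformerEncoderLayer, ConformerEncoder and RelPositionalEncoding classes from the conformer.py file this commit edits; the import path, model sizes, and layer indices are illustrative assumptions, and the codebook-loss module that consumes the embeddings is only named in the docstring, so it is not shown.

import torch

# Assumed import path for the conformer.py touched by this commit.
from conformer import (
    ConformerEncoder,
    ConformerEncoderLayer,
    RelPositionalEncoding,
)

d_model, nhead, num_layers = 256, 4, 12
encoder_layer = ConformerEncoderLayer(d_model, nhead)
encoder = ConformerEncoder(encoder_layer, num_layers)
pos_encoder = RelPositionalEncoding(d_model, 0.1)

# Dummy batch: N=8 utterances, S=100 frames, E=d_model features, shape (N, S, E).
x = torch.randn(8, 100, d_model)
x, pos_emb = pos_encoder(x)   # relative positional embedding for the batch
x = x.permute(1, 0, 2)        # the encoder expects (S, N, E)

# Collect hidden states after layers 5 and 11 (0-based) for the codebook loss.
output, mid_mem_embeddings = encoder(x, pos_emb, mid_layer_list=[5, 11])
assert len(mid_mem_embeddings) == 2
assert mid_mem_embeddings[0].shape == output.shape  # each is (S, N, E)

# Without mid_layer_list the call is unchanged and returns only `output`.
output = encoder(x, pos_emb)

Note that the return type depends on whether mid_layer_list is given: a plain Tensor as before, or a (Tensor, list) pair when intermediate embeddings are requested, so existing call sites keep working unchanged.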