mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 14:44:18 +00:00
output from middle layer
This commit is contained in:
parent
8d73423a29
commit
9ee57959ec
@ -18,7 +18,7 @@
|
||||
import copy
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from encoder_interface import EncoderInterface
|
||||
@ -61,6 +61,7 @@ class Conformer(EncoderInterface):
|
||||
dropout: float = 0.1,
|
||||
layer_dropout: float = 0.075,
|
||||
cnn_module_kernel: int = 31,
|
||||
extra_output_layer: int = None, # Default no distillation loss.
|
||||
) -> None:
|
||||
super(Conformer, self).__init__()
|
||||
|
||||
@ -86,7 +87,18 @@ class Conformer(EncoderInterface):
|
||||
layer_dropout,
|
||||
cnn_module_kernel,
|
||||
)
|
||||
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
|
||||
|
||||
# The last layer is always needed.
|
||||
self.output_layers = [num_encoder_layers - 1]
|
||||
if extra_output_layer is not None:
|
||||
assert (
|
||||
extra_output_layer >= 0
|
||||
and extra_output_layer < num_encoder_layers - 1
|
||||
)
|
||||
self.output_layers.insert(extra_output_layer, 0)
|
||||
self.encoder = ConformerEncoder(
|
||||
encoder_layer, num_encoder_layers, output_layers=self.output_layers
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
|
||||
@ -119,12 +131,18 @@ class Conformer(EncoderInterface):
|
||||
assert x.size(0) == lengths.max().item()
|
||||
mask = make_pad_mask(lengths)
|
||||
|
||||
x = self.encoder(
|
||||
layers_result = self.encoder(
|
||||
x, pos_emb, src_key_padding_mask=mask, warmup=warmup
|
||||
) # (T, N, C)
|
||||
# layers_result[0] is the output from a middle layer for distillation.
|
||||
# layers_result[-1] is the output from the final layer for RNN-T loss.
|
||||
x = layers_result[-1]
|
||||
|
||||
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
if len(self.output_layers) > 1:
|
||||
assert len(self.output_layers) == len(layers_result)
|
||||
return x, lengths, layers_result[0]
|
||||
return x, lengths
|
||||
|
||||
|
||||
@ -276,12 +294,18 @@ class ConformerEncoder(nn.Module):
|
||||
>>> out = conformer_encoder(src, pos_emb)
|
||||
"""
|
||||
|
||||
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
encoder_layer: nn.Module,
|
||||
num_layers: int,
|
||||
output_layers: List[int],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(
|
||||
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
|
||||
)
|
||||
self.num_layers = num_layers
|
||||
self.output_layers = output_layers
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -290,7 +314,7 @@ class ConformerEncoder(nn.Module):
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
warmup: float = 1.0,
|
||||
) -> Tensor:
|
||||
) -> List[Tensor]:
|
||||
r"""Pass the input through the encoder layers in turn.
|
||||
|
||||
Args:
|
||||
@ -309,6 +333,7 @@ class ConformerEncoder(nn.Module):
|
||||
"""
|
||||
output = src
|
||||
|
||||
layers_result = []
|
||||
for i, mod in enumerate(self.layers):
|
||||
output = mod(
|
||||
output,
|
||||
@ -317,8 +342,10 @@ class ConformerEncoder(nn.Module):
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
warmup=warmup,
|
||||
)
|
||||
if i in self.output_layers:
|
||||
layers_result.append(output)
|
||||
|
||||
return output
|
||||
return layers_result
|
||||
|
||||
|
||||
class RelPositionalEncoding(torch.nn.Module):
|
||||
|
Loading…
x
Reference in New Issue
Block a user