Return the output of a middle encoder layer (for distillation)

Guo Liyong 2022-04-27 11:20:19 +08:00
parent 8d73423a29
commit 9ee57959ec


@@ -18,7 +18,7 @@
import copy
import math
import warnings
from typing import Optional, Tuple
from typing import List, Optional, Tuple
import torch
from encoder_interface import EncoderInterface
@@ -61,6 +61,7 @@ class Conformer(EncoderInterface):
        dropout: float = 0.1,
        layer_dropout: float = 0.075,
        cnn_module_kernel: int = 31,
        extra_output_layer: Optional[int] = None,  # Default: no distillation loss.
    ) -> None:
        super(Conformer, self).__init__()
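The new extra_output_layer argument selects one middle encoder layer whose output is exposed in addition to the final layer's. A hedged construction sketch, not part of this commit: num_features=80 and layer index 5 are made-up values, and the remaining constructor arguments are assumed to keep the defaults from the surrounding file:

# Hypothetical instantiation; only extra_output_layer is new in this commit.
model = Conformer(
    num_features=80,        # assumed input feature dimension
    num_encoder_layers=12,
    extra_output_layer=5,   # also return the output of encoder layer 5
)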
@@ -86,7 +87,18 @@ class Conformer(EncoderInterface):
            layer_dropout,
            cnn_module_kernel,
        )
        self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
        # The last layer is always needed.
        self.output_layers = [num_encoder_layers - 1]
        if extra_output_layer is not None:
            assert 0 <= extra_output_layer < num_encoder_layers - 1
            # Prepend the middle layer so the final layer stays last.
            self.output_layers.insert(0, extra_output_layer)
        self.encoder = ConformerEncoder(
            encoder_layer, num_encoder_layers, output_layers=self.output_layers
        )
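For concreteness, this is what self.output_layers ends up holding under the code above; a standalone sketch in which 12 and 5 are arbitrary example values:

num_encoder_layers = 12  # hypothetical
extra_output_layer = 5   # hypothetical middle layer chosen for distillation
output_layers = [num_encoder_layers - 1]  # the final layer is always kept
if extra_output_layer is not None:
    assert 0 <= extra_output_layer < num_encoder_layers - 1
    output_layers.insert(0, extra_output_layer)
print(output_layers)  # [5, 11]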
    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
@@ -119,12 +131,18 @@ class Conformer(EncoderInterface):
        assert x.size(0) == lengths.max().item()
        mask = make_pad_mask(lengths)
        x = self.encoder(
        layers_result = self.encoder(
            x, pos_emb, src_key_padding_mask=mask, warmup=warmup
        )  # a list of tensors, each of shape (T, N, C)
        # layers_result[0] is the output of a middle layer, for the distillation loss.
        # layers_result[-1] is the output of the final layer, for the RNN-T loss.
        x = layers_result[-1]
        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
        if len(self.output_layers) > 1:
            assert len(self.output_layers) == len(layers_result)
            return x, lengths, layers_result[0]
        return x, lengths
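Callers of Conformer.forward now receive either a two-element or a three-element tuple. A minimal handling sketch (model, x, x_lens, and the unpacked names are illustrative, not part of this commit):

outputs = model(x, x_lens)
if len(outputs) == 3:
    # extra_output_layer was set; the third element is the middle-layer
    # output, still in (T, N, C) layout (only the final output is permuted).
    encoder_out, encoder_out_lens, middle_out = outputs
else:
    encoder_out, encoder_out_lens = outputs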
@@ -276,12 +294,18 @@ class ConformerEncoder(nn.Module):
        >>> out = conformer_encoder(src, pos_emb)
    """
    def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
    def __init__(
        self,
        encoder_layer: nn.Module,
        num_layers: int,
        output_layers: List[int],
    ) -> None:
        super().__init__()
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
        )
        self.num_layers = num_layers
        self.output_layers = output_layers
    def forward(
        self,
@@ -290,7 +314,7 @@ class ConformerEncoder(nn.Module):
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        warmup: float = 1.0,
    ) -> Tensor:
    ) -> List[Tensor]:
        r"""Pass the input through the encoder layers in turn.
        Args:
@@ -309,6 +333,7 @@ class ConformerEncoder(nn.Module):
        """
        output = src
        layers_result = []
        for i, mod in enumerate(self.layers):
            output = mod(
                output,
@@ -317,8 +342,10 @@ class ConformerEncoder(nn.Module):
                src_key_padding_mask=src_key_padding_mask,
                warmup=warmup,
            )
            if i in self.output_layers:
                layers_result.append(output)
        return output
        return layers_result
class RelPositionalEncoding(torch.nn.Module):
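As a self-contained illustration of the collection pattern that the new ConformerEncoder.forward implements, here is a runnable toy version in which nn.Linear layers stand in for the real ConformerEncoderLayer:

from typing import List

import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    # Toy stand-in for ConformerEncoder: run the layers in sequence and
    # collect the outputs of the layers listed in output_layers.
    def __init__(self, num_layers: int, output_layers: List[int]) -> None:
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(8, 8) for _ in range(num_layers)]
        )
        self.output_layers = output_layers

    def forward(self, src: torch.Tensor) -> List[torch.Tensor]:
        output = src
        layers_result = []
        for i, mod in enumerate(self.layers):
            output = mod(output)
            if i in self.output_layers:
                layers_result.append(output)
        return layers_result


enc = TinyEncoder(num_layers=4, output_layers=[1, 3])
results = enc(torch.randn(5, 8))
assert len(results) == 2  # [middle-layer output, final-layer output]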