From 9ee57959ec6e66515d536eacb914c40d6e6b6a10 Mon Sep 17 00:00:00 2001
From: Guo Liyong
Date: Wed, 27 Apr 2022 11:20:19 +0800
Subject: [PATCH] output from middle layer

Let Conformer optionally return the output of one extra encoder layer
in addition to the output of the final layer.

- Conformer.__init__ takes extra_output_layer (default None). When set,
  it must satisfy 0 <= extra_output_layer < num_encoder_layers - 1.
- ConformerEncoder takes output_layers, and its forward() now returns a
  list holding the outputs of those layers, in layer order.
- Conformer.forward() returns (x, lengths) when only the final layer is
  collected and (x, lengths, layers_result[0]) otherwise. The extra
  tensor is returned as produced by the encoder; unlike x, it is not
  permuted to (N, T, C).

The extra layer is registered with list.insert(0, extra_output_layer).
list.insert takes (index, value); with the arguments swapped, layer 0
would always be collected instead of the requested layer.
---
 .../conformer.py                              | 39 ++++++++++++++++---
 1 file changed, 33 insertions(+), 6 deletions(-)

diff --git a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/conformer.py
index 257936b59..e71911c35 100644
--- a/egs/librispeech/ASR/vq_pruned_transducer_stateless2/conformer.py
+++ b/egs/librispeech/ASR/vq_pruned_transducer_stateless2/conformer.py
@@ -18,7 +18,7 @@
 import copy
 import math
 import warnings
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from encoder_interface import EncoderInterface
@@ -61,6 +61,7 @@ class Conformer(EncoderInterface):
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
         cnn_module_kernel: int = 31,
+        extra_output_layer: Optional[int] = None,  # None: final layer only.
     ) -> None:
         super(Conformer, self).__init__()
 
@@ -86,7 +87,18 @@ class Conformer(EncoderInterface):
             layer_dropout,
             cnn_module_kernel,
         )
-        self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
+
+        # The last layer is always needed.
+        self.output_layers = [num_encoder_layers - 1]
+        if extra_output_layer is not None:
+            assert (
+                extra_output_layer >= 0
+                and extra_output_layer < num_encoder_layers - 1
+            )
+            self.output_layers.insert(0, extra_output_layer)
+        self.encoder = ConformerEncoder(
+            encoder_layer, num_encoder_layers, output_layers=self.output_layers
+        )
 
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
@@ -119,12 +131,18 @@ class Conformer(EncoderInterface):
         assert x.size(0) == lengths.max().item()
         mask = make_pad_mask(lengths)
 
-        x = self.encoder(
+        layers_result = self.encoder(
             x, pos_emb, src_key_padding_mask=mask, warmup=warmup
         )  # (T, N, C)
 
+        # layers_result[0] is the output from a middle layer for distillation.
+        # layers_result[-1] is the output from the final layer for RNN-T loss.
+        x = layers_result[-1]
         x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
 
+        if len(self.output_layers) > 1:
+            assert len(self.output_layers) == len(layers_result)
+            return x, lengths, layers_result[0]
         return x, lengths
 
 
@@ -276,12 +294,18 @@ class ConformerEncoder(nn.Module):
         >>> out = conformer_encoder(src, pos_emb)
     """
 
-    def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
+    def __init__(
+        self,
+        encoder_layer: nn.Module,
+        num_layers: int,
+        output_layers: List[int],
+    ) -> None:
         super().__init__()
         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
         )
         self.num_layers = num_layers
+        self.output_layers = output_layers
 
     def forward(
         self,
@@ -290,7 +314,7 @@ class ConformerEncoder(nn.Module):
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
-    ) -> Tensor:
+    ) -> List[Tensor]:
         r"""Pass the input through the encoder layers in turn.
 
         Args:
@@ -309,6 +333,7 @@ class ConformerEncoder(nn.Module):
         """
         output = src
 
+        layers_result = []
         for i, mod in enumerate(self.layers):
             output = mod(
                 output,
@@ -317,8 +342,10 @@ class ConformerEncoder(nn.Module):
                 src_key_padding_mask=src_key_padding_mask,
                 warmup=warmup,
             )
+            if i in self.output_layers:
+                layers_result.append(output)
 
-        return output
+        return layers_result
 
 
 class RelPositionalEncoding(torch.nn.Module):