From 5e4b1e01fe20132d990ad76d668cd0c632dc86ac Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 21 Feb 2022 19:56:45 +0800 Subject: [PATCH] Get the input for the auxiliary branch. --- .../transducer_stateless_aux_kl/conformer.py | 48 +++++++++++--- .../subsampling.py | 1 + .../test_conformer.py | 62 +++++++++++++++++++ .../ASR/transducer_stateless_aux_kl/train.py | 4 ++ 4 files changed, 108 insertions(+), 7 deletions(-) create mode 120000 egs/librispeech/ASR/transducer_stateless_aux_kl/subsampling.py create mode 100644 egs/librispeech/ASR/transducer_stateless_aux_kl/test_conformer.py diff --git a/egs/librispeech/ASR/transducer_stateless_aux_kl/conformer.py b/egs/librispeech/ASR/transducer_stateless_aux_kl/conformer.py index 81d7708f9..bea19a20f 100644 --- a/egs/librispeech/ASR/transducer_stateless_aux_kl/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless_aux_kl/conformer.py @@ -52,6 +52,7 @@ class Conformer(Transformer): nhead: int = 4, dim_feedforward: int = 2048, num_encoder_layers: int = 12, + mid_layer: int = 6, dropout: float = 0.1, cnn_module_kernel: int = 31, normalize_before: bool = True, @@ -69,6 +70,7 @@ class Conformer(Transformer): normalize_before=normalize_before, vgg_frontend=vgg_frontend, ) + assert 0 <= mid_layer < num_encoder_layers self.encoder_pos = RelPositionalEncoding(d_model, dropout) @@ -80,14 +82,18 @@ class Conformer(Transformer): cnn_module_kernel, normalize_before, ) - self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) + self.encoder = ConformerEncoder( + encoder_layer, num_encoder_layers, mid_layer=mid_layer + ) self.normalize_before = normalize_before if self.normalize_before: self.after_norm = nn.LayerNorm(d_model) + self.mid_layer_norm = nn.LayerNorm(d_model) else: # Note: TorchScript detects that self.after_norm could be used inside forward() # and throws an error without this change. 
self.after_norm = identity + self.mid_layer_norm = identity def forward( self, x: torch.Tensor, x_lens: torch.Tensor @@ -114,15 +120,24 @@ class Conformer(Transformer): assert x.size(0) == lengths.max().item() mask = make_pad_mask(lengths) - x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, N, C) + # Both x and mid_layer_out are of shape (T, N, d_model) + x, mid_layer_out = self.encoder( + x, + pos_emb, + src_key_padding_mask=mask, + ) if self.normalize_before: x = self.after_norm(x) + mid_layer_out = self.mid_layer_norm(mid_layer_out) + + # (T, N, d_model) -> (N, T, d_model) + mid_layer_out = mid_layer_out.permute(1, 0, 2) logits = self.encoder_output_layer(x) logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - return logits, lengths + return logits, lengths, mid_layer_out class ConformerEncoderLayer(nn.Module): @@ -281,11 +296,28 @@ class ConformerEncoder(nn.TransformerEncoder): """ def __init__( - self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None + self, + encoder_layer: nn.Module, + num_layers: int, + mid_layer: int, + norm: nn.Module = None, ) -> None: + """ + Args: + encoder_layer: + Type of the encoder layer. + num_layers: + Number of encoder layers. + mid_layer: + Also return the output of this layer in `forward()`. + norm: + If not None, the output of the last layer is processed by `norm`. + """ super(ConformerEncoder, self).__init__( encoder_layer=encoder_layer, num_layers=num_layers, norm=norm ) + assert 0 <= mid_layer < num_layers, (mid_layer, num_layers) + self.mid_layer = mid_layer def forward( self, @@ -293,7 +325,7 @@ class ConformerEncoder(nn.TransformerEncoder): pos_emb: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: + ) -> Tuple[Tensor, Tensor]: r"""Pass the input through the encoder layers in turn. 
Args: @@ -312,18 +344,20 @@ class ConformerEncoder(nn.TransformerEncoder): """ output = src - for mod in self.layers: + for i, mod in enumerate(self.layers): output = mod( output, pos_emb, src_mask=mask, src_key_padding_mask=src_key_padding_mask, ) + if i == self.mid_layer: + mid_layer_output = output if self.norm is not None: output = self.norm(output) - return output + return output, mid_layer_output class RelPositionalEncoding(torch.nn.Module): diff --git a/egs/librispeech/ASR/transducer_stateless_aux_kl/subsampling.py b/egs/librispeech/ASR/transducer_stateless_aux_kl/subsampling.py new file mode 120000 index 000000000..6fee09e58 --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless_aux_kl/subsampling.py @@ -0,0 +1 @@ +../conformer_ctc/subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/transducer_stateless_aux_kl/test_conformer.py b/egs/librispeech/ASR/transducer_stateless_aux_kl/test_conformer.py new file mode 100644 index 000000000..59698b24c --- /dev/null +++ b/egs/librispeech/ASR/transducer_stateless_aux_kl/test_conformer.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./transducer_stateless_aux_kl/test_conformer.py +""" + +import torch +from conformer import Conformer + + +def test_conformer(): + output_dim = 1024 + d_model = 512 + conformer = Conformer( + num_features=80, + output_dim=output_dim, + subsampling_factor=4, + d_model=d_model, + nhead=8, + dim_feedforward=2048, + num_encoder_layers=12, + mid_layer=6, + ) + N = 3 + T = 100 + C = 80 + x = torch.randn(N, T, C) + x_lens = torch.tensor([50, 100, 80]) + logits, logit_lens, mid_layer_out = conformer(x, x_lens) + + expected_T = ((T - 1) // 2 - 1) // 2 + assert logits.shape == (N, expected_T, output_dim) + assert logit_lens.max().item() == expected_T + assert mid_layer_out.shape == (N, expected_T, d_model) + print(logits.shape) + print(logit_lens) + + +def main(): + test_conformer() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/transducer_stateless_aux_kl/train.py b/egs/librispeech/ASR/transducer_stateless_aux_kl/train.py index 720151ea0..ae91e76fd 100755 --- a/egs/librispeech/ASR/transducer_stateless_aux_kl/train.py +++ b/egs/librispeech/ASR/transducer_stateless_aux_kl/train.py @@ -233,6 +233,9 @@ def get_params() -> AttributeDict: "nhead": 8, "dim_feedforward": 2048, "num_encoder_layers": 12, + # We use the output from mid_layer as the input of the + # auxiliary branch + "mid_layer": 6, "vgg_frontend": False, # parameters for Noam "warm_step": 80000, # For the 100h subset, use 8k @@ -253,6 +256,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: nhead=params.nhead, dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, + mid_layer=params.mid_layer, vgg_frontend=params.vgg_frontend, ) return encoder