# Copyright 2022 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# It is modified based on
# https://github.com/pytorch/audio/blob/main/torchaudio/models/emformer.py

import math
from typing import Tuple

import torch


class RelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding module.

    See: Appendix B in "Transformer-XL: Attentive Language Models Beyond a
    Fixed-Length Context".
    Modified from
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py  # noqa

    Suppose:
      i -> position of query,
      j -> position of key(value),
    we use positive relative position embeddings when the key(value) is to
    the left of the query (i.e., i > j) and negative embeddings otherwise.

    Args:
      d_model:
        Embedding dimension.
      dropout:
        Dropout rate.
      max_len:
        Maximum input length. The positive and negative encoding tables
        are regenerated on demand if a longer length is requested in
        forward().
    """

    def __init__(
        self, d_model: int, dropout: float, max_len: int = 5000
    ) -> None:
        """Construct a RelPositionalEncoding object."""
        super(RelPositionalEncoding, self).__init__()
        self.d_model = d_model
        # Inputs are scaled by sqrt(d_model) before dropout, as in the
        # standard Transformer.
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.pos_len = max_len
        self.neg_len = max_len
        self.gen_pe_positive()
        self.gen_pe_negative()

    def gen_pe_positive(self) -> None:
        """Generate the positive positional encodings, i.e., for relative
        positions 0, 1, ..., pos_len - 1, stored in reversed order."""
        pe_positive = torch.zeros(self.pos_len, self.d_model)
        position_positive = torch.arange(
            0, self.pos_len, dtype=torch.float32
        ).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe_positive[:, 0::2] = torch.sin(position_positive * div_term)
        pe_positive[:, 1::2] = torch.cos(position_positive * div_term)
        # Reverse the order of positive indices so positive and negative
        # encodings can be concatenated in get_pe(). This is used to
        # support the shifting trick as in "Transformer-XL: Attentive
        # Language Models Beyond a Fixed-Length Context".
        self.pe_positive = torch.flip(pe_positive, [0])

    def gen_pe_negative(self) -> None:
        """Generate the negative positional encodings, i.e., for relative
        positions 0, -1, ..., -(neg_len - 1)."""
        # Suppose `i` is the position of the query vector and `j` the
        # position of the key vector. We use positive relative positions
        # when keys are to the left (i > j) and negative relative
        # positions otherwise (i < j).
        # NOTE(review): this method was garbled in the patch under review;
        # it is reconstructed to mirror gen_pe_positive() with negated
        # positions, following the espnet RelPositionalEncoding this file
        # says it is modified from — confirm against the upstream source.
        pe_negative = torch.zeros(self.neg_len, self.d_model)
        position_negative = torch.arange(
            0, self.neg_len, dtype=torch.float32
        ).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe_negative[:, 0::2] = torch.sin(-1 * position_negative * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position_negative * div_term)
        self.pe_negative = pe_negative

    def get_pe(
        self,
        pos_len: int,
        neg_len: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Get the positional encoding given positive length and negative
        length, moved to the requested device and dtype.

        The result has pos_len + neg_len - 1 rows, covering relative
        positions pos_len - 1, ..., 1, 0, -1, ..., -(neg_len - 1); the
        entry for relative position 0 is shared (hence the "- 1").
        """
        if self.pe_positive.dtype != dtype or str(
            self.pe_positive.device
        ) != str(device):
            self.pe_positive = self.pe_positive.to(dtype=dtype, device=device)
        if self.pe_negative.dtype != dtype or str(
            self.pe_negative.device
        ) != str(device):
            self.pe_negative = self.pe_negative.to(dtype=dtype, device=device)
        pe = torch.cat(
            [
                self.pe_positive[self.pos_len - pos_len :],
                # Skip index 0 of pe_negative: relative position 0 is
                # already covered by the last row of pe_positive.
                self.pe_negative[1:neg_len],
            ],
            dim=0,
        )
        return pe

    def forward(
        self,
        x: torch.Tensor,
        pos_len: int,
        neg_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scale input x and get the positional encoding.

        Args:
          x:
            Input tensor (`*`).
          pos_len:
            Number of positive relative positions required.
          neg_len:
            Number of negative relative positions required.

        Returns:
          A tuple of 2 tensors:
            - the scaled input with dropout applied, of shape (`*`);
            - the position embedding with dropout applied, of shape
              (pos_len + neg_len - 1, `*`).
        """
        x = x * self.xscale
        # Grow the cached tables if a longer length than any seen so far
        # is requested.
        if pos_len > self.pos_len:
            self.pos_len = pos_len
            self.gen_pe_positive()
        if neg_len > self.neg_len:
            self.neg_len = neg_len
            self.gen_pe_negative()
        pos_emb = self.get_pe(pos_len, neg_len, x.device, x.dtype)
        return self.dropout(x), self.dropout(pos_emb)


def test_rel_positional_encoding():
    """Smoke test: the embedding has pos_len + neg_len - 1 rows."""
    D = 256
    pos_enc = RelPositionalEncoding(D, dropout=0.1)
    pos_len = 100
    neg_len = 100
    x = torch.randn(2, D)
    x, pos_emb = pos_enc(x, pos_len, neg_len)
    assert pos_emb.shape == (pos_len + neg_len - 1, D)


if __name__ == "__main__":
    test_rel_positional_encoding()