From 3aacf75652ce6b3caeffa578578bf3c39c2b70b6 Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 6 Jun 2022 06:46:40 +0800 Subject: [PATCH] fix relative positional encoding in streaming decoding for computation saving --- .../pruned_transducer_stateless2/conformer.py | 18 ++++++++++-------- .../ASR/transducer_stateless/conformer.py | 18 ++++++++++-------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index f72c63036..a0872b934 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -717,9 +717,9 @@ class RelPositionalEncoding(torch.nn.Module): self.pe = None self.extend_pe(torch.tensor(0.0).expand(1, max_len)) - def extend_pe(self, x: Tensor, context: int = 0) -> None: + def extend_pe(self, x: Tensor, left_context: int = 0) -> None: """Reset the positional encodings.""" - x_size_1 = x.size(1) + context + x_size_1 = x.size(1) + left_context if self.pe is not None: # self.pe contains both positive and negative parts # the length of self.pe is 2 * input_len - 1 @@ -756,13 +756,13 @@ class RelPositionalEncoding(torch.nn.Module): def forward( self, x: torch.Tensor, - context: int = 0, + left_context: int = 0, ) -> Tuple[Tensor, Tensor]: """Add positional encoding. Args: x (torch.Tensor): Input tensor (batch, time, `*`). - context (int): left context (in frames) used during streaming decoding. + left_context (int): left context (in frames) used during streaming decoding. this is used only in real streaming decoding, in other circumstances, it MUST be 0. @@ -771,14 +771,14 @@ class RelPositionalEncoding(torch.nn.Module): torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). 
""" - self.extend_pe(x, context) - x_size_1 = x.size(1) + context + self.extend_pe(x, left_context) + x_size_1 = x.size(1) + left_context pos_emb = self.pe[ :, self.pe.size(1) // 2 - x_size_1 + 1 : self.pe.size(1) // 2 # noqa E203 - + x_size_1, + + x.size(1), ] return self.dropout(x), self.dropout(pos_emb) @@ -931,7 +931,9 @@ class RelPositionMultiheadAttention(nn.Module): (batch_size, num_heads, time1, n) = x.shape time2 = time1 + left_context - assert n == 2 * time2 - 1, f"{n} == 2 * {time2} - 1" + assert ( + n == left_context + 2 * time1 - 1 + ), f"{n} == {left_context} + 2 * {time1} - 1" # Note: TorchScript requires explicit arg for stride() batch_stride = x.stride(0) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 3d6b089c1..ec33565ba 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -681,9 +681,9 @@ class RelPositionalEncoding(torch.nn.Module): self.pe = None self.extend_pe(torch.tensor(0.0).expand(1, max_len)) - def extend_pe(self, x: Tensor, context: int = 0) -> None: + def extend_pe(self, x: Tensor, left_context: int = 0) -> None: """Reset the positional encodings.""" - x_size_1 = x.size(1) + context + x_size_1 = x.size(1) + left_context if self.pe is not None: # self.pe contains both positive and negative parts # the length of self.pe is 2 * input_len - 1 @@ -718,13 +718,13 @@ class RelPositionalEncoding(torch.nn.Module): self.pe = pe.to(device=x.device, dtype=x.dtype) def forward( - self, x: torch.Tensor, context: int = 0 + self, x: torch.Tensor, left_context: int = 0 ) -> Tuple[Tensor, Tensor]: """Add positional encoding. Args: x (torch.Tensor): Input tensor (batch, time, `*`). - context (int): left context (in frames) used during streaming decoding. + left_context (int): left context (in frames) used during streaming decoding. 
this is used only in real streaming decoding, in other circumstances, it MUST be 0. Returns: @@ -732,15 +732,15 @@ class RelPositionalEncoding(torch.nn.Module): torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). """ - self.extend_pe(x, context) + self.extend_pe(x, left_context) x = x * self.xscale - x_size_1 = x.size(1) + context + x_size_1 = x.size(1) + left_context pos_emb = self.pe[ :, self.pe.size(1) // 2 - x_size_1 + 1 : self.pe.size(1) // 2 # noqa E203 - + x_size_1, + + x.size(1), ] return self.dropout(x), self.dropout(pos_emb) @@ -888,7 +888,9 @@ class RelPositionMultiheadAttention(nn.Module): (batch_size, num_heads, time1, n) = x.shape time2 = time1 + left_context - assert n == 2 * time2 - 1, f"{n} == 2 * {time2} - 1" + assert ( + n == left_context + 2 * time1 - 1 + ), f"{n} == {left_context} + 2 * {time1} - 1" # Note: TorchScript requires explicit arg for stride() batch_stride = x.stride(0)