fix relative positional encoding in streaming decoding for computation saving

pkufool 2022-06-06 06:46:40 +08:00
parent fc54a99a56
commit 3aacf75652
2 changed files with 20 additions and 16 deletions
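
In short: the positional-embedding slice in RelPositionalEncoding.forward() now ends at x.size(1) instead of x.size(1) + left_context, so only left_context + 2 * time1 - 1 positions are kept rather than 2 * (time1 + left_context) - 1, and rel_shift() asserts the smaller size. A minimal back-of-the-envelope sketch (the numbers below are assumed for illustration, not taken from the commit):

    # Illustrative only: chunk size and cached left context are assumed values.
    time1 = 16          # frames in the current decoding chunk
    left_context = 64   # cached left-context frames in streaming decoding

    old_len = 2 * (time1 + left_context) - 1   # 159 positions kept before
    new_len = left_context + 2 * time1 - 1     # 95 positions kept after

    print(old_len - new_len)                   # 64 == left_context columns saved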


@@ -717,9 +717,9 @@ class RelPositionalEncoding(torch.nn.Module):
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))

-    def extend_pe(self, x: Tensor, context: int = 0) -> None:
+    def extend_pe(self, x: Tensor, left_context: int = 0) -> None:
         """Reset the positional encodings."""
-        x_size_1 = x.size(1) + context
+        x_size_1 = x.size(1) + left_context
         if self.pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
@@ -756,13 +756,13 @@ class RelPositionalEncoding(torch.nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        context: int = 0,
+        left_context: int = 0,
     ) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.

         Args:
             x (torch.Tensor): Input tensor (batch, time, `*`).
-            context (int): left context (in frames) used during streaming decoding.
+            left_context (int): left context (in frames) used during streaming decoding.
                 this is used only in real streaming decoding, in other circumstances,
                 it MUST be 0.
@@ -771,14 +771,14 @@ class RelPositionalEncoding(torch.nn.Module):
             torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
         """
-        self.extend_pe(x, context)
-        x_size_1 = x.size(1) + context
+        self.extend_pe(x, left_context)
+        x_size_1 = x.size(1) + left_context
         pos_emb = self.pe[
             :,
             self.pe.size(1) // 2
             - x_size_1
             + 1 : self.pe.size(1) // 2  # noqa E203
-            + x_size_1,
+            + x.size(1),
         ]
         return self.dropout(x), self.dropout(pos_emb)
@@ -931,7 +931,9 @@ class RelPositionMultiheadAttention(nn.Module):
         (batch_size, num_heads, time1, n) = x.shape
         time2 = time1 + left_context
-        assert n == 2 * time2 - 1, f"{n} == 2 * {time2} - 1"
+        assert (
+            n == left_context + 2 * time1 - 1
+        ), f"{n} == {left_context} + 2 * {time1} - 1"

         # Note: TorchScript requires explicit arg for stride()
         batch_stride = x.stride(0)


@@ -681,9 +681,9 @@ class RelPositionalEncoding(torch.nn.Module):
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))

-    def extend_pe(self, x: Tensor, context: int = 0) -> None:
+    def extend_pe(self, x: Tensor, left_context: int = 0) -> None:
         """Reset the positional encodings."""
-        x_size_1 = x.size(1) + context
+        x_size_1 = x.size(1) + left_context
         if self.pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
@@ -718,13 +718,13 @@ class RelPositionalEncoding(torch.nn.Module):
         self.pe = pe.to(device=x.device, dtype=x.dtype)

     def forward(
-        self, x: torch.Tensor, context: int = 0
+        self, x: torch.Tensor, left_context: int = 0
     ) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.

         Args:
             x (torch.Tensor): Input tensor (batch, time, `*`).
-            context (int): left context (in frames) used during streaming decoding.
+            left_context (int): left context (in frames) used during streaming decoding.
                 this is used only in real streaming decoding, in other circumstances,
                 it MUST be 0.

         Returns:
@@ -732,15 +732,15 @@ class RelPositionalEncoding(torch.nn.Module):
             torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
         """
-        self.extend_pe(x, context)
+        self.extend_pe(x, left_context)
         x = x * self.xscale
-        x_size_1 = x.size(1) + context
+        x_size_1 = x.size(1) + left_context
         pos_emb = self.pe[
             :,
             self.pe.size(1) // 2
             - x_size_1
             + 1 : self.pe.size(1) // 2  # noqa E203
-            + x_size_1,
+            + x.size(1),
         ]
         return self.dropout(x), self.dropout(pos_emb)
@@ -888,7 +888,9 @@ class RelPositionMultiheadAttention(nn.Module):
         (batch_size, num_heads, time1, n) = x.shape
         time2 = time1 + left_context
-        assert n == 2 * time2 - 1, f"{n} == 2 * {time2} - 1"
+        assert (
+            n == left_context + 2 * time1 - 1
+        ), f"{n} == {left_context} + 2 * {time1} - 1"

         # Note: TorchScript requires explicit arg for stride()
         batch_stride = x.stride(0)
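
As a sanity check, the slice bounds computed in forward() can be reproduced in isolation; this is a standalone sketch with an assumed max_len and assumed chunk sizes, not code from the repo:

    def pos_emb_len(time1: int, left_context: int, max_len: int = 5000) -> int:
        # self.pe holds 2 * max_len - 1 positions, centred at index pe_size // 2.
        x_size_1 = time1 + left_context
        pe_size = 2 * max_len - 1
        start = pe_size // 2 - x_size_1 + 1
        end = pe_size // 2 + time1   # previously "+ x_size_1", i.e. time1 + left_context
        return end - start

    # Matches the new assertion in rel_shift(): n == left_context + 2 * time1 - 1.
    assert pos_emb_len(time1=16, left_context=64) == 64 + 2 * 16 - 1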