fix relative positional encoding in streaming decoding for computation saving

pkufool 2022-06-06 06:46:40 +08:00
parent fc54a99a56
commit 3aacf75652
2 changed files with 20 additions and 16 deletions
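
The saving comes from forward() no longer returning positional-encoding frames past the current chunk: with a chunk of time1 frames and left_context cached frames, the returned pos_emb width shrinks from 2 * (time1 + left_context) - 1 to left_context + 2 * time1 - 1, which is exactly what the updated assertion in rel_shift() checks. A minimal sketch of that arithmetic, using illustrative numbers that are not taken from any recipe:

# Illustrative sizes only (assumed): a 16-frame chunk with 64 cached left-context frames.
time1, left_context = 16, 64

old_len = 2 * (time1 + left_context) - 1  # old slice width: 2 * x_size_1 - 1
new_len = left_context + 2 * time1 - 1    # new slice: right edge stops at x.size(1)

print(old_len, new_len, old_len - new_len)  # 159 95 64 -> left_context columns saved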


@@ -717,9 +717,9 @@ class RelPositionalEncoding(torch.nn.Module):
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
 
-    def extend_pe(self, x: Tensor, context: int = 0) -> None:
+    def extend_pe(self, x: Tensor, left_context: int = 0) -> None:
         """Reset the positional encodings."""
-        x_size_1 = x.size(1) + context
+        x_size_1 = x.size(1) + left_context
         if self.pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
@@ -756,13 +756,13 @@ class RelPositionalEncoding(torch.nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        context: int = 0,
+        left_context: int = 0,
     ) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.
 
         Args:
           x (torch.Tensor): Input tensor (batch, time, `*`).
-          context (int): left context (in frames) used during streaming decoding.
+          left_context (int): left context (in frames) used during streaming decoding.
             this is used only in real streaming decoding, in other circumstances,
             it MUST be 0.
@@ -771,14 +771,14 @@ class RelPositionalEncoding(torch.nn.Module):
           torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
         """
-        self.extend_pe(x, context)
-        x_size_1 = x.size(1) + context
+        self.extend_pe(x, left_context)
+        x_size_1 = x.size(1) + left_context
         pos_emb = self.pe[
             :,
             self.pe.size(1) // 2
             - x_size_1
             + 1 : self.pe.size(1) // 2  # noqa E203
-            + x_size_1,
+            + x.size(1),
         ]
         return self.dropout(x), self.dropout(pos_emb)
@@ -931,7 +931,9 @@ class RelPositionMultiheadAttention(nn.Module):
         (batch_size, num_heads, time1, n) = x.shape
 
         time2 = time1 + left_context
-        assert n == 2 * time2 - 1, f"{n} == 2 * {time2} - 1"
+        assert (
+            n == left_context + 2 * time1 - 1
+        ), f"{n} == {left_context} + 2 * {time1} - 1"
 
         # Note: TorchScript requires explicit arg for stride()
         batch_stride = x.stride(0)
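
As a sanity check on the new slice, here is a minimal stand-alone sketch of the same indexing; the tensor `pe` and all sizes are assumed stand-ins for illustration, not values from the library. It confirms that the pos_emb returned by forward() now has exactly the width the updated assertion in rel_shift() expects. The second changed file below receives the same rename and the same slice fix.

import torch

# Assumed, illustrative sizes: d_model=8, a 16-frame chunk, 64 frames of left context.
time1, left_context, d_model = 16, 64, 8
x_size_1 = time1 + left_context

# self.pe stores positive and negative positions, so its width is 2 * max_len - 1.
max_len = 512
pe = torch.zeros(1, 2 * max_len - 1, d_model)

center = pe.size(1) // 2
# The right edge now uses the chunk length (x.size(1)), not time1 + left_context.
pos_emb = pe[:, center - x_size_1 + 1 : center + time1]

assert pos_emb.size(1) == left_context + 2 * time1 - 1  # 95 columns instead of 159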


@@ -681,9 +681,9 @@ class RelPositionalEncoding(torch.nn.Module):
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
 
-    def extend_pe(self, x: Tensor, context: int = 0) -> None:
+    def extend_pe(self, x: Tensor, left_context: int = 0) -> None:
         """Reset the positional encodings."""
-        x_size_1 = x.size(1) + context
+        x_size_1 = x.size(1) + left_context
        if self.pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
@@ -718,13 +718,13 @@ class RelPositionalEncoding(torch.nn.Module):
         self.pe = pe.to(device=x.device, dtype=x.dtype)
 
     def forward(
-        self, x: torch.Tensor, context: int = 0
+        self, x: torch.Tensor, left_context: int = 0
     ) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.
 
         Args:
           x (torch.Tensor): Input tensor (batch, time, `*`).
-          context (int): left context (in frames) used during streaming decoding.
+          left_context (int): left context (in frames) used during streaming decoding.
             this is used only in real streaming decoding, in other circumstances,
             it MUST be 0.
 
         Returns:
@@ -732,15 +732,15 @@ class RelPositionalEncoding(torch.nn.Module):
           torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
         """
-        self.extend_pe(x, context)
+        self.extend_pe(x, left_context)
         x = x * self.xscale
-        x_size_1 = x.size(1) + context
+        x_size_1 = x.size(1) + left_context
         pos_emb = self.pe[
             :,
             self.pe.size(1) // 2
             - x_size_1
             + 1 : self.pe.size(1) // 2  # noqa E203
-            + x_size_1,
+            + x.size(1),
         ]
         return self.dropout(x), self.dropout(pos_emb)
@@ -888,7 +888,9 @@ class RelPositionMultiheadAttention(nn.Module):
         (batch_size, num_heads, time1, n) = x.shape
 
         time2 = time1 + left_context
-        assert n == 2 * time2 - 1, f"{n} == 2 * {time2} - 1"
+        assert (
+            n == left_context + 2 * time1 - 1
+        ), f"{n} == {left_context} + 2 * {time1} - 1"
 
         # Note: TorchScript requires explicit arg for stride()
         batch_stride = x.stride(0)
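
For completeness, a minimal sketch of what the new assertion guards: an attention-score tensor whose last dimension is left_context + 2 * time1 - 1 can still be re-viewed as (time1, time2) slices. The as_strided call below mirrors the shape/stride pattern used in rel_shift() (which the context lines above hint at via batch_stride); the tensor sizes are assumed for illustration.

import torch

# Assumed, illustrative sizes.
batch_size, num_heads, time1, left_context = 2, 4, 16, 64
time2 = time1 + left_context
n = left_context + 2 * time1 - 1  # exactly what the updated assert requires

x = torch.randn(batch_size, num_heads, time1, n)
out = x.as_strided(
    (batch_size, num_heads, time1, time2),
    (x.stride(0), x.stride(1), x.stride(2) - x.stride(3), x.stride(3)),
    storage_offset=x.stride(3) * (time1 - 1),
)
assert out.shape == (batch_size, num_heads, time1, time2)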