mirror of https://github.com/k2-fsa/icefall.git
fix relative positional encoding in streaming decoding for computation saving
This commit is contained in:
parent fc54a99a56
commit 3aacf75652
@@ -717,9 +717,9 @@ class RelPositionalEncoding(torch.nn.Module):
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
 
-    def extend_pe(self, x: Tensor, context: int = 0) -> None:
+    def extend_pe(self, x: Tensor, left_context: int = 0) -> None:
         """Reset the positional encodings."""
-        x_size_1 = x.size(1) + context
+        x_size_1 = x.size(1) + left_context
         if self.pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
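The comments above note that self.pe covers both positive and negative relative offsets, with total length 2 * input_len - 1. A minimal sketch of such a table follows, assuming the usual interleaved sin/cos layout; build_rel_pe and its argument names are illustrative only, and icefall's actual extend_pe may construct the buffer differently. It only illustrates the convention that index pe.size(1) // 2 corresponds to relative offset 0, which the slice in forward() relies on.

import math
import torch

def build_rel_pe(input_len: int, d_model: int) -> torch.Tensor:
    # Relative offsets run from +(input_len - 1) down to -(input_len - 1),
    # so the table has 2 * input_len - 1 rows and row input_len - 1 is offset 0.
    offsets = torch.arange(input_len - 1, -input_len, -1.0).unsqueeze(1)
    div = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
    pe = torch.zeros(2 * input_len - 1, d_model)
    pe[:, 0::2] = torch.sin(offsets * div)
    pe[:, 1::2] = torch.cos(offsets * div)
    return pe.unsqueeze(0)  # (1, 2 * input_len - 1, d_model)

pe = build_rel_pe(input_len=6, d_model=8)
assert pe.size(1) == 2 * 6 - 1
# Index pe.size(1) // 2 is the row for relative offset 0.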
@@ -756,13 +756,13 @@ class RelPositionalEncoding(torch.nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        context: int = 0,
+        left_context: int = 0,
     ) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.
 
         Args:
             x (torch.Tensor): Input tensor (batch, time, `*`).
-            context (int): left context (in frames) used during streaming decoding.
+            left_context (int): left context (in frames) used during streaming decoding.
                 this is used only in real streaming decoding, in other circumstances,
                 it MUST be 0.
 
@@ -771,14 +771,14 @@ class RelPositionalEncoding(torch.nn.Module):
             torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
 
         """
-        self.extend_pe(x, context)
-        x_size_1 = x.size(1) + context
+        self.extend_pe(x, left_context)
+        x_size_1 = x.size(1) + left_context
         pos_emb = self.pe[
             :,
             self.pe.size(1) // 2
             - x_size_1
             + 1 : self.pe.size(1) // 2  # noqa E203
-            + x_size_1,
+            + x.size(1),
         ]
         return self.dropout(x), self.dropout(pos_emb)
 
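The saving comes from the upper bound of the slice above: the positive (future) side of the window never needs more than x.size(1) positions, so pos_emb shrinks from 2 * (time1 + left_context) - 1 frames to left_context + 2 * time1 - 1. A small self-contained check, with arbitrary example sizes rather than values taken from this diff:

import torch

time1, left_context, d_model = 4, 2, 8           # illustrative sizes only
x_size_1 = time1 + left_context
pe = torch.randn(1, 2 * x_size_1 - 1, d_model)   # full positive + negative range
center = pe.size(1) // 2

# Old slice: symmetric around the centre.
old = pe[:, center - x_size_1 + 1 : center + x_size_1]
# New slice: the positive side stops at time1 (= x.size(1)).
new = pe[:, center - x_size_1 + 1 : center + time1]

assert old.size(1) == 2 * (time1 + left_context) - 1   # 11 frames
assert new.size(1) == left_context + 2 * time1 - 1     # 9 frames

When left_context is much larger than the chunk size, as in real streaming decoding, this roughly halves the width of pos_emb and of everything computed from it.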
@@ -931,7 +931,9 @@ class RelPositionMultiheadAttention(nn.Module):
         (batch_size, num_heads, time1, n) = x.shape
 
         time2 = time1 + left_context
-        assert n == 2 * time2 - 1, f"{n} == 2 * {time2} - 1"
+        assert (
+            n == left_context + 2 * time1 - 1
+        ), f"{n} == {left_context} + 2 * {time1} - 1"
 
         # Note: TorchScript requires explicit arg for stride()
         batch_stride = x.stride(0)
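The reworked assertion states the matching invariant on the attention side: after the change in forward(), the relative-position axis of the scores has length n = left_context + 2 * time1 - 1. The rest of the function body is not shown in this hunk, so the snippet below is only a plausible stride-based relative shift consistent with the batch_stride line above (sizes are illustrative); it checks that such a view stays in bounds exactly when the asserted length holds.

import torch

batch_size, num_heads, time1, left_context = 2, 4, 5, 3   # illustrative sizes
time2 = time1 + left_context
n = left_context + 2 * time1 - 1           # the length the new assert requires
x = torch.randn(batch_size, num_heads, time1, n)

# Shift each query row i so that it reads time2 consecutive relative positions
# starting at column time1 - 1 - i, without copying any data.
shifted = x.as_strided(
    (batch_size, num_heads, time1, time2),
    (x.stride(0), x.stride(1), x.stride(2) - x.stride(3), x.stride(3)),
    storage_offset=x.stride(3) * (time1 - 1),
)

# Row 0 starts at column time1 - 1; the last row starts at column 0.
assert torch.equal(shifted[0, 0, 0], x[0, 0, 0, time1 - 1 : time1 - 1 + time2])
assert torch.equal(shifted[0, 0, -1], x[0, 0, -1, 0:time2])

Because the shift only adjusts strides and a storage offset, trimming n directly reduces the work per attention call without introducing any copy.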
@@ -681,9 +681,9 @@ class RelPositionalEncoding(torch.nn.Module):
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
 
-    def extend_pe(self, x: Tensor, context: int = 0) -> None:
+    def extend_pe(self, x: Tensor, left_context: int = 0) -> None:
         """Reset the positional encodings."""
-        x_size_1 = x.size(1) + context
+        x_size_1 = x.size(1) + left_context
         if self.pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
@@ -718,13 +718,13 @@ class RelPositionalEncoding(torch.nn.Module):
         self.pe = pe.to(device=x.device, dtype=x.dtype)
 
     def forward(
-        self, x: torch.Tensor, context: int = 0
+        self, x: torch.Tensor, left_context: int = 0
     ) -> Tuple[Tensor, Tensor]:
         """Add positional encoding.
 
         Args:
             x (torch.Tensor): Input tensor (batch, time, `*`).
-            context (int): left context (in frames) used during streaming decoding.
+            left_context (int): left context (in frames) used during streaming decoding.
                 this is used only in real streaming decoding, in other circumstances,
                 it MUST be 0.
         Returns:
@@ -732,15 +732,15 @@ class RelPositionalEncoding(torch.nn.Module):
             torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
 
         """
-        self.extend_pe(x, context)
+        self.extend_pe(x, left_context)
         x = x * self.xscale
-        x_size_1 = x.size(1) + context
+        x_size_1 = x.size(1) + left_context
         pos_emb = self.pe[
             :,
             self.pe.size(1) // 2
             - x_size_1
             + 1 : self.pe.size(1) // 2  # noqa E203
-            + x_size_1,
+            + x.size(1),
         ]
         return self.dropout(x), self.dropout(pos_emb)
 
@@ -888,7 +888,9 @@ class RelPositionMultiheadAttention(nn.Module):
         (batch_size, num_heads, time1, n) = x.shape
         time2 = time1 + left_context
 
-        assert n == 2 * time2 - 1, f"{n} == 2 * {time2} - 1"
+        assert (
+            n == left_context + 2 * time1 - 1
+        ), f"{n} == {left_context} + 2 * {time1} - 1"
 
         # Note: TorchScript requires explicit arg for stride()
         batch_stride = x.stride(0)