mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-11 10:04:21 +00:00
Fix relative positional encoding in streaming decoding to save computation
This commit is contained in:
parent
fc54a99a56
commit
3aacf75652
@ -717,9 +717,9 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
self.pe = None
|
self.pe = None
|
||||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||||
|
|
||||||
def extend_pe(self, x: Tensor, context: int = 0) -> None:
|
def extend_pe(self, x: Tensor, left_context: int = 0) -> None:
|
||||||
"""Reset the positional encodings."""
|
"""Reset the positional encodings."""
|
||||||
x_size_1 = x.size(1) + context
|
x_size_1 = x.size(1) + left_context
|
||||||
if self.pe is not None:
|
if self.pe is not None:
|
||||||
# self.pe contains both positive and negative parts
|
# self.pe contains both positive and negative parts
|
||||||
# the length of self.pe is 2 * input_len - 1
|
# the length of self.pe is 2 * input_len - 1
|
||||||
@ -756,13 +756,13 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
context: int = 0,
|
left_context: int = 0,
|
||||||
) -> Tuple[Tensor, Tensor]:
|
) -> Tuple[Tensor, Tensor]:
|
||||||
"""Add positional encoding.
|
"""Add positional encoding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (torch.Tensor): Input tensor (batch, time, `*`).
|
x (torch.Tensor): Input tensor (batch, time, `*`).
|
||||||
context (int): left context (in frames) used during streaming decoding.
|
left_context (int): left context (in frames) used during streaming decoding.
|
||||||
this is used only in real streaming decoding, in other circumstances,
|
this is used only in real streaming decoding, in other circumstances,
|
||||||
it MUST be 0.
|
it MUST be 0.
|
||||||
|
|
||||||
@ -771,14 +771,14 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.extend_pe(x, context)
|
self.extend_pe(x, left_context)
|
||||||
x_size_1 = x.size(1) + context
|
x_size_1 = x.size(1) + left_context
|
||||||
pos_emb = self.pe[
|
pos_emb = self.pe[
|
||||||
:,
|
:,
|
||||||
self.pe.size(1) // 2
|
self.pe.size(1) // 2
|
||||||
- x_size_1
|
- x_size_1
|
||||||
+ 1 : self.pe.size(1) // 2 # noqa E203
|
+ 1 : self.pe.size(1) // 2 # noqa E203
|
||||||
+ x_size_1,
|
+ x.size(1),
|
||||||
]
|
]
|
||||||
return self.dropout(x), self.dropout(pos_emb)
|
return self.dropout(x), self.dropout(pos_emb)
|
||||||
|
|
||||||
@ -931,7 +931,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
(batch_size, num_heads, time1, n) = x.shape
|
(batch_size, num_heads, time1, n) = x.shape
|
||||||
|
|
||||||
time2 = time1 + left_context
|
time2 = time1 + left_context
|
||||||
assert n == 2 * time2 - 1, f"{n} == 2 * {time2} - 1"
|
assert (
|
||||||
|
n == left_context + 2 * time1 - 1
|
||||||
|
), f"{n} == {left_context} + 2 * {time1} - 1"
|
||||||
|
|
||||||
# Note: TorchScript requires explicit arg for stride()
|
# Note: TorchScript requires explicit arg for stride()
|
||||||
batch_stride = x.stride(0)
|
batch_stride = x.stride(0)
|
||||||
|
@ -681,9 +681,9 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
self.pe = None
|
self.pe = None
|
||||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||||
|
|
||||||
def extend_pe(self, x: Tensor, context: int = 0) -> None:
|
def extend_pe(self, x: Tensor, left_context: int = 0) -> None:
|
||||||
"""Reset the positional encodings."""
|
"""Reset the positional encodings."""
|
||||||
x_size_1 = x.size(1) + context
|
x_size_1 = x.size(1) + left_context
|
||||||
if self.pe is not None:
|
if self.pe is not None:
|
||||||
# self.pe contains both positive and negative parts
|
# self.pe contains both positive and negative parts
|
||||||
# the length of self.pe is 2 * input_len - 1
|
# the length of self.pe is 2 * input_len - 1
|
||||||
@ -718,13 +718,13 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: torch.Tensor, context: int = 0
|
self, x: torch.Tensor, left_context: int = 0
|
||||||
) -> Tuple[Tensor, Tensor]:
|
) -> Tuple[Tensor, Tensor]:
|
||||||
"""Add positional encoding.
|
"""Add positional encoding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (torch.Tensor): Input tensor (batch, time, `*`).
|
x (torch.Tensor): Input tensor (batch, time, `*`).
|
||||||
context (int): left context (in frames) used during streaming decoding.
|
left_context (int): left context (in frames) used during streaming decoding.
|
||||||
this is used only in real streaming decoding, in other circumstances,
|
this is used only in real streaming decoding, in other circumstances,
|
||||||
it MUST be 0.
|
it MUST be 0.
|
||||||
Returns:
|
Returns:
|
||||||
@ -732,15 +732,15 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.extend_pe(x, context)
|
self.extend_pe(x, left_context)
|
||||||
x = x * self.xscale
|
x = x * self.xscale
|
||||||
x_size_1 = x.size(1) + context
|
x_size_1 = x.size(1) + left_context
|
||||||
pos_emb = self.pe[
|
pos_emb = self.pe[
|
||||||
:,
|
:,
|
||||||
self.pe.size(1) // 2
|
self.pe.size(1) // 2
|
||||||
- x_size_1
|
- x_size_1
|
||||||
+ 1 : self.pe.size(1) // 2 # noqa E203
|
+ 1 : self.pe.size(1) // 2 # noqa E203
|
||||||
+ x_size_1,
|
+ x.size(1),
|
||||||
]
|
]
|
||||||
return self.dropout(x), self.dropout(pos_emb)
|
return self.dropout(x), self.dropout(pos_emb)
|
||||||
|
|
||||||
@ -888,7 +888,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
(batch_size, num_heads, time1, n) = x.shape
|
(batch_size, num_heads, time1, n) = x.shape
|
||||||
time2 = time1 + left_context
|
time2 = time1 + left_context
|
||||||
|
|
||||||
assert n == 2 * time2 - 1, f"{n} == 2 * {time2} - 1"
|
assert (
|
||||||
|
n == left_context + 2 * time1 - 1
|
||||||
|
), f"{n} == {left_context} + 2 * {time1} - 1"
|
||||||
|
|
||||||
# Note: TorchScript requires explicit arg for stride()
|
# Note: TorchScript requires explicit arg for stride()
|
||||||
batch_stride = x.stride(0)
|
batch_stride = x.stride(0)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user