# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE

"""A streamable transformer."""

import typing as tp
from typing import Any, List, Optional, Union

import torch
import torch.nn.functional as F
from torch import Tensor, nn


def create_sin_embedding(positions: Tensor, dim: int, max_period: float = 10000):
    """Create a sinusoidal time embedding for the given positions, with target dimension `dim`."""
    # We aim for BTC (batch, time, channel) format.
    assert dim % 2 == 0
    half_dim = dim // 2
    adim = torch.arange(half_dim, device=positions.device).view(1, 1, -1)
    phase = positions / (max_period ** (adim / (half_dim - 1)))
    return torch.cat(
        [
            torch.cos(phase),
            torch.sin(phase),
        ],
        dim=-1,
    )
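

# Usage sketch (illustrative, not part of the original file): for `positions` of
# shape (B, T, 1) and an even `dim`, the returned embedding has shape (B, T, dim),
# e.g.
#
#   pos = torch.arange(50, dtype=torch.float32).view(1, -1, 1)
#   emb = create_sin_embedding(pos, 256)  # -> torch.Size([1, 50, 256])
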
class StreamingTransformerEncoderLayer(nn.TransformerEncoderLayer):
    def forward(self, x: Tensor, x_past: Tensor, past_context: int):  # type: ignore
        if self.norm_first:
            sa_input = self.norm1(x)
            x = x + self._sa_block(sa_input, x_past, past_context)
            x = x + self._ff_block(self.norm2(x))
        else:
            sa_input = x
            x = self.norm1(x + self._sa_block(sa_input, x_past, past_context))
            x = self.norm2(x + self._ff_block(x))

        return x, sa_input

    # self-attention block
    def _sa_block(self, x: Tensor, x_past: Tensor, past_context: int):  # type: ignore
        _, T, _ = x.shape
        _, H, _ = x_past.shape

        # Attend over the cached past frames followed by the current chunk.
        queries = x
        keys = torch.cat([x_past, x], dim=1)
        values = keys

        # A query at absolute position t may only attend to keys at positions s
        # with t - past_context <= s <= t (causal mask with a bounded receptive field).
        queries_pos = torch.arange(H, T + H, device=x.device).view(-1, 1)
        keys_pos = torch.arange(T + H, device=x.device).view(1, -1)
        delta = queries_pos - keys_pos
        valid_access = (delta >= 0) & (delta <= past_context)
        x = self.self_attn(
            queries, keys, values, attn_mask=~valid_access, need_weights=False
        )[0]
        return self.dropout1(x)
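

# Mask sketch for `StreamingTransformerEncoderLayer._sa_block` (illustrative): with
# T = 3 new frames, H = 2 cached frames and past_context = 2, `valid_access`
# (True = attention allowed) is
#
#   keys:          p0 p1 | x0 x1 x2
#   query x0  ->    1  1    1  0  0
#   query x1  ->    0  1    1  1  0
#   query x2  ->    0  0    1  1  1
#
# i.e. each frame attends to itself and to at most `past_context` preceding frames,
# whether they come from the cache or from the current chunk.
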
class StreamingTransformerEncoder(nn.Module):
    """TransformerEncoder with streaming support.

    Args:
        dim (int): dimension of the data.
        hidden_scale (float): intermediate dimension of the FF module is this times the dimension.
        num_heads (int): number of heads.
        num_layers (int): number of layers.
        max_period (float): maximum period of the cosines in the positional embedding.
        past_context (int or None): receptive field for the causal mask, infinite if None.
        gelu (bool): if true, uses GELUs; otherwise uses ReLUs.
        norm_in (bool): normalize the input.
        dropout (float): dropout probability.
        **kwargs: See `nn.TransformerEncoderLayer`.
    """

    def __init__(
        self,
        dim,
        hidden_scale: float = 4.0,
        num_heads: int = 8,
        num_layers: int = 5,
        max_period: float = 10000,
        past_context: int = 1000,
        gelu: bool = True,
        norm_in: bool = True,
        dropout: float = 0.0,
        **kwargs
    ):
        super().__init__()
        assert dim % num_heads == 0
        hidden_dim = int(dim * hidden_scale)

        self.max_period = max_period
        self.past_context = past_context
        activation: Any = F.gelu if gelu else F.relu

        self.norm_in: nn.Module
        if norm_in:
            self.norm_in = nn.LayerNorm(dim)
        else:
            self.norm_in = nn.Identity()

        self.layers = nn.ModuleList()
        for idx in range(num_layers):
            self.layers.append(
                StreamingTransformerEncoderLayer(
                    dim,
                    num_heads,
                    hidden_dim,
                    activation=activation,
                    batch_first=True,
                    dropout=dropout,
                    **kwargs
                )
            )

    def forward(
        self,
        x: Tensor,
        states: Optional[List[Tensor]] = None,
        offset: Union[int, Tensor] = 0,
    ):
        B, T, C = x.shape
        if states is None:
            # No history yet: initialize each cache entry with a single frame of zeros.
            states = [torch.zeros_like(x[:, :1]) for _ in range(1 + len(self.layers))]

        positions = torch.arange(T, device=x.device).view(1, -1, 1) + offset
        pos_emb = create_sin_embedding(positions, C, max_period=self.max_period)

        new_state: List[Tensor] = []
        x = self.norm_in(x)
        x = x + pos_emb

        for layer_state, layer in zip(states, self.layers):
            x, new_layer_state = layer(x, layer_state, self.past_context)
            # Keep at most `past_context` frames of history for the next call.
            new_layer_state = torch.cat([layer_state, new_layer_state], dim=1)
            new_state.append(new_layer_state[:, -self.past_context :, :])
        return x, new_state, offset + T
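

# A minimal streaming sketch (illustrative, not part of the original file): feed the
# encoder chunk by chunk, carrying `states` and `offset` across calls. The model
# size, chunk length and number of chunks below are arbitrary assumptions.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = StreamingTransformerEncoder(
        dim=64, num_heads=4, num_layers=2, past_context=16
    )
    model.eval()

    states, offset = None, 0
    with torch.no_grad():
        for _ in range(3):  # three chunks of 8 frames each, in (B, T, C) layout
            chunk = torch.randn(1, 8, 64)
            out, states, offset = model(chunk, states, offset)
    print(out.shape, offset)  # expected: torch.Size([1, 8, 64]) 24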