zr_jin e8b6b920c0
A LibriTTS recipe on both ASR & Neural Codec Tasks (#1746)
* added ASR & CODEC recipes for LibriTTS corpus
2024-10-21 11:30:14 +08:00

142 lines
4.7 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE
"""A streamable transformer."""
import typing as tp
from typing import Any, List, Optional, Union
import torch
import torch.nn.functional as F
from torch import Tensor, nn
def create_sin_embedding(positions: Tensor, dim: int, max_period: float = 10000):
"""Create time embedding for the given positions, target dimension `dim`."""
# We aim for BTC format
assert dim % 2 == 0
half_dim = dim // 2
adim = torch.arange(half_dim, device=positions.device).view(1, 1, -1)
phase = positions / (max_period ** (adim / (half_dim - 1)))
return torch.cat(
[
torch.cos(phase),
torch.sin(phase),
],
dim=-1,
)
class StreamingTransformerEncoderLayer(nn.TransformerEncoderLayer):
def forward(self, x: Tensor, x_past: Tensor, past_context: int): # type: ignore
if self.norm_first:
sa_input = self.norm1(x)
x = x + self._sa_block(sa_input, x_past, past_context)
x = x + self._ff_block(self.norm2(x))
else:
sa_input = x
x = self.norm1(x + self._sa_block(sa_input, x_past, past_context))
x = self.norm2(x + self._ff_block(x))
return x, sa_input
# self-attention block
def _sa_block(self, x: Tensor, x_past: Tensor, past_context: int): # type: ignore
_, T, _ = x.shape
_, H, _ = x_past.shape
queries = x
keys = torch.cat([x_past, x], dim=1)
values = keys
queries_pos = torch.arange(H, T + H, device=x.device).view(-1, 1)
keys_pos = torch.arange(T + H, device=x.device).view(1, -1)
delta = queries_pos - keys_pos
valid_access = (delta >= 0) & (delta <= past_context)
x = self.self_attn(
queries, keys, values, attn_mask=~valid_access, need_weights=False
)[0]
return self.dropout1(x)
class StreamingTransformerEncoder(nn.Module):
"""TransformerEncoder with streaming support.
Args:
dim (int): dimension of the data.
hidden_scale (int): intermediate dimension of FF module is this times the dimension.
num_heads (int): number of heads.
num_layers (int): number of layers.
max_period (float): maxium period of cosines in the positional embedding.
past_context (int or None): receptive field for the causal mask, infinite if None.
gelu (bool): if true uses GeLUs, otherwise use ReLUs.
norm_in (bool): normalize the input.
dropout (float): dropout probability.
**kwargs: See `nn.TransformerEncoderLayer`.
"""
def __init__(
self,
dim,
hidden_scale: float = 4.0,
num_heads: int = 8,
num_layers: int = 5,
max_period: float = 10000,
past_context: int = 1000,
gelu: bool = True,
norm_in: bool = True,
dropout: float = 0.0,
**kwargs
):
super().__init__()
assert dim % num_heads == 0
hidden_dim = int(dim * hidden_scale)
self.max_period = max_period
self.past_context = past_context
activation: Any = F.gelu if gelu else F.relu
self.norm_in: nn.Module
if norm_in:
self.norm_in = nn.LayerNorm(dim)
else:
self.norm_in = nn.Identity()
self.layers = nn.ModuleList()
for idx in range(num_layers):
self.layers.append(
StreamingTransformerEncoderLayer(
dim,
num_heads,
hidden_dim,
activation=activation,
batch_first=True,
dropout=dropout,
**kwargs
)
)
def forward(
self,
x: Tensor,
states: Optional[List[Tensor]] = None,
offset: Union[int, Tensor] = 0,
):
B, T, C = x.shape
if states is None:
states = [torch.zeros_like(x[:, :1]) for _ in range(1 + len(self.layers))]
positions = torch.arange(T, device=x.device).view(1, -1, 1) + offset
pos_emb = create_sin_embedding(positions, C, max_period=self.max_period)
new_state: List[Tensor] = []
x = self.norm_in(x)
x = x + pos_emb
for layer_state, layer in zip(states, self.layers):
x, new_layer_state = layer(x, layer_state, self.past_context)
new_layer_state = torch.cat([layer_state, new_layer_state], dim=1)
new_state.append(new_layer_state[:, -self.past_context :, :])
return x, new_state, offset + T