mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
417 lines
13 KiB
Python
417 lines
13 KiB
Python
# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu)
|
|
#
|
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
import math
|
|
from typing import Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from encoder_interface import EncoderInterface
|
|
from subsampling import Conv2dSubsampling, VggSubsampling
|
|
|
|
from icefall.utils import make_pad_mask
|
|
|
|
|
|
class Transformer(EncoderInterface):
    """Transformer encoder.

    The model chains a subsampling front-end, a positional encoding, a stack
    of self-attention encoder layers and a final dropout + linear projection
    to ``output_dim``.
    """

    def __init__(
        self,
        num_features: int,
        output_dim: int,
        subsampling_factor: int = 4,
        d_model: int = 256,
        nhead: int = 4,
        dim_feedforward: int = 2048,
        num_encoder_layers: int = 12,
        dropout: float = 0.1,
        normalize_before: bool = True,
        vgg_frontend: bool = False,
    ) -> None:
        """
        Args:
          num_features:
            Dimension of the input features.
          output_dim:
            Dimension of the output of the model.
          subsampling_factor:
            The number of output frames is roughly
            num_in_frames // subsampling_factor. Only the value 4 is
            supported; any other value raises NotImplementedError.
          d_model:
            Attention dimension.
          nhead:
            Number of heads in multi-head attention.
            d_model must be divisible by nhead.
          dim_feedforward:
            Hidden dimension of the feed-forward sub-layer of each
            encoder layer.
          num_encoder_layers:
            Number of encoder layers.
          dropout:
            Dropout probability used in the positional encoding, in the
            encoder layers and before the output projection.
          normalize_before:
            If True, use pre-layer norm (plus a final LayerNorm after the
            last encoder layer); if False, use post-layer norm.
          vgg_frontend:
            If True, use VggSubsampling as the front-end; otherwise use
            Conv2dSubsampling.
        """
        super().__init__()

        self.num_features = num_features
        self.output_dim = output_dim
        self.subsampling_factor = subsampling_factor
        if subsampling_factor != 4:
            raise NotImplementedError("Support only 'subsampling_factor=4'.")

        # The front-end maps (N, T, num_features) to
        # (N, T//subsampling_factor, d_model), i.e. it performs
        #   (1) subsampling: T -> T//subsampling_factor
        #   (2) embedding:   num_features -> d_model
        # in a single module.
        frontend_cls = VggSubsampling if vgg_frontend else Conv2dSubsampling
        self.encoder_embed = frontend_cls(num_features, d_model)

        self.encoder_pos = PositionalEncoding(d_model, dropout)

        layer = TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            normalize_before=normalize_before,
        )

        # With pre-norm layers the output of the last layer is not
        # normalized, so a final LayerNorm is appended to the stack.
        final_norm = nn.LayerNorm(d_model) if normalize_before else None

        self.encoder = nn.TransformerEncoder(
            encoder_layer=layer,
            num_layers=num_encoder_layers,
            norm=final_norm,
        )

        # TODO(fangjun): remove dropout
        self.encoder_output_layer = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(d_model, output_dim),
        )

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            The input tensor of shape (batch_size, seq_len, feature_dim).
          x_lens:
            A tensor of shape (batch_size,) holding the number of frames
            of each sequence in `x` before padding.
        Returns:
          A tuple of two tensors:
            - logits, of shape (batch_size, output_seq_len, output_dim)
            - logit_lens, of shape (batch_size,), the number of frames of
              each sequence in `logits` before padding, computed as
              ((x_lens - 1) // 2 - 1) // 2.
        """
        embedded = self.encoder_pos(self.encoder_embed(x))
        hidden = embedded.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        # Caution: We assume the subsampling factor is 4!
        after_first_halving = (x_lens - 1) // 2
        lengths = (after_first_halving - 1) // 2
        # Sanity check: the longest sequence must span the whole time axis.
        assert hidden.size(0) == lengths.max().item()

        # Mask built from the per-sequence lengths so that attention can
        # ignore padded frames.
        padding_mask = make_pad_mask(lengths)
        hidden = self.encoder(
            hidden, src_key_padding_mask=padding_mask
        )  # (T, N, C)

        projected = self.encoder_output_layer(hidden)
        logits = projected.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

        return logits, lengths
|
|
|
|
|
|
class TransformerEncoderLayer(nn.Module):
    """
    A single encoder layer, modified from torch.nn.TransformerEncoderLayer
    to support `normalize_before`, i.e. applying layer norm at the input of
    each sub-layer (pre-norm) instead of at its output (post-norm).

    Args:
      d_model:
        the number of expected features in the input (required).
      nhead:
        the number of heads in the multi-head attention (required).
      dim_feedforward:
        the dimension of the feedforward network model (default=2048).
      dropout:
        the dropout value (default=0.1). Note that the attention module
        itself is built with dropout=0.0.
      activation:
        the activation function of the intermediate layer, "relu" or
        "gelu" (default="relu").
      normalize_before:
        whether to apply layer norm before (True) or after (False) each
        sub-layer.

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: str = "relu",
        normalize_before: bool = True,
    ) -> None:
        super().__init__()
        # Self-attention sub-layer.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)

        # Position-wise feed-forward sub-layer.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # One LayerNorm and one residual dropout per sub-layer.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

        self.normalize_before = normalize_before

    def __setstate__(self, state):
        # States that carry no "activation" entry fall back to relu.
        state.setdefault("activation", nn.functional.relu)
        super().__setstate__(state)

    def forward(
        self,
        src: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run the input through the self-attention and feed-forward
        sub-layers, each wrapped in a residual connection.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch
              (optional).

        Shape:
            src: (S, N, E).
            src_mask: (S, S).
            src_key_padding_mask: (N, S).
            S is the source sequence length, N is the batch size,
            E is the feature number.

        Returns:
            A tensor with the same shape as `src`.
        """
        pre_norm = self.normalize_before

        # ---- self-attention sub-layer ----
        attn_in = self.norm1(src) if pre_norm else src
        attn_out, _ = self.self_attn(
            attn_in,
            attn_in,
            attn_in,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )
        src = src + self.dropout1(attn_out)
        if not pre_norm:
            src = self.norm1(src)

        # ---- feed-forward sub-layer ----
        ff_in = self.norm2(src) if pre_norm else src
        ff_hidden = self.dropout(self.activation(self.linear1(ff_in)))
        ff_out = self.linear2(ff_hidden)
        src = src + self.dropout2(ff_out)
        if not pre_norm:
            src = self.norm2(src)

        return src
|
|
|
|
|
|
def _get_activation_fn(activation: str):
|
|
if activation == "relu":
|
|
return nn.functional.relu
|
|
elif activation == "gelu":
|
|
return nn.functional.gelu
|
|
|
|
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
|
|
|
|
|
|
class PositionalEncoding(nn.Module):
    """This class implements the positional encoding
    proposed in the following paper:

    - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf

        PE(pos, 2i)   = sin(pos / (10000^(2i/d_model)))
        PE(pos, 2i+1) = cos(pos / (10000^(2i/d_model)))

    Note::

        1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model)))
                                 = exp(-1 * 2i / d_model * log(10000))
                                 = exp(2i * -(log(10000) / d_model))

    The encoding table is cached in `self.pe` (a plain tensor attribute,
    not a registered buffer) and is grown lazily to the longest sequence
    seen so far.
    """

    def __init__(self, d_model: int, dropout: float = 0.1) -> None:
        """
        Args:
          d_model:
            Embedding dimension. Both even and odd values are supported.
          dropout:
            Dropout probability to be applied to the output of this module.
        """
        super().__init__()
        self.d_model = d_model
        # The input is scaled by sqrt(d_model) before the encoding is added.
        self.xscale = math.sqrt(self.d_model)
        self.dropout = nn.Dropout(p=dropout)
        # not doing: self.pe = None because of errors thrown by torchscript
        self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32)

    def extend_pe(self, x: torch.Tensor) -> None:
        """Extend the time axis of the cached positional encoding if required.

        The shape of `self.pe` is (1, T1, d_model). The shape of the input x
        is (N, T, d_model). If T > T1, `self.pe` is recomputed with shape
        (1, T, d_model). Otherwise the cached table is kept and only moved
        to the dtype/device of `x`.

        Args:
          x:
            It is a tensor of shape (N, T, C).
        Returns:
          Return None.
        """
        num_frames = x.size(1)
        if self.pe.size(1) >= num_frames:
            # The cache is long enough; just follow x's dtype and device.
            self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return

        pe = torch.zeros(num_frames, self.d_model, dtype=torch.float32)
        position = torch.arange(0, num_frames, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        angles = position * div_term  # (T, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # There are only d_model // 2 odd-indexed columns, which is one
        # fewer than the number of frequencies when d_model is odd, so the
        # angles are truncated to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : self.d_model // 2])
        pe = pe.unsqueeze(0)
        # Now pe is of shape (1, T, d_model), where T is x.size(1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Scale the input by sqrt(d_model), add the positional encoding and
        apply dropout.

        Args:
          x:
            Its shape is (N, T, C)

        Returns:
          Return a tensor of shape (N, T, C)
        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1), :]
        return self.dropout(x)
|
|
|
|
|
|
class Noam:
    """
    Implements Noam optimizer: an Adam optimizer whose learning rate is
    reset before every update according to the schedule

        lr(step) = factor * model_size^(-0.5)
                   * min(step^(-0.5), step * warm_step^(-1.5))

    i.e. a linear warm-up over `warm_step` steps followed by an inverse
    square-root decay.

    Proposed in
    "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf

    Modified from
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py  # noqa

    Args:
      params:
        iterable of parameters to optimize or dicts defining parameter groups
      model_size:
        attention dimension of the transformer model
      factor:
        learning rate factor
      warm_step:
        warmup steps
      weight_decay:
        weight decay passed through to the wrapped torch.optim.Adam
    """

    def __init__(
        self,
        params,
        model_size: int = 256,
        factor: float = 10.0,
        warm_step: int = 25000,
        weight_decay=0,
    ) -> None:
        """Construct a Noam object."""
        # lr starts at 0; it is overwritten on every call to step().
        self.optimizer = torch.optim.Adam(
            params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
        )
        self._step = 0
        self.warmup = warm_step
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    @property
    def param_groups(self):
        """Return the param_groups of the wrapped optimizer."""
        return self.optimizer.param_groups

    def step(self):
        """Advance the step counter, set the new learning rate on every
        parameter group and update the parameters."""
        self._step += 1
        rate = self.rate()
        for group in self.optimizer.param_groups:
            group["lr"] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Return the learning rate of the schedule at `step`.

        Args:
          step:
            The step to evaluate; defaults to the current internal step.
        Returns:
          The learning rate. For step <= 0 (e.g. before the first call to
          `step()`) it is 0.0, the starting point of the linear warm-up.
        """
        if step is None:
            step = self._step
        if step <= 0:
            # step ** (-0.5) is undefined here (division by zero at 0).
            return 0.0
        return (
            self.factor
            * self.model_size ** (-0.5)
            * min(step ** (-0.5), step * self.warmup ** (-1.5))
        )

    def zero_grad(self):
        """Reset the gradients of the wrapped optimizer."""
        self.optimizer.zero_grad()

    def state_dict(self):
        """Return the schedule state together with the wrapped optimizer's
        state_dict (under the key "optimizer")."""
        return {
            "_step": self._step,
            "warmup": self.warmup,
            "factor": self.factor,
            "model_size": self.model_size,
            "_rate": self._rate,
            "optimizer": self.optimizer.state_dict(),
        }

    def load_state_dict(self, state_dict):
        """Restore the state produced by `state_dict()`.

        The "optimizer" entry is loaded into the wrapped optimizer; every
        other entry is set as an attribute of this object.
        """
        for key, value in state_dict.items():
            if key == "optimizer":
                self.optimizer.load_state_dict(value)
            else:
                setattr(self, key, value)
|