Remove some more unused code; rename BasicNorm->BiasNorm, Zipformer->Zipformer2
parent 3424b60d8f
commit f59da65d82
@@ -325,7 +325,7 @@ class MaxEigLimiterFunction(torch.autograd.Function):
 
 
 
-class BasicNormFunction(torch.autograd.Function):
+class BiasNormFunction(torch.autograd.Function):
     # This computes:
     # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
     # return (x - bias) * scales
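
For reference, the computation these comments describe, written out as a plain eager-mode function. This is a minimal sketch: the channel_dim argument is an assumption here, since the comment's torch.mean(..., keepdim=True) needs an explicit dim to actually run.

import torch

def bias_norm(x: torch.Tensor, bias: torch.Tensor,
              log_scale: torch.Tensor, channel_dim: int = -1) -> torch.Tensor:
    # Mean squared deviation of x from the trainable bias gives the
    # normalization scale; exp(log_scale) is a trainable scalar gain.
    scales = (torch.mean((x - bias) ** 2, dim=channel_dim,
                         keepdim=True) ** -0.5) * log_scale.exp()
    return (x - bias) * scales
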
@@ -368,7 +368,7 @@ class BasicNormFunction(torch.autograd.Function):
 
 
 
-class BasicNorm(torch.nn.Module):
+class BiasNorm(torch.nn.Module):
     """
     This is intended to be a simpler, and hopefully cheaper, replacement for
     LayerNorm.  The observation this is based on, is that Transformer-type
@@ -378,9 +378,10 @@ class BasicNorm(torch.nn.Module):
     on the other (useful) features.  Presumably the weight and bias of the
     LayerNorm are required to allow it to do this.
 
-    So the idea is to introduce this large constant value as an explicit
-    parameter, that takes the role of the "eps" in LayerNorm, so the network
-    doesn't have to do this trick.  We make the "eps" learnable.
+    Instead, we give the BiasNorm a trainable bias that it can use when
+    computing the scale for normalization.  We also give it a (scalar)
+    trainable scale on the output.
+
 
     Args:
       num_channels: the number of channels, e.g. 512.
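
Putting this docstring together with the BiasNormFunction comments above, the module being described can be sketched in plain PyTorch. This is a minimal sketch, not the recipe's implementation: the custom autograd function, the log_scale min/max limits, and the store_output_for_backprop option are omitted, the channel dim is assumed to be last, and the default log_scale value of 1.0 is an assumption.

import torch
import torch.nn as nn

class BiasNormSketch(nn.Module):
    def __init__(self, num_channels: int, log_scale: float = 1.0):
        super().__init__()
        # Trainable per-channel bias used in the scale computation, plus
        # a scalar trainable log-scale (default value assumed).
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.log_scale = nn.Parameter(torch.tensor(log_scale))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scales = (torch.mean((x - self.bias) ** 2, dim=-1,
                             keepdim=True) ** -0.5) * self.log_scale.exp()
        return (x - self.bias) * scales
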
@@ -397,7 +398,6 @@ class BasicNorm(torch.nn.Module):
         than the input of this module to be required to be stored for the
         backprop.
     """
-
     def __init__(
         self,
         num_channels: int,
@@ -407,7 +407,7 @@ class BasicNorm(torch.nn.Module):
         log_scale_max: float = 1.5,
         store_output_for_backprop: bool = False
     ) -> None:
-        super(BasicNorm, self).__init__()
+        super(BiasNorm, self).__init__()
         self.num_channels = num_channels
         self.channel_dim = channel_dim
         self.log_scale = nn.Parameter(torch.tensor(log_scale))
@@ -438,248 +438,12 @@ class BasicNorm(torch.nn.Module):
                                       max=float(self.log_scale_max),
                                       training=self.training)
 
-        return BasicNormFunction.apply(x, self.bias, log_scale,
+        return BiasNormFunction.apply(x, self.bias, log_scale,
                                        self.channel_dim,
                                        self.store_output_for_backprop)
 
 
 
-class PositiveConv1d(nn.Conv1d):
-    """
-    A modified form of nn.Conv1d where the weight parameters are constrained
-    to be positive and there is no bias.
-    """
-    def __init__(
-            self, *args, min: FloatLike = 0.01, max: FloatLike = 1.0,
-            **kwargs):
-        super().__init__(*args, **kwargs, bias=False)
-        self.min = min
-        self.max = max
-
-        # initialize weight to all positive values.
-        with torch.no_grad():
-            self.weight[:] = 1.0 / self.weight[0][0].numel()
-
-    def forward(self, input: Tensor) -> Tensor:
-        """
-        Forward function.  Input and returned tensor have shape:
-            (N, C, H)
-        i.e. (batch_size, num_channels, height)
-        """
-        weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max),
-                                   training=self.training)
-        # make absolutely sure there are no negative values.  For parameter-averaging-related
-        # reasons, we prefer to also use limit_param_value to make sure the weights stay
-        # positive.
-        weight = weight.abs()
-
-        if self.padding_mode != 'zeros':
-            return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
-                            weight, self.bias, self.stride,
-                            _single(0), self.dilation, self.groups)
-        return F.conv1d(input, weight, self.bias, self.stride,
-                        self.padding, self.dilation, self.groups)
-
-
-
-class ConvNorm1d(torch.nn.Module):
-    """
-    This is like BasicNorm except the denominator is summed over time using
-    convolution with positive weights.
-
-
-    Args:
-      num_channels: the number of channels, e.g. 512.
-      eps: the initial "epsilon" that we add as ballast in:
-             scale = ((input_vec**2).mean() + epsilon)**-0.5
-           Note: our epsilon is actually large, but we keep the name
-           to indicate the connection with conventional LayerNorm.
-      learn_eps: if true, we learn epsilon; if false, we keep it
-           at the initial value.
-      eps_min: float
-      eps_max: float
-    """
-
-    def __init__(
-        self,
-        num_channels: int,
-        eps: float = 0.25,
-        learn_eps: bool = True,
-        eps_min: float = -3.0,
-        eps_max: float = 3.0,
-        conv_min: float = 0.001,
-        conv_max: float = 1.0,
-        kernel_size: int = 15,
-    ) -> None:
-        super().__init__()
-        self.num_channels = num_channels
-        if learn_eps:
-            self.eps = nn.Parameter(torch.tensor(eps).log().detach())
-        else:
-            self.register_buffer("eps", torch.tensor(eps).log().detach())
-        self.eps_min = eps_min
-        self.eps_max = eps_max
-        pad = kernel_size // 2
-        # it has bias=False.
-        self.conv = PositiveConv1d(1, 1, kernel_size=kernel_size, padding=pad,
-                                   min=conv_min, max=conv_max)
-
-
-    def forward(self, x: Tensor,
-                src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
-        """
-        x shape: (N, C, T)
-
-        src_key_padding_mask: the mask for the src keys per batch (optional):
-            (N, T), contains True in masked positions.
-
-        """
-        assert x.ndim == 3 and x.shape[1] == self.num_channels
-        eps = self.eps
-        if self.training and random.random() < 0.25:
-            # with probability 0.25, in training mode, clamp eps between the min
-            # and max; this will encourage it to learn parameters within the
-            # allowed range by making parameters that are outside the allowed
-            # range noisy.
-
-            # gradients to allow the parameter to get back into the allowed
-            # region if it happens to exit it.
-            eps = torch.clamp(eps, min=self.eps_min, max=self.eps_max)
-
-        # sqnorms: (N, 1, T)
-        sqnorms = (
-            torch.mean(x ** 2, dim=1, keepdim=True)
-        )
-        # 'counts' is a mechanism to correct for edge effects.
-        counts = torch.ones_like(sqnorms)
-        if src_key_padding_mask is not None:
-            counts = counts.masked_fill_(src_key_padding_mask.unsqueeze(1), 0.0)
-            sqnorms = sqnorms * counts
-        sqnorms = self.conv(sqnorms)
-        # the clamping is to avoid division by zero for padding frames.
-        counts = torch.clamp(self.conv(counts), min=0.01)
-        # scales: (N, 1, T)
-        scales = (sqnorms / counts + eps.exp()) ** -0.5
-        return x * scales
-
-
-class PositiveConv2d(nn.Conv2d):
-    """
-    A modified form of nn.Conv2d where the weight parameters are constrained
-    to be positive and there is no bias.
-    """
-    def __init__(
-            self, *args, min: FloatLike = 0.01, max: FloatLike = 1.0,
-            **kwargs):
-        super().__init__(*args, **kwargs, bias=False)
-        self.min = min
-        self.max = max
-
-        # initialize weight to all positive values.
-        with torch.no_grad():
-            self.weight[:] = 1.0 / self.weight[0][0].numel()
-
-    def forward(self, input: Tensor) -> Tensor:
-        """
-        Forward function.  Input and returned tensor have shape:
-            (N, C, H, W)
-        i.e. (batch_size, num_channels, height, width)
-        """
-        weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max),
-                                   training=self.training)
-        # make absolutely sure there are no negative values.  For parameter-averaging-related
-        # reasons, we prefer to also use limit_param_value to make sure the weights stay
-        # positive.
-        weight = weight.abs()
-
-        if self.padding_mode != 'zeros':
-            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
-                            weight, self.bias, self.stride,
-                            _pair(0), self.dilation, self.groups)
-        return F.conv2d(input, weight, self.bias, self.stride,
-                        self.padding, self.dilation, self.groups)
-
-
-class ConvNorm2d(torch.nn.Module):
-    """
-    This is like BasicNorm except the denominator is summed over time using
-    convolution with positive weights.
-
-
-    Args:
-      num_channels: the number of channels, e.g. 512.
-      eps: the initial "epsilon" that we add as ballast in:
-             scale = ((input_vec**2).mean() + epsilon)**-0.5
-           Note: our epsilon is actually large, but we keep the name
-           to indicate the connection with conventional LayerNorm.
-      learn_eps: if true, we learn epsilon; if false, we keep it
-           at the initial value.
-      eps_min: float
-      eps_max: float
-    """
-
-    def __init__(
-        self,
-        num_channels: int,
-        eps: float = 0.25,
-        learn_eps: bool = True,
-        eps_min: float = -3.0,
-        eps_max: float = 3.0,
-        conv_min: float = 0.001,
-        conv_max: float = 1.0,
-        kernel_size: Tuple[int, int] = (3, 3),
-    ) -> None:
-        super().__init__()
-        self.num_channels = num_channels
-        if learn_eps:
-            self.eps = nn.Parameter(torch.tensor(eps).log().detach())
-        else:
-            self.register_buffer("eps", torch.tensor(eps).log().detach())
-        self.eps_min = eps_min
-        self.eps_max = eps_max
-        pad = (kernel_size[0] // 2, kernel_size[1] // 2)
-        # it has bias=False.
-        self.conv = PositiveConv2d(1, 1, kernel_size=kernel_size, padding=pad,
-                                   min=conv_min, max=conv_max)
-
-
-    def forward(self, x: Tensor) -> Tensor:
-        """
-        x shape: (N, C, H, W)
-        """
-        assert x.ndim == 4 and x.shape[1] == self.num_channels
-        eps = self.eps
-        if self.training and random.random() < 0.25:
-            # with probability 0.25, in training mode, clamp eps between the min
-            # and max; this will encourage it to learn parameters within the
-            # allowed range by making parameters that are outside the allowed
-            # range noisy.
-
-            # gradients to allow the parameter to get back into the allowed
-            # region if it happens to exit it.
-            eps = torch.clamp(eps, min=self.eps_min, max=self.eps_max)
-
-        # sqnorms: (N, 1, H, W)
-        sqnorms = (
-            torch.mean(x ** 2, dim=1, keepdim=True)
-        )
-        # 'counts' is a mechanism to correct for edge effects.
-        # TODO: key-padding mask
-
-        counts = torch.ones_like(sqnorms)
-        #if src_key_padding_mask is not None:
-        #    counts = counts.masked_fill_(src_key_padding_mask.unsqueeze(1), 0.0)
-        #sqnorms = sqnorms * counts
-        sqnorms = self.conv(sqnorms)
-        # the clamping is to avoid division by zero for padding frames.
-        counts = torch.clamp(self.conv(counts), min=0.01)
-        # scales: (N, 1, H, W)
-        scales = (sqnorms / counts + eps.exp()) ** -0.5
-        return x * scales
-
-
-
-
 def ScaledLinear(*args,
                  initial_scale: float = 1.0,
                  **kwargs ) -> nn.Linear:
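
The removed ConvNorm1d/ConvNorm2d normalized by a squared-norm denominator smoothed over time with a positive-weight convolution, convolving a tensor of ones ("counts") with the same kernel to correct for edge effects and padding. A minimal sketch of that counts trick, assuming a fixed box-car kernel in place of the learned PositiveConv1d weights:

import torch
import torch.nn.functional as F

def smoothed_scales(x: torch.Tensor, kernel_size: int = 15,
                    eps: float = 0.25) -> torch.Tensor:
    # x: (N, C, T).  Per-frame mean squared value over channels: (N, 1, T).
    sqnorms = torch.mean(x ** 2, dim=1, keepdim=True)
    weight = torch.full((1, 1, kernel_size), 1.0 / kernel_size,
                        dtype=x.dtype, device=x.device)
    pad = kernel_size // 2
    smoothed = F.conv1d(sqnorms, weight, padding=pad)
    # Convolving all-ones with the same kernel measures how much kernel
    # mass overlapped each frame, so edge frames are not under-weighted.
    counts = torch.clamp(F.conv1d(torch.ones_like(sqnorms), weight,
                                  padding=pad), min=0.01)
    return (smoothed / counts + eps) ** -0.5

x * smoothed_scales(x) then plays the role of ConvNorm1d's output, up to the learned kernel, the log-parameterized eps, and the padding-mask handling.
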
@@ -60,7 +60,7 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from zipformer import Zipformer
+from zipformer import Zipformer2
 from scaling import ScheduledFloat
 from decoder import Decoder
 from joiner import Joiner
@@ -536,7 +536,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Zipformer and Transformer
     def to_int_tuple(s: str):
         return tuple(map(int, s.split(',')))
-    encoder = Zipformer(
+    encoder = Zipformer2(
         num_features=params.feature_dim,
         output_downsampling_factor=2,
         downsampling_factor=to_int_tuple(params.downsampling_factor),
@@ -26,9 +26,7 @@ import random
 from encoder_interface import EncoderInterface
 from scaling import (
     Balancer,
-    BasicNorm,
-    ConvNorm1d,
-    ConvNorm2d,
+    BiasNorm,
     Dropout2,
     Dropout3,
     SwooshL,
@@ -53,7 +51,7 @@ from icefall.utils import make_pad_mask
 from icefall.dist import get_rank
 
 
-class Zipformer(EncoderInterface):
+class Zipformer2(EncoderInterface):
     """
     Args:
 
@@ -127,7 +125,7 @@ class Zipformer(EncoderInterface):
         chunk_size: Tuple[int] = [-1],
         left_context_frames: Tuple[int] = [-1],
     ) -> None:
-        super(Zipformer, self).__init__()
+        super(Zipformer2, self).__init__()
 
         if dropout is None:
             dropout = ScheduledFloat((0.0, 0.3),
@@ -185,13 +183,13 @@ class Zipformer(EncoderInterface):
                              dropout=dropout)
 
 
-        # each one will be ZipformerEncoder or DownsampledZipformerEncoder
+        # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
         encoders = []
 
         num_encoders = len(downsampling_factor)
         for i in range(num_encoders):
 
-            encoder_layer = ZipformerEncoderLayer(
+            encoder_layer = Zipformer2EncoderLayer(
                 embed_dim=encoder_dim[i],
                 pos_dim=pos_dim,
                 num_heads=num_heads[i],
@@ -206,7 +204,7 @@ class Zipformer(EncoderInterface):
 
             # For the segment of the warmup period, we let the Conv2dSubsampling
             # layer learn something.  Then we start to warm up the other encoders.
-            encoder = ZipformerEncoder(
+            encoder = Zipformer2Encoder(
                 encoder_layer,
                 num_encoder_layers[i],
                 pos_dim=pos_dim,
@@ -218,7 +216,7 @@ class Zipformer(EncoderInterface):
             )
 
             if downsampling_factor[i] != 1:
-                encoder = DownsampledZipformerEncoder(
+                encoder = DownsampledZipformer2Encoder(
                     encoder,
                     dim=encoder_dim[i],
                     downsample=downsampling_factor[i],
@@ -492,7 +490,7 @@ def _balancer_schedule(min_prob: float):
 
 
 
-class ZipformerEncoderLayer(nn.Module):
+class Zipformer2EncoderLayer(nn.Module):
     """
     Args:
         embed_dim: the number of expected features in the input (required).
@@ -502,7 +500,7 @@ class ZipformerEncoderLayer(nn.Module):
         cnn_module_kernel (int): Kernel size of convolution module.
 
     Examples::
-        >>> encoder_layer = ZipformerEncoderLayer(embed_dim=512, nhead=8)
+        >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
         >>> src = torch.rand(10, 32, 512)
         >>> pos_emb = torch.rand(32, 19, 512)
         >>> out = encoder_layer(src, pos_emb)
@@ -530,7 +528,7 @@ class ZipformerEncoderLayer(nn.Module):
         bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0),
         bypass_max: FloatLike = 1.0,
     ) -> None:
-        super(ZipformerEncoderLayer, self).__init__()
+        super(Zipformer2EncoderLayer, self).__init__()
         self.embed_dim = embed_dim
 
         # probability of skipping the entire layer.
@@ -578,7 +576,7 @@ class ZipformerEncoderLayer(nn.Module):
 
         #self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
 
-        self.norm = BasicNorm(embed_dim)
+        self.norm = BiasNorm(embed_dim)
 
         self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
 
@@ -760,17 +758,17 @@ class ZipformerEncoderLayer(nn.Module):
 
         return src, attn_weights
 
-class ZipformerEncoder(nn.Module):
-    r"""ZipformerEncoder is a stack of N encoder layers
+class Zipformer2Encoder(nn.Module):
+    r"""Zipformer2Encoder is a stack of N encoder layers
 
     Args:
-        encoder_layer: an instance of the ZipformerEncoderLayer() class (required).
+        encoder_layer: an instance of the Zipformer2EncoderLayer() class (required).
         num_layers: the number of sub-encoder-layers in the encoder (required).
         pos_dim: the dimension for the relative positional encoding
 
     Examples::
-        >>> encoder_layer = ZipformerEncoderLayer(embed_dim=512, nhead=8)
-        >>> zipformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6)
+        >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
+        >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6)
         >>> src = torch.rand(10, 32, 512)
         >>> out = zipformer_encoder(src)
     """
@@ -856,9 +854,9 @@ class ZipformerEncoder(nn.Module):
         return output
 
 
-class DownsampledZipformerEncoder(nn.Module):
+class DownsampledZipformer2Encoder(nn.Module):
     r"""
-    DownsampledZipformerEncoder is a zipformer encoder evaluated at a reduced frame rate,
+    DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate,
     after convolutional downsampling, and then upsampled again at the output, and combined
     with the origin input, so that the output has the same shape as the input.
     """
@@ -867,7 +865,7 @@ class DownsampledZipformerEncoder(nn.Module):
                  dim: int,
                  downsample: int,
                  dropout: FloatLike):
-        super(DownsampledZipformerEncoder, self).__init__()
+        super(DownsampledZipformer2Encoder, self).__init__()
         self.downsample_factor = downsample
         self.downsample = SimpleDownsample(dim,
                                            downsample, dropout)
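
The docstring above describes the wrapper's dataflow: downsample, run the inner encoder at the reduced frame rate, upsample back, and combine with the original input so the output keeps the input's shape. A minimal sketch of that pattern, with naive frame-skipping and repetition standing in for the learned SimpleDownsample/upsampling, and a fixed 0.5/0.5 mix as a placeholder for the learned combination:

import torch

def downsampled_encode(x: torch.Tensor, encoder, factor: int) -> torch.Tensor:
    # x: (seq_len, batch, dim); encoder is any seq-to-seq module.
    ds = encoder(x[::factor])                 # evaluate at reduced frame rate
    us = ds.repeat_interleave(factor, dim=0)  # back to the full frame rate
    us = us[: x.shape[0]]                     # trim to the input length
    return 0.5 * (x + us)                     # combine with the original input
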
@@ -1031,79 +1029,6 @@ class SimpleCombiner(torch.nn.Module):
 
 
 
-class SmallConvolutionModule(nn.Module):
-    """Part of Zipformer model: a small version of the Convolution module that uses a small kernel.
-    Inspired by convnext (i.e. have the depthwise conv first.)
-
-    Args:
-        channels (int): The number of channels of conv layers.
-        kernel_size (int): Kernerl size of conv layers.
-        bias (bool): Whether to use bias in conv layers (default=True).
-    """
-
-    def __init__(
-            self, channels: int,
-            hidden_dim: int = 128,
-            kernel_size: int = 5,
-            causal: bool = False,
-    ) -> None:
-        super().__init__()
-
-        self.depthwise_conv = ChunkCausalDepthwiseConv1d(
-            channels=channels,
-            kernel_size=kernel_size) if causal else nn.Conv1d(
-            in_channels=channels,
-            out_channels=channels,
-            groups=channels,
-            kernel_size=kernel_size,
-            padding=kernel_size // 2)
-
-        self.linear1 = nn.Linear(
-            channels, hidden_dim)
-
-        # balancer and activation as tuned for ConvolutionModule.
-
-        self.balancer = Balancer(
-            hidden_dim, channel_dim=-1,
-            min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
-            max_positive=1.0,
-            min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
-            max_abs=10.0,
-        )
-
-        self.activation = SwooshR()
-
-        self.linear2 = ScaledLinear(hidden_dim, channels,
-                                    initial_scale=0.05)
-
-    def forward(self,
-                x: Tensor,
-                src_key_padding_mask: Optional[Tensor] = None,
-    ) -> Tensor:
-        """Compute convolution module.
-
-        Args:
-            x: Input tensor (#time, batch, channels).
-            src_key_padding_mask: the mask for the src keys per batch (optional):
-                (batch, #time), contains bool in masked positions.
-
-        Returns:
-            Tensor: Output tensor (#time, batch, channels).
-        """
-        # exchange the temporal dimension and the feature dimension
-        x = x.permute(1, 2, 0)  # (#batch, channels, time).
-
-        if src_key_padding_mask is not None:
-            x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
-        x = self.depthwise_conv(x)  # (batch, channels, time)
-        x = x.permute(2, 0, 1)  # (time, batch, channels)
-        x = self.linear1(x)  # (time, batch, hidden_dim)
-        x = self.balancer(x)
-        x = self.activation(x)
-        x = self.linear2(x)  # (time, batch, channels)
-        return x
-
-
 class CompactRelPositionalEncoding(torch.nn.Module):
     """
     Relative positional encoding module.  This version is "compact" meaning it is able to encode
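
The removed SmallConvolutionModule followed the ConvNeXt ordering its docstring mentions: depthwise convolution over time first, then a pointwise expansion, nonlinearity, and projection back. A minimal sketch of that ordering, with plain nn modules standing in for the recipe's ChunkCausalDepthwiseConv1d, Balancer, SwooshR, and ScaledLinear:

import torch
import torch.nn as nn

class ConvNextStyleBlock(nn.Module):
    def __init__(self, channels: int, hidden_dim: int = 128,
                 kernel_size: int = 5):
        super().__init__()
        # Depthwise conv mixes over time, one channel at a time.
        self.depthwise = nn.Conv1d(channels, channels, kernel_size,
                                   groups=channels, padding=kernel_size // 2)
        # Pointwise MLP mixes over channels, one frame at a time.
        self.expand = nn.Linear(channels, hidden_dim)
        self.act = nn.ReLU()  # stand-in for SwooshR
        self.project = nn.Linear(hidden_dim, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (time, batch, channels)
        x = self.depthwise(x.permute(1, 2, 0)).permute(2, 0, 1)
        return self.project(self.act(self.expand(x)))
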
@@ -1502,7 +1427,7 @@ class SelfAttention(nn.Module):
 
 
 class FeedforwardModule(nn.Module):
-    """Feedforward module in Zipformer model.
+    """Feedforward module in Zipformer2 model.
     """
     def __init__(self,
                  embed_dim: int,
@@ -1645,7 +1570,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
 
 
 class ConvolutionModule(nn.Module):
-    """ConvolutionModule in Zipformer model.
+    """ConvolutionModule in Zipformer2 model.
     Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py
 
     Args:
@@ -1983,7 +1908,7 @@ class Conv2dSubsampling(nn.Module):
 
         # max_log_eps=0.0 is to prevent both eps and the output of self.out from
         # getting large, there is an unnecessary degree of freedom.
-        self.out_norm = BasicNorm(out_channels)
+        self.out_norm = BiasNorm(out_channels)
         self.dropout = Dropout3(dropout, shared_dim=1)
 
 
@@ -2018,143 +1943,6 @@ class Conv2dSubsampling(nn.Module):
         x = self.dropout(x)
         return x
 
-class AttentionCombine(nn.Module):
-    """
-    This module combines a list of Tensors, all with the same shape, to
-    produce a single output of that same shape which, in training time,
-    is a random combination of all the inputs; but which in test time
-    will be just the last input.
-
-    All but the last input will have a linear transform before we
-    randomly combine them; these linear transforms will be initialized
-    to the identity transform.
-
-    The idea is that the list of Tensors will be a list of outputs of multiple
-    zipformer layers.  This has a similar effect as iterated loss. (See:
-    DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
-    NETWORKS).
-    """
-
-    def __init__(
-        self,
-        num_channels: int,
-        num_inputs: int,
-        random_prob: float = 0.25,
-        single_prob: float = 0.333,
-    ) -> None:
-        """
-        Args:
-          num_channels:
-            the number of channels
-          num_inputs:
-            The number of tensor inputs, which equals the number of layers'
-            outputs that are fed into this module.  E.g. in an 18-layer neural
-            net if we output layers 16, 12, 18, num_inputs would be 3.
-          random_prob:
-            the probability with which we apply a nontrivial mask, in training
-            mode.
-          single_prob:
-            the probability with which we mask to allow just a single
-            module's output (in training)
-        """
-        super().__init__()
-
-        self.random_prob = random_prob
-        self.single_prob = single_prob
-        self.weight = torch.nn.Parameter(torch.zeros(num_channels,
-                                                     num_inputs))
-        self.bias = torch.nn.Parameter(torch.zeros(num_inputs))
-
-        self.name = None  # will be set from training code
-        assert 0 <= random_prob <= 1, random_prob
-        assert 0 <= single_prob <= 1, single_prob
-
-
-
-    def forward(self, inputs: List[Tensor]) -> Tensor:
-        """Forward function.
-        Args:
-          inputs:
-            A list of Tensor, e.g. from various layers of a transformer.
-            All must be the same shape, of (*, num_channels)
-        Returns:
-          A Tensor of shape (*, num_channels).  In test mode
-          this is just the final input.
-        """
-        num_inputs = self.weight.shape[1]
-        assert len(inputs) == num_inputs
-
-        # Shape of weights: (*, num_inputs)
-        num_channels = inputs[0].shape[-1]
-        num_frames = inputs[0].numel() // num_channels
-
-        ndim = inputs[0].ndim
-        # stacked_inputs: (num_frames, num_channels, num_inputs)
-        stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
-            (num_frames, num_channels, num_inputs)
-        )
-
-        scores = (stacked_inputs * self.weight).sum(dim=(1,)) + self.bias
-
-        if random.random() < 0.002:
-            logging.info(f"Average scores are {scores.softmax(dim=1).mean(dim=0)}")
-
-        if self.training:
-            # random masking..
-            mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob),
-                                       size=(num_frames,), device=scores.device).unsqueeze(1)
-            # mask will have rows like: [ False, False, False, True, True, .. ]
-            arange = torch.arange(num_inputs, device=scores.device).unsqueeze(0).expand(
-                num_frames, num_inputs)
-            mask = arange >= mask_start
-
-            apply_single_prob = torch.logical_and(torch.rand(size=(num_frames, 1),
-                                                             device=scores.device) < self.single_prob,
-                                                  mask_start < num_inputs)
-            single_prob_mask = torch.logical_and(apply_single_prob,
-                                                 arange < mask_start - 1)
-
-            mask = torch.logical_or(mask,
-                                    single_prob_mask)
-
-            scores = scores.masked_fill(mask, float('-inf'))
-
-        if self.training and random.random() < 0.1:
-            scores = penalize_abs_values_gt(scores,
-                                            limit=10.0,
-                                            penalty=1.0e-04,
-                                            name=self.name)
-
-        weights = scores.softmax(dim=1)
-
-        # (num_frames, num_channels, num_inputs) * (num_frames, num_inputs, 1) -> (num_frames, num_channels, 1),
-        ans = torch.matmul(stacked_inputs, weights.unsqueeze(2))
-        # ans: (*, num_channels)
-        ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels)
-
-        if __name__ == "__main__":
-            # for testing only...
-            print("Weights = ", weights.reshape(num_frames, num_inputs))
-        return ans
-
-
-def _test_random_combine():
-    print("_test_random_combine()")
-    num_inputs = 3
-    num_channels = 50
-    m = AttentionCombine(
-        num_channels=num_channels,
-        num_inputs=num_inputs,
-        random_prob=0.5,
-        single_prob=0.0)
-
-
-    x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
-
-    y = m(x)
-    assert y.shape == x[0].shape
-    assert torch.allclose(y, x[0])  # .. since actually all ones.
 
 
 def _test_zipformer_main(causal: bool = False):
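
At its core, the removed AttentionCombine computed per-frame scores for each layer's output, took a softmax over layers, and returned the convex combination of the stacked outputs. A minimal sketch of just that combination step (the training-time random masking, the score penalty, and the logging are omitted):

import torch

def softmax_combine(inputs, weight, bias):
    # inputs: list of (num_frames, num_channels) tensors, one per layer.
    # weight: (num_channels, num_inputs); bias: (num_inputs,).
    stacked = torch.stack(inputs, dim=2)           # (frames, channels, inputs)
    scores = (stacked * weight).sum(dim=1) + bias  # (frames, inputs)
    w = scores.softmax(dim=1)                      # per-frame layer weights
    return torch.matmul(stacked, w.unsqueeze(2)).squeeze(2)
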
@@ -2164,7 +1952,7 @@ def _test_zipformer_main(causal: bool = False):
     feature_dim = 50
     # Just make sure the forward pass runs.
 
-    c = Zipformer(
+    c = Zipformer2(
         num_features=feature_dim, encoder_dim=(64,96), encoder_unmasked_dim=(48,64), num_heads=(4,4),
         causal=causal,
         chunk_size=(4,) if causal else (-1,),
@@ -2191,6 +1979,5 @@ if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
-    _test_random_combine()
     _test_zipformer_main(False)
     _test_zipformer_main(True)