Remove some more unused code; rename BasicNorm->BiasNorm, Zipformer->Zipformer2
This commit is contained in:
parent 3424b60d8f · commit f59da65d82
@@ -325,7 +325,7 @@ class MaxEigLimiterFunction(torch.autograd.Function):
 
 
 
-class BasicNormFunction(torch.autograd.Function):
+class BiasNormFunction(torch.autograd.Function):
     # This computes:
     #   scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
     #   return (x - bias) * scales
@@ -368,7 +368,7 @@ class BasicNormFunction(torch.autograd.Function):
 
 
 
-class BasicNorm(torch.nn.Module):
+class BiasNorm(torch.nn.Module):
     """
     This is intended to be a simpler, and hopefully cheaper, replacement for
     LayerNorm.  The observation this is based on, is that Transformer-type
@@ -378,9 +378,10 @@ class BasicNorm(torch.nn.Module):
     on the other (useful) features.  Presumably the weight and bias of the
     LayerNorm are required to allow it to do this.
 
-    So the idea is to introduce this large constant value as an explicit
-    parameter, that takes the role of the "eps" in LayerNorm, so the network
-    doesn't have to do this trick.  We make the "eps" learnable.
+    Instead, we give the BiasNorm a trainable bias that it can use when
+    computing the scale for normalization.  We also give it a (scalar)
+    trainable scale on the output.
+
 
     Args:
       num_channels: the number of channels, e.g. 512.
@@ -397,7 +398,6 @@ class BasicNorm(torch.nn.Module):
         than the input of this module to be required to be stored for the
         backprop.
     """
-
     def __init__(
         self,
         num_channels: int,
@@ -407,7 +407,7 @@ class BasicNorm(torch.nn.Module):
         log_scale_max: float = 1.5,
         store_output_for_backprop: bool = False
     ) -> None:
-        super(BasicNorm, self).__init__()
+        super(BiasNorm, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.log_scale = nn.Parameter(torch.tensor(log_scale))
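For reference, the renamed module computes what the comment at the top of BiasNormFunction describes. Below is a minimal plain-PyTorch sketch of that computation; the real BiasNormFunction is a custom autograd function whose main difference is memory behavior (it can store the output rather than the input for the backward pass). The shapes and the channel_dim default here are illustrative assumptions, not part of the commit:

```python
import torch

def bias_norm_sketch(x: torch.Tensor,
                     bias: torch.Tensor,
                     log_scale: torch.Tensor,
                     channel_dim: int = -1) -> torch.Tensor:
    # Mirrors the comment in BiasNormFunction:
    #   scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
    #   return (x - bias) * scales
    scales = torch.mean((x - bias) ** 2, dim=channel_dim,
                        keepdim=True) ** -0.5 * log_scale.exp()
    return (x - bias) * scales

# Hypothetical sizes: (seq_len, batch, num_channels) with channel_dim=-1.
x = torch.randn(10, 32, 512)
bias = torch.zeros(512)          # the trainable per-channel bias
log_scale = torch.tensor(0.0)    # the scalar trainable log-scale
y = bias_norm_sketch(x, bias, log_scale)
assert y.shape == x.shape
```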
@@ -438,248 +438,12 @@ class BasicNorm(torch.nn.Module):
                                      max=float(self.log_scale_max),
                                      training=self.training)
 
-        return BasicNormFunction.apply(x, self.bias, log_scale,
+        return BiasNormFunction.apply(x, self.bias, log_scale,
                                       self.channel_dim,
                                       self.store_output_for_backprop)
 
 
-
-class PositiveConv1d(nn.Conv1d):
-    """
-    A modified form of nn.Conv1d where the weight parameters are constrained
-    to be positive and there is no bias.
-    """
-    def __init__(
-            self, *args, min: FloatLike = 0.01, max: FloatLike = 1.0,
-            **kwargs):
-        super().__init__(*args, **kwargs, bias=False)
-        self.min = min
-        self.max = max
-
-        # initialize weight to all positive values.
-        with torch.no_grad():
-            self.weight[:] = 1.0 / self.weight[0][0].numel()
-
-    def forward(self, input: Tensor) -> Tensor:
-        """
-        Forward function.  Input and returned tensor have shape:
-            (N, C, H)
-        i.e. (batch_size, num_channels, height)
-        """
-        weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max),
-                                   training=self.training)
-        # make absolutely sure there are no negative values.  For parameter-averaging-related
-        # reasons, we prefer to also use limit_param_value to make sure the weights stay
-        # positive.
-        weight = weight.abs()
-
-        if self.padding_mode != 'zeros':
-            return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
-                            weight, self.bias, self.stride,
-                            _single(0), self.dilation, self.groups)
-        return F.conv1d(input, weight, self.bias, self.stride,
-                        self.padding, self.dilation, self.groups)
-
-
-class ConvNorm1d(torch.nn.Module):
-    """
-    This is like BasicNorm except the denominator is summed over time using
-    convolution with positive weights.
-
-    Args:
-       num_channels: the number of channels, e.g. 512.
-       eps: the initial "epsilon" that we add as ballast in:
-              scale = ((input_vec**2).mean() + epsilon)**-0.5
-            Note: our epsilon is actually large, but we keep the name
-            to indicate the connection with conventional LayerNorm.
-       learn_eps: if true, we learn epsilon; if false, we keep it
-            at the initial value.
-       eps_min: float
-       eps_max: float
-    """
-    def __init__(
-        self,
-        num_channels: int,
-        eps: float = 0.25,
-        learn_eps: bool = True,
-        eps_min: float = -3.0,
-        eps_max: float = 3.0,
-        conv_min: float = 0.001,
-        conv_max: float = 1.0,
-        kernel_size: int = 15,
-    ) -> None:
-        super().__init__()
-        self.num_channels = num_channels
-        if learn_eps:
-            self.eps = nn.Parameter(torch.tensor(eps).log().detach())
-        else:
-            self.register_buffer("eps", torch.tensor(eps).log().detach())
-        self.eps_min = eps_min
-        self.eps_max = eps_max
-        pad = kernel_size // 2
-        # it has bias=False.
-        self.conv = PositiveConv1d(1, 1, kernel_size=kernel_size, padding=pad,
-                                   min=conv_min, max=conv_max)
-
-    def forward(self, x: Tensor,
-                src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
-        """
-        x shape: (N, C, T)
-
-        src_key_padding_mask: the mask for the src keys per batch (optional):
-            (N, T), contains True in masked positions.
-        """
-        assert x.ndim == 3 and x.shape[1] == self.num_channels
-        eps = self.eps
-        if self.training and random.random() < 0.25:
-            # with probability 0.25, in training mode, clamp eps between the min
-            # and max; this will encourage it to learn parameters within the
-            # allowed range by making parameters that are outside the allowed
-            # range noisy.
-
-            # gradients to allow the parameter to get back into the allowed
-            # region if it happens to exit it.
-            eps = torch.clamp(eps, min=self.eps_min, max=self.eps_max)
-
-        # sqnorms: (N, 1, T)
-        sqnorms = (
-            torch.mean(x ** 2, dim=1, keepdim=True)
-        )
-        # 'counts' is a mechanism to correct for edge effects.
-        counts = torch.ones_like(sqnorms)
-        if src_key_padding_mask is not None:
-            counts = counts.masked_fill_(src_key_padding_mask.unsqueeze(1), 0.0)
-            sqnorms = sqnorms * counts
-        sqnorms = self.conv(sqnorms)
-        # the clamping is to avoid division by zero for padding frames.
-        counts = torch.clamp(self.conv(counts), min=0.01)
-        # scales: (N, 1, T)
-        scales = (sqnorms / counts + eps.exp()) ** -0.5
-        return x * scales
-
-
-class PositiveConv2d(nn.Conv2d):
-    """
-    A modified form of nn.Conv2d where the weight parameters are constrained
-    to be positive and there is no bias.
-    """
-    def __init__(
-            self, *args, min: FloatLike = 0.01, max: FloatLike = 1.0,
-            **kwargs):
-        super().__init__(*args, **kwargs, bias=False)
-        self.min = min
-        self.max = max
-
-        # initialize weight to all positive values.
-        with torch.no_grad():
-            self.weight[:] = 1.0 / self.weight[0][0].numel()
-
-    def forward(self, input: Tensor) -> Tensor:
-        """
-        Forward function.  Input and returned tensor have shape:
-            (N, C, H, W)
-        i.e. (batch_size, num_channels, height, width)
-        """
-        weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max),
-                                   training=self.training)
-        # make absolutely sure there are no negative values.  For parameter-averaging-related
-        # reasons, we prefer to also use limit_param_value to make sure the weights stay
-        # positive.
-        weight = weight.abs()
-
-        if self.padding_mode != 'zeros':
-            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
-                            weight, self.bias, self.stride,
-                            _pair(0), self.dilation, self.groups)
-        return F.conv2d(input, weight, self.bias, self.stride,
-                        self.padding, self.dilation, self.groups)
-
-
-class ConvNorm2d(torch.nn.Module):
-    """
-    This is like BasicNorm except the denominator is summed over time using
-    convolution with positive weights.
-
-    Args:
-       num_channels: the number of channels, e.g. 512.
-       eps: the initial "epsilon" that we add as ballast in:
-              scale = ((input_vec**2).mean() + epsilon)**-0.5
-            Note: our epsilon is actually large, but we keep the name
-            to indicate the connection with conventional LayerNorm.
-       learn_eps: if true, we learn epsilon; if false, we keep it
-            at the initial value.
-       eps_min: float
-       eps_max: float
-    """
-    def __init__(
-        self,
-        num_channels: int,
-        eps: float = 0.25,
-        learn_eps: bool = True,
-        eps_min: float = -3.0,
-        eps_max: float = 3.0,
-        conv_min: float = 0.001,
-        conv_max: float = 1.0,
-        kernel_size: Tuple[int, int] = (3, 3),
-    ) -> None:
-        super().__init__()
-        self.num_channels = num_channels
-        if learn_eps:
-            self.eps = nn.Parameter(torch.tensor(eps).log().detach())
-        else:
-            self.register_buffer("eps", torch.tensor(eps).log().detach())
-        self.eps_min = eps_min
-        self.eps_max = eps_max
-        pad = (kernel_size[0] // 2, kernel_size[1] // 2)
-        # it has bias=False.
-        self.conv = PositiveConv2d(1, 1, kernel_size=kernel_size, padding=pad,
-                                   min=conv_min, max=conv_max)
-
-    def forward(self, x: Tensor) -> Tensor:
-        """
-        x shape: (N, C, H, W)
-        """
-        assert x.ndim == 4 and x.shape[1] == self.num_channels
-        eps = self.eps
-        if self.training and random.random() < 0.25:
-            # with probability 0.25, in training mode, clamp eps between the min
-            # and max; this will encourage it to learn parameters within the
-            # allowed range by making parameters that are outside the allowed
-            # range noisy.
-
-            # gradients to allow the parameter to get back into the allowed
-            # region if it happens to exit it.
-            eps = torch.clamp(eps, min=self.eps_min, max=self.eps_max)
-
-        # sqnorms: (N, 1, H, W)
-        sqnorms = (
-            torch.mean(x ** 2, dim=1, keepdim=True)
-        )
-        # 'counts' is a mechanism to correct for edge effects.
-        # TODO: key-padding mask
-
-        counts = torch.ones_like(sqnorms)
-        #if src_key_padding_mask is not None:
-        #    counts = counts.masked_fill_(src_key_padding_mask.unsqueeze(1), 0.0)
-        #sqnorms = sqnorms * counts
-        sqnorms = self.conv(sqnorms)
-        # the clamping is to avoid division by zero for padding frames.
-        counts = torch.clamp(self.conv(counts), min=0.01)
-        # scales: (N, 1, H, W)
-        scales = (sqnorms / counts + eps.exp()) ** -0.5
-        return x * scales
-
-
 
 
 def ScaledLinear(*args,
                  initial_scale: float = 1.0,
                  **kwargs ) -> nn.Linear:
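The removed ConvNorm1d/ConvNorm2d classes normalize like BasicNorm but smooth the denominator over time with a positive-weight convolution, and use a 'counts' tensor so that frames near sequence edges (or adjacent to padding) are not biased low. A minimal sketch of just that counts mechanism, substituting a fixed uniform kernel for the trainable PositiveConv1d; the function name and sizes are illustrative:

```python
import torch
import torch.nn.functional as F

def smoothed_mean_sqnorm(x: torch.Tensor, kernel_size: int = 15) -> torch.Tensor:
    # x: (N, C, T) -> per-frame mean squared value, smoothed over time.
    sqnorms = torch.mean(x ** 2, dim=1, keepdim=True)            # (N, 1, T)
    weight = torch.full((1, 1, kernel_size), 1.0 / kernel_size)  # uniform stand-in
    pad = kernel_size // 2
    counts = torch.ones_like(sqnorms)
    # Numerator and denominator go through the same convolution, so frames
    # near the edges (where the kernel overlaps zero-padding) are divided
    # by a proportionally smaller count instead of being biased low.
    num = F.conv1d(sqnorms, weight, padding=pad)
    den = torch.clamp(F.conv1d(counts, weight, padding=pad), min=0.01)
    return num / den

x = torch.randn(2, 512, 100)
m = smoothed_mean_sqnorm(x)      # (2, 1, 100)
scales = (m + 0.25) ** -0.5      # eps = 0.25, the removed ConvNorm1d default
y = x * scales                   # normalized like ConvNorm1d, minus the learned parts
```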
@@ -60,7 +60,7 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from zipformer import Zipformer
+from zipformer import Zipformer2
 from scaling import ScheduledFloat
 from decoder import Decoder
 from joiner import Joiner
@@ -536,7 +536,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Zipformer and Transformer
     def to_int_tuple(s: str):
         return tuple(map(int, s.split(',')))
-    encoder = Zipformer(
+    encoder = Zipformer2(
         num_features=params.feature_dim,
         output_downsampling_factor=2,
         downsampling_factor=to_int_tuple(params.downsampling_factor),
@@ -26,9 +26,7 @@ import random
 from encoder_interface import EncoderInterface
 from scaling import (
     Balancer,
-    BasicNorm,
-    ConvNorm1d,
-    ConvNorm2d,
+    BiasNorm,
     Dropout2,
     Dropout3,
     SwooshL,
@@ -53,7 +51,7 @@ from icefall.utils import make_pad_mask
 from icefall.dist import get_rank
 
 
-class Zipformer(EncoderInterface):
+class Zipformer2(EncoderInterface):
     """
     Args:
 
@@ -127,7 +125,7 @@ class Zipformer(EncoderInterface):
         chunk_size: Tuple[int] = [-1],
         left_context_frames: Tuple[int] = [-1],
     ) -> None:
-        super(Zipformer, self).__init__()
+        super(Zipformer2, self).__init__()
 
         if dropout is None:
             dropout = ScheduledFloat((0.0, 0.3),
@@ -185,13 +183,13 @@ class Zipformer(EncoderInterface):
                                      dropout=dropout)
 
 
-        # each one will be ZipformerEncoder or DownsampledZipformerEncoder
+        # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
         encoders = []
 
         num_encoders = len(downsampling_factor)
         for i in range(num_encoders):
 
-            encoder_layer = ZipformerEncoderLayer(
+            encoder_layer = Zipformer2EncoderLayer(
                 embed_dim=encoder_dim[i],
                 pos_dim=pos_dim,
                 num_heads=num_heads[i],
@@ -206,7 +204,7 @@ class Zipformer(EncoderInterface):
 
             # For the segment of the warmup period, we let the Conv2dSubsampling
             # layer learn something.  Then we start to warm up the other encoders.
-            encoder = ZipformerEncoder(
+            encoder = Zipformer2Encoder(
                 encoder_layer,
                 num_encoder_layers[i],
                 pos_dim=pos_dim,
@@ -218,7 +216,7 @@ class Zipformer(EncoderInterface):
             )
 
             if downsampling_factor[i] != 1:
-                encoder = DownsampledZipformerEncoder(
+                encoder = DownsampledZipformer2Encoder(
                     encoder,
                     dim=encoder_dim[i],
                     downsample=downsampling_factor[i],
@@ -492,7 +490,7 @@ def _balancer_schedule(min_prob: float):
 
 
 
-class ZipformerEncoderLayer(nn.Module):
+class Zipformer2EncoderLayer(nn.Module):
     """
     Args:
         embed_dim: the number of expected features in the input (required).
@@ -502,7 +500,7 @@ class ZipformerEncoderLayer(nn.Module):
         cnn_module_kernel (int): Kernel size of convolution module.
 
     Examples::
-        >>> encoder_layer = ZipformerEncoderLayer(embed_dim=512, nhead=8)
+        >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
         >>> src = torch.rand(10, 32, 512)
         >>> pos_emb = torch.rand(32, 19, 512)
         >>> out = encoder_layer(src, pos_emb)
@@ -530,7 +528,7 @@ class ZipformerEncoderLayer(nn.Module):
             bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0),
             bypass_max: FloatLike = 1.0,
     ) -> None:
-        super(ZipformerEncoderLayer, self).__init__()
+        super(Zipformer2EncoderLayer, self).__init__()
         self.embed_dim = embed_dim
 
         # probability of skipping the entire layer.
@@ -578,7 +576,7 @@ class ZipformerEncoderLayer(nn.Module):
 
         #self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
 
-        self.norm = BasicNorm(embed_dim)
+        self.norm = BiasNorm(embed_dim)
 
         self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
 
@@ -760,17 +758,17 @@ class ZipformerEncoderLayer(nn.Module):
 
         return src, attn_weights
 
-class ZipformerEncoder(nn.Module):
-    r"""ZipformerEncoder is a stack of N encoder layers
+class Zipformer2Encoder(nn.Module):
+    r"""Zipformer2Encoder is a stack of N encoder layers
 
     Args:
-        encoder_layer: an instance of the ZipformerEncoderLayer() class (required).
+        encoder_layer: an instance of the Zipformer2EncoderLayer() class (required).
         num_layers: the number of sub-encoder-layers in the encoder (required).
         pos_dim: the dimension for the relative positional encoding
 
     Examples::
-        >>> encoder_layer = ZipformerEncoderLayer(embed_dim=512, nhead=8)
-        >>> zipformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6)
+        >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
+        >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6)
         >>> src = torch.rand(10, 32, 512)
         >>> out = zipformer_encoder(src)
     """
@@ -856,9 +854,9 @@ class ZipformerEncoder(nn.Module):
         return output
 
 
-class DownsampledZipformerEncoder(nn.Module):
+class DownsampledZipformer2Encoder(nn.Module):
     r"""
-    DownsampledZipformerEncoder is a zipformer encoder evaluated at a reduced frame rate,
+    DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate,
     after convolutional downsampling, and then upsampled again at the output, and combined
     with the origin input, so that the output has the same shape as the input.
     """
@@ -867,7 +865,7 @@ class DownsampledZipformerEncoder(nn.Module):
                  dim: int,
                  downsample: int,
                  dropout: FloatLike):
-        super(DownsampledZipformerEncoder, self).__init__()
+        super(DownsampledZipformer2Encoder, self).__init__()
         self.downsample_factor = downsample
         self.downsample = SimpleDownsample(dim,
                                            downsample, dropout)
@@ -1031,79 +1029,6 @@ class SimpleCombiner(torch.nn.Module):
 
 
 
-class SmallConvolutionModule(nn.Module):
-    """Part of Zipformer model: a small version of the Convolution module that uses a small kernel.
-    Inspired by convnext (i.e. have the depthwise conv first.)
-
-    Args:
-        channels (int): The number of channels of conv layers.
-        kernel_size (int): Kernerl size of conv layers.
-        bias (bool): Whether to use bias in conv layers (default=True).
-    """
-
-    def __init__(
-            self, channels: int,
-            hidden_dim: int = 128,
-            kernel_size: int = 5,
-            causal: bool = False,
-    ) -> None:
-        super().__init__()
-
-        self.depthwise_conv = ChunkCausalDepthwiseConv1d(
-            channels=channels,
-            kernel_size=kernel_size) if causal else nn.Conv1d(
-            in_channels=channels,
-            out_channels=channels,
-            groups=channels,
-            kernel_size=kernel_size,
-            padding=kernel_size // 2)
-
-        self.linear1 = nn.Linear(channels, hidden_dim)
-
-        # balancer and activation as tuned for ConvolutionModule.
-        self.balancer = Balancer(
-            hidden_dim, channel_dim=-1,
-            min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
-            max_positive=1.0,
-            min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
-            max_abs=10.0,
-        )
-
-        self.activation = SwooshR()
-
-        self.linear2 = ScaledLinear(hidden_dim, channels,
-                                    initial_scale=0.05)
-
-    def forward(self,
-                x: Tensor,
-                src_key_padding_mask: Optional[Tensor] = None,
-    ) -> Tensor:
-        """Compute convolution module.
-
-        Args:
-            x: Input tensor (#time, batch, channels).
-            src_key_padding_mask: the mask for the src keys per batch (optional):
-               (batch, #time), contains bool in masked positions.
-
-        Returns:
-            Tensor: Output tensor (#time, batch, channels).
-        """
-        # exchange the temporal dimension and the feature dimension
-        x = x.permute(1, 2, 0)  # (#batch, channels, time).
-
-        if src_key_padding_mask is not None:
-            x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
-        x = self.depthwise_conv(x)  # (batch, channels, time)
-        x = x.permute(2, 0, 1)  # (time, batch, channels)
-        x = self.linear1(x)  # (time, batch, hidden_dim)
-        x = self.balancer(x)
-        x = self.activation(x)
-        x = self.linear2(x)  # (time, batch, channels)
-        return x
-
-
 class CompactRelPositionalEncoding(torch.nn.Module):
     """
     Relative positional encoding module.  This version is "compact" meaning it is able to encode
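The removed SmallConvolutionModule follows the ConvNeXt ordering its docstring mentions: depthwise convolution over time first, then a pointwise expansion, activation, and projection back to the model dimension. A rough stand-in sketch of that data flow with plain PyTorch modules (SiLU substituted for SwooshR, no Balancer, and illustrative sizes):

```python
import torch
import torch.nn as nn

channels, hidden_dim, kernel_size = 256, 128, 5
depthwise = nn.Conv1d(channels, channels, kernel_size=kernel_size,
                      groups=channels, padding=kernel_size // 2)
linear1 = nn.Linear(channels, hidden_dim)
activation = nn.SiLU()               # stand-in for SwooshR
linear2 = nn.Linear(hidden_dim, channels)

x = torch.randn(100, 8, channels)    # (time, batch, channels)
h = depthwise(x.permute(1, 2, 0))    # depthwise conv first, ConvNeXt-style
h = h.permute(2, 0, 1)               # back to (time, batch, channels)
y = linear2(activation(linear1(h)))  # pointwise expand -> activate -> project
assert y.shape == x.shape
```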
@@ -1502,7 +1427,7 @@ class SelfAttention(nn.Module):
 
 
 class FeedforwardModule(nn.Module):
-    """Feedforward module in Zipformer model.
+    """Feedforward module in Zipformer2 model.
     """
     def __init__(self,
                  embed_dim: int,
@@ -1645,7 +1570,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
 
 
 class ConvolutionModule(nn.Module):
-    """ConvolutionModule in Zipformer model.
+    """ConvolutionModule in Zipformer2 model.
     Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py
 
     Args:
@@ -1983,7 +1908,7 @@ class Conv2dSubsampling(nn.Module):
 
         # max_log_eps=0.0 is to prevent both eps and the output of self.out from
        # getting large, there is an unnecessary degree of freedom.
-        self.out_norm = BasicNorm(out_channels)
+        self.out_norm = BiasNorm(out_channels)
         self.dropout = Dropout3(dropout, shared_dim=1)
 
 
@@ -2018,143 +1943,6 @@ class Conv2dSubsampling(nn.Module):
         x = self.dropout(x)
         return x
 
-class AttentionCombine(nn.Module):
-    """
-    This module combines a list of Tensors, all with the same shape, to
-    produce a single output of that same shape which, in training time,
-    is a random combination of all the inputs; but which in test time
-    will be just the last input.
-
-    All but the last input will have a linear transform before we
-    randomly combine them; these linear transforms will be initialized
-    to the identity transform.
-
-    The idea is that the list of Tensors will be a list of outputs of multiple
-    zipformer layers.  This has a similar effect as iterated loss. (See:
-    DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
-    NETWORKS).
-    """
-
-    def __init__(
-        self,
-        num_channels: int,
-        num_inputs: int,
-        random_prob: float = 0.25,
-        single_prob: float = 0.333,
-    ) -> None:
-        """
-        Args:
-          num_channels:
-            the number of channels
-          num_inputs:
-            The number of tensor inputs, which equals the number of layers'
-            outputs that are fed into this module.  E.g. in an 18-layer neural
-            net if we output layers 16, 12, 18, num_inputs would be 3.
-          random_prob:
-            the probability with which we apply a nontrivial mask, in training
-            mode.
-          single_prob:
-            the probability with which we mask to allow just a single
-            module's output (in training)
-        """
-        super().__init__()
-
-        self.random_prob = random_prob
-        self.single_prob = single_prob
-        self.weight = torch.nn.Parameter(torch.zeros(num_channels,
-                                                     num_inputs))
-        self.bias = torch.nn.Parameter(torch.zeros(num_inputs))
-
-        self.name = None  # will be set from training code
-        assert 0 <= random_prob <= 1, random_prob
-        assert 0 <= single_prob <= 1, single_prob
-
-    def forward(self, inputs: List[Tensor]) -> Tensor:
-        """Forward function.
-        Args:
-          inputs:
-            A list of Tensor, e.g. from various layers of a transformer.
-            All must be the same shape, of (*, num_channels)
-        Returns:
-          A Tensor of shape (*, num_channels).  In test mode
-          this is just the final input.
-        """
-        num_inputs = self.weight.shape[1]
-        assert len(inputs) == num_inputs
-
-        # Shape of weights: (*, num_inputs)
-        num_channels = inputs[0].shape[-1]
-        num_frames = inputs[0].numel() // num_channels
-
-        ndim = inputs[0].ndim
-        # stacked_inputs: (num_frames, num_channels, num_inputs)
-        stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
-            (num_frames, num_channels, num_inputs)
-        )
-
-        scores = (stacked_inputs * self.weight).sum(dim=(1,)) + self.bias
-
-        if random.random() < 0.002:
-            logging.info(f"Average scores are {scores.softmax(dim=1).mean(dim=0)}")
-
-        if self.training:
-            # random masking..
-            mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob),
-                                       size=(num_frames,), device=scores.device).unsqueeze(1)
-            # mask will have rows like: [ False, False, False, True, True, .. ]
-            arange = torch.arange(num_inputs, device=scores.device).unsqueeze(0).expand(
-                num_frames, num_inputs)
-            mask = arange >= mask_start
-
-            apply_single_prob = torch.logical_and(torch.rand(size=(num_frames, 1),
-                                                             device=scores.device) < self.single_prob,
-                                                  mask_start < num_inputs)
-            single_prob_mask = torch.logical_and(apply_single_prob,
-                                                 arange < mask_start - 1)
-
-            mask = torch.logical_or(mask,
-                                    single_prob_mask)
-
-            scores = scores.masked_fill(mask, float('-inf'))
-
-        if self.training and random.random() < 0.1:
-            scores = penalize_abs_values_gt(scores,
-                                            limit=10.0,
-                                            penalty=1.0e-04,
-                                            name=self.name)
-
-        weights = scores.softmax(dim=1)
-
-        # (num_frames, num_channels, num_inputs) * (num_frames, num_inputs, 1) -> (num_frames, num_channels, 1),
-        ans = torch.matmul(stacked_inputs, weights.unsqueeze(2))
-        # ans: (*, num_channels)
-        ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels)
-
-        if __name__ == "__main__":
-            # for testing only...
-            print("Weights = ", weights.reshape(num_frames, num_inputs))
-        return ans
-
-
-def _test_random_combine():
-    print("_test_random_combine()")
-    num_inputs = 3
-    num_channels = 50
-    m = AttentionCombine(
-        num_channels=num_channels,
-        num_inputs=num_inputs,
-        random_prob=0.5,
-        single_prob=0.0)
-
-    x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
-
-    y = m(x)
-    assert y.shape == x[0].shape
-    assert torch.allclose(y, x[0])  # .. since actually all ones.
-
-
 def _test_zipformer_main(causal: bool = False):
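The removed AttentionCombine computes softmax weights over several layer outputs and, during training, randomly masks out the later inputs so that earlier layers still receive gradient, the iterated-loss effect its docstring cites. A small sketch of just the mask construction; the sizes are hypothetical:

```python
import torch

# Illustrative sizes; in AttentionCombine, num_frames is batch*time flattened.
num_frames, num_inputs, random_prob = 6, 3, 0.25
# Each row draws a cutoff; inputs at positions >= cutoff are masked out.
# With high = num_inputs / random_prob, most draws land at or beyond
# num_inputs, giving an all-False (trivial) mask, so a nontrivial mask
# is applied with probability roughly random_prob.
mask_start = torch.randint(low=1, high=int(num_inputs / random_prob),
                           size=(num_frames, 1))
arange = torch.arange(num_inputs).unsqueeze(0).expand(num_frames, num_inputs)
mask = arange >= mask_start                   # rows like [False, False, True]
scores = torch.randn(num_frames, num_inputs)  # stand-in for the learned scores
weights = scores.masked_fill(mask, float('-inf')).softmax(dim=1)
# Masked inputs receive exactly zero weight; unmasked rows combine all inputs.
```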
@@ -2164,7 +1952,7 @@ def _test_zipformer_main(causal: bool = False):
     feature_dim = 50
     # Just make sure the forward pass runs.
 
-    c = Zipformer(
+    c = Zipformer2(
         num_features=feature_dim, encoder_dim=(64,96), encoder_unmasked_dim=(48,64), num_heads=(4,4),
         causal=causal,
         chunk_size=(4,) if causal else (-1,),
@@ -2191,6 +1979,5 @@ if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
-    _test_random_combine()
     _test_zipformer_main(False)
     _test_zipformer_main(True)