Remove some more unused code; rename BasicNorm->BiasNorm, Zipformer->Zipformer2

Daniel Povey 2023-03-06 14:24:09 +08:00
parent 3424b60d8f
commit f59da65d82
3 changed files with 35 additions and 484 deletions


@@ -325,7 +325,7 @@ class MaxEigLimiterFunction(torch.autograd.Function):

-class BasicNormFunction(torch.autograd.Function):
+class BiasNormFunction(torch.autograd.Function):
    # This computes:
    #   scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
    #   return (x - bias) * scales
@@ -368,7 +368,7 @@ class BasicNormFunction(torch.autograd.Function):

-class BasicNorm(torch.nn.Module):
+class BiasNorm(torch.nn.Module):
    """
    This is intended to be a simpler, and hopefully cheaper, replacement for
    LayerNorm.  The observation this is based on, is that Transformer-type
@@ -378,9 +378,10 @@ class BasicNorm(torch.nn.Module):
    on the other (useful) features.  Presumably the weight and bias of the
    LayerNorm are required to allow it to do this.

-    So the idea is to introduce this large constant value as an explicit
-    parameter, that takes the role of the "eps" in LayerNorm, so the network
-    doesn't have to do this trick.  We make the "eps" learnable.
+    Instead, we give the BiasNorm a trainable bias that it can use when
+    computing the scale for normalization.  We also give it a (scalar)
+    trainable scale on the output.

    Args:
      num_channels: the number of channels, e.g. 512.
@@ -397,7 +398,6 @@ class BasicNorm(torch.nn.Module):
        than the input of this module to be required to be stored for the
        backprop.
    """
    def __init__(
        self,
        num_channels: int,
@@ -407,7 +407,7 @@ class BasicNorm(torch.nn.Module):
        log_scale_max: float = 1.5,
        store_output_for_backprop: bool = False
    ) -> None:
-        super(BasicNorm, self).__init__()
+        super(BiasNorm, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.log_scale = nn.Parameter(torch.tensor(log_scale))
@@ -438,245 +438,9 @@ class BasicNorm(torch.nn.Module):
                                      max=float(self.log_scale_max),
                                      training=self.training)
-        return BasicNormFunction.apply(x, self.bias, log_scale,
+        return BiasNormFunction.apply(x, self.bias, log_scale,
                                       self.channel_dim,
                                       self.store_output_for_backprop)
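For orientation, here is a minimal eager-mode sketch of what the renamed BiasNorm computes, following the comment at the top of BiasNormFunction above. The real module routes this through the custom autograd Function so that less has to be stored for the backprop; the helper name, shapes and values below are only illustrative:

import torch

def bias_norm_reference(x: torch.Tensor, bias: torch.Tensor,
                        log_scale: torch.Tensor, channel_dim: int = -1) -> torch.Tensor:
    # Per-frame scale computed from the mean squared deviation of x from the
    # trainable bias, times the trainable scalar factor log_scale.exp().
    scales = torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 * log_scale.exp()
    return (x - bias) * scales

# e.g. a (seq_len, batch_size, num_channels) input with a per-channel bias:
x = torch.randn(100, 8, 512)
y = bias_norm_reference(x, bias=torch.zeros(512), log_scale=torch.tensor(1.0))
assert y.shape == x.shape

The PositiveConv1d, ConvNorm1d, PositiveConv2d and ConvNorm2d classes listed below are part of the unused code this commit removes.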
class PositiveConv1d(nn.Conv1d):
"""
A modified form of nn.Conv1d where the weight parameters are constrained
to be positive and there is no bias.
"""
def __init__(
self, *args, min: FloatLike = 0.01, max: FloatLike = 1.0,
**kwargs):
super().__init__(*args, **kwargs, bias=False)
self.min = min
self.max = max
# initialize weight to all positive values.
with torch.no_grad():
self.weight[:] = 1.0 / self.weight[0][0].numel()
def forward(self, input: Tensor) -> Tensor:
"""
Forward function. Input and returned tensor have shape:
(N, C, H)
i.e. (batch_size, num_channels, height)
"""
weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max),
training=self.training)
# make absolutely sure there are no negative values. For parameter-averaging-related
# reasons, we prefer to also use limit_param_value to make sure the weights stay
# positive.
weight = weight.abs()
if self.padding_mode != 'zeros':
return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
weight, self.bias, self.stride,
_single(0), self.dilation, self.groups)
return F.conv1d(input, weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
class ConvNorm1d(torch.nn.Module):
"""
This is like BasicNorm except the denominator is summed over time using
convolution with positive weights.
Args:
num_channels: the number of channels, e.g. 512.
eps: the initial "epsilon" that we add as ballast in:
scale = ((input_vec**2).mean() + epsilon)**-0.5
Note: our epsilon is actually large, but we keep the name
to indicate the connection with conventional LayerNorm.
learn_eps: if true, we learn epsilon; if false, we keep it
at the initial value.
eps_min: float
eps_max: float
"""
def __init__(
self,
num_channels: int,
eps: float = 0.25,
learn_eps: bool = True,
eps_min: float = -3.0,
eps_max: float = 3.0,
conv_min: float = 0.001,
conv_max: float = 1.0,
kernel_size: int = 15,
) -> None:
super().__init__()
self.num_channels = num_channels
if learn_eps:
self.eps = nn.Parameter(torch.tensor(eps).log().detach())
else:
self.register_buffer("eps", torch.tensor(eps).log().detach())
self.eps_min = eps_min
self.eps_max = eps_max
pad = kernel_size // 2
# it has bias=False.
self.conv = PositiveConv1d(1, 1, kernel_size=kernel_size, padding=pad,
min=conv_min, max=conv_max)
def forward(self, x: Tensor,
src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
"""
x shape: (N, C, T)
src_key_padding_mask: the mask for the src keys per batch (optional):
(N, T), contains True in masked positions.
"""
assert x.ndim == 3 and x.shape[1] == self.num_channels
eps = self.eps
if self.training and random.random() < 0.25:
# with probability 0.25, in training mode, clamp eps between the min
# and max; this will encourage it to learn parameters within the
# allowed range by making parameters that are outside the allowed
# range noisy, while still passing gradients to allow the parameter
# to get back into the allowed region if it happens to exit it.
eps = torch.clamp(eps, min=self.eps_min, max=self.eps_max)
# sqnorms: (N, 1, T)
sqnorms = (
torch.mean(x ** 2, dim=1, keepdim=True)
)
# 'counts' is a mechanism to correct for edge effects.
counts = torch.ones_like(sqnorms)
if src_key_padding_mask is not None:
counts = counts.masked_fill_(src_key_padding_mask.unsqueeze(1), 0.0)
sqnorms = sqnorms * counts
sqnorms = self.conv(sqnorms)
# the clamping is to avoid division by zero for padding frames.
counts = torch.clamp(self.conv(counts), min=0.01)
# scales: (N, 1, T)
scales = (sqnorms / counts + eps.exp()) ** -0.5 #
return x * scales
class PositiveConv2d(nn.Conv2d):
"""
A modified form of nn.Conv2d where the weight parameters are constrained
to be positive and there is no bias.
"""
def __init__(
self, *args, min: FloatLike = 0.01, max: FloatLike = 1.0,
**kwargs):
super().__init__(*args, **kwargs, bias=False)
self.min = min
self.max = max
# initialize weight to all positive values.
with torch.no_grad():
self.weight[:] = 1.0 / self.weight[0][0].numel()
def forward(self, input: Tensor) -> Tensor:
"""
Forward function. Input and returned tensor have shape:
(N, C, H, W)
i.e. (batch_size, num_channels, height, width)
"""
weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max),
training=self.training)
# make absolutely sure there are no negative values. For parameter-averaging-related
# reasons, we prefer to also use limit_param_value to make sure the weights stay
# positive.
weight = weight.abs()
if self.padding_mode != 'zeros':
return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
weight, self.bias, self.stride,
_pair(0), self.dilation, self.groups)
return F.conv2d(input, weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
class ConvNorm2d(torch.nn.Module):
"""
This is like BasicNorm except the denominator is summed over time using
convolution with positive weights.
Args:
num_channels: the number of channels, e.g. 512.
eps: the initial "epsilon" that we add as ballast in:
scale = ((input_vec**2).mean() + epsilon)**-0.5
Note: our epsilon is actually large, but we keep the name
to indicate the connection with conventional LayerNorm.
learn_eps: if true, we learn epsilon; if false, we keep it
at the initial value.
eps_min: float
eps_max: float
"""
def __init__(
self,
num_channels: int,
eps: float = 0.25,
learn_eps: bool = True,
eps_min: float = -3.0,
eps_max: float = 3.0,
conv_min: float = 0.001,
conv_max: float = 1.0,
kernel_size: Tuple[int, int] = (3, 3),
) -> None:
super().__init__()
self.num_channels = num_channels
if learn_eps:
self.eps = nn.Parameter(torch.tensor(eps).log().detach())
else:
self.register_buffer("eps", torch.tensor(eps).log().detach())
self.eps_min = eps_min
self.eps_max = eps_max
pad = (kernel_size[0] // 2, kernel_size[1] // 2)
# it has bias=False.
self.conv = PositiveConv2d(1, 1, kernel_size=kernel_size, padding=pad,
min=conv_min, max=conv_max)
def forward(self, x: Tensor) -> Tensor:
"""
x shape: (N, C, H, W)
"""
assert x.ndim == 4 and x.shape[1] == self.num_channels
eps = self.eps
if self.training and random.random() < 0.25:
# with probability 0.25, in training mode, clamp eps between the min
# and max; this will encourage it to learn parameters within the
# allowed range by making parameters that are outside the allowed
# range noisy, while still passing gradients to allow the parameter
# to get back into the allowed region if it happens to exit it.
eps = torch.clamp(eps, min=self.eps_min, max=self.eps_max)
# sqnorms: (N, 1, H, W)
sqnorms = (
torch.mean(x ** 2, dim=1, keepdim=True)
)
# 'counts' is a mechanism to correct for edge effects.
# TODO: key-padding mask
counts = torch.ones_like(sqnorms)
#if src_key_padding_mask is not None:
# counts = counts.masked_fill_(src_key_padding_mask.unsqueeze(1), 0.0)
#sqnorms = sqnorms * counts
sqnorms = self.conv(sqnorms)
# the clamping is to avoid division by zero for padding frames.
counts = torch.clamp(self.conv(counts), min=0.01)
# scales: (N, 1, H, W)
scales = (sqnorms / counts + eps.exp()) ** -0.5
return x * scales


@@ -60,7 +60,7 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
-from zipformer import Zipformer
+from zipformer import Zipformer2
from scaling import ScheduledFloat
from decoder import Decoder
from joiner import Joiner
@@ -536,7 +536,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
    # TODO: We can add an option to switch between Zipformer and Transformer
    def to_int_tuple(s: str):
        return tuple(map(int, s.split(',')))

-    encoder = Zipformer(
+    encoder = Zipformer2(
        num_features=params.feature_dim,
        output_downsampling_factor=2,
        downsampling_factor=to_int_tuple(params.downsampling_factor),
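As a small illustration of the helper kept in get_encoder_model(): the comma-separated option strings are parsed into the tuples that Zipformer2 expects. The option value used here is hypothetical:

def to_int_tuple(s: str):
    return tuple(map(int, s.split(',')))

# e.g. a downsampling-factor option given as a comma-separated string:
assert to_int_tuple("1,2,4,8,4,2") == (1, 2, 4, 8, 4, 2)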


@@ -26,9 +26,7 @@ import random
from encoder_interface import EncoderInterface
from scaling import (
    Balancer,
-    BasicNorm,
-    ConvNorm1d,
-    ConvNorm2d,
+    BiasNorm,
    Dropout2,
    Dropout3,
    SwooshL,
@@ -53,7 +51,7 @@ from icefall.utils import make_pad_mask
from icefall.dist import get_rank


-class Zipformer(EncoderInterface):
+class Zipformer2(EncoderInterface):
    """
    Args:
@@ -127,7 +125,7 @@ class Zipformer(EncoderInterface):
        chunk_size: Tuple[int] = [-1],
        left_context_frames: Tuple[int] = [-1],
    ) -> None:
-        super(Zipformer, self).__init__()
+        super(Zipformer2, self).__init__()

        if dropout is None:
            dropout = ScheduledFloat((0.0, 0.3),
@@ -185,13 +183,13 @@ class Zipformer(EncoderInterface):
                                     dropout=dropout)

-        # each one will be ZipformerEncoder or DownsampledZipformerEncoder
+        # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
        encoders = []

        num_encoders = len(downsampling_factor)
        for i in range(num_encoders):
-            encoder_layer = ZipformerEncoderLayer(
+            encoder_layer = Zipformer2EncoderLayer(
                embed_dim=encoder_dim[i],
                pos_dim=pos_dim,
                num_heads=num_heads[i],
@@ -206,7 +204,7 @@ class Zipformer(EncoderInterface):
            # For the segment of the warmup period, we let the Conv2dSubsampling
            # layer learn something.  Then we start to warm up the other encoders.
-            encoder = ZipformerEncoder(
+            encoder = Zipformer2Encoder(
                encoder_layer,
                num_encoder_layers[i],
                pos_dim=pos_dim,
@@ -218,7 +216,7 @@ class Zipformer(EncoderInterface):
            )

            if downsampling_factor[i] != 1:
-                encoder = DownsampledZipformerEncoder(
+                encoder = DownsampledZipformer2Encoder(
                    encoder,
                    dim=encoder_dim[i],
                    downsample=downsampling_factor[i],
@@ -492,7 +490,7 @@ def _balancer_schedule(min_prob: float):

-class ZipformerEncoderLayer(nn.Module):
+class Zipformer2EncoderLayer(nn.Module):
    """
    Args:
        embed_dim: the number of expected features in the input (required).
@@ -502,7 +500,7 @@ class ZipformerEncoderLayer(nn.Module):
        cnn_module_kernel (int): Kernel size of convolution module.

    Examples::
-        >>> encoder_layer = ZipformerEncoderLayer(embed_dim=512, nhead=8)
+        >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> pos_emb = torch.rand(32, 19, 512)
        >>> out = encoder_layer(src, pos_emb)
@@ -530,7 +528,7 @@ class ZipformerEncoderLayer(nn.Module):
        bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0),
        bypass_max: FloatLike = 1.0,
    ) -> None:
-        super(ZipformerEncoderLayer, self).__init__()
+        super(Zipformer2EncoderLayer, self).__init__()
        self.embed_dim = embed_dim

        # probability of skipping the entire layer.
@@ -578,7 +576,7 @@ class ZipformerEncoderLayer(nn.Module):
        #self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)

-        self.norm = BasicNorm(embed_dim)
+        self.norm = BiasNorm(embed_dim)

        self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
@@ -760,17 +758,17 @@ class ZipformerEncoderLayer(nn.Module):
        return src, attn_weights


-class ZipformerEncoder(nn.Module):
-    r"""ZipformerEncoder is a stack of N encoder layers
+class Zipformer2Encoder(nn.Module):
+    r"""Zipformer2Encoder is a stack of N encoder layers

    Args:
-        encoder_layer: an instance of the ZipformerEncoderLayer() class (required).
+        encoder_layer: an instance of the Zipformer2EncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        pos_dim: the dimension for the relative positional encoding

    Examples::
-        >>> encoder_layer = ZipformerEncoderLayer(embed_dim=512, nhead=8)
-        >>> zipformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6)
+        >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
+        >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = zipformer_encoder(src)
    """
@@ -856,9 +854,9 @@ class ZipformerEncoder(nn.Module):
        return output


-class DownsampledZipformerEncoder(nn.Module):
+class DownsampledZipformer2Encoder(nn.Module):
    r"""
-    DownsampledZipformerEncoder is a zipformer encoder evaluated at a reduced frame rate,
+    DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate,
    after convolutional downsampling, and then upsampled again at the output, and combined
    with the origin input, so that the output has the same shape as the input.
    """
@@ -867,7 +865,7 @@ class DownsampledZipformerEncoder(nn.Module):
                 dim: int,
                 downsample: int,
                 dropout: FloatLike):
-        super(DownsampledZipformerEncoder, self).__init__()
+        super(DownsampledZipformer2Encoder, self).__init__()
        self.downsample_factor = downsample
        self.downsample = SimpleDownsample(dim,
                                           downsample, dropout)
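To make the docstring above concrete, here is a rough sketch of the forward pattern that DownsampledZipformer2Encoder describes: downsample in time, run the wrapped encoder stack, upsample back, and combine with the original input so the output keeps the input's shape. The upsample/combine callables and the trimming step are assumptions for illustration, not the exact implementation:

from torch import Tensor

def downsampled_encoder_forward(src: Tensor, encoder, downsample, upsample, combine) -> Tensor:
    # src: (seq_len, batch_size, embed_dim)
    src_orig = src
    src = downsample(src)          # roughly (seq_len // downsample_factor, batch_size, embed_dim)
    src = encoder(src)             # the wrapped Zipformer2Encoder stack
    src = upsample(src)            # back to (approximately) seq_len frames
    src = src[:src_orig.shape[0]]  # drop any extra frames introduced by upsampling
    return combine(src_orig, src)  # combine with the original input; same shape as src_orig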
@@ -1031,79 +1029,6 @@ class SimpleCombiner(torch.nn.Module):
class SmallConvolutionModule(nn.Module):
"""Part of Zipformer model: a small version of the Convolution module that uses a small kernel.
Inspired by convnext (i.e. have the depthwise conv first.)
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernerl size of conv layers.
bias (bool): Whether to use bias in conv layers (default=True).
"""
def __init__(
self, channels: int,
hidden_dim: int = 128,
kernel_size: int = 5,
causal: bool = False,
) -> None:
super().__init__()
self.depthwise_conv = ChunkCausalDepthwiseConv1d(
channels=channels,
kernel_size=kernel_size) if causal else nn.Conv1d(
in_channels=channels,
out_channels=channels,
groups=channels,
kernel_size=kernel_size,
padding=kernel_size // 2)
self.linear1 = nn.Linear(
channels, hidden_dim)
# balancer and activation as tuned for ConvolutionModule.
self.balancer = Balancer(
hidden_dim, channel_dim=-1,
min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
max_positive=1.0,
min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
max_abs=10.0,
)
self.activation = SwooshR()
self.linear2 = ScaledLinear(hidden_dim, channels,
initial_scale=0.05)
def forward(self,
x: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
src_key_padding_mask: the mask for the src keys per batch (optional):
(batch, #time), contains bool in masked positions.
Returns:
Tensor: Output tensor (#time, batch, channels).
"""
# exchange the temporal dimension and the feature dimension
x = x.permute(1, 2, 0) # (#batch, channels, time).
if src_key_padding_mask is not None:
x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = self.depthwise_conv(x) # (batch, channels, time)
x = x.permute(2, 0, 1) # (time, batch, channels)
x = self.linear1(x) # (time, batch, hidden_dim)
x = self.balancer(x)
x = self.activation(x)
x = self.linear2(x) # (time, batch, channels)
return x
class CompactRelPositionalEncoding(torch.nn.Module):
    """
    Relative positional encoding module.  This version is "compact" meaning it is able to encode
@@ -1502,7 +1427,7 @@ class SelfAttention(nn.Module):

class FeedforwardModule(nn.Module):
-    """Feedforward module in Zipformer model.
+    """Feedforward module in Zipformer2 model.
    """
    def __init__(self,
                 embed_dim: int,
@@ -1645,7 +1570,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)

class ConvolutionModule(nn.Module):
-    """ConvolutionModule in Zipformer model.
+    """ConvolutionModule in Zipformer2 model.
    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py

    Args:
@@ -1983,7 +1908,7 @@ class Conv2dSubsampling(nn.Module):
        # max_log_eps=0.0 is to prevent both eps and the output of self.out from
        # getting large, there is an unnecessary degree of freedom.
-        self.out_norm = BasicNorm(out_channels)
+        self.out_norm = BiasNorm(out_channels)

        self.dropout = Dropout3(dropout, shared_dim=1)
@@ -2018,143 +1943,6 @@ class Conv2dSubsampling(nn.Module):
        x = self.dropout(x)
        return x
class AttentionCombine(nn.Module):
"""
This module combines a list of Tensors, all with the same shape, to
produce a single output of that same shape which, in training time,
is a random combination of all the inputs; but which in test time
will be just the last input.
All but the last input will have a linear transform before we
randomly combine them; these linear transforms will be initialized
to the identity transform.
The idea is that the list of Tensors will be a list of outputs of multiple
zipformer layers. This has a similar effect as iterated loss. (See:
DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
NETWORKS).
"""
def __init__(
self,
num_channels: int,
num_inputs: int,
random_prob: float = 0.25,
single_prob: float = 0.333,
) -> None:
"""
Args:
num_channels:
the number of channels
num_inputs:
The number of tensor inputs, which equals the number of layers'
outputs that are fed into this module. E.g. in an 18-layer neural
net if we output layers 16, 12, 18, num_inputs would be 3.
random_prob:
the probability with which we apply a nontrivial mask, in training
mode.
single_prob:
the probability with which we mask to allow just a single
module's output (in training)
"""
super().__init__()
self.random_prob = random_prob
self.single_prob = single_prob
self.weight = torch.nn.Parameter(torch.zeros(num_channels,
num_inputs))
self.bias = torch.nn.Parameter(torch.zeros(num_inputs))
self.name = None # will be set from training code
assert 0 <= random_prob <= 1, random_prob
assert 0 <= single_prob <= 1, single_prob
def forward(self, inputs: List[Tensor]) -> Tensor:
"""Forward function.
Args:
inputs:
A list of Tensor, e.g. from various layers of a transformer.
All must be the same shape, of (*, num_channels)
Returns:
A Tensor of shape (*, num_channels). In test mode
this is just the final input.
"""
num_inputs = self.weight.shape[1]
assert len(inputs) == num_inputs
# Shape of weights: (*, num_inputs)
num_channels = inputs[0].shape[-1]
num_frames = inputs[0].numel() // num_channels
ndim = inputs[0].ndim
# stacked_inputs: (num_frames, num_channels, num_inputs)
stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
(num_frames, num_channels, num_inputs)
)
scores = (stacked_inputs * self.weight).sum(dim=(1,)) + self.bias
if random.random() < 0.002:
logging.info(f"Average scores are {scores.softmax(dim=1).mean(dim=0)}")
if self.training:
# random masking..
mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob),
size=(num_frames,), device=scores.device).unsqueeze(1)
# mask will have rows like: [ False, False, False, True, True, .. ]
arange = torch.arange(num_inputs, device=scores.device).unsqueeze(0).expand(
num_frames, num_inputs)
mask = arange >= mask_start
apply_single_prob = torch.logical_and(torch.rand(size=(num_frames, 1),
device=scores.device) < self.single_prob,
mask_start < num_inputs)
single_prob_mask = torch.logical_and(apply_single_prob,
arange < mask_start - 1)
mask = torch.logical_or(mask,
single_prob_mask)
scores = scores.masked_fill(mask, float('-inf'))
if self.training and random.random() < 0.1:
scores = penalize_abs_values_gt(scores,
limit=10.0,
penalty=1.0e-04,
name=self.name)
weights = scores.softmax(dim=1)
# (num_frames, num_channels, num_inputs) * (num_frames, num_inputs, 1) -> (num_frames, num_channels, 1),
ans = torch.matmul(stacked_inputs, weights.unsqueeze(2))
# ans: (*, num_channels)
ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels)
if __name__ == "__main__":
# for testing only...
print("Weights = ", weights.reshape(num_frames, num_inputs))
return ans
def _test_random_combine():
print("_test_random_combine()")
num_inputs = 3
num_channels = 50
m = AttentionCombine(
num_channels=num_channels,
num_inputs=num_inputs,
random_prob=0.5,
single_prob=0.0)
x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
y = m(x)
assert y.shape == x[0].shape
assert torch.allclose(y, x[0]) # .. since actually all ones.
def _test_zipformer_main(causal: bool = False):
@@ -2164,7 +1952,7 @@ def _test_zipformer_main(causal: bool = False):
    feature_dim = 50

    # Just make sure the forward pass runs.
-    c = Zipformer(
+    c = Zipformer2(
        num_features=feature_dim, encoder_dim=(64,96), encoder_unmasked_dim=(48,64), num_heads=(4,4),
        causal=causal,
        chunk_size=(4,) if causal else (-1,),
@@ -2191,6 +1979,5 @@ if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
-    _test_random_combine()
    _test_zipformer_main(False)
    _test_zipformer_main(True)