diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 0f7307c56..dc0263d1f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -325,7 +325,7 @@ class MaxEigLimiterFunction(torch.autograd.Function): -class BasicNormFunction(torch.autograd.Function): +class BiasNormFunction(torch.autograd.Function): # This computes: # scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp() # return (x - bias) * scales @@ -368,7 +368,7 @@ class BasicNormFunction(torch.autograd.Function): -class BasicNorm(torch.nn.Module): +class BiasNorm(torch.nn.Module): """ This is intended to be a simpler, and hopefully cheaper, replacement for LayerNorm. The observation this is based on, is that Transformer-type @@ -378,9 +378,10 @@ class BasicNorm(torch.nn.Module): on the other (useful) features. Presumably the weight and bias of the LayerNorm are required to allow it to do this. - So the idea is to introduce this large constant value as an explicit - parameter, that takes the role of the "eps" in LayerNorm, so the network - doesn't have to do this trick. We make the "eps" learnable. + Instead, we give the BiasNorm a trainable bias that it can use when + computing the scale for normalization. We also give it a (scalar) + trainable scale on the output. + Args: num_channels: the number of channels, e.g. 512. @@ -397,7 +398,6 @@ class BasicNorm(torch.nn.Module): than the input of this module to be required to be stored for the backprop. """ - def __init__( self, num_channels: int, @@ -407,7 +407,7 @@ class BasicNorm(torch.nn.Module): log_scale_max: float = 1.5, store_output_for_backprop: bool = False ) -> None: - super(BasicNorm, self).__init__() + super(BiasNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim self.log_scale = nn.Parameter(torch.tensor(log_scale)) @@ -438,245 +438,9 @@ class BasicNorm(torch.nn.Module): max=float(self.log_scale_max), training=self.training) - return BasicNormFunction.apply(x, self.bias, log_scale, - self.channel_dim, - self.store_output_for_backprop) - - - -class PositiveConv1d(nn.Conv1d): - """ - A modified form of nn.Conv1d where the weight parameters are constrained - to be positive and there is no bias. - """ - def __init__( - self, *args, min: FloatLike = 0.01, max: FloatLike = 1.0, - **kwargs): - super().__init__(*args, **kwargs, bias=False) - self.min = min - self.max = max - - # initialize weight to all positive values. - with torch.no_grad(): - self.weight[:] = 1.0 / self.weight[0][0].numel() - - def forward(self, input: Tensor) -> Tensor: - """ - Forward function. Input and returned tensor have shape: - (N, C, H) - i.e. (batch_size, num_channels, height) - """ - weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max), - training=self.training) - # make absolutely sure there are no negative values. For parameter-averaging-related - # reasons, we prefer to also use limit_param_value to make sure the weights stay - # positive. - weight = weight.abs() - - if self.padding_mode != 'zeros': - return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - weight, self.bias, self.stride, - _single(0), self.dilation, self.groups) - return F.conv1d(input, weight, self.bias, self.stride, - self.padding, self.dilation, self.groups) - - - -class ConvNorm1d(torch.nn.Module): - """ - This is like BasicNorm except the denominator is summed over time using - convolution with positive weights. - - - Args: - num_channels: the number of channels, e.g. 512. - eps: the initial "epsilon" that we add as ballast in: - scale = ((input_vec**2).mean() + epsilon)**-0.5 - Note: our epsilon is actually large, but we keep the name - to indicate the connection with conventional LayerNorm. - learn_eps: if true, we learn epsilon; if false, we keep it - at the initial value. - eps_min: float - eps_max: float - """ - - def __init__( - self, - num_channels: int, - eps: float = 0.25, - learn_eps: bool = True, - eps_min: float = -3.0, - eps_max: float = 3.0, - conv_min: float = 0.001, - conv_max: float = 1.0, - kernel_size: int = 15, - ) -> None: - super().__init__() - self.num_channels = num_channels - if learn_eps: - self.eps = nn.Parameter(torch.tensor(eps).log().detach()) - else: - self.register_buffer("eps", torch.tensor(eps).log().detach()) - self.eps_min = eps_min - self.eps_max = eps_max - pad = kernel_size // 2 - # it has bias=False. - self.conv = PositiveConv1d(1, 1, kernel_size=kernel_size, padding=pad, - min=conv_min, max=conv_max) - - - def forward(self, x: Tensor, - src_key_padding_mask: Optional[Tensor] = None) -> Tensor: - """ - x shape: (N, C, T) - - src_key_padding_mask: the mask for the src keys per batch (optional): - (N, T), contains True in masked positions. - - """ - assert x.ndim == 3 and x.shape[1] == self.num_channels - eps = self.eps - if self.training and random.random() < 0.25: - # with probability 0.25, in training mode, clamp eps between the min - # and max; this will encourage it to learn parameters within the - # allowed range by making parameters that are outside the allowed - # range noisy. - - # gradients to allow the parameter to get back into the allowed - # region if it happens to exit it. - eps = torch.clamp(eps, min=self.eps_min, max=self.eps_max) - - # sqnorms: (N, 1, T) - sqnorms = ( - torch.mean(x ** 2, dim=1, keepdim=True) - ) - # 'counts' is a mechanism to correct for edge effects. - counts = torch.ones_like(sqnorms) - if src_key_padding_mask is not None: - counts = counts.masked_fill_(src_key_padding_mask.unsqueeze(1), 0.0) - sqnorms = sqnorms * counts - sqnorms = self.conv(sqnorms) - # the clamping is to avoid division by zero for padding frames. - counts = torch.clamp(self.conv(counts), min=0.01) - # scales: (N, 1, T) - scales = (sqnorms / counts + eps.exp()) ** -0.5 # - return x * scales - - -class PositiveConv2d(nn.Conv2d): - """ - A modified form of nn.Conv2d where the weight parameters are constrained - to be positive and there is no bias. - """ - def __init__( - self, *args, min: FloatLike = 0.01, max: FloatLike = 1.0, - **kwargs): - super().__init__(*args, **kwargs, bias=False) - self.min = min - self.max = max - - # initialize weight to all positive values. - with torch.no_grad(): - self.weight[:] = 1.0 / self.weight[0][0].numel() - - def forward(self, input: Tensor) -> Tensor: - """ - Forward function. Input and returned tensor have shape: - (N, C, H, W) - i.e. (batch_size, num_channels, height, width) - """ - weight = limit_param_value(self.weight, min=float(self.min), max=float(self.max), - training=self.training) - # make absolutely sure there are no negative values. For parameter-averaging-related - # reasons, we prefer to also use limit_param_value to make sure the weights stay - # positive. - weight = weight.abs() - - if self.padding_mode != 'zeros': - return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), - weight, self.bias, self.stride, - _pair(0), self.dilation, self.groups) - return F.conv2d(input, weight, self.bias, self.stride, - self.padding, self.dilation, self.groups) - - -class ConvNorm2d(torch.nn.Module): - """ - This is like BasicNorm except the denominator is summed over time using - convolution with positive weights. - - - Args: - num_channels: the number of channels, e.g. 512. - eps: the initial "epsilon" that we add as ballast in: - scale = ((input_vec**2).mean() + epsilon)**-0.5 - Note: our epsilon is actually large, but we keep the name - to indicate the connection with conventional LayerNorm. - learn_eps: if true, we learn epsilon; if false, we keep it - at the initial value. - eps_min: float - eps_max: float - """ - - def __init__( - self, - num_channels: int, - eps: float = 0.25, - learn_eps: bool = True, - eps_min: float = -3.0, - eps_max: float = 3.0, - conv_min: float = 0.001, - conv_max: float = 1.0, - kernel_size: Tuple[int, int] = (3, 3), - ) -> None: - super().__init__() - self.num_channels = num_channels - if learn_eps: - self.eps = nn.Parameter(torch.tensor(eps).log().detach()) - else: - self.register_buffer("eps", torch.tensor(eps).log().detach()) - self.eps_min = eps_min - self.eps_max = eps_max - pad = (kernel_size[0] // 2, kernel_size[1] // 2) - # it has bias=False. - self.conv = PositiveConv2d(1, 1, kernel_size=kernel_size, padding=pad, - min=conv_min, max=conv_max) - - - def forward(self, x: Tensor) -> Tensor: - """ - x shape: (N, C, H, W) - """ - assert x.ndim == 4 and x.shape[1] == self.num_channels - eps = self.eps - if self.training and random.random() < 0.25: - # with probability 0.25, in training mode, clamp eps between the min - # and max; this will encourage it to learn parameters within the - # allowed range by making parameters that are outside the allowed - # range noisy. - - # gradients to allow the parameter to get back into the allowed - # region if it happens to exit it. - eps = torch.clamp(eps, min=self.eps_min, max=self.eps_max) - - # sqnorms: (N, 1, H, W) - sqnorms = ( - torch.mean(x ** 2, dim=1, keepdim=True) - ) - # 'counts' is a mechanism to correct for edge effects. - # TODO: key-padding mask - - counts = torch.ones_like(sqnorms) - #if src_key_padding_mask is not None: - # counts = counts.masked_fill_(src_key_padding_mask.unsqueeze(1), 0.0) - #sqnorms = sqnorms * counts - sqnorms = self.conv(sqnorms) - # the clamping is to avoid division by zero for padding frames. - counts = torch.clamp(self.conv(counts), min=0.01) - # scales: (N, 1, H, W) - scales = (sqnorms / counts + eps.exp()) ** -0.5 - return x * scales - + return BiasNormFunction.apply(x, self.bias, log_scale, + self.channel_dim, + self.store_output_for_backprop) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 0af974f5f..b58743944 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -60,7 +60,7 @@ import torch import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from zipformer import Zipformer +from zipformer import Zipformer2 from scaling import ScheduledFloat from decoder import Decoder from joiner import Joiner @@ -536,7 +536,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: # TODO: We can add an option to switch between Zipformer and Transformer def to_int_tuple(s: str): return tuple(map(int, s.split(','))) - encoder = Zipformer( + encoder = Zipformer2( num_features=params.feature_dim, output_downsampling_factor=2, downsampling_factor=to_int_tuple(params.downsampling_factor), diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index fe8ebb8b8..3ae6b3ce0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -26,9 +26,7 @@ import random from encoder_interface import EncoderInterface from scaling import ( Balancer, - BasicNorm, - ConvNorm1d, - ConvNorm2d, + BiasNorm, Dropout2, Dropout3, SwooshL, @@ -53,7 +51,7 @@ from icefall.utils import make_pad_mask from icefall.dist import get_rank -class Zipformer(EncoderInterface): +class Zipformer2(EncoderInterface): """ Args: @@ -127,7 +125,7 @@ class Zipformer(EncoderInterface): chunk_size: Tuple[int] = [-1], left_context_frames: Tuple[int] = [-1], ) -> None: - super(Zipformer, self).__init__() + super(Zipformer2, self).__init__() if dropout is None: dropout = ScheduledFloat((0.0, 0.3), @@ -185,13 +183,13 @@ class Zipformer(EncoderInterface): dropout=dropout) - # each one will be ZipformerEncoder or DownsampledZipformerEncoder + # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder encoders = [] num_encoders = len(downsampling_factor) for i in range(num_encoders): - encoder_layer = ZipformerEncoderLayer( + encoder_layer = Zipformer2EncoderLayer( embed_dim=encoder_dim[i], pos_dim=pos_dim, num_heads=num_heads[i], @@ -206,7 +204,7 @@ class Zipformer(EncoderInterface): # For the segment of the warmup period, we let the Conv2dSubsampling # layer learn something. Then we start to warm up the other encoders. - encoder = ZipformerEncoder( + encoder = Zipformer2Encoder( encoder_layer, num_encoder_layers[i], pos_dim=pos_dim, @@ -218,7 +216,7 @@ class Zipformer(EncoderInterface): ) if downsampling_factor[i] != 1: - encoder = DownsampledZipformerEncoder( + encoder = DownsampledZipformer2Encoder( encoder, dim=encoder_dim[i], downsample=downsampling_factor[i], @@ -492,7 +490,7 @@ def _balancer_schedule(min_prob: float): -class ZipformerEncoderLayer(nn.Module): +class Zipformer2EncoderLayer(nn.Module): """ Args: embed_dim: the number of expected features in the input (required). @@ -502,7 +500,7 @@ class ZipformerEncoderLayer(nn.Module): cnn_module_kernel (int): Kernel size of convolution module. Examples:: - >>> encoder_layer = ZipformerEncoderLayer(embed_dim=512, nhead=8) + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) >>> src = torch.rand(10, 32, 512) >>> pos_emb = torch.rand(32, 19, 512) >>> out = encoder_layer(src, pos_emb) @@ -530,7 +528,7 @@ class ZipformerEncoderLayer(nn.Module): bypass_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.2), default=0), bypass_max: FloatLike = 1.0, ) -> None: - super(ZipformerEncoderLayer, self).__init__() + super(Zipformer2EncoderLayer, self).__init__() self.embed_dim = embed_dim # probability of skipping the entire layer. @@ -578,7 +576,7 @@ class ZipformerEncoderLayer(nn.Module): #self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2) - self.norm = BasicNorm(embed_dim) + self.norm = BiasNorm(embed_dim) self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) @@ -760,17 +758,17 @@ class ZipformerEncoderLayer(nn.Module): return src, attn_weights -class ZipformerEncoder(nn.Module): - r"""ZipformerEncoder is a stack of N encoder layers +class Zipformer2Encoder(nn.Module): + r"""Zipformer2Encoder is a stack of N encoder layers Args: - encoder_layer: an instance of the ZipformerEncoderLayer() class (required). + encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). num_layers: the number of sub-encoder-layers in the encoder (required). pos_dim: the dimension for the relative positional encoding Examples:: - >>> encoder_layer = ZipformerEncoderLayer(embed_dim=512, nhead=8) - >>> zipformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6) + >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8) + >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6) >>> src = torch.rand(10, 32, 512) >>> out = zipformer_encoder(src) """ @@ -856,9 +854,9 @@ class ZipformerEncoder(nn.Module): return output -class DownsampledZipformerEncoder(nn.Module): +class DownsampledZipformer2Encoder(nn.Module): r""" - DownsampledZipformerEncoder is a zipformer encoder evaluated at a reduced frame rate, + DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate, after convolutional downsampling, and then upsampled again at the output, and combined with the origin input, so that the output has the same shape as the input. """ @@ -867,7 +865,7 @@ class DownsampledZipformerEncoder(nn.Module): dim: int, downsample: int, dropout: FloatLike): - super(DownsampledZipformerEncoder, self).__init__() + super(DownsampledZipformer2Encoder, self).__init__() self.downsample_factor = downsample self.downsample = SimpleDownsample(dim, downsample, dropout) @@ -1031,79 +1029,6 @@ class SimpleCombiner(torch.nn.Module): -class SmallConvolutionModule(nn.Module): - """Part of Zipformer model: a small version of the Convolution module that uses a small kernel. - Inspired by convnext (i.e. have the depthwise conv first.) - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). - """ - - def __init__( - self, channels: int, - hidden_dim: int = 128, - kernel_size: int = 5, - causal: bool = False, - ) -> None: - super().__init__() - - self.depthwise_conv = ChunkCausalDepthwiseConv1d( - channels=channels, - kernel_size=kernel_size) if causal else nn.Conv1d( - in_channels=channels, - out_channels=channels, - groups=channels, - kernel_size=kernel_size, - padding=kernel_size // 2) - - self.linear1 = nn.Linear( - channels, hidden_dim) - - # balancer and activation as tuned for ConvolutionModule. - - self.balancer = Balancer( - hidden_dim, channel_dim=-1, - min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)), - max_positive=1.0, - min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)), - max_abs=10.0, - ) - - self.activation = SwooshR() - - self.linear2 = ScaledLinear(hidden_dim, channels, - initial_scale=0.05) - - def forward(self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch (optional): - (batch, #time), contains bool in masked positions. - - Returns: - Tensor: Output tensor (#time, batch, channels). - """ - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - x = self.depthwise_conv(x) # (batch, channels, time) - x = x.permute(2, 0, 1) # (time, batch, channels) - x = self.linear1(x) # (time, batch, hidden_dim) - x = self.balancer(x) - x = self.activation(x) - x = self.linear2(x) # (time, batch, channels) - return x - - class CompactRelPositionalEncoding(torch.nn.Module): """ Relative positional encoding module. This version is "compact" meaning it is able to encode @@ -1502,7 +1427,7 @@ class SelfAttention(nn.Module): class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer model. + """Feedforward module in Zipformer2 model. """ def __init__(self, embed_dim: int, @@ -1645,7 +1570,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) class ConvolutionModule(nn.Module): - """ConvolutionModule in Zipformer model. + """ConvolutionModule in Zipformer2 model. Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py Args: @@ -1983,7 +1908,7 @@ class Conv2dSubsampling(nn.Module): # max_log_eps=0.0 is to prevent both eps and the output of self.out from # getting large, there is an unnecessary degree of freedom. - self.out_norm = BasicNorm(out_channels) + self.out_norm = BiasNorm(out_channels) self.dropout = Dropout3(dropout, shared_dim=1) @@ -2018,143 +1943,6 @@ class Conv2dSubsampling(nn.Module): x = self.dropout(x) return x -class AttentionCombine(nn.Module): - """ - This module combines a list of Tensors, all with the same shape, to - produce a single output of that same shape which, in training time, - is a random combination of all the inputs; but which in test time - will be just the last input. - - All but the last input will have a linear transform before we - randomly combine them; these linear transforms will be initialized - to the identity transform. - - The idea is that the list of Tensors will be a list of outputs of multiple - zipformer layers. This has a similar effect as iterated loss. (See: - DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER - NETWORKS). - """ - - def __init__( - self, - num_channels: int, - num_inputs: int, - random_prob: float = 0.25, - single_prob: float = 0.333, - ) -> None: - """ - Args: - num_channels: - the number of channels - num_inputs: - The number of tensor inputs, which equals the number of layers' - outputs that are fed into this module. E.g. in an 18-layer neural - net if we output layers 16, 12, 18, num_inputs would be 3. - random_prob: - the probability with which we apply a nontrivial mask, in training - mode. - single_prob: - the probability with which we mask to allow just a single - module's output (in training) - """ - super().__init__() - - self.random_prob = random_prob - self.single_prob = single_prob - self.weight = torch.nn.Parameter(torch.zeros(num_channels, - num_inputs)) - self.bias = torch.nn.Parameter(torch.zeros(num_inputs)) - - self.name = None # will be set from training code - assert 0 <= random_prob <= 1, random_prob - assert 0 <= single_prob <= 1, single_prob - - - - def forward(self, inputs: List[Tensor]) -> Tensor: - """Forward function. - Args: - inputs: - A list of Tensor, e.g. from various layers of a transformer. - All must be the same shape, of (*, num_channels) - Returns: - A Tensor of shape (*, num_channels). In test mode - this is just the final input. - """ - num_inputs = self.weight.shape[1] - assert len(inputs) == num_inputs - - # Shape of weights: (*, num_inputs) - num_channels = inputs[0].shape[-1] - num_frames = inputs[0].numel() // num_channels - - ndim = inputs[0].ndim - # stacked_inputs: (num_frames, num_channels, num_inputs) - stacked_inputs = torch.stack(inputs, dim=ndim).reshape( - (num_frames, num_channels, num_inputs) - ) - - scores = (stacked_inputs * self.weight).sum(dim=(1,)) + self.bias - - if random.random() < 0.002: - logging.info(f"Average scores are {scores.softmax(dim=1).mean(dim=0)}") - - if self.training: - # random masking.. - mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob), - size=(num_frames,), device=scores.device).unsqueeze(1) - # mask will have rows like: [ False, False, False, True, True, .. ] - arange = torch.arange(num_inputs, device=scores.device).unsqueeze(0).expand( - num_frames, num_inputs) - mask = arange >= mask_start - - apply_single_prob = torch.logical_and(torch.rand(size=(num_frames, 1), - device=scores.device) < self.single_prob, - mask_start < num_inputs) - single_prob_mask = torch.logical_and(apply_single_prob, - arange < mask_start - 1) - - mask = torch.logical_or(mask, - single_prob_mask) - - scores = scores.masked_fill(mask, float('-inf')) - - if self.training and random.random() < 0.1: - scores = penalize_abs_values_gt(scores, - limit=10.0, - penalty=1.0e-04, - name=self.name) - - weights = scores.softmax(dim=1) - - # (num_frames, num_channels, num_inputs) * (num_frames, num_inputs, 1) -> (num_frames, num_channels, 1), - ans = torch.matmul(stacked_inputs, weights.unsqueeze(2)) - # ans: (*, num_channels) - ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels) - - if __name__ == "__main__": - # for testing only... - print("Weights = ", weights.reshape(num_frames, num_inputs)) - return ans - - - -def _test_random_combine(): - print("_test_random_combine()") - num_inputs = 3 - num_channels = 50 - m = AttentionCombine( - num_channels=num_channels, - num_inputs=num_inputs, - random_prob=0.5, - single_prob=0.0) - - - x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] - - y = m(x) - assert y.shape == x[0].shape - assert torch.allclose(y, x[0]) # .. since actually all ones. def _test_zipformer_main(causal: bool = False): @@ -2164,7 +1952,7 @@ def _test_zipformer_main(causal: bool = False): feature_dim = 50 # Just make sure the forward pass runs. - c = Zipformer( + c = Zipformer2( num_features=feature_dim, encoder_dim=(64,96), encoder_unmasked_dim=(48,64), num_heads=(4,4), causal=causal, chunk_size=(4,) if causal else (-1,), @@ -2191,6 +1979,5 @@ if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) torch.set_num_interop_threads(1) - _test_random_combine() _test_zipformer_main(False) _test_zipformer_main(True)