diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
index 3ac1c1a90..8696aa61a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
@@ -127,6 +127,7 @@ from icefall.checkpoint import (
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    make_pad_mask,
     setup_logger,
     store_transcripts,
     str2bool,
@@ -365,9 +366,15 @@ def decode_one_batch(
         value=LOG_EPS,
     )

+    x, x_lens = model.encoder_embed(feature, feature_lens)
+
+    src_key_padding_mask = make_pad_mask(x_lens)
+    x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+
     encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
+        x, x_lens, src_key_padding_mask
     )
+    encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

     hyps = []
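The decode.py change above boils down to a three-step calling convention: run encoder_embed, build a padding mask from the subsampled lengths, then call the encoder on (T, N, C) input. A minimal sketch of that flow, assuming a model built by get_transducer_model() in train.py below; the helper name encode_batch is illustrative and not part of the patch:

    import torch
    from icefall.utils import make_pad_mask


    def encode_batch(model, feature: torch.Tensor, feature_lens: torch.Tensor):
        """feature: (N, T, C) fbank frames; feature_lens: (N,) valid frame counts."""
        x, x_lens = model.encoder_embed(feature, feature_lens)  # (N, T', C'), T' = (T - 7) // 2
        src_key_padding_mask = make_pad_mask(x_lens)            # True at padded positions
        x = x.permute(1, 0, 2)                                  # (N, T', C') -> (T', N, C')
        encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
        encoder_out = encoder_out.permute(1, 0, 2)              # (T', N, C') -> (N, T', C')
        return encoder_out, encoder_out_lens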
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
index c11b78937..6d203e994 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
@@ -19,14 +19,11 @@
 import k2
 import torch
 import torch.nn as nn
 import random
+import warnings
 from encoder_interface import EncoderInterface

-from icefall.utils import add_sos
-from scaling import (
-    penalize_abs_values_gt,
-    ScaledLinear
-)
-
+from icefall.utils import add_sos, make_pad_mask
+from scaling import penalize_abs_values_gt, ScaledLinear

 class Transducer(nn.Module):
@@ -36,6 +33,7 @@ class Transducer(nn.Module):

     def __init__(
         self,
+        encoder_embed: nn.Module,
         encoder: EncoderInterface,
         decoder: nn.Module,
         joiner: nn.Module,
@@ -46,6 +44,10 @@ class Transducer(nn.Module):
     ):
         """
         Args:
+          encoder_embed:
+            It is a Convolutional 2D subsampling module. It converts
+            an input of shape (N, T, idim) to an output of shape
+            (N, T', odim), where T' = (T-3)//2 - 2 = (T-7)//2.
           encoder:
             It is the transcription network in the paper. Its accepts
             two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
@@ -64,18 +66,22 @@ class Transducer(nn.Module):
         assert isinstance(encoder, EncoderInterface), type(encoder)
         assert hasattr(decoder, "blank_id")

+        self.encoder_embed = encoder_embed
         self.encoder = encoder
         self.decoder = decoder
         self.joiner = joiner

         self.simple_am_proj = ScaledLinear(
-            encoder_dim, vocab_size, initial_scale=0.25,
+            encoder_dim,
+            vocab_size,
+            initial_scale=0.25,
         )
         self.simple_lm_proj = ScaledLinear(
-            decoder_dim, vocab_size, initial_scale=0.25,
+            decoder_dim,
+            vocab_size,
+            initial_scale=0.25,
         )
-
     def forward(
         self,
         x: torch.Tensor,
@@ -119,7 +125,15 @@ class Transducer(nn.Module):
         assert x.size(0) == x_lens.size(0) == y.dim0

-        encoder_out, x_lens = self.encoder(x, x_lens)
+        # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
+        x, x_lens = self.encoder_embed(x, x_lens)
+        # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
+
+        src_key_padding_mask = make_pad_mask(x_lens)
+        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+
+        encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask)
+        encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

         assert torch.all(x_lens > 0)
@@ -142,7 +156,9 @@ class Transducer(nn.Module):
         y_padded = y_padded.to(torch.int64)
         boundary = torch.zeros(
-            (x.size(0), 4), dtype=torch.int64, device=x.device
+            (encoder_out.size(0), 4),
+            dtype=torch.int64,
+            device=encoder_out.device,
         )
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens
@@ -150,9 +166,9 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)

-        #if self.training and random.random() < 0.25:
+        # if self.training and random.random() < 0.25:
         #    lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
-        #if self.training and random.random() < 0.25:
+        # if self.training and random.random() < 0.25:
         #    am = penalize_abs_values_gt(am, 30.0, 1.0e-04)

         with torch.cuda.amp.autocast(enabled=False):
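In the updated forward() above, x is reassigned to the subsampled features, so the boundary tensor for the pruned RNN-T loss is now sized and placed from encoder_out instead. A small illustration of its layout with made-up lengths; reading the columns as [begin_symbol, begin_frame, end_symbol, end_frame] is our understanding of k2's rnnt_loss convention:

    import torch

    y_lens = torch.tensor([3, 5])    # symbols per utterance (illustrative values)
    x_lens = torch.tensor([50, 48])  # encoder output frames per utterance

    boundary = torch.zeros((2, 4), dtype=torch.int64)
    boundary[:, 2] = y_lens
    boundary[:, 3] = x_lens
    # boundary is now [[0, 0, 3, 50],
    #                  [0, 0, 5, 48]]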
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/subsampling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/subsampling.py
new file mode 100644
index 000000000..cb6301e33
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/subsampling.py
@@ -0,0 +1,280 @@
+#!/usr/bin/env python3
+# Copyright  2023  Xiaomi Corp.  (authors: Daniel Povey)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple
+import warnings
+
+import torch
+from torch import Tensor, nn
+from scaling import (
+    Balancer,
+    BiasNorm,
+    Dropout3,
+    FloatLike,
+    Optional,
+    ScaledConv2d,
+    ScaleGrad,
+    ScheduledFloat,
+    SwooshL,
+    SwooshR,
+    Whiten,
+)
+
+
+class ConvNeXt(nn.Module):
+    """
+    Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf
+    """
+
+    def __init__(
+        self,
+        channels: int,
+        hidden_ratio: int = 3,
+        kernel_size: Tuple[int, int] = (7, 7),
+        layerdrop_rate: FloatLike = None,
+    ):
+        super().__init__()
+        padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
+        hidden_channels = channels * hidden_ratio
+        if layerdrop_rate is None:
+            layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015))
+        self.layerdrop_rate = layerdrop_rate
+
+        self.depthwise_conv = nn.Conv2d(
+            in_channels=channels,
+            out_channels=channels,
+            groups=channels,
+            kernel_size=kernel_size,
+            padding=padding,
+        )
+
+        self.pointwise_conv1 = nn.Conv2d(
+            in_channels=channels, out_channels=hidden_channels, kernel_size=1
+        )
+
+        self.hidden_balancer = Balancer(
+            hidden_channels,
+            channel_dim=1,
+            min_positive=0.3,
+            max_positive=1.0,
+            min_abs=0.75,
+            max_abs=5.0,
+        )
+
+        self.activation = SwooshL()
+        self.pointwise_conv2 = ScaledConv2d(
+            in_channels=hidden_channels,
+            out_channels=channels,
+            kernel_size=1,
+            initial_scale=0.01,
+        )
+
+        self.out_balancer = Balancer(
+            channels,
+            channel_dim=1,
+            min_positive=0.4,
+            max_positive=0.6,
+            min_abs=1.0,
+            max_abs=6.0,
+        )
+        self.out_whiten = Whiten(
+            num_groups=1,
+            whitening_limit=5.0,
+            prob=(0.025, 0.25),
+            grad_scale=0.01,
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        if torch.jit.is_scripting() or not self.training:
+            return self.forward_internal(x)
+        layerdrop_rate = float(self.layerdrop_rate)
+
+        if layerdrop_rate != 0.0:
+            batch_size = x.shape[0]
+            mask = (
+                torch.rand(
+                    (batch_size, 1, 1, 1), dtype=x.dtype, device=x.device
+                )
+                > layerdrop_rate
+            )
+        else:
+            mask = None
+        # turns out this caching idea does not work with --world-size > 1
+        # return caching_eval(self.forward_internal, x, mask)
+        return self.forward_internal(x, mask)
+
+    def forward_internal(
+        self, x: Tensor, layer_skip_mask: Optional[Tensor] = None
+    ) -> Tensor:
+        """
+        x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
+
+        The returned value has the same shape as x.
+        """
+        bypass = x
+        x = self.depthwise_conv(x)
+        x = self.pointwise_conv1(x)
+        x = self.hidden_balancer(x)
+        x = self.activation(x)
+        x = self.pointwise_conv2(x)
+
+        if layer_skip_mask is not None:
+            x = x * layer_skip_mask
+
+        x = bypass + x
+        x = self.out_balancer(x)
+        x = x.transpose(1, 3)  # (N, W, H, C); need channel dim to be last
+        x = self.out_whiten(x)
+        x = x.transpose(1, 3)  # (N, C, H, W)
+
+        return x
+
+
+class Conv2dSubsampling(nn.Module):
+    """Convolutional 2D subsampling (to 1/2 length).
+
+    Convert an input of shape (N, T, idim) to an output
+    with shape (N, T', odim), where
+    T' = (T-3)//2 - 2 == (T-7)//2
+
+    It is based on
+    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        layer1_channels: int = 8,
+        layer2_channels: int = 32,
+        layer3_channels: int = 128,
+        dropout: FloatLike = 0.1,
+    ) -> None:
+        """
+        Args:
+          in_channels:
+            Number of channels in. The input shape is (N, T, in_channels).
+            Caution: It requires: T >=7, in_channels >=7
+          out_channels:
+            Output dim. The output shape is (N, (T-7)//2, out_channels)
+          layer1_channels:
+            Number of channels in layer1
+          layer2_channels:
+            Number of channels in layer2
+          bottleneck:
+            bottleneck dimension for 1d squeeze-excite
+        """
+        assert in_channels >= 7
+        super().__init__()
+
+        # The ScaleGrad module is there to prevent the gradients
+        # w.r.t. the weight or bias of the first Conv2d module in self.conv from
+        # exceeding the range of fp16 when using automatic mixed precision (amp)
+        # training. (The second one is necessary to stop its bias from getting
+        # a too-large gradient).
+
+        self.conv = nn.Sequential(
+            nn.Conv2d(
+                in_channels=1,
+                out_channels=layer1_channels,
+                kernel_size=3,
+                padding=(0, 1),  # (time, freq)
+            ),
+            ScaleGrad(0.2),
+            Balancer(layer1_channels, channel_dim=1, max_abs=1.0),
+            SwooshR(),
+            nn.Conv2d(
+                in_channels=layer1_channels,
+                out_channels=layer2_channels,
+                kernel_size=3,
+                stride=2,
+                padding=0,
+            ),
+            Balancer(layer2_channels, channel_dim=1, max_abs=4.0),
+            SwooshR(),
+            nn.Conv2d(
+                in_channels=layer2_channels,
+                out_channels=layer3_channels,
+                kernel_size=3,
+                stride=(1, 2),  # (time, freq)
+            ),
+            Balancer(layer3_channels, channel_dim=1, max_abs=4.0),
+            SwooshR(),
+        )
+
+        # just one convnext layer
+        self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))
+
+        out_width = (((in_channels - 1) // 2) - 1) // 2
+
+        self.out = nn.Linear(out_width * layer3_channels, out_channels)
+        # use a larger than normal grad_scale on this whitening module; there is
+        # only one such module, so there is not a concern about adding together
+        # many copies of this extra gradient term.
+        self.out_whiten = Whiten(
+            num_groups=1,
+            whitening_limit=ScheduledFloat(
+                (0.0, 4.0), (20000.0, 8.0), default=4.0
+            ),
+            prob=(0.025, 0.25),
+            grad_scale=0.02,
+        )
+
+        # max_log_eps=0.0 is to prevent both eps and the output of self.out from
+        # getting large, there is an unnecessary degree of freedom.
+        self.out_norm = BiasNorm(out_channels)
+        self.dropout = Dropout3(dropout, shared_dim=1)
+
+    def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Subsample x.
+
+        Args:
+          x:
+            Its shape is (N, T, idim).
+          x_lens:
+            A tensor of shape (batch_size,) containing the number of frames in
+            `x` before padding.
+
+        Returns:
+          - a tensor of shape (N, (T-7)//2, odim)
+          - output lengths, of shape (batch_size,)
+        """
+        # On entry, x is (N, T, idim)
+        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
+        # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
+        # training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
+        # gradients.
+        x = self.conv(x)
+        x = self.convnext(x)
+
+        # Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2)
+        b, c, t, f = x.size()
+
+        x = x.transpose(1, 2).reshape(b, t, c * f)
+        # now x: (N, ((T-1)//2 - 1))//2, out_width * layer3_channels))
+
+        x = self.out(x)
+        # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
+        x = self.out_whiten(x)
+        x = self.out_norm(x)
+        x = self.dropout(x)
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            x_lens = (x_lens - 7) // 2
+            assert x.size(1) == x_lens.max().item()
+
+        return x, x_lens
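The shape arithmetic documented in Conv2dSubsampling above can be checked in isolation: the time axis shrinks as T' = (T-3)//2 - 2 = (T-7)//2, and the frequency axis that feeds self.out is (((idim-1)//2) - 1)//2. The 80-dim fbank input below is only an assumed example, not something this file fixes:

    def subsampled_len(T: int) -> int:
        # time frames after the three convs (kernel 3; time strides 1, 2, 1)
        return (T - 7) // 2          # same value as (T - 3) // 2 - 2


    def out_width(idim: int) -> int:
        # frequency bins after the two stride-2 convs in the frequency direction
        return (((idim - 1) // 2) - 1) // 2


    assert subsampled_len(107) == 50
    assert out_width(80) == 19       # self.out becomes nn.Linear(19 * layer3_channels, out_channels)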
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 4b153bd8e..4a4288f61 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -64,6 +64,7 @@ from zipformer import Zipformer2
 from scaling import ScheduledFloat
 from decoder import Decoder
 from joiner import Joiner
+from subsampling import Conv2dSubsampling
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
@@ -525,29 +526,46 @@ def get_params() -> AttributeDict:
     return params


+def _to_int_tuple(s: str):
+    return tuple(map(int, s.split(',')))
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+    # encoder_embed converts the input of shape (N, T, num_features)
+    # to the shape (N, (T - 7) // 2, encoder_dims).
+    # That is, it does two things simultaneously:
+    #   (1) subsampling: T -> (T - 7) // 2
+    #   (2) embedding: num_features -> encoder_dims
+    # In the normal configuration, we will downsample once more at the end
+    # by a factor of 2, and most of the encoder stacks will run at a lower
+    # sampling rate.
+    encoder_embed = Conv2dSubsampling(
+        in_channels=params.feature_dim,
+        out_channels=_to_int_tuple(params.encoder_dim)[0],
+        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
+    )
+    return encoder_embed
+
+
 def get_encoder_model(params: AttributeDict) -> nn.Module:
-    # TODO: We can add an option to switch between Zipformer and Transformer
-    def to_int_tuple(s: str):
-        return tuple(map(int, s.split(',')))
     encoder = Zipformer2(
-        num_features=params.feature_dim,
         output_downsampling_factor=2,
-        downsampling_factor=to_int_tuple(params.downsampling_factor),
-        num_encoder_layers=to_int_tuple(params.num_encoder_layers),
-        encoder_dim=to_int_tuple(params.encoder_dim),
-        encoder_unmasked_dim=to_int_tuple(params.encoder_unmasked_dim),
-        query_head_dim=to_int_tuple(params.query_head_dim),
-        pos_head_dim=to_int_tuple(params.pos_head_dim),
-        value_head_dim=to_int_tuple(params.value_head_dim),
+        downsampling_factor=_to_int_tuple(params.downsampling_factor),
+        num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
+        encoder_dim=_to_int_tuple(params.encoder_dim),
+        encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
+        query_head_dim=_to_int_tuple(params.query_head_dim),
+        pos_head_dim=_to_int_tuple(params.pos_head_dim),
+        value_head_dim=_to_int_tuple(params.value_head_dim),
         pos_dim=params.pos_dim,
-        num_heads=to_int_tuple(params.num_heads),
-        feedforward_dim=to_int_tuple(params.feedforward_dim),
-        cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
+        num_heads=_to_int_tuple(params.num_heads),
+        feedforward_dim=_to_int_tuple(params.feedforward_dim),
+        cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
         dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
         warmup_batches=4000.0,
         causal=params.causal,
-        chunk_size=to_int_tuple(params.chunk_size),
-        left_context_frames=to_int_tuple(params.left_context_frames),
+        chunk_size=_to_int_tuple(params.chunk_size),
+        left_context_frames=_to_int_tuple(params.left_context_frames),
     )
     return encoder
@@ -564,7 +582,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:

 def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        encoder_dim=int(max(params.encoder_dim.split(','))),
+        encoder_dim=max(_to_int_tuple(params.encoder_dim)),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
@@ -573,11 +591,13 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:


 def get_transducer_model(params: AttributeDict) -> nn.Module:
+    encoder_embed = get_encoder_embed(params)
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)

     model = Transducer(
+        encoder_embed=encoder_embed,
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
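The Joiner change above is more than a style cleanup: params.encoder_dim is a comma-separated string, and max() over the split strings compares lexicographically, so a dimension with more digits can be ranked lower. A short illustration with a hypothetical dimension string:

    def _to_int_tuple(s: str):
        return tuple(map(int, s.split(',')))


    encoder_dim = "384,512,1024"                     # hypothetical config value
    assert int(max(encoder_dim.split(","))) == 512   # old code: lexicographic max of strings
    assert max(_to_int_tuple(encoder_dim)) == 1024   # new code: numeric max, as intended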
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 855d1c4d8..f5a40d402 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -18,7 +18,6 @@
 import copy
 import math
 import warnings
-import itertools
 from typing import List, Optional, Tuple, Union
 import logging
 import torch
@@ -28,13 +27,8 @@ from scaling import (
     Balancer,
     BiasNorm,
     Dropout2,
-    Dropout3,
-    SwooshL,
-    SwooshR,
     ChunkCausalDepthwiseConv1d,
     ActivationDropoutAndLinear,
-    ScaledConv1d,
-    ScaledConv2d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
     Whiten,
     Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
@@ -44,13 +38,9 @@ from scaling import (
     FloatLike,
     limit_param_value,
     convert_num_channels,
-    ScaleGrad,
 )
 from torch import Tensor, nn

-from icefall.utils import make_pad_mask
-from icefall.dist import get_rank
-

 class Zipformer2(EncoderInterface):
     """
@@ -60,8 +50,6 @@ class Zipformer2(EncoderInterface):
           as downsampling_factor if they are single ints or one-element tuples.  The length of
           downsampling_factor defines the number of stacks.
-
-        num_features (int): Number of input features, e.g. 40.

         output_downsampling_factor (int): how much to downsample at the output.  Note:
            we also downsample by a factor of 2 in the Conv2dSubsampling encoder.
            You should probably leave this at 2.
@@ -104,7 +92,6 @@ class Zipformer2(EncoderInterface):
     """
     def __init__(
         self,
-        num_features: int,
         output_downsampling_factor: int = 2,
         downsampling_factor: Tuple[int] = (2, 4),
         encoder_dim: Union[int, Tuple[int]] = 384,
@@ -140,7 +127,6 @@ class Zipformer2(EncoderInterface):
             assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
             return x

-        self.num_features = num_features  # int
         self.output_downsampling_factor = output_downsampling_factor  # int
         self.downsampling_factor = downsampling_factor  # tuple
         self.encoder_dim = encoder_dim = _to_tuple(encoder_dim)  # tuple
@@ -160,18 +146,6 @@ class Zipformer2(EncoderInterface):
         for u,d in zip(encoder_unmasked_dim, encoder_dim):
             assert u <= d

-        # self.encoder_embed converts the input of shape (N, T, num_features)
-        # to the shape (N, (T - 7) // 2, encoder_dims).
-        # That is, it does two things simultaneously:
-        #   (1) subsampling: T -> (T - 7) // 2
-        #   (2) embedding: num_features -> encoder_dims
-        # In the normal configuration, we will downsample once more at the end
-        # by a factor of 2, and most of the encoder stacks will run at a lower
-        # sampling rate.
-        self.encoder_embed = Conv2dSubsampling(num_features, encoder_dim[0],
-                                               dropout=dropout)
-
-
         # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
         encoders = []
@@ -297,7 +271,9 @@

     def forward(
-        self, x: torch.Tensor, x_lens: torch.Tensor,
+        self, x: torch.Tensor,
+        x_lens: torch.Tensor,
+        src_key_padding_mask: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -306,34 +282,15 @@
           x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `x` before padding.
-          chunk_size: Number of frames per chunk (only set this if causal == True).
-             Must divide all elements of downsampling_factor.  At 50hz frame
-             rate, i.e. after encoder_embed.  If not specified, no chunking.
-          left_context_chunks: Number of left-context chunks for each chunk (affects
-             attention mask); only set this if chunk_size specified.  If -1, there
-             is no limit on the left context.  If not -1, require:
-             left_context_chunks * context_size >= downsampling_factor[i] *
-                cnn_module_kernel[i] // 2.
+          src_key_padding_mask:
+            The mask for padding, of shape (batch_size, seq_len); True means
+            masked position. May be None.
         Returns:
           Return a tuple containing 2 tensors:
            - embeddings: its shape is (batch_size, output_seq_len, max(encoder_dim))
            - lengths, a tensor of shape (batch_size,) containing the number of
              frames in `embeddings` before padding.
         """
-        # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
-
-        x = self.encoder_embed(x)
-
-        # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
-
-        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
-
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            lengths = (x_lens - 7) // 2
-            assert x.size(0) == lengths.max().item()
-        src_key_padding_mask = make_pad_mask(lengths)
-
         outputs = []
         feature_masks = self.get_feature_masks(x)
@@ -379,9 +336,7 @@
         assert self.output_downsampling_factor == 2
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
-            lengths = (lengths + 1) // 2
-
-        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+            lengths = (x_lens + 1) // 2

         return x, lengths
@@ -700,12 +655,9 @@ class Zipformer2EncoderLayer(nn.Module):
                                                             src_key_padding_mask=src_key_padding_mask),
                                           float(self.conv_skip_rate))
-
         src = src + self.sequence_dropout(self.balancer_ff3(self.feed_forward3(src)),
                                           float(self.ff3_skip_rate))
-
-
         src = self.balancer1(src)
         src = self.norm(src)
@@ -1686,244 +1638,13 @@ class ScalarMultiply(nn.Module):
         return x * self.scale

-
-class ConvNeXt(nn.Module):
-    """
-    Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf
-    """
-    def __init__(self,
-                 channels: int,
-                 hidden_ratio: int = 3,
-                 kernel_size: Tuple[int, int] = (7, 7),
-                 layerdrop_rate: FloatLike = None):
-        super().__init__()
-        padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
-        hidden_channels = channels * hidden_ratio
-        if layerdrop_rate is None:
-            layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015))
-        self.layerdrop_rate = layerdrop_rate
-
-        self.depthwise_conv = nn.Conv2d(
-            in_channels=channels,
-            out_channels=channels,
-            groups=channels,
-            kernel_size=kernel_size,
-            padding=padding)
-
-        self.pointwise_conv1 = nn.Conv2d(
-            in_channels=channels,
-            out_channels=hidden_channels,
-            kernel_size=1)
-
-        self.hidden_balancer = Balancer(hidden_channels,
-                                        channel_dim=1,
-                                        min_positive=0.3,
-                                        max_positive=1.0,
-                                        min_abs=0.75,
-                                        max_abs=5.0)
-
-        self.activation = SwooshL()
-        self.pointwise_conv2 = ScaledConv2d(
-            in_channels=hidden_channels,
-            out_channels=channels,
-            kernel_size=1,
-            initial_scale=0.01)
-
-        self.out_balancer = Balancer(
-            channels, channel_dim=1,
-            min_positive=0.4, max_positive=0.6,
-            min_abs=1.0, max_abs=6.0,
-        )
-        self.out_whiten = Whiten(num_groups=1,
-                                 whitening_limit=5.0,
-                                 prob=(0.025, 0.25),
-                                 grad_scale=0.01)
-
-    def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or not self.training:
-            return self.forward_internal(x)
-        layerdrop_rate = float(self.layerdrop_rate)
-
-        if layerdrop_rate != 0.0:
-            batch_size = x.shape[0]
-            mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate
-        else:
-            mask = None
-        # turns out this caching idea does not work with --world-size > 1
-        #return caching_eval(self.forward_internal, x, mask)
-        return self.forward_internal(x, mask)
-
-
-    def forward_internal(self,
-                         x: Tensor,
-                         layer_skip_mask: Optional[Tensor] = None) -> Tensor:
-        """
-        x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
-
-        The returned value has the same shape as x.
-        """
-        bypass = x
-        x = self.depthwise_conv(x)
-        x = self.pointwise_conv1(x)
-        x = self.hidden_balancer(x)
-        x = self.activation(x)
-        x = self.pointwise_conv2(x)
-
-        if layer_skip_mask is not None:
-            x = x * layer_skip_mask
-
-        x = bypass + x
-        x = self.out_balancer(x)
-        x = x.transpose(1, 3)  # (N, W, H, C); need channel dim to be last
-        x = self.out_whiten(x)
-        x = x.transpose(1, 3)  # (N, C, H, W)
-
-        return x
-
-
-
-class Conv2dSubsampling(nn.Module):
-    """Convolutional 2D subsampling (to 1/2 length).
-
-    Convert an input of shape (N, T, idim) to an output
-    with shape (N, T', odim), where
-    T' = (T-3)//2 - 2 == (T-7)//2
-
-    It is based on
-    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
-    """
-
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        layer1_channels: int = 8,
-        layer2_channels: int = 32,
-        layer3_channels: int = 128,
-        dropout: FloatLike = 0.1,
-    ) -> None:
-        """
-        Args:
-          in_channels:
-            Number of channels in. The input shape is (N, T, in_channels).
-            Caution: It requires: T >=7, in_channels >=7
-          out_channels
-            Output dim. The output shape is (N, (T-3)//2, out_channels)
-          layer1_channels:
-            Number of channels in layer1
-          layer1_channels:
-            Number of channels in layer2
-          bottleneck:
-            bottleneck dimension for 1d squeeze-excite
-        """
-        assert in_channels >= 7
-        super().__init__()
-
-        # The ScaleGrad module is there to prevent the gradients
-        # w.r.t. the weight or bias of the first Conv2d module in self.conv from
-        # exceeding the range of fp16 when using automatic mixed precision (amp)
-        # training.  (The second one is necessary to stop its bias from getting
-        # a too-large gradient).
-
-        self.conv = nn.Sequential(
-            nn.Conv2d(
-                in_channels=1,
-                out_channels=layer1_channels,
-                kernel_size=3,
-                padding=(0, 1),  # (time, freq)
-            ),
-            ScaleGrad(0.2),
-            Balancer(layer1_channels,
-                     channel_dim=1,
-                     max_abs=1.0),
-            SwooshR(),
-            nn.Conv2d(
-                in_channels=layer1_channels,
-                out_channels=layer2_channels,
-                kernel_size=3,
-                stride=2,
-                padding=0,
-            ),
-            Balancer(layer2_channels,
-                     channel_dim=1,
-                     max_abs=4.0),
-            SwooshR(),
-            nn.Conv2d(
-                in_channels=layer2_channels,
-                out_channels=layer3_channels,
-                kernel_size=3,
-                stride=(1, 2),  # (time, freq)
-            ),
-            Balancer(layer3_channels,
-                     channel_dim=1,
-                     max_abs=4.0),
-            SwooshR(),
-        )
-
-        cur_width = (in_channels - 1) // 2
-
-        # just one convnext layer
-        self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))
-
-        out_width = (((in_channels - 1) // 2) - 1) // 2
-
-        self.out = nn.Linear(out_width * layer3_channels, out_channels)
-        # use a larger than normal grad_scale on this whitening module; there is
-        # only one such module, so there is not a concern about adding together
-        # many copies of this extra gradient term.
-        self.out_whiten = Whiten(num_groups=1,
-                                 whitening_limit=_whitening_schedule(4.0),
-                                 prob=(0.025, 0.25),
-                                 grad_scale=0.02)
-
-        # max_log_eps=0.0 is to prevent both eps and the output of self.out from
-        # getting large, there is an unnecessary degree of freedom.
-        self.out_norm = BiasNorm(out_channels)
-        self.dropout = Dropout3(dropout, shared_dim=1)
-
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Subsample x.
-
-        Args:
-          x:
-            Its shape is (N, T, idim).
-
-        Returns:
-          Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
-        """
-        # On entry, x is (N, T, idim)
-        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
-        # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
-        # training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
-        # gradients.
-        x = self.conv(x)
-        x = self.convnext(x)
-
-        # Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2)
-        b, c, t, f = x.size()
-
-        x = x.transpose(1, 2).reshape(b, t, c * f)
-        # now x: (N, ((T-1)//2 - 1))//2, out_width * layer3_channels))
-
-        x = self.out(x)
-        # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
-        x = self.out_whiten(x)
-        x = self.out_norm(x)
-        x = self.dropout(x)
-        return x
-
-
-

 def _test_zipformer_main(causal: bool = False):
-    feature_dim = 50
     batch_size = 5
     seq_len = 20
-    feature_dim = 50
     # Just make sure the forward pass runs.
     c = Zipformer2(
-        num_features=feature_dim, encoder_dim=(64,96), encoder_unmasked_dim=(48,64), num_heads=(4,4),
+        encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
         causal=causal,
         chunk_size=(4,) if causal else (-1,),
         left_context_frames=(64,)
@@ -1932,19 +1653,18 @@ def _test_zipformer_main(causal: bool = False):
     seq_len = 20
     # Just make sure the forward pass runs.
     f = c(
-        torch.randn(batch_size, seq_len, feature_dim),
+        torch.randn(seq_len, batch_size, 64),
         torch.full((batch_size,), seq_len, dtype=torch.int64),
     )
     f[0].sum().backward()
     c.eval()
     f = c(
-        torch.randn(batch_size, seq_len, feature_dim),
+        torch.randn(seq_len, batch_size, 64),
         torch.full((batch_size,), seq_len, dtype=torch.int64),
     )
     f  # to remove flake8 warnings

-
 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
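With num_features and the built-in Conv2dSubsampling removed, Zipformer2 now takes already-embedded input in (T, N, C) layout plus an optional key-padding mask, as the updated _test_zipformer_main shows. A usage sketch along the same lines, assuming the recipe-local import paths and non-causal defaults:

    import torch
    from icefall.utils import make_pad_mask
    from zipformer import Zipformer2

    encoder = Zipformer2(
        encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
    )
    seq_len, batch_size = 20, 5
    x = torch.randn(seq_len, batch_size, 64)  # (T, N, C); C must equal encoder_dim[0]
    x_lens = torch.full((batch_size,), seq_len, dtype=torch.int64)
    encoder_out, out_lens = encoder(x, x_lens, make_pad_mask(x_lens))
    # output_downsampling_factor defaults to 2, so out_lens should equal (x_lens + 1) // 2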