From e7e7560bba5fc4e0d03df7dd5fdb69304f0ff9c5 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 10 Feb 2023 14:53:47 +0800
Subject: [PATCH] Implement chunking

---
 .../ASR/pruned_transducer_stateless7/model.py |   8 +-
 .../pruned_transducer_stateless7/scaling.py   | 131 +++++++++-
 .../ASR/pruned_transducer_stateless7/train.py |  43 ++++
 .../pruned_transducer_stateless7/zipformer.py | 225 +++++++++++-------
 4 files changed, 324 insertions(+), 83 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
index 8f707cf4f..7197ace17 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
@@ -84,6 +84,8 @@ class Transducer(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        chunk_size: int = -1,
+        left_context_chunks: int = -1,
     ) -> torch.Tensor:
         """
         Args:
@@ -104,6 +106,9 @@ class Transducer(nn.Module):
           lm_scale:
             The scale to smooth the loss with lm (output of predictor network)
             part
+          chunk_size, left_context_chunks:
+            For chunkwise causal training; will be passed to the zipformer encoder.
+            chunk_size is specified in frames at 50Hz, i.e. after 2x downsampling.
         Returns:
           Return the transducer loss.

@@ -119,7 +124,8 @@ class Transducer(nn.Module):

         assert x.size(0) == x_lens.size(0) == y.dim0

-        encoder_out, x_lens = self.encoder(x, x_lens)
+        encoder_out, x_lens = self.encoder(x, x_lens, chunk_size=chunk_size,
+                                           left_context_chunks=left_context_chunks)
         assert torch.all(x_lens > 0)

         # Now for the decoder, i.e., the prediction network
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index be7a4abd6..3e1470fbf 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -1014,12 +1014,13 @@ def ScaledConv2d(*args,
                  initial_scale: float = 1.0,
                  **kwargs ) -> nn.Conv2d:
     """
-    Behaves like a constructor of a modified version of nn.Conv1d
+    Behaves like a constructor of a modified version of nn.Conv2d
    that gives an easy way to set the default initial parameter scale.

    Args:
       Accepts the standard args and kwargs that nn.Linear accepts
-       e.g. in_features, out_features, bias=False.
+       e.g. in_features, out_features, bias=False, but:
+       NO PADDING-RELATED ARGS.

       initial_scale: you can override this if you want to increase
         or decrease the initial magnitude of the module's output
@@ -1037,6 +1038,132 @@ def ScaledConv2d(*args,
     return ans


+class ChunkCausalDepthwiseConv1d(torch.nn.Module):
+    """
+    Behaves like a depthwise 1d convolution, except that it is causal in
+    a chunkwise way, as if we had a block-triangular attention mask.
+    The chunk size is provided at test time (it should probably be
+    kept in sync with the attention mask).
+
+    This has a little more than twice the parameters of a conventional
+    depthwise conv1d module: we implement it by having one
+    depthwise convolution, of half the width, that is causal (via
+    left-padding); and one depthwise convolution that is applied only
+    within chunks, that we multiply by a scaling factor which depends
+    on the position within the chunk.
+
+    Args:
+       channels: the number of input and output channels (the convolution is
+          depthwise, so these are equal).
+       kernel_size: the kernel size (must be odd); the chunkwise conv uses the
+          full width, the causal conv uses (kernel_size + 1) // 2 taps.
+       bias: if True, include a bias term.
+
+       initial_scale: you can override this if you want to increase
+          or decrease the initial magnitude of the module's output
+          (affects the initialization of weight_scale and bias_scale).
+          Another option, if you want to do something like this, is
+          to re-initialize the parameters.
+    """
+    def __init__(self,
+                 channels: int,
+                 kernel_size: int,
+                 initial_scale: float = 1.0,
+                 bias: bool = True):
+        super().__init__()
+        assert kernel_size % 2 == 1
+
+        half_kernel_size = (kernel_size + 1) // 2
+        # will pad manually, on one side.
+        self.causal_conv = nn.Conv1d(in_channels=channels,
+                                     out_channels=channels,
+                                     groups=channels,
+                                     kernel_size=half_kernel_size,
+                                     padding=0,
+                                     bias=True)
+
+        self.chunkwise_conv = nn.Conv1d(in_channels=channels,
+                                        out_channels=channels,
+                                        groups=channels,
+                                        kernel_size=kernel_size,
+                                        padding=kernel_size // 2,
+                                        bias=bias)
+
+        # first row is correction factors added to the scale near the left edge of the chunk,
+        # second row is correction factors added to the scale near the right edge of the chunk,
+        # both of these are added to a default scale of 1.0.
+        self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size))
+        self.kernel_size = kernel_size
+
+        with torch.no_grad():
+            self.causal_conv.weight[:] *= initial_scale
+            self.chunkwise_conv.weight[:] *= initial_scale
+            if bias:
+                torch.nn.init.uniform_(self.causal_conv.bias,
+                                       -0.1 * initial_scale,
+                                       0.1 * initial_scale)
+
+
+    def forward(self,
+                x: Tensor,
+                chunk_size: int = -1) -> Tensor:
+        """
+        Forward function.  Args:
+           x: a Tensor of shape (batch_size, channels, seq_len)
+           chunk_size: the chunk size, in frames; does not have to divide seq_len exactly.
+        """
+        (batch_size, num_channels, seq_len) = x.shape
+
+        half_kernel_size = (self.kernel_size + 1) // 2
+        # left_pad is half_kernel_size - 1 where half_kernel_size is the size used
+        # in the causal conv.  It's the amount by which we must pad on the left,
+        # to make the convolution causal.
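+        # For example (illustrative numbers, not from the patch): with kernel_size=31,
+        # the causal conv has half_kernel_size = 16, so left_pad = 15 = kernel_size // 2;
+        # after padding 15 frames on the left, output frame t of the causal branch
+        # depends only on input frames <= t.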
+        left_pad = self.kernel_size // 2
+
+        if chunk_size < 0:
+            chunk_size = seq_len
+        right_pad = -seq_len % chunk_size
+
+        x = torch.nn.functional.pad(x, (left_pad, right_pad))
+
+        x_causal = self.causal_conv(x[..., :seq_len + left_pad])
+        assert x_causal.shape == (batch_size, num_channels, seq_len)
+
+        x_chunk = x[..., left_pad:]
+        num_chunks = x_chunk.shape[2] // chunk_size
+        x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size)
+        x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(batch_size * num_chunks,
+                                                      num_channels, chunk_size)
+        x_chunk = self.chunkwise_conv(x_chunk)  # does not change shape
+
+        chunk_scale = self._get_chunk_scale(chunk_size)
+
+        x_chunk = x_chunk * chunk_scale
+        x_chunk = x_chunk.reshape(batch_size, num_chunks,
+                                  num_channels, chunk_size).permute(0, 2, 1, 3)
+        x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[..., :seq_len]
+
+        return x_chunk + x_causal
+
+    def _get_chunk_scale(self, chunk_size: int):
+        """Returns tensor of shape (num_channels, chunk_size) that will be used to
+        scale the output of self.chunkwise_conv."""
+        left_edge = self.chunkwise_conv_scale[0]
+        right_edge = self.chunkwise_conv_scale[1]
+        if chunk_size < self.kernel_size:
+            left_edge = left_edge[:, :chunk_size]
+            right_edge = right_edge[:, -chunk_size:]
+        else:
+            t = chunk_size - self.kernel_size
+            channels = left_edge.shape[0]
+            pad = torch.zeros(channels, t,
+                              device=left_edge.device,
+                              dtype=left_edge.dtype)
+            left_edge = torch.cat((left_edge, pad), dim=-1)
+            right_edge = torch.cat((pad, right_edge), dim=-1)
+        return 1.0 + (left_edge + right_edge)
+
+
+
+
 class ActivationBalancer(torch.nn.Module):
     """
     Modifies the backpropped derivatives of a function to try to encourage, for
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index b98ec1c0e..2e5b2a641 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -47,6 +47,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 import argparse
 import copy
 import logging
+import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
@@ -225,6 +226,23 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         """,
     )

+    parser.add_argument(
+        "--chunk-size",
+        type=str,
+        default="-1",
+        help="Chunk sizes for chunkwise (causal) training, at the 50Hz frame rate of the "
+        "encoder: a single int or comma-separated list, from which one value is chosen "
+        "randomly for each batch.  -1 means full context, i.e. no chunking."
+    )
+
+    parser.add_argument(
+        "--chunk-left-context-frames",
+        type=str,
+        default="64,128,256,-1",
+        help="Left-contexts for chunkwise training, measured in frames (positive values must be "
+        "multiples of all positive elements of chunk-size).  If --chunk-size is specified, "
+        "chunk left-context frames will be chosen randomly from this list."
+    )
+
+
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -526,6 +544,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
         dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
         warmup_batches=4000.0,
+        causal=(params.chunk_size != "-1"),
     )
     return encoder

@@ -686,6 +705,26 @@ def save_checkpoint(
         copyfile(src=filename, dst=best_valid_filename)


+def get_chunk_info(params: AttributeDict) -> Tuple[int, int]:
+    """
+    Returns chunk_size and left_context_chunks.
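+    For example (illustrative values): with --chunk-size "16,32,64,-1" and the default
+    --chunk-left-context-frames "64,128,256,-1", one call might return (32, 4), i.e.
+    128 frames of left context, and another call might return (-1, -1), i.e. full context.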
+
+    """
+    chunk_sizes = list(map(int, params.chunk_size.split(',')))
+    chunk_size = random.choice(chunk_sizes)
+    if chunk_size == -1:
+        left_context_chunks = -1
+    else:
+        chunk_left_context_frames = list(map(int, params.chunk_left_context_frames.split(',')))
+        left_context_frames = random.choice(chunk_left_context_frames)
+        if left_context_frames != -1:
+            assert left_context_frames % chunk_size == 0, "Invalid --chunk-left-context-frames value"
+        # Note: in Python, -1 // n == -1 for n > 0
+        left_context_chunks = left_context_frames // chunk_size
+    return chunk_size, left_context_chunks
+
+
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
@@ -731,6 +770,8 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y).to(device)

+    chunk_size, left_context_chunks = get_chunk_info(params)
+
     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss = model(
             x=feature,
@@ -739,6 +780,8 @@ def compute_loss(
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
+            chunk_size=chunk_size,
+            left_context_chunks=left_context_chunks,
         )

         s = params.simple_loss_scale
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 0ef8c925e..089b918e0 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -37,6 +37,7 @@ from scaling import (
     SwooshL,
     SwooshR,
     TanSwish,
+    ChunkCausalDepthwiseConv1d,
     ScaledConv1d,
     ScaledConv2d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
@@ -96,8 +97,12 @@ class Zipformer(EncoderInterface):
            dropout (float): dropout rate
            warmup_batches (float): number of batches to warm up over; this controls
              dropout of encoder layers.
+           causal (bool): if True, support chunkwise causal convolution.  This should
+              not hurt WER as no modeling power is lost, but the convolution modules will be
+              slightly slower and use more memory.  Enables use of the chunk_size and
+              left_context_chunks options in forward(), which simulate streaming
+              decoding.
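+           Example (an illustrative sketch with assumed shapes and values, not from the patch):
+              encoder = Zipformer(num_features=80, causal=True)
+              x = torch.randn(4, 500, 80)        # (batch, time, num_features)
+              x_lens = torch.full((4,), 500)
+              y, y_lens = encoder(x, x_lens, chunk_size=32, left_context_chunks=4)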
""" - def __init__( self, num_features: int, @@ -116,6 +121,7 @@ class Zipformer(EncoderInterface): pos_dim: int = 192, dropout: FloatLike = None, # see code below for default warmup_batches: float = 4000.0, + causal: bool = False, ) -> None: super(Zipformer, self).__init__() @@ -144,6 +150,7 @@ class Zipformer(EncoderInterface): self.num_features = num_features # int self.output_downsampling_factor = output_downsampling_factor # int self.downsampling_factor = downsampling_factor # tuple + self.downsampling_factor_gcd = next(n for n in range(1, 10000) if all(n % d == 0 for d in downsampling_factor)) self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(encoder_unmasked_dim) # tuple num_encoder_layers = _to_tuple(num_encoder_layers) @@ -153,8 +160,7 @@ class Zipformer(EncoderInterface): num_heads = _to_tuple(num_heads) attention_share_layers = _to_tuple(attention_share_layers) feedforward_dim = _to_tuple(feedforward_dim) - cnn_module_kernel = _to_tuple(cnn_module_kernel) - + self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) for u,d in zip(encoder_unmasked_dim, encoder_dim): assert u <= d @@ -187,6 +193,7 @@ class Zipformer(EncoderInterface): feedforward_dim=feedforward_dim[i], dropout=dropout, cnn_module_kernel=cnn_module_kernel[i], + causal=causal, ) # For the segment of the warmup period, we let the Conv2dSubsampling @@ -314,6 +321,8 @@ class Zipformer(EncoderInterface): def forward( self, x: torch.Tensor, x_lens: torch.Tensor, + chunk_size: int = -1, + left_context_chunks: int = -1, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -322,6 +331,14 @@ class Zipformer(EncoderInterface): x_lens: A tensor of shape (batch_size,) containing the number of frames in `x` before padding. + chunk_size: Number of frames per chunk (only set this if causal == True). + Must divide all elements of downsampling_factor. At 50hz frame + rate, i.e. after encoder_embed. If not specified, no chunking. + left_context_chunks: Number of left-context chunks for each chunk (affects + attention mask); only set this if chunk_size specified. If -1, there + is no limit on the left context. If not -1, require: + left_context_chunks * context_size >= downsampling_factor[i] * + cnn_module_kernel[i] // 2. 
         Returns:
           Return a tuple containing 2 tensors:
             - embeddings: its shape is (batch_size, output_seq_len, max(encoder_dim))
@@ -340,11 +357,13 @@ class Zipformer(EncoderInterface):
             warnings.simplefilter("ignore")
             lengths = (x_lens - 7) // 2
         assert x.size(0) == lengths.max().item()
-        mask = make_pad_mask(lengths)
+        src_key_padding_mask = make_pad_mask(lengths)

         outputs = []
         feature_masks = self.get_feature_masks(x)

+        attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)
+
         for i, module in enumerate(self.encoders):
             ds = self.downsampling_factor[i]
             if self.skip_layers[i] is not None:
@@ -361,8 +380,12 @@ class Zipformer(EncoderInterface):
                 else:
                     x = skip_x
             x = module(x,
+                       chunk_size=chunk_size,
                        feature_mask=feature_masks[i],
-                       src_key_padding_mask=None if mask is None else mask[...,::ds])
+                       src_key_padding_mask=(None if src_key_padding_mask is None
+                                             else src_key_padding_mask[...,::ds]),
+                       attn_mask=attn_mask,
+                       )
             outputs.append(x)

         def get_full_dim_output():
@@ -395,6 +418,42 @@ class Zipformer(EncoderInterface):

         return x, lengths

+    def _get_attn_mask(self, x: Tensor,
+                       chunk_size: int,
+                       left_context_chunks: int
+                       ) -> Optional[Tensor]:
+        """
+        Return None if chunk_size <= 0, else return an attention mask of shape
+        (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len).  True
+        means a masked position.
+        Args:
+           x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
+           chunk_size: chunk size; must be a multiple of all elements of self.downsampling_factor.
+        """
+        if chunk_size <= 0:
+            return None
+        assert all(chunk_size % d == 0 for d in self.downsampling_factor)
+        if left_context_chunks >= 0:
+            num_encoders = len(self.encoder_dim)
+            assert all (chunk_size * left_context_chunks >=
+                        (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i]
+                        for i in range(num_encoders))
+        else:
+            left_context_chunks = 1000000
+
+        seq_len = x.shape[0]
+
+        # t is frame index, shape (seq_len,)
+        t = torch.arange(seq_len, dtype=torch.int32)
+        # c is chunk index for each frame, shape (seq_len,)
+        c = t // chunk_size
+        src_c = c
+        tgt_c = c.unsqueeze(-1)
+
+        attn_mask = torch.logical_or(src_c > tgt_c,
+                                     src_c < tgt_c - left_context_chunks)
+        if __name__ == "__main__":
+            logging.info(f"attn_mask = {attn_mask}")
+        return attn_mask
+
@@ -434,6 +493,7 @@ class ZipformerEncoderLayer(nn.Module):
         feedforward_dim: int,
         dropout: FloatLike = 0.1,
         cnn_module_kernel: int = 31,
+        causal: bool = False,
         # layer_skip_rate will be overwritten to change warmup begin and end times.
         # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
         # to work correctly.
@@ -487,7 +547,8 @@ class ZipformerEncoderLayer(nn.Module):
                                                  hidden_channels=3 * embed_dim // 4)

         self.conv_module = ConvolutionModule(embed_dim,
-                                             cnn_module_kernel)
+                                             cnn_module_kernel,
+                                             causal=causal)

         #self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
@@ -566,27 +627,24 @@ class ZipformerEncoderLayer(nn.Module):
         self,
         src: Tensor,
         pos_emb: Tensor,
-        src_mask: Optional[Tensor] = None,
+        chunk_size: int = -1,
+        attn_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         attn_weights: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
         """
         Pass the input through the encoder layer.
-
         Args:
-            src: the sequence to the encoder layer (required).
-            pos_emb: Positional embedding tensor (required).
-            src_mask: the mask for the src sequence (optional).
-            src_key_padding_mask: the mask for the src keys per batch (optional).
-            attn_weights: possibly attention weights computed by the previous layer,
-               to be used if self.self_attn_weights is None
-
-        Shape:
-            src: (S, N, E).
-            pos_emb: (N, 2*S-1, E)
-            src_mask: (S, S).
-            src_key_padding_mask: (N, S).
-            S is the source sequence length, N is the batch size, E is the feature number
+            src: the sequence to the encoder layer (required): shape (seq_len, batch_size, embedding_dim).
+            pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
+            chunk_size: the number of frames per chunk, if positive; -1 means no chunking.
+            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+                interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+                True means masked position.  May be None.
+            src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
+                masked position.  May be None.
+            attn_weights: optionally, attention weights computed by a previous layer,
+                to be re-used if this layer's self_attn_weights module is None.  May be None.

         Returns:
            (x, attn_weights) where x has the same shape as src, and attn_weights are of
@@ -602,7 +660,7 @@ class ZipformerEncoderLayer(nn.Module):
             attn_weights = self.self_attn_weights(
                 src,
                 pos_emb=pos_emb,
-                attn_mask=src_mask,
+                attn_mask=attn_mask,
                 key_padding_mask=src_key_padding_mask,
             )
         # else rely on the ones passed in
@@ -642,7 +700,8 @@ class ZipformerEncoderLayer(nn.Module):
                                             src, attn_weights)

         if torch.jit.is_scripting() or random.random() >= float(self.conv_skip_rate):
-            src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
+            src = src + self.conv_module(src, chunk_size=chunk_size,
+                                         src_key_padding_mask=src_key_padding_mask)

         if torch.jit.is_scripting() or random.random() >= float(self.ff2_skip_rate):
             src = src + self.balancer_ff2(self.feed_forward2(src))
@@ -660,7 +719,6 @@ class ZipformerEncoderLayer(nn.Module):

         return src, attn_weights

-
 class ZipformerEncoder(nn.Module):
     r"""ZipformerEncoder is a stack of N encoder layers

@@ -713,32 +771,29 @@ class ZipformerEncoder(nn.Module):
     def forward(
         self,
         src: Tensor,
+        chunk_size: int = -1,
         feature_mask: Union[Tensor, float] = 1.0,
-        mask: Optional[Tensor] = None,
+        attn_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.

         Args:
-            src: the sequence to the encoder (required).
+            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+            chunk_size: the number of frames per chunk, if positive; -1 means no chunking.
             feature_mask: something that broadcasts with src, that we'll multiply `src`
-                by at every layer.
-            mask: the mask for the src sequence (optional).
-            src_key_padding_mask: the mask for the src keys per batch (optional).
+                by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
+            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+                interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+                True means masked position.  May be None.
+            src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
+                masked position.  May be None.

-        Shape:
-            src: (S, N, E).
-            pos_emb: (N, 2*S-1, E)
-            mask: (S, S).
-            src_key_padding_mask: (N, S).
-            S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
-
-        Returns: (x, x_no_combine), both of shape (S, N, E)
+        Returns: a Tensor with the same shape as src.
         """
         pos_emb = self.encoder_pos(src)
         output = src

-        rnd_seed = src.numel() + random.randint(0, 1000)

         output = output * feature_mask

@@ -749,7 +804,8 @@ class ZipformerEncoder(nn.Module):
             output, attn_weights = mod(
                 output,
                 pos_emb,
-                src_mask=mask,
+                chunk_size=chunk_size,
+                attn_mask=attn_mask,
                 src_key_padding_mask=src_key_padding_mask,
                 attn_weights=attn_weights,
             )
@@ -774,7 +830,7 @@ class DownsampledZipformerEncoder(nn.Module):
         super(DownsampledZipformerEncoder, self).__init__()
         self.downsample_factor = downsample
         self.downsample = SimpleDownsample(input_dim, output_dim,
-                                            downsample, dropout)
+                                           downsample, dropout)
         self.encoder = encoder
         self.upsample = SimpleUpsample(output_dim, downsample)
         self.out_combiner = SimpleCombiner(input_dim,
@@ -784,39 +840,37 @@ class DownsampledZipformerEncoder(nn.Module):

     def forward(self,
                 src: Tensor,
+                chunk_size: int = -1,
                 feature_mask: Union[Tensor, float] = 1.0,
-                mask: Optional[Tensor] = None,
+                attn_mask: Optional[Tensor] = None,
                 src_key_padding_mask: Optional[Tensor] = None,
                 ) -> Tuple[Tensor, Tensor]:
         r"""Downsample, go through encoder, upsample.

         Args:
-            src: the sequence to the encoder (required).
+            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+            chunk_size: the number of frames per chunk at the frame rate of this module's input,
+                if positive; it is divided by the downsample factor before being passed to the
+                inner encoder.  -1 means no chunking.
             feature_mask: something that broadcasts with src, that we'll multiply `src`
-                by at every layer.  feature_mask is expected to be already downsampled by
-                self.downsample_factor.
-            mask: the mask for the src sequence (optional).  CAUTION: we need to downsample
-                  this, if we are to support it.  Won't work correctly yet.
-            src_key_padding_mask: the mask for the src keys per batch (optional).  Should
-                  be downsampled already.
+                by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
+            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+                interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+                True means masked position.  May be None.
+            src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
+                masked position.  May be None.

-        Shape:
-            src: (S, N, E).
-            mask: (S, S).
-            src_key_padding_mask: (N, S).
-            S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
-
-        Returns: output of shape (S, N, F) where F is the number of output features
-            (output_dim to constructor)
+        Returns: a Tensor with the same shape as src.
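+
+        For example (illustrative): with downsample=2 and chunk_size=32 at the input of
+        this module, the inner encoder is run with chunk_size=16 and the attention mask
+        is subsampled as attn_mask[::2, ::2].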
""" src_orig = src src = self.downsample(src) ds = self.downsample_factor - if mask is not None: - mask = mask[::ds,::ds] + if attn_mask is not None: + attn_mask = attn_mask[::ds,::ds] src = self.encoder( - src, feature_mask=feature_mask, mask=mask, src_key_padding_mask=mask, + src, + chunk_size=chunk_size // ds, + feature_mask=feature_mask, + attn_mask=attn_mask, + src_key_padding_mask=src_key_padding_mask, ) src = self.upsample(src) # remove any extra frames that are not a multiple of downsample_factor @@ -990,8 +1044,9 @@ class SmallConvolutionModule(nn.Module): ) -> None: super().__init__() - - self.depthwise_conv = nn.Conv1d( + self.depthwise_conv = ChunkCausalDepthwiseConv1d( + channels=channels, + kernel_size=kernel_size) if causal else nn.Conv1d( in_channels=channels, out_channels=channels, groups=channels, @@ -1139,13 +1194,13 @@ class CompactRelPositionalEncoding(torch.nn.Module): def forward(self, x: torch.Tensor) -> Tensor: - """Add positional encoding. + """Create positional encoding. Args: x (torch.Tensor): Input tensor (time, batch, `*`). Returns: - torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). + positional embedding, of shape (1, 2*time-1, `*`). """ self.extend_pe(x) @@ -1235,6 +1290,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): self, x: Tensor, pos_emb: Tensor, + chunk_size: int = -1, key_padding_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None, ) -> Tensor: @@ -1242,6 +1298,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): Args: x: input of shape (seq_len, batch_size, embed_dim) pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 2, pos_dim) + chunk_size key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that are True in this mask will be ignored as sources in the attention weighting. attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len), @@ -1687,9 +1744,8 @@ class ConvolutionModule(nn.Module): bias (bool): Whether to use bias in conv layers (default=True). """ - def __init__( - self, channels: int, kernel_size: int, + self, channels: int, kernel_size: int, causal: bool, ) -> None: """Construct a ConvolutionModule object.""" super(ConvolutionModule, self).__init__() @@ -1697,7 +1753,7 @@ class ConvolutionModule(nn.Module): assert (kernel_size - 1) % 2 == 0 bottleneck_dim = channels - + self.causal = causal self.in_proj = nn.Linear( channels, 2 * bottleneck_dim, @@ -1706,7 +1762,6 @@ class ConvolutionModule(nn.Module): # sigmoid in glu. self.in_proj.lr_scale = 0.9 - # after in_proj we put x through a gated linear unit (nn.functional.glu). 
# For most layers the normal rms value of channels of x seems to be in the range 1 to 4, # but sometimes, for some reason, for layer 0 the rms ends up being very large, @@ -1734,15 +1789,17 @@ class ConvolutionModule(nn.Module): self.activation2 = Identity() # for diagnostics - self.depthwise_conv = nn.Conv1d( - bottleneck_dim, - bottleneck_dim, - kernel_size, - stride=1, - padding=(kernel_size - 1) // 2, + assert kernel_size % 2 == 1 + + self.depthwise_conv = ChunkCausalDepthwiseConv1d( + channels=bottleneck_dim, + kernel_size=kernel_size) if causal else nn.Conv1d( + in_channels=bottleneck_dim, + out_channels=bottleneck_dim, groups=bottleneck_dim, - bias=True, - ) + kernel_size=kernel_size, + padding=kernel_size // 2) + self.balancer2 = Balancer( bottleneck_dim, channel_dim=1, @@ -1768,6 +1825,7 @@ class ConvolutionModule(nn.Module): def forward(self, x: Tensor, src_key_padding_mask: Optional[Tensor] = None, + chunk_size: int = -1, ) -> Tensor: """Compute convolution module. @@ -1798,8 +1856,11 @@ class ConvolutionModule(nn.Module): if src_key_padding_mask is not None: x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - # 1D Depthwise Conv - x = self.depthwise_conv(x) + if chunk_size >= 0: + assert self.causal, "Must initialize model with causal=True if you use chunk_size" + x = self.depthwise_conv(x, chunk_size=chunk_size) + else: + x = self.depthwise_conv(x) x = self.balancer2(x) x = x.permute(2, 0, 1) # (time, batch, channels) @@ -2186,7 +2247,7 @@ def _test_random_combine(): assert torch.allclose(y, x[0]) # .. since actually all ones. -def _test_zipformer_main(): +def _test_zipformer_main(causal: bool = False): feature_dim = 50 batch_size = 5 seq_len = 20 @@ -2194,7 +2255,8 @@ def _test_zipformer_main(): # Just make sure the forward pass runs. c = Zipformer( - num_features=feature_dim, encoder_dim=(64,96), encoder_unmasked_dim=(48,64), num_heads=(4,4) + num_features=feature_dim, encoder_dim=(64,96), encoder_unmasked_dim=(48,64), num_heads=(4,4), + causal=causal, ) batch_size = 5 seq_len = 20 @@ -2202,6 +2264,7 @@ def _test_zipformer_main(): f = c( torch.randn(batch_size, seq_len, feature_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), + chunk_size=4 if causal else -1, ) f[0].sum().backward() c.eval() @@ -2212,9 +2275,11 @@ def _test_zipformer_main(): f # to remove flake8 warnings + if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) torch.set_num_interop_threads(1) _test_random_combine() - _test_zipformer_main() + _test_zipformer_main(False) + _test_zipformer_main(True)
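Below is an illustrative sketch (not part of the patch) of how the pieces above fit
together: it rebuilds the same block-causal attention mask that Zipformer._get_attn_mask()
constructs, and checks by gradient probing that a ChunkCausalDepthwiseConv1d output frame
never depends on later chunks.  It assumes torch is installed and that the patched
scaling.py is importable; the shapes and values are chosen only for the example.

    import torch
    from scaling import ChunkCausalDepthwiseConv1d

    def chunk_attn_mask(seq_len: int, chunk_size: int, left_context_chunks: int) -> torch.Tensor:
        # True means a masked (disallowed) position, as in Zipformer._get_attn_mask().
        t = torch.arange(seq_len, dtype=torch.int32)
        c = t // chunk_size          # chunk index of each frame
        src_c = c
        tgt_c = c.unsqueeze(-1)
        return torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks)

    # Each frame may attend to its own chunk and to one chunk of left context.
    print(chunk_attn_mask(seq_len=8, chunk_size=2, left_context_chunks=1).int())

    conv = ChunkCausalDepthwiseConv1d(channels=4, kernel_size=7)
    x = torch.randn(1, 4, 12, requires_grad=True)
    y = conv(x, chunk_size=4)
    # A loss that uses only the first chunk (frames 0..3) must not receive any
    # gradient from frames in later chunks.
    y[..., :4].sum().backward()
    assert torch.all(x.grad[..., 4:] == 0.0)
    print("chunkwise causality check passed")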