diff --git a/egs/libriheavy/LM/zipformer1/subformer.py b/egs/libriheavy/LM/zipformer1/subformer.py
index ee5358324..8d2d323a2 100644
--- a/egs/libriheavy/LM/zipformer1/subformer.py
+++ b/egs/libriheavy/LM/zipformer1/subformer.py
@@ -81,6 +81,7 @@ class Subformer(EncoderInterface):
     def __init__(
         self,
         encoder_dim: Union[int, Tuple[int]] = (384, 512, 384),
+        encoder_chunk_size: Union[int, Tuple[int]] = 128,
         num_encoder_layers: Union[int, Tuple[int]] = 4,
         query_head_dim: Union[int, Tuple[int]] = 24,
         value_head_dim: Union[int, Tuple[int]] = 12,
@@ -110,6 +111,7 @@ class Subformer(EncoderInterface):
             return x
 
         self.encoder_dim = encoder_dim
+        encoder_chunk_size = _to_tuple(encoder_chunk_size)
         num_encoder_layers = _to_tuple(num_encoder_layers)
         query_head_dim = _to_tuple(query_head_dim)
         value_head_dim = _to_tuple(value_head_dim)
@@ -148,6 +150,7 @@ class Subformer(EncoderInterface):
                 encoder_layer,
                 num_encoder_layers[i],
                 dropout=dropout,
+                chunk_size=encoder_chunk_size[i],
                 warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                 warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
                 final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
@@ -460,7 +463,8 @@ class SubformerEncoderLayer(nn.Module):
         Pass the input through the encoder layer.
         Args:
             src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
-            pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
+            pos_emb: (batch_size, seq_len, seq_len, pos_dim), with e.g. pos_dim=4: relative positional
+               embedding tensor.
             feature_mask: something that broadcasts with src, that we'll multiply `src`
                by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
             attn_offset: the attention offset, of shape broadcasting with (batch_size, seq_len, seq_len),
@@ -565,11 +569,14 @@ class SubformerEncoder(nn.Module):
         dropout: float,
         warmup_begin: float,
         warmup_end: float,
+        chunk_size: int = 256,
         initial_layerdrop_rate: float = 0.5,
         final_layerdrop_rate: float = 0.05,
     ) -> None:
         super().__init__()
 
+        self.chunk_size = chunk_size
+
         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
         )
@@ -595,7 +602,6 @@ class SubformerEncoder(nn.Module):
         self,
         src: Tensor,
         pos_emb: Tensor,
-        feature_mask: Optional[Tensor] = None,
         attn_offset: Optional[Tensor] = None,
         memory: Optional[Tensor] = None,
         memory_key_padding_mask: Optional[Tensor] = None,
@@ -623,23 +629,107 @@ class SubformerEncoder(nn.Module):
         rnd_seed = src.numel() + random.randint(0, 1000)
 
-        if feature_mask is not None:
-            output = output * feature_mask
+        #if feature_mask is not None:
+        #    output = output * feature_mask
+
+        chunk_sizes, chunk_indexes = self._get_chunk_sizes(src)
+
+        pos_embs = [ self._pos_emb_to_chunk_size(pos_emb, c) for c in chunk_sizes ]
+        attn_offsets = [ self._attn_offset_to_chunk_size(attn_offset, c) for c in chunk_sizes ]
+        # TODO: support this for memory also; would require duplicating it maybe;
+        # or could modify the interior code to just assume chunking
+        # when doing cross-attention.
 
         for i, mod in enumerate(self.layers):
+            ci = chunk_indexes[i]
+            c = chunk_sizes[ci]
+            output = self._to_chunk_size(output, c)
             output = mod(
                 output,
-                pos_emb,
-                attn_offset=attn_offset,
+                pos_embs[ci],
+                attn_offset=attn_offsets[ci],
                 memory=memory,
                 memory_key_padding_mask=memory_key_padding_mask,
             )
 
-        if feature_mask is not None:
-            output = output * feature_mask
+        #if feature_mask is not None:
+        #    output = output * feature_mask
+
+        output = self._to_chunk_size(output, src.shape[0])
 
         return self.bypass(src, output)
 
+    def _get_chunk_sizes(self, src: Tensor) -> Tuple[List[int], List[int]]:
+        """
+        Decide the chunk sizes (in frames) to use for each layer.
+        Args:
+            src: the input embeddings, of shape (seq_len, batch_size, embed_dim)
+        Returns: (chunk_sizes, chunk_indexes), where:
+            chunk_sizes: a list of the unique chunk sizes to use, e.g. [ 128, 256 ]
+            chunk_indexes: a list of indexes into chunk_sizes, one per layer.
+        """
+        seq_len = src.shape[0]
+        assert seq_len < self.chunk_size or seq_len % self.chunk_size == 0
+        if seq_len <= self.chunk_size:
+            return [ seq_len ], [ 0 ] * len(self.layers)
+        else:
+            assert seq_len % self.chunk_size == 0, (seq_len, self.chunk_size)
+            num_layers = len(self.layers)
+            chunk_indexes = [0, 1] * ((num_layers + 1) // 2)
+            return [ self.chunk_size, self.chunk_size * 2 ], chunk_indexes[:num_layers]
+
+
+    def _to_chunk_size(self, src: Tensor, chunk_size: int) -> Tensor:
+        """
+        Reshape embeddings 'src' to have a different chunk size (in frames)
+        """
+        num_channels = src.shape[-1]
+        return src.reshape(chunk_size, -1, num_channels)
+
+    def _attn_offset_to_chunk_size(self, attn_offset: Tensor, chunk_size: int) -> Tensor:
+        """
+        Break up attention offset into a given chunk size
+        """
+        (batch_size, seq_len, seq_len) = attn_offset.shape
+        if seq_len == chunk_size:
+            return attn_offset
+        assert seq_len % chunk_size == 0
+
+        num_chunks = seq_len // chunk_size
+
+        batch_stride, tgt_stride, src_stride = attn_offset.stride()
+
+        # have the 'chunk' dimension first so it has larger stride than the original batch; this
+        # is to match what happens to the embeddings in 'src' where the time-stride is first.
+        attn_offset = attn_offset.as_strided((num_chunks, batch_size, chunk_size, chunk_size),
+                                             ((tgt_stride + src_stride) * chunk_size, batch_stride,
+                                              tgt_stride, src_stride))
+
+        return attn_offset.contiguous().reshape(num_chunks * batch_size, chunk_size, chunk_size)
+
+
+    def _pos_emb_to_chunk_size(self, pos_emb: Tensor, chunk_size: int) -> Tensor:
+        """
+        Break up positional embedding tensor into a given chunk size
+        """
+        (batch_size, seq_len, seq_len, pos_dim) = pos_emb.shape
+        if seq_len == chunk_size:
+            return pos_emb
+        assert seq_len % chunk_size == 0
+
+        num_chunks = seq_len // chunk_size
+
+        batch_stride, tgt_stride, src_stride, channel_stride = pos_emb.stride()
+
+        pos_emb = pos_emb.as_strided((num_chunks, batch_size, chunk_size, chunk_size, pos_dim),
+                                     ((tgt_stride + src_stride) * chunk_size, batch_stride,
+                                      tgt_stride, src_stride, channel_stride))
+
+        return pos_emb.contiguous().reshape(num_chunks * batch_size,
+                                            chunk_size, chunk_size,
+                                            pos_dim)
+
+
 class BypassModule(nn.Module):
     """
@@ -1137,11 +1227,10 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         """Create positional encoding.
 
         Args:
-            x (torch.Tensor): Input tensor (time, batch, num_channels_in)
+            x (torch.Tensor): Input tensor (seq_len, batch_size, num_channels_in)
 
         Returns:
-            positional embedding, of shape (batch_size, 2*time-1, pos_dim).
-
+            positional embedding, of shape (batch_size, seq_len, seq_len, pos_dim).
         """
         self.extend_pe(x)
         seq_len = x.size(0)
diff --git a/egs/libriheavy/LM/zipformer1/train.py b/egs/libriheavy/LM/zipformer1/train.py
index 32535c438..cdeff426d 100755
--- a/egs/libriheavy/LM/zipformer1/train.py
+++ b/egs/libriheavy/LM/zipformer1/train.py
@@ -146,6 +146,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Embedding dimension in encoder stacks: a single int or comma-separated list."
     )
 
+    parser.add_argument(
+        "--encoder-chunk-size",
+        type=str,
+        default="128",
+        help="Base chunk size for attention in encoder stacks; alternate layers will use this value or "
+        "double this value."
+    )
+
     parser.add_argument(
         "--query-head-dim",
         type=str,
@@ -415,6 +423,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Subformer(
         num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
         encoder_dim=_to_int_tuple(params.encoder_dim),
+        encoder_chunk_size=_to_int_tuple(params.encoder_chunk_size),
         query_head_dim=_to_int_tuple(params.query_head_dim),
         pos_dim=int(params.pos_dim),
         value_head_dim=_to_int_tuple(params.value_head_dim),
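For reference, below is a small standalone sketch (not part of the patch) of the as_strided trick that SubformerEncoder._attn_offset_to_chunk_size uses in the diff above. The helper name split_attn_offset_into_chunks and the toy shapes are invented for illustration only; it checks that chunk n of batch item b is exactly the n-th diagonal (chunk_size x chunk_size) block of the full attention-offset matrix, with the chunk index placed ahead of the batch index in the combined batch dimension.

import torch

def split_attn_offset_into_chunks(attn_offset: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """(batch_size, seq_len, seq_len) -> (num_chunks * batch_size, chunk_size, chunk_size),
    keeping only the diagonal blocks so that each chunk attends within itself."""
    batch_size, seq_len, _ = attn_offset.shape
    assert seq_len % chunk_size == 0
    num_chunks = seq_len // chunk_size
    batch_stride, tgt_stride, src_stride = attn_offset.stride()
    # Put the chunk dimension first so it has a larger stride than the original batch dimension.
    blocks = attn_offset.as_strided(
        (num_chunks, batch_size, chunk_size, chunk_size),
        ((tgt_stride + src_stride) * chunk_size, batch_stride, tgt_stride, src_stride),
    )
    return blocks.contiguous().reshape(num_chunks * batch_size, chunk_size, chunk_size)

if __name__ == "__main__":
    batch_size, seq_len, chunk_size = 2, 8, 4
    offset = torch.arange(batch_size * seq_len * seq_len, dtype=torch.float32).reshape(
        batch_size, seq_len, seq_len)
    chunks = split_attn_offset_into_chunks(offset, chunk_size)
    for n in range(seq_len // chunk_size):
        for b in range(batch_size):
            # Expected: the n-th diagonal block of batch item b.
            block = offset[b, n * chunk_size:(n + 1) * chunk_size, n * chunk_size:(n + 1) * chunk_size]
            assert torch.equal(chunks[n * batch_size + b], block)
    print("block-diagonal chunking OK:", tuple(chunks.shape))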