diff --git a/egs/libriheavy/LM/zipformer1/subformer.py b/egs/libriheavy/LM/zipformer1/subformer.py index 8d2d323a2..e2810471a 100644 --- a/egs/libriheavy/LM/zipformer1/subformer.py +++ b/egs/libriheavy/LM/zipformer1/subformer.py @@ -634,9 +634,10 @@ class SubformerEncoder(nn.Module): chunk_sizes, chunk_indexes = self._get_chunk_sizes(src) + b = src.shape[1] # batch_size pos_embs = [ self._pos_emb_to_chunk_size(pos_emb, c) for c in chunk_sizes ] - attn_offsets = [ self._attn_offset_to_chunk_size(attn_offset, c) for c in chunk_sizes ] + attn_offsets = [ self._attn_offset_to_chunk_size(attn_offset, b, c) for c in chunk_sizes ] # TODO: support this for memory also; would require duplicating it maybe; # or could modify the interior code to just assume chunking # when doing cross-attention. @@ -647,7 +648,7 @@ class SubformerEncoder(nn.Module): output = mod( output, pos_embs[ci], - attn_offset=attn_offset[ci], + attn_offset=attn_offsets[ci], memory=memory, memory_key_padding_mask=memory_key_padding_mask, ) @@ -686,16 +687,20 @@ class SubformerEncoder(nn.Module): num_channels = src.shape[-1] return src.reshape(chunk_size, -1, num_channels) - def _attn_offset_to_chunk_size(self, attn_offset: Tensor, chunk_size: int) -> Tensor: + def _attn_offset_to_chunk_size(self, attn_offset: Tensor, batch_size: int, chunk_size: int) -> Tensor: """ Break up attention offset into a given chunk size """ - (batch_size, seq_len, seq_len) = attn_offset.shape + (_batch_size, seq_len, seq_len) = attn_offset.shape if seq_len == chunk_size: return attn_offset + if _batch_size != batch_size: + assert _batch_size == 1 + attn_offset = attn_offset.expand(batch_size, seq_len, seq_len) + assert seq_len % chunk_size == 0 - num_chunks = seq_len / chunk_size + num_chunks = seq_len // chunk_size batch_stride, tgt_stride, src_stride = attn_offset.stride() @@ -717,7 +722,7 @@ class SubformerEncoder(nn.Module): return pos_emb assert seq_len % chunk_size == 0 - num_chunks = seq_len / chunk_size + num_chunks = seq_len // chunk_size batch_stride, tgt_stride, src_stride, channel_stride = pos_emb.stride() diff --git a/egs/libriheavy/LM/zipformer1/train.py b/egs/libriheavy/LM/zipformer1/train.py index cdeff426d..5c1ae9843 100755 --- a/egs/libriheavy/LM/zipformer1/train.py +++ b/egs/libriheavy/LM/zipformer1/train.py @@ -149,7 +149,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--encoder-chunk-size", type=str, - default="128" + default="128", help="Base chunk size for attention in encoder stacks; alternate layers will use this value or " "double this value." )