Various bug fixes, implementing chunking

Daniel Povey 2023-05-16 16:27:09 +08:00
parent 0006a4c4db
commit 3f72813a96
2 changed files with 12 additions and 7 deletions


@@ -634,9 +634,10 @@ class SubformerEncoder(nn.Module):
         chunk_sizes, chunk_indexes = self._get_chunk_sizes(src)
         b = src.shape[1] # batch_size
         pos_embs = [ self._pos_emb_to_chunk_size(pos_emb, c) for c in chunk_sizes ]
-        attn_offsets = [ self._attn_offset_to_chunk_size(attn_offset, c) for c in chunk_sizes ]
+        attn_offsets = [ self._attn_offset_to_chunk_size(attn_offset, b, c) for c in chunk_sizes ]
         # TODO: support this for memory also; would require duplicating it maybe;
         # or could modify the interior code to just assume chunking
         # when doing cross-attention.
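
The call site now threads the true batch size b (read off src) into _attn_offset_to_chunk_size, because the offset tensor may arrive with a broadcast batch dimension of 1 and so cannot supply the batch size itself. A minimal sketch of the shape situation, with invented tensor sizes:

import torch

seq_len, batch, channels = 256, 4, 512
src = torch.randn(seq_len, batch, channels)     # (seq, batch, channels)
attn_offset = torch.zeros(1, seq_len, seq_len)  # batch dim may be a broadcast 1

b = src.shape[1]  # 4: the true batch size, recoverable only from src
# The chunking helper needs b to expand the broadcast offset before
# slicing it into chunks (see the method body in the third hunk below).
assert attn_offset.shape[0] in (1, b)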
@@ -647,7 +648,7 @@ class SubformerEncoder(nn.Module):
             output = mod(
                 output,
                 pos_embs[ci],
-                attn_offset=attn_offset[ci],
+                attn_offset=attn_offsets[ci],
                 memory=memory,
                 memory_key_padding_mask=memory_key_padding_mask,
             )
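
This hunk fixes a plain typo, but one that Python would not catch by name: attn_offset (the unchunked tensor) was indexed instead of attn_offsets (the per-chunk-size list built above), so attn_offset[ci] silently selects a batch slice rather than the chunked offset for layer ci. A small illustration with made-up shapes:

import torch

attn_offset = torch.zeros(4, 256, 256)     # unchunked (batch, seq, seq) tensor
attn_offsets = [torch.zeros(8, 128, 128),  # list of offsets, one per chunk size
                torch.zeros(4, 256, 256)]

ci = 0
print(attn_offset[ci].shape)   # torch.Size([256, 256]): a batch slice, the wrong object
print(attn_offsets[ci].shape)  # torch.Size([8, 128, 128]): the intended chunked offset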
@@ -686,16 +687,20 @@ class SubformerEncoder(nn.Module):
         num_channels = src.shape[-1]
         return src.reshape(chunk_size, -1, num_channels)
 
-    def _attn_offset_to_chunk_size(self, attn_offset: Tensor, chunk_size: int) -> Tensor:
+    def _attn_offset_to_chunk_size(self, attn_offset: Tensor, batch_size: int, chunk_size: int) -> Tensor:
         """
         Break up attention offset into a given chunk size
         """
-        (batch_size, seq_len, seq_len) = attn_offset.shape
+        (_batch_size, seq_len, seq_len) = attn_offset.shape
         if seq_len == chunk_size:
             return attn_offset
+        if _batch_size != batch_size:
+            assert _batch_size == 1
+            attn_offset = attn_offset.expand(batch_size, seq_len, seq_len)
         assert seq_len % chunk_size == 0
-        num_chunks = seq_len / chunk_size
+        num_chunks = seq_len // chunk_size
         batch_stride, tgt_stride, src_stride = attn_offset.stride()
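
The hunk ends at the .stride() call, so the arithmetic that follows is not shown. For orientation, here is a hypothetical sketch of how stride-based chunking of the offset can work: each diagonal (chunk_size, chunk_size) block of the (seq_len, seq_len) offset becomes its own batch element via torch.Tensor.as_strided. This illustrates the technique only; the commit's actual continuation, including the final batch/chunk flattening order, is not visible here.

import torch

def attn_offset_to_chunk_size(attn_offset: torch.Tensor,
                              batch_size: int,
                              chunk_size: int) -> torch.Tensor:
    # attn_offset: (batch, seq_len, seq_len); batch may be a broadcast 1.
    (_batch_size, seq_len, _) = attn_offset.shape
    if seq_len == chunk_size:
        return attn_offset
    if _batch_size != batch_size:
        assert _batch_size == 1
        # expand() returns a view whose batch stride is 0 (no copy).
        attn_offset = attn_offset.expand(batch_size, seq_len, seq_len)
    assert seq_len % chunk_size == 0
    num_chunks = seq_len // chunk_size  # must be an int: it is used as a size below
    batch_stride, tgt_stride, src_stride = attn_offset.stride()
    # Stepping to the next diagonal block advances both the target and the
    # source position by chunk_size elements.
    blocks = attn_offset.as_strided(
        size=(batch_size, num_chunks, chunk_size, chunk_size),
        stride=(batch_stride,
                (tgt_stride + src_stride) * chunk_size,
                tgt_stride,
                src_stride),
    )
    return blocks.reshape(batch_size * num_chunks, chunk_size, chunk_size)

The as_strided view itself is copy-free; the final reshape may copy when the batch dimension was expanded, since a zero-stride dimension cannot be flattened into a view.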
@@ -717,7 +722,7 @@ class SubformerEncoder(nn.Module):
             return pos_emb
         assert seq_len % chunk_size == 0
-        num_chunks = seq_len / chunk_size
+        num_chunks = seq_len // chunk_size
         batch_stride, tgt_stride, src_stride, channel_stride = pos_emb.stride()
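
Both / -> // changes fix the same bug: Python 3's true division returns a float even when the division is exact, and the quotient is subsequently used where PyTorch expects integers (sizes and strides). A two-line demonstration:

seq_len, chunk_size = 256, 128
print(seq_len / chunk_size)   # 2.0: a float, rejected wherever a size is expected
print(seq_len // chunk_size)  # 2: floor division keeps it an int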


@@ -149,7 +149,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--encoder-chunk-size",
         type=str,
-        default="128"
+        default="128",
         help="Base chunk size for attention in encoder stacks; alternate layers will use this value or "
         "double this value."
     )
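
The added trailing comma is more than style: without it, default="128" is immediately followed by the help= keyword argument, which is a syntax error, so the file would fail to parse at all. A short usage sketch of the corrected option; the alternating-layer interpretation follows the help text, and the layer count is invented:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--encoder-chunk-size",
    type=str,
    default="128",
    help="Base chunk size for attention in encoder stacks; alternate layers "
    "will use this value or double this value.",
)
args = parser.parse_args(["--encoder-chunk-size", "64"])

base = int(args.encoder_chunk_size)
# Per the help text: alternate layers use the base chunk size or double it.
chunk_sizes = [base if i % 2 == 0 else 2 * base for i in range(6)]
print(chunk_sizes)  # [64, 128, 64, 128, 64, 128]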