Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)
Various bug fixes, implementing chunking
parent 0006a4c4db
commit 3f72813a96
@@ -634,9 +634,10 @@ class SubformerEncoder(nn.Module):
         chunk_sizes, chunk_indexes = self._get_chunk_sizes(src)
+        b = src.shape[1]  # batch_size

         pos_embs = [ self._pos_emb_to_chunk_size(pos_emb, c) for c in chunk_sizes ]
-        attn_offsets = [ self._attn_offset_to_chunk_size(attn_offset, c) for c in chunk_sizes ]
+        attn_offsets = [ self._attn_offset_to_chunk_size(attn_offset, b, c) for c in chunk_sizes ]
         # TODO: support this for memory also; would require duplicating it maybe;
         # or could modify the interior code to just assume chunking
         # when doing cross-attention.
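For reference, the reason the batch size b is now threaded through to _attn_offset_to_chunk_size: the attention offset may be shared across the whole batch (stored with a singleton batch dimension), and the per-chunk view built later needs the real batch size. A minimal toy illustration with made-up shapes, not the icefall code:

import torch

seq_len, b = 8, 3                                      # b plays the role of src.shape[1]
attn_offset = torch.zeros(1, seq_len, seq_len)         # one offset shared by the whole batch
attn_offset = attn_offset.expand(b, seq_len, seq_len)  # (3, 8, 8) view, no copy
print(attn_offset.shape)                               # torch.Size([3, 8, 8])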
@@ -647,7 +648,7 @@ class SubformerEncoder(nn.Module):
             output = mod(
                 output,
                 pos_embs[ci],
-                attn_offset=attn_offset[ci],
+                attn_offset=attn_offsets[ci],
                 memory=memory,
                 memory_key_padding_mask=memory_key_padding_mask,
             )
@@ -686,16 +687,20 @@ class SubformerEncoder(nn.Module):
         num_channels = src.shape[-1]
         return src.reshape(chunk_size, -1, num_channels)

-    def _attn_offset_to_chunk_size(self, attn_offset: Tensor, chunk_size: int) -> Tensor:
+    def _attn_offset_to_chunk_size(self, attn_offset: Tensor, batch_size: int, chunk_size: int) -> Tensor:
         """
         Break up attention offset into a given chunk size
         """
-        (batch_size, seq_len, seq_len) = attn_offset.shape
+        (_batch_size, seq_len, seq_len) = attn_offset.shape
         if seq_len == chunk_size:
             return attn_offset
+        if _batch_size != batch_size:
+            assert _batch_size == 1
+            attn_offset = attn_offset.expand(batch_size, seq_len, seq_len)
+
         assert seq_len % chunk_size == 0

-        num_chunks = seq_len / chunk_size
+        num_chunks = seq_len // chunk_size

         batch_stride, tgt_stride, src_stride = attn_offset.stride()
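A minimal, self-contained sketch of the general trick the stride() bookkeeping above is building toward: viewing the block-diagonal chunk_size x chunk_size blocks of a (batch, seq, seq) offset tensor with torch.as_strided, without copying. This is an illustration under assumed shapes, not the icefall implementation itself; note that num_chunks has to be an int (hence // rather than /) before it can be used as a size.

import torch

def chunk_attn_offset(attn_offset: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # Return a (batch, num_chunks, chunk, chunk) view of the diagonal blocks.
    batch_size, seq_len, _ = attn_offset.shape
    assert seq_len % chunk_size == 0
    num_chunks = seq_len // chunk_size
    batch_stride, tgt_stride, src_stride = attn_offset.stride()
    return attn_offset.as_strided(
        (batch_size, num_chunks, chunk_size, chunk_size),
        (batch_stride,
         chunk_size * (tgt_stride + src_stride),  # step to the next diagonal block
         tgt_stride,
         src_stride),
    )

offsets = torch.arange(2 * 8 * 8, dtype=torch.float32).reshape(2, 8, 8)
blocks = chunk_attn_offset(offsets, chunk_size=4)            # shape (2, 2, 4, 4)
assert torch.equal(blocks[0, 1], offsets[0, 4:8, 4:8])       # second diagonal block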
@@ -717,7 +722,7 @@ class SubformerEncoder(nn.Module):
             return pos_emb
         assert seq_len % chunk_size == 0

-        num_chunks = seq_len / chunk_size
+        num_chunks = seq_len // chunk_size

         batch_stride, tgt_stride, src_stride, channel_stride = pos_emb.stride()
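The / to // change matters because Python 3 true division returns a float, and a float cannot be used as a dimension size when building a view. A quick illustration:

import torch

seq_len, chunk_size = 8, 4
print(seq_len / chunk_size)    # 2.0  (float)
print(seq_len // chunk_size)   # 2    (int)

x = torch.zeros(seq_len, 2)
x.reshape(seq_len // chunk_size, -1)   # fine
# x.reshape(seq_len / chunk_size, -1)  # TypeError: sizes must be ints, got float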
@@ -149,7 +149,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--encoder-chunk-size",
         type=str,
-        default="128"
+        default="128",
         help="Base chunk size for attention in encoder stacks; alternate layers will use this value or "
         "double this value."
     )
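For context, --encoder-chunk-size is declared as a string rather than an int; options in this style usually hold one value or a comma-separated list. The helper below is a hedged sketch of how such a string might be parsed; the helper name and the list handling are assumptions, not taken from this commit:

def _to_int_tuple(s: str) -> tuple:
    """Parse e.g. '128' or '128,256' into a tuple of ints (hypothetical helper)."""
    return tuple(int(x) for x in s.split(","))

chunk_sizes = _to_int_tuple("128")   # (128,)
# Per the help text, alternate encoder stacks use the base value or double it:
per_stack = [chunk_sizes[0] * (2 if i % 2 == 1 else 1) for i in range(4)]  # [128, 256, 128, 256]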