mirror of https://github.com/k2-fsa/icefall.git
Various bug fixes, implementing chunking

parent 0006a4c4db
commit 3f72813a96
@@ -634,9 +634,10 @@ class SubformerEncoder(nn.Module):
         chunk_sizes, chunk_indexes = self._get_chunk_sizes(src)
+        b = src.shape[1]  # batch_size
         pos_embs = [ self._pos_emb_to_chunk_size(pos_emb, c) for c in chunk_sizes ]
-        attn_offsets = [ self._attn_offset_to_chunk_size(attn_offset, c) for c in chunk_sizes ]
+        attn_offsets = [ self._attn_offset_to_chunk_size(attn_offset, b, c) for c in chunk_sizes ]
         # TODO: support this for memory also; would require duplicating it maybe;
         # or could modify the interior code to just assume chunking
         # when doing cross-attention.
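The new b = src.shape[1] line exists because the attention offset can arrive with a broadcastable batch dimension of 1 while src carries the real batch size; the helper below needs the true value in order to expand the offset before chunking. A minimal sketch of the shapes involved, assuming the (seq_len, batch_size, num_channels) layout used elsewhere in this file (illustrative sizes, not from the commit):

    import torch

    seq_len, b, num_channels = 256, 4, 192
    src = torch.randn(seq_len, b, num_channels)     # assumed (seq, batch, channels) layout
    attn_offset = torch.zeros(1, seq_len, seq_len)  # batch dim of 1, broadcastable
    # The offset's batch dim need not match src's, hence the extra 'b' argument:
    assert src.shape[1] != attn_offset.shape[0]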
@@ -647,7 +648,7 @@ class SubformerEncoder(nn.Module):
             output = mod(
                 output,
                 pos_embs[ci],
-                attn_offset=attn_offset[ci],
+                attn_offset=attn_offsets[ci],
                 memory=memory,
                 memory_key_padding_mask=memory_key_padding_mask,
             )
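The one-character change above fixes a lookup bug: attn_offset is the raw (batch, tgt, src) tensor, so indexing it with [ci] would slice its batch dimension, while attn_offsets[ci] picks the offset prepared for layer ci's chunk size. Sketched with illustrative shapes:

    import torch

    attn_offset = torch.zeros(1, 256, 256)      # raw (batch, tgt_len, src_len) tensor
    attn_offsets = [attn_offset, attn_offset]   # one entry per layer's chunk size
    ci = 0
    # attn_offset[ci] would give a (256, 256) slice of the batch dimension;
    # attn_offsets[ci] is the chunked offset intended for layer ci:
    assert attn_offsets[ci].shape == (1, 256, 256)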
@@ -686,16 +687,20 @@ class SubformerEncoder(nn.Module):
         num_channels = src.shape[-1]
         return src.reshape(chunk_size, -1, num_channels)

-    def _attn_offset_to_chunk_size(self, attn_offset: Tensor, chunk_size: int) -> Tensor:
+    def _attn_offset_to_chunk_size(self, attn_offset: Tensor, batch_size: int, chunk_size: int) -> Tensor:
         """
         Break up attention offset into a given chunk size
         """
-        (batch_size, seq_len, seq_len) = attn_offset.shape
+        (_batch_size, seq_len, seq_len) = attn_offset.shape
         if seq_len == chunk_size:
             return attn_offset
+        if _batch_size != batch_size:
+            assert _batch_size == 1
+            attn_offset = attn_offset.expand(batch_size, seq_len, seq_len)
+
         assert seq_len % chunk_size == 0

-        num_chunks = seq_len / chunk_size
+        num_chunks = seq_len // chunk_size

         batch_stride, tgt_stride, src_stride = attn_offset.stride()

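The hunk cuts off before the part of _attn_offset_to_chunk_size that uses the strides it reads. Below is a plausible reconstruction of the whole helper: the as_strided call selecting diagonal (chunk_size, chunk_size) blocks is an assumption, not the committed code, though it is consistent with the strides read above and with the new expand/integer-division fixes.

    import torch
    from torch import Tensor

    def attn_offset_to_chunk_size(attn_offset: Tensor, batch_size: int, chunk_size: int) -> Tensor:
        (_batch_size, seq_len, _) = attn_offset.shape
        if seq_len == chunk_size:
            return attn_offset
        if _batch_size != batch_size:
            assert _batch_size == 1
            attn_offset = attn_offset.expand(batch_size, seq_len, seq_len)
        assert seq_len % chunk_size == 0
        # '//' keeps num_chunks an int; '/' would give a float, which as_strided rejects.
        num_chunks = seq_len // chunk_size
        batch_stride, tgt_stride, src_stride = attn_offset.stride()
        # View the diagonal (chunk_size, chunk_size) blocks: chunk i attends only
        # within positions [i * chunk_size, (i + 1) * chunk_size).
        blocks = attn_offset.as_strided(
            size=(batch_size, num_chunks, chunk_size, chunk_size),
            stride=(batch_stride,
                    chunk_size * (tgt_stride + src_stride),
                    tgt_stride,
                    src_stride),
        )
        return blocks.reshape(batch_size * num_chunks, chunk_size, chunk_size)

    # Quick shape check: one broadcastable offset, batch 2, two chunks of 4.
    out = attn_offset_to_chunk_size(torch.zeros(1, 8, 8), batch_size=2, chunk_size=4)
    assert out.shape == (4, 4, 4)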
@@ -717,7 +722,7 @@ class SubformerEncoder(nn.Module):
             return pos_emb
         assert seq_len % chunk_size == 0

-        num_chunks = seq_len / chunk_size
+        num_chunks = seq_len // chunk_size

         batch_stride, tgt_stride, src_stride, channel_stride = pos_emb.stride()
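The same '/' to '//' fix recurs here in the positional-embedding helper. The reason for the change, in isolation:

    seq_len, chunk_size = 256, 128
    num_chunks = seq_len / chunk_size    # 2.0: true division yields a float in Python 3
    num_chunks = seq_len // chunk_size   # 2: an int, as tensor size arguments require
    assert isinstance(num_chunks, int)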
@@ -149,7 +149,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--encoder-chunk-size",
         type=str,
-        default="128"
+        default="128",
         help="Base chunk size for attention in encoder stacks; alternate layers will use this value or "
         "double this value."
     )
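The added trailing comma is not cosmetic: without it, default="128" followed by help=... is a SyntaxError, since keyword arguments must be comma-separated, so the file would fail to import at all. A minimal self-contained check; the int() conversion at the end is an assumption about how the string option is consumed downstream, not something this diff shows:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--encoder-chunk-size",
        type=str,
        default="128",  # the comma this hunk adds; omitting it is a SyntaxError
        help="Base chunk size for attention in encoder stacks; alternate layers "
        "will use this value or double this value.",
    )
    args = parser.parse_args([])
    chunk_size = int(args.encoder_chunk_size)  # assumed downstream parsing
    assert chunk_size == 128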