Implement chunk sizes, to the extent that the program runs.

Daniel Povey 2023-05-16 16:13:20 +08:00
parent 4562b25a6a
commit 0006a4c4db
2 changed files with 109 additions and 11 deletions


@@ -81,6 +81,7 @@ class Subformer(EncoderInterface):
     def __init__(
         self,
         encoder_dim: Union[int, Tuple[int]] = (384, 512, 384),
+        encoder_chunk_size: Union[int, Tuple[int]] = 128,
         num_encoder_layers: Union[int, Tuple[int]] = 4,
         query_head_dim: Union[int, Tuple[int]] = 24,
         value_head_dim: Union[int, Tuple[int]] = 12,
@@ -110,6 +111,7 @@ class Subformer(EncoderInterface):
             return x
         self.encoder_dim = encoder_dim
+        encoder_chunk_size = _to_tuple(encoder_chunk_size)
         num_encoder_layers = _to_tuple(num_encoder_layers)
         query_head_dim = _to_tuple(query_head_dim)
         value_head_dim = _to_tuple(value_head_dim)
@@ -148,6 +150,7 @@ class Subformer(EncoderInterface):
                 encoder_layer,
                 num_encoder_layers[i],
                 dropout=dropout,
+                chunk_size=encoder_chunk_size[i],
                 warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                 warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
                 final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
@@ -460,7 +463,8 @@ class SubformerEncoderLayer(nn.Module):
         Pass the input through the encoder layer.
         Args:
             src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
-            pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
+            pos_emb: (batch_size, seq_len, seq_len, pos_dim), with e.g. pos_dim=4: relative positional
+               embedding tensor.
             feature_mask: something that broadcasts with src, that we'll multiply `src`
                by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
             attn_offset: the attention offset, of shape broadcasting with (batch_size, seq_len, seq_len),
@@ -565,11 +569,14 @@ class SubformerEncoder(nn.Module):
         dropout: float,
         warmup_begin: float,
         warmup_end: float,
+        chunk_size: int = 256,
         initial_layerdrop_rate: float = 0.5,
         final_layerdrop_rate: float = 0.05,
     ) -> None:
         super().__init__()
+        self.chunk_size = chunk_size
         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
         )
@@ -595,7 +602,6 @@ class SubformerEncoder(nn.Module):
         self,
         src: Tensor,
         pos_emb: Tensor,
-        feature_mask: Optional[Tensor] = None,
         attn_offset: Optional[Tensor] = None,
         memory: Optional[Tensor] = None,
         memory_key_padding_mask: Optional[Tensor] = None,
@@ -623,23 +629,107 @@ class SubformerEncoder(nn.Module):
         rnd_seed = src.numel() + random.randint(0, 1000)

-        if feature_mask is not None:
-            output = output * feature_mask
+        #if feature_mask is not None:
+        #    output = output * feature_mask
+
+        chunk_sizes, chunk_indexes = self._get_chunk_sizes(src)
+
+        pos_embs = [ self._pos_emb_to_chunk_size(pos_emb, c) for c in chunk_sizes ]
+        attn_offsets = [ self._attn_offset_to_chunk_size(attn_offset, c) for c in chunk_sizes ]
+
+        # TODO: support this for memory also; would require duplicating it maybe;
+        # or could modify the interior code to just assume chunking
+        # when doing cross-attention.

         for i, mod in enumerate(self.layers):
+            ci = chunk_indexes[i]
+            c = chunk_sizes[ci]
+            output = self._to_chunk_size(output, c)
+
             output = mod(
                 output,
-                pos_emb,
-                attn_offset=attn_offset,
+                pos_embs[ci],
+                attn_offset=attn_offsets[ci],
                 memory=memory,
                 memory_key_padding_mask=memory_key_padding_mask,
             )

-            if feature_mask is not None:
-                output = output * feature_mask
+            #if feature_mask is not None:
+            #    output = output * feature_mask
+
+        output = self._to_chunk_size(output, src.shape[0])

         return self.bypass(src, output)

+    def _get_chunk_sizes(self, src: Tensor) -> Tuple[List[int], List[int]]:
+        """
+        Decide the chunk sizes (in frames) to use for each layer.
+        Args:
+            src: the input embeddings, of shape (seq_len, batch_size, embed_dim)
+        Returns: (chunk_sizes, chunk_indexes), where:
+            chunk_sizes: a list of the unique chunk sizes to use, e.g. [ 128, 256 ]
+            chunk_indexes: a list of indexes into chunk_sizes, one per layer.
+        """
+        seq_len = src.shape[0]
+        assert seq_len < self.chunk_size or seq_len % self.chunk_size == 0
+        if seq_len <= self.chunk_size:
+            return [ seq_len ], [ 0 ] * len(self.layers)
+        else:
+            assert seq_len % self.chunk_size == 0, (seq_len, self.chunk_size)
+            num_layers = len(self.layers)
+            # alternate between chunk_size and 2 * chunk_size on successive layers.
+            chunk_indexes = [0, 1] * ((num_layers + 1) // 2)
+            return [ self.chunk_size, self.chunk_size * 2 ], chunk_indexes[:num_layers]
+
+    def _to_chunk_size(self, src: Tensor, chunk_size: int) -> Tensor:
+        """
+        Reshape embeddings 'src' to have a different chunk size (in frames)
+        """
+        num_channels = src.shape[-1]
+        return src.reshape(chunk_size, -1, num_channels)
+
+    def _attn_offset_to_chunk_size(self, attn_offset: Tensor, chunk_size: int) -> Tensor:
+        """
+        Break up attention offset into a given chunk size
+        """
+        (batch_size, seq_len, seq_len) = attn_offset.shape
+        if seq_len == chunk_size:
+            return attn_offset
+        assert seq_len % chunk_size == 0
+        num_chunks = seq_len // chunk_size
+
+        batch_stride, tgt_stride, src_stride = attn_offset.stride()
+
+        # have the 'chunk' dimension first so it has larger stride than the original batch; this
+        # is to match what happens to the embeddings in 'src' where the time-stride is first.
+        attn_offset = attn_offset.as_strided((num_chunks, batch_size, chunk_size, chunk_size),
+                                             ((tgt_stride + src_stride) * chunk_size, batch_stride,
+                                              tgt_stride, src_stride))
+
+        return attn_offset.contiguous().reshape(num_chunks * batch_size, chunk_size, chunk_size)
+
+    def _pos_emb_to_chunk_size(self, pos_emb: Tensor, chunk_size: int) -> Tensor:
+        """
+        Break up positional embedding tensor into a given chunk size
+        """
+        (batch_size, seq_len, seq_len, pos_dim) = pos_emb.shape
+        if seq_len == chunk_size:
+            return pos_emb
+        assert seq_len % chunk_size == 0
+        num_chunks = seq_len // chunk_size
+
+        batch_stride, tgt_stride, src_stride, channel_stride = pos_emb.stride()
+
+        pos_emb = pos_emb.as_strided((num_chunks, batch_size, chunk_size, chunk_size, pos_dim),
+                                     ((tgt_stride + src_stride) * chunk_size, batch_stride,
+                                      tgt_stride, src_stride, channel_stride))
+
+        return pos_emb.contiguous().reshape(num_chunks * batch_size,
+                                            chunk_size, chunk_size,
+                                            pos_dim)
+

 class BypassModule(nn.Module):
     """
@@ -1137,11 +1227,10 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         """Create positional encoding.
         Args:
-            x (torch.Tensor): Input tensor (time, batch, num_channels_in)
+            x (torch.Tensor): Input tensor (seq_len, batch_size, num_channels_in)
         Returns:
-            positional embedding, of shape (batch_size, 2*time-1, pos_dim).
+            positional embedding, of shape (batch_size, seq_len, seq_len, pos_dim).
         """
         self.extend_pe(x)
         seq_len = x.size(0)
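
The docstring change means CompactRelPositionalEncoding now emits one pos_dim-dimensional feature per (target frame, source frame) pair instead of a single (batch_size, 2*time-1, pos_dim) table. A toy sketch of that shape convention only; the channels below are invented for illustration and are not how the module actually computes its encoding:

    import torch

    batch_size, seq_len, pos_dim = 2, 6, 4
    tgt = torch.arange(seq_len).view(seq_len, 1)
    src = torch.arange(seq_len).view(1, seq_len)
    rel = (src - tgt).float()            # (seq_len, seq_len) signed relative offsets

    # four made-up channels derived from the relative offset, just to fill pos_dim=4
    feats = torch.stack([rel, rel.abs(), rel.sign(), torch.tanh(rel / seq_len)], dim=-1)
    pos_emb = feats.unsqueeze(0).expand(batch_size, -1, -1, -1)
    print(pos_emb.shape)                 # torch.Size([2, 6, 6, 4])
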


@@ -146,6 +146,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Embedding dimension in encoder stacks: a single int or comma-separated list."
     )

+    parser.add_argument(
+        "--encoder-chunk-size",
+        type=str,
+        default="128",
+        help="Base chunk size for attention in encoder stacks; alternate layers will use this value or "
+        "double this value."
+    )
+
     parser.add_argument(
         "--query-head-dim",
         type=str,
@@ -415,6 +423,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Subformer(
         num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
         encoder_dim=_to_int_tuple(params.encoder_dim),
+        encoder_chunk_size=_to_int_tuple(params.encoder_chunk_size),
         query_head_dim=_to_int_tuple(params.query_head_dim),
         pos_dim=int(params.pos_dim),
         value_head_dim=_to_int_tuple(params.value_head_dim),
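
A sketch of how the new option flows from the command line into the model. The body of _to_int_tuple is not shown in this diff; the helper below only mirrors how it is used here (splitting a comma-separated string into ints) and is an assumption, as is the training-script name:

    # e.g.  ./train.py --encoder-chunk-size 128
    #       ./train.py --encoder-chunk-size 128,256,128    # one value per encoder stack

    def _to_int_tuple(s: str):
        # assumed behaviour of the existing helper, shown for illustration only
        return tuple(int(x) for x in s.split(","))

    assert _to_int_tuple("128") == (128,)
    assert _to_int_tuple("128,256,128") == (128, 256, 128)

    # get_encoder_model() passes the parsed tuple as encoder_chunk_size=..., and
    # Subformer broadcasts a length-1 tuple across all encoder stacks via _to_tuple().
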