Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

Implement chunk sizes, to the extent that the program runs.

commit 0006a4c4db
parent 4562b25a6a
@@ -81,6 +81,7 @@ class Subformer(EncoderInterface):
     def __init__(
         self,
         encoder_dim: Union[int, Tuple[int]] = (384, 512, 384),
+        encoder_chunk_size: Union[int, Tuple[int]] = 128,
         num_encoder_layers: Union[int, Tuple[int]] = 4,
         query_head_dim: Union[int, Tuple[int]] = 24,
         value_head_dim: Union[int, Tuple[int]] = 12,
@@ -110,6 +111,7 @@ class Subformer(EncoderInterface):
             return x
 
         self.encoder_dim = encoder_dim
+        encoder_chunk_size = _to_tuple(encoder_chunk_size)
         num_encoder_layers = _to_tuple(num_encoder_layers)
         query_head_dim = _to_tuple(query_head_dim)
         value_head_dim = _to_tuple(value_head_dim)
@@ -148,6 +150,7 @@ class Subformer(EncoderInterface):
                 encoder_layer,
                 num_encoder_layers[i],
                 dropout=dropout,
+                chunk_size=encoder_chunk_size[i],
                 warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                 warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
                 final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
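Note that the chunk size is broadcast per encoder stack the same way as the other per-stack options, via the _to_tuple helper already used in this file. The following standalone sketch only illustrates that broadcasting behaviour; it is a simplified stand-in, not the helper from the file:

from typing import Tuple, Union

def to_per_stack_tuple(x: Union[int, Tuple[int, ...]], num_stacks: int) -> Tuple[int, ...]:
    # Broadcast a single int (or a 1-tuple) to one value per encoder stack;
    # a tuple of the right length is passed through unchanged.
    if isinstance(x, int):
        x = (x,)
    if len(x) == 1:
        x = x * num_stacks
    assert len(x) == num_stacks, (x, num_stacks)
    return tuple(x)

# e.g. with 3 encoder stacks:
assert to_per_stack_tuple(128, 3) == (128, 128, 128)
assert to_per_stack_tuple((128, 256, 128), 3) == (128, 256, 128)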
@@ -460,7 +463,8 @@ class SubformerEncoderLayer(nn.Module):
         Pass the input through the encoder layer.
         Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
-           pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
+           pos_emb: (batch_size, seq_len, seq_len, pos_dim), with e.g. pos_dim=4: relative positional
+               embedding tensor.
           feature_mask: something that broadcasts with src, that we'll multiply `src`
              by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
         attn_offset: the attention offset, of shape broadcasting with (batch_size, seq_len, seq_len),
@@ -565,11 +569,14 @@ class SubformerEncoder(nn.Module):
         dropout: float,
         warmup_begin: float,
         warmup_end: float,
+        chunk_size: int = 256,
         initial_layerdrop_rate: float = 0.5,
         final_layerdrop_rate: float = 0.05,
     ) -> None:
         super().__init__()
 
+        self.chunk_size = chunk_size
+
         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
         )
@@ -595,7 +602,6 @@ class SubformerEncoder(nn.Module):
         self,
         src: Tensor,
         pos_emb: Tensor,
-        feature_mask: Optional[Tensor] = None,
         attn_offset: Optional[Tensor] = None,
         memory: Optional[Tensor] = None,
         memory_key_padding_mask: Optional[Tensor] = None,
@@ -623,23 +629,107 @@ class SubformerEncoder(nn.Module):
 
         rnd_seed = src.numel() + random.randint(0, 1000)
 
-        if feature_mask is not None:
-            output = output * feature_mask
+        #if feature_mask is not None:
+        #    output = output * feature_mask
 
+        chunk_sizes, chunk_indexes = self._get_chunk_sizes(src)
+
+        pos_embs = [ self._pos_emb_to_chunk_size(pos_emb, c) for c in chunk_sizes ]
+        attn_offsets = [ self._attn_offset_to_chunk_size(attn_offset, c) for c in chunk_sizes ]
+        # TODO: support this for memory also; would require duplicating it maybe;
+        # or could modify the interior code to just assume chunking
+        # when doing cross-attention.
         for i, mod in enumerate(self.layers):
+            ci = chunk_indexes[i]
+            c = chunk_sizes[ci]
+            output = self._to_chunk_size(output, c)
             output = mod(
                 output,
-                pos_emb,
-                attn_offset=attn_offset,
+                pos_embs[ci],
+                attn_offset=attn_offsets[ci],
                 memory=memory,
                 memory_key_padding_mask=memory_key_padding_mask,
             )
 
-        if feature_mask is not None:
-            output = output * feature_mask
+        #if feature_mask is not None:
+        #    output = output * feature_mask
 
+        output = self._to_chunk_size(output, src.shape[0])
+
         return self.bypass(src, output)
 
+    def _get_chunk_sizes(self, src: Tensor) -> Tuple[List[int], List[int]]:
+        """
+        Decide the chunk sizes (in frames) to use for each layer.
+        Args:
+            src: the input embeddings, of shape (seq_len, batch_size, embed_dim)
+        Returns: (chunk_sizes, chunk_indexes), where:
+            chunk_sizes: a list of the unique chunk sizes to use, e.g. [ 128, 256 ]
+            chunk_indexes: a list of indexes into chunk_sizes, one per layer.
+        """
+        seq_len = src.shape[0]
+        assert seq_len < self.chunk_size or seq_len % self.chunk_size == 0
+        if seq_len <= self.chunk_size:
+            return [ seq_len ], [ 0 ] * len(self.layers)
+        else:
+            assert seq_len % self.chunk_size == 0, (seq_len, self.chunk_size)
+            num_layers = len(self.layers)
+            chunk_indexes = [0, 1] * ((num_layers + 1) // 2)
+            return [ self.chunk_size, self.chunk_size * 2 ], chunk_indexes[:num_layers]
+
+
+    def _to_chunk_size(self, src: Tensor, chunk_size: int) -> Tensor:
+        """
+        Reshape embeddings 'src' to have a different chunk size (in frames)
+        """
+        num_channels = src.shape[-1]
+        return src.reshape(chunk_size, -1, num_channels)
+
+    def _attn_offset_to_chunk_size(self, attn_offset: Tensor, chunk_size: int) -> Tensor:
+        """
+        Break up attention offset into a given chunk size
+        """
+        (batch_size, seq_len, seq_len) = attn_offset.shape
+        if seq_len == chunk_size:
+            return attn_offset
+        assert seq_len % chunk_size == 0
+
+        num_chunks = seq_len // chunk_size
+
+        batch_stride, tgt_stride, src_stride = attn_offset.stride()
+
+        # have the 'chunk' dimension first so it has larger stride than the original batch; this
+        # is to match what happens to the embeddings in 'src' where the time-stride is first.
+        attn_offset = attn_offset.as_strided((num_chunks, batch_size, chunk_size, chunk_size),
+                                             ((tgt_stride + src_stride) * chunk_size, batch_stride,
+                                              tgt_stride, src_stride))
+
+        return attn_offset.contiguous().reshape(num_chunks * batch_size, chunk_size, chunk_size)
+
+
+    def _pos_emb_to_chunk_size(self, pos_emb: Tensor, chunk_size: int) -> Tensor:
+        """
+        Break up positional embedding tensor into a given chunk size
+        """
+        (batch_size, seq_len, seq_len, pos_dim) = pos_emb.shape
+        if seq_len == chunk_size:
+            return pos_emb
+        assert seq_len % chunk_size == 0
+
+        num_chunks = seq_len // chunk_size
+
+        batch_stride, tgt_stride, src_stride, channel_stride = pos_emb.stride()
+
+        pos_emb = pos_emb.as_strided((num_chunks, batch_size, chunk_size, chunk_size, pos_dim),
+                                     ((tgt_stride + src_stride) * chunk_size, batch_stride,
+                                      tgt_stride, src_stride, channel_stride))
+
+        return pos_emb.contiguous().reshape(num_chunks * batch_size,
+                                            chunk_size, chunk_size,
+                                            pos_dim)
+
+
 class BypassModule(nn.Module):
     """
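To sanity-check the as_strided bookkeeping in _attn_offset_to_chunk_size and the plain reshape used by _to_chunk_size, here is a small standalone sketch (test scaffolding written for this note, not part of the commit). It compares the strided block extraction against direct slicing of the diagonal (chunk_size x chunk_size) blocks, and shows that reshaping the embeddings back to the full sequence length is an exact inverse:

import torch

B, S, C, D = 2, 8, 4, 3        # batch, seq_len, chunk_size, embed_dim
N = S // C                     # number of chunks

attn_offset = torch.randn(B, S, S)
batch_stride, tgt_stride, src_stride = attn_offset.stride()

# Same as_strided call as in _attn_offset_to_chunk_size: view the diagonal
# (C x C) blocks, with the chunk index ahead of the original batch index.
blocks = attn_offset.as_strided(
    (N, B, C, C),
    ((tgt_stride + src_stride) * C, batch_stride, tgt_stride, src_stride),
).contiguous().reshape(N * B, C, C)

# Reference: slice the diagonal blocks directly.
expected = torch.stack(
    [attn_offset[b, n * C:(n + 1) * C, n * C:(n + 1) * C]
     for n in range(N) for b in range(B)]
)
assert torch.equal(blocks, expected)

# _to_chunk_size is a plain reshape on the (seq_len, batch, dim) embeddings;
# reshaping back to the full sequence length recovers the input exactly.
src = torch.randn(S, B, D)
chunked = src.reshape(C, -1, D)            # (C, N * B, D)
assert torch.equal(chunked.reshape(S, -1, D), src)

In the extracted blocks the combined batch dimension has size num_chunks * batch_size with the chunk index varying slowest, which is what the "chunk dimension first" comment in the diff refers to.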
@@ -1137,11 +1227,10 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         """Create positional encoding.
 
         Args:
-            x (torch.Tensor): Input tensor (time, batch, num_channels_in)
+            x (torch.Tensor): Input tensor (seq_len, batch_size, num_channels_in)
 
         Returns:
-            positional embedding, of shape (batch_size, 2*time-1, pos_dim).
+            positional embedding, of shape (batch_size, seq_len, seq_len, pos_dim).
 
         """
         self.extend_pe(x)
         seq_len = x.size(0)
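The new return shape means the positional encoding now carries one pos_dim-dimensional vector per (query, key) pair rather than a single sequence of length 2*seq_len-1. CompactRelPositionalEncoding computes its embeddings in its own way; the sketch below only illustrates the (batch_size, seq_len, seq_len, pos_dim) layout, using a hypothetical table of relative-offset embeddings:

import torch

seq_len, pos_dim, batch_size = 6, 4, 2

# Hypothetical table of embeddings for relative offsets -(seq_len-1) .. (seq_len-1).
rel_table = torch.randn(2 * seq_len - 1, pos_dim)

# Entry [i, j] holds the embedding of the relative position (i - j), shifted to be >= 0.
i = torch.arange(seq_len).unsqueeze(1)     # (seq_len, 1)
j = torch.arange(seq_len).unsqueeze(0)     # (1, seq_len)
rel_index = i - j + (seq_len - 1)          # (seq_len, seq_len), values in [0, 2*seq_len-2]

pos_emb = rel_table[rel_index]             # (seq_len, seq_len, pos_dim)
pos_emb = pos_emb.unsqueeze(0).expand(batch_size, -1, -1, -1)
print(pos_emb.shape)                       # torch.Size([2, 6, 6, 4])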
@@ -146,6 +146,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        help="Embedding dimension in encoder stacks: a single int or comma-separated list."
     )
 
+    parser.add_argument(
+        "--encoder-chunk-size",
+        type=str,
+        default="128",
+        help="Base chunk size for attention in encoder stacks; alternate layers will use this value or "
+        "double this value."
+    )
+
     parser.add_argument(
         "--query-head-dim",
         type=str,
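The option is kept as a string so that a comma-separated list can give one chunk size per encoder stack; it is converted with _to_int_tuple in get_encoder_model() (next hunk). A minimal stand-in, assuming the usual comma-splitting behaviour of that helper:

def to_int_tuple(s: str):
    # Assumed behaviour of the _to_int_tuple helper used in get_encoder_model():
    # split a comma-separated string into a tuple of ints.
    return tuple(int(x) for x in s.split(","))

# e.g. the default and an explicit per-stack setting:
assert to_int_tuple("128") == (128,)
assert to_int_tuple("128,256,128") == (128, 256, 128)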
@@ -415,6 +423,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Subformer(
         num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
         encoder_dim=_to_int_tuple(params.encoder_dim),
+        encoder_chunk_size=_to_int_tuple(params.encoder_chunk_size),
         query_head_dim=_to_int_tuple(params.query_head_dim),
         pos_dim=int(params.pos_dim),
         value_head_dim=_to_int_tuple(params.value_head_dim),