Partial work

Daniel Povey 2023-05-23 14:01:04 +08:00
parent bcc9971ebe
commit e51a2c9170

@@ -36,6 +36,7 @@ from scaling import (
     ScheduledFloat,
     FloatLike,
     limit_param_value,
+    clip_grad,
     convert_num_channels,
 )
 from torch import Tensor, nn
@@ -49,14 +50,15 @@ class Subformer(EncoderInterface):
         as downsampling_factor if they are single ints or one-element tuples.  The length of
         downsampling_factor defines the number of stacks.
 
-        output_downsampling_factor (int): how much to downsample at the output.  Note:
-           we also downsample by a factor of 2 in the Conv2dSubsampling encoder.
-           You should probably leave this at 2.
-        downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
-           Note: this is in addition to the downsampling factor of 2 that is applied in
-           the frontend (self.encoder_embed).
-        encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per
-           encoder stack.
+        structure (str): determines the structure of the module: "S" is an encoder stack,
+           an open-parenthesis is a downsampling operation, and a close-parenthesis is the
+           corresponding upsampling operation (not every parenthesis has to be closed, if
+           you want the whole stack to downsample).
+        encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per
+           encoder stack (i.e. one per "S" in structure).
+        downsampling_factor (Tuple[int]): downsampling factor for each downsampling
+           operation (each open-parenthesis).
         num_encoder_layers (int or Tuple[int]): number of encoder layers for each stack
         query_head_dim (int or Tuple[int]): dimension of query and key per attention
            head; per stack, if a tuple.
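As a worked illustration of the structure grammar described above (this reading is inferred from the docstring and the constructor comments, not from code in this diff):

# structure = "S(S)S", the new default:
#   S  -> encoder stack 0, at the input resolution
#   (  -> downsample by downsampling_factor[0] before the next stack
#   S  -> encoder stack 1, at the reduced resolution
#   )  -> upsample back, reusing the same downsampling object plus a BypassModule
#         (see the constructor comment further down)
#   S  -> encoder stack 2, back at the input resolution
# A string such as "S(S" leaves the parenthesis unclosed, so the output of the
# whole stack stays downsampled.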
@@ -80,13 +82,15 @@ class Subformer(EncoderInterface):
     """
     def __init__(
             self,
-            encoder_dim: Union[int, Tuple[int]] = (384, 512, 384),
-            encoder_chunk_size: Union[int, Tuple[int]] = 128,
-            num_encoder_layers: Union[int, Tuple[int]] = 4,
-            query_head_dim: Union[int, Tuple[int]] = 24,
-            value_head_dim: Union[int, Tuple[int]] = 12,
-            num_heads: Union[int, Tuple[int]] = 8,
-            feedforward_dim: Union[int, Tuple[int]] = 1536,
+            structure: str = "S(S)S",
+            encoder_dim: Tuple[int, ...] = (384, 512, 384),
+            downsampling_factor: Tuple[int, ...] = (2,),
+            encoder_chunk_sizes: Tuple[Tuple[int, ...]] = (128,),
+            num_encoder_layers: Union[int, Tuple[int, ...]] = (4,),
+            query_head_dim: Tuple[int, ...] = (24,),
+            value_head_dim: Tuple[int, ...] = (12,),
+            num_heads: Tuple[int, ...] = (8,),
+            feedforward_dim: Tuple[int, ...] = (1536,),
             memory_dim: int = -1,
             pos_dim: int = 4,
             dropout: Optional[FloatLike] = None,  # see code below for default
@@ -99,15 +103,20 @@ class Subformer(EncoderInterface):
             dropout = ScheduledFloat((0.0, 0.3),
                                      (20000.0, 0.1))
 
+        num_encoders = len([s for s in structure if s == 'S'])
+        num_downsamplers = len([s for s in structure if s == '('])
+        # when we upsample, we use the same downsampling object that we
+        # downsampled with, but we also need a BypassModule at that point.
+        num_bypass = len([s for s in structure if s == ')'])
+
         def _to_tuple(x):
-            """ Converts a single int or a 1-tuple of an int to a tuple with the same length
-            as encoder_dim"""
-            if isinstance(x, int):
-                x = (x,)
+            """ Converts a 1-tuple of an int to a tuple with one element per encoder stack
+            (i.e. of length num_encoders); longer tuples must already have that length."""
+            assert isinstance(x, tuple)
             if len(x) == 1:
-                x = x * len(encoder_dim)
+                x = x * num_encoders
             else:
-                assert len(x) == len(encoder_dim) and isinstance(x[0], int)
+                assert len(x) == num_encoders
             return x
 
         self.encoder_dim = encoder_dim
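A quick sketch of how _to_tuple now behaves, assuming structure = "S(S)S" so that num_encoders == 3 (the values are illustrative):

# _to_tuple((24,))        -> (24, 24, 24)    # 1-tuple broadcast to one value per "S"
# _to_tuple((24, 32, 24)) -> (24, 32, 24)    # already one value per encoder stack
# _to_tuple((24, 32))     -> AssertionError  # length must be 1 or num_encoders
# _to_tuple(24)           -> AssertionError  # plain ints are no longer accepted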
@@ -120,8 +129,46 @@ class Subformer(EncoderInterface):
 
         self.causal = causal
 
+        if len(downsampling_factor) == 1:
+            downsampling_factor = downsampling_factor * num_downsamplers
+        assert len(downsampling_factor) == num_downsamplers
+
         # each one will be SubformerEncoder or DownsampledSubformerEncoder
         encoders = []
+        downsamplers = []
+        bypass = []
+
+        for s in structure:
+            if s == 'S':
+                i = len(encoders)
+                encoder_layer = SubformerEncoderLayer(
+                    embed_dim=encoder_dim[i],
+                    pos_dim=pos_dim,
+                    num_heads=num_heads[i],
+                    query_head_dim=query_head_dim[i],
+                    value_head_dim=value_head_dim[i],
+                    feedforward_dim=feedforward_dim[i],
+                    memory_dim=memory_dim,
+                    dropout=dropout,
+                    causal=causal,
+                )
+                encoder = SubformerEncoder(
+                    encoder_layer,
+                    num_encoder_layers[i],
+                    dropout=dropout,
+                    chunk_sizes=encoder_chunk_sizes[i],
+                    warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
+                    warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
+                    final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
+                )
+                encoders.append(encoder)
+            elif s == '(':
+                pass
+            else:
+                assert s == ')'
 
         num_encoders = len(encoder_dim)
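The '(' and ')' branches are still stubs in this partial commit. Purely as a hedged guess at where they are headed, based on the num_downsamplers/num_bypass counts above and the LearnedDownsamplingModule signature further down (BypassModule's exact signature is not shown in this diff), they would presumably populate the other two lists, roughly:

# elif s == '(':
#     i = len(downsamplers)
#     downsamplers.append(
#         LearnedDownsamplingModule(encoder_dim[len(encoders) - 1],
#                                   downsampling_factor[i]))
# else:
#     assert s == ')'
#     bypass.append(BypassModule(...))  # hypothetical; arguments not shown in this diff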
@@ -132,31 +179,6 @@ class Subformer(EncoderInterface):
 
         for i in range(num_encoders):
-            encoder_layer = SubformerEncoderLayer(
-                embed_dim=encoder_dim[i],
-                pos_dim=pos_dim,
-                num_heads=num_heads[i],
-                query_head_dim=query_head_dim[i],
-                value_head_dim=value_head_dim[i],
-                feedforward_dim=feedforward_dim[i],
-                memory_dim=memory_dim,
-                dropout=dropout,
-                causal=causal,
-            )
-
-            # For the segment of the warmup period, we let the Conv2dSubsampling
-            # layer learn something.  Then we start to warm up the other encoders.
-            encoder = SubformerEncoder(
-                encoder_layer,
-                num_encoder_layers[i],
-                dropout=dropout,
-                chunk_size=encoder_chunk_size[i],
-                warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
-                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
-                final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
-            )
-            encoders.append(encoder)
 
         mid = len(encoders) // 2
         encoder = DownsampledSubformerEncoder(
@@ -567,13 +589,13 @@ class SubformerEncoder(nn.Module):
             dropout: float,
             warmup_begin: float,
             warmup_end: float,
-            chunk_size: int = 256,
+            chunk_sizes: Tuple[int, ...] = (128, 2048),
             initial_layerdrop_rate: float = 0.5,
             final_layerdrop_rate: float = 0.05,
     ) -> None:
         super().__init__()
-        self.chunk_size = chunk_size
+        self.chunk_sizes = chunk_sizes
 
         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
@@ -668,13 +690,17 @@ class SubformerEncoder(nn.Module):
            chunk_indexes: a list of indexes into chunk_sizes, one per layer.
         """
         seq_len = src.shape[0]
-        if seq_len <= self.chunk_size or seq_len % self.chunk_size != 0:
-            return [ seq_len ], [ 0 ] * len(self.layers)
-        else:
-            num_layers = len(self.layers)
-            chunk_indexes = [0, 1] * (num_layers + 1 // 2)
-            return [ self.chunk_size, seq_len ], chunk_indexes[:num_layers]
+        chunk_indexes = []
+        chunk_sizes = []
+        for chunk_size in self.chunk_sizes:
+            # use each configured chunk size only if it evenly divides the sequence;
+            # otherwise fall back to a single chunk covering the whole sequence.
+            chunk_sizes.append(chunk_size if seq_len % chunk_size == 0
+                               else seq_len)
+
+        num_chunk_sizes = len(self.chunk_sizes)
+        for i in range(self.num_layers):
+            chunk_indexes.append(i % num_chunk_sizes)
+        return chunk_sizes, chunk_indexes
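A worked example of the new per-layer chunking, assuming the new default chunk_sizes=(128, 2048) and an input with seq_len=256:

# 256 % 128  == 0  -> keep the 128-frame chunk size
# 256 % 2048 != 0  -> fall back to the whole sequence, 256
# result: chunk_sizes == [128, 256], chunk_indexes == [0, 1, 0, 1, ...]
# i.e. the layers alternate between 128-frame chunks and whole-sequence chunks.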
 
     def _to_chunk_size(self, src: Tensor, chunk_size: int) -> Tensor:
         """
@@ -809,6 +835,7 @@ class LearnedDownsamplingModule(nn.Module):
     def __init__(self,
                  embed_dim: int,
                  downsampling_factor: int):
+        assert downsampling_factor > 1
 
         super().__init__()
@ -864,9 +891,8 @@ class LearnedDownsamplingModule(nn.Module):
d = self.downsampling_factor d = self.downsampling_factor
seq_len_reduced = (seq_len + d - 1) // d seq_len_reduced = (seq_len + d - 1) // d
weights_discarded = weights[:, seq_len_reduced:2*seq_len_reduced] weights_discarded = weights[:, seq_len_reduced:2*seq_len_reduced]
missing = weights_discarded.shape[1] - seq_len_reduced missing = seq_len_reduced - weights_discarded.shape[1]
if missing != 0: if missing != 0:
weights_discarded = torch.cat((weights_discarded, weights_discarded = torch.cat((weights_discarded,
torch.zeros(batch_size, missing, torch.zeros(batch_size, missing,
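The sign flip matters when fewer than seq_len_reduced frames remain to be discarded; a small worked example (numbers chosen purely for illustration):

# seq_len = 3, d = 2  ->  seq_len_reduced = (3 + 1) // 2 = 2
# weights_discarded = weights[:, 2:4] has only 1 column (the sequence ends at index 3),
# so missing = 2 - 1 = 1 and one column of zeros is appended.  With the old expression
# missing would be -1, and torch.zeros(batch_size, -1, ...) would raise an error.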
@@ -986,6 +1012,12 @@ class LearnedDownsamplingModule(nn.Module):
         assert len(attn_offset.shape) == 3  # (1, seq_len, seq_len) or (batch_size, seq_len, seq_len)
         attn_offset = attn_offset.expand(batch_size, seq_len, seq_len)
 
+        if torch.is_autocast_enabled():
+            # It's possible to get large gradients at this point; clip them here to
+            # reduce how much the autocast GradScaler has to reduce its grad_scale.
+            weights = clip_grad(weights, 5000.0)
+
         attn_offset = attn_offset.gather(dim=1, index=indexes.unsqueeze(-1).expand(
             batch_size, seq_len_reduced, seq_len))
         attn_offset = attn_offset.gather(dim=2, index=indexes.unsqueeze(1).expand(
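clip_grad comes from scaling.py (see the import added at the top of this diff); its implementation is not shown here. As a rough mental model only, an op of this kind is typically the identity in the forward pass and limits gradient magnitude in the backward pass. A minimal sketch under that assumption (the real clip_grad may well differ, e.g. by rescaling rather than clamping):

import torch
from torch import Tensor

class _ClipGradSketch(torch.autograd.Function):
    # Identity in forward; clamps each gradient element to [-limit, limit] in backward.
    @staticmethod
    def forward(ctx, x: Tensor, limit: float) -> Tensor:
        ctx.limit = limit
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor):
        # one gradient per forward input: the clamped grad for x, None for limit
        return x_grad.clamp(min=-ctx.limit, max=ctx.limit), None

def clip_grad_sketch(x: Tensor, limit: float) -> Tensor:
    return _ClipGradSketch.apply(x, limit)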
@@ -1849,7 +1881,7 @@ def _test_zipformer_main(causal: bool = False):
         memory_dim=memory_dim,
     )
     batch_size = 5
-    seq_len = 20
+    seq_len = 128
     # Just make sure the forward pass runs.
     f = c(
         torch.randn(seq_len, batch_size, 64),