Partial work

Daniel Povey 2023-05-23 14:01:04 +08:00
parent bcc9971ebe
commit e51a2c9170


@@ -36,6 +36,7 @@ from scaling import (
ScheduledFloat,
FloatLike,
limit_param_value,
clip_grad,
convert_num_channels,
)
from torch import Tensor, nn
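The newly imported clip_grad comes from scaling.py; the diff uses it below to bound gradients flowing into the learned-downsampling weights. As a rough sketch of the kind of identity-in-forward, clip-in-backward operator it provides (an assumption; the actual implementation lives in scaling.py and may differ):

    import torch
    from torch import Tensor

    class ClipGradFunction(torch.autograd.Function):
        # Identity in the forward pass; element-wise clamping of the
        # incoming gradient to [-limit, limit] in the backward pass.
        @staticmethod
        def forward(ctx, x: Tensor, limit: float) -> Tensor:
            ctx.limit = limit
            return x

        @staticmethod
        def backward(ctx, x_grad: Tensor):
            return x_grad.clamp(min=-ctx.limit, max=ctx.limit), None

    def clip_grad(x: Tensor, limit: float) -> Tensor:
        return ClipGradFunction.apply(x, limit)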
@@ -49,14 +50,15 @@ class Subformer(EncoderInterface):
as downsampling_factor if they are single ints or one-element tuples. The length of
downsampling_factor defines the number of stacks.
output_downsampling_factor (int): how much to downsample at the output. Note:
we also downsample by a factor of 2 in the Conv2dSubsampling encoder.
You should probably leave this at 2.
downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
Note: this is in addition to the downsampling factor of 2 that is applied in
the frontend (self.encoder_embed).
structure (str): determines the structure of the module: "S" is an encoder stack,
an open-parenthesis is a downsampling operation, and a close-parenthesis is the
corresponding upsampling operation (not all parentheses have to be closed, if you
want the stack as a whole to downsample).
encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per
encoder stack.
encoder stack (i.e. one per "S" in structure).
downsampling_factor (Tuple[int]): downsampling factor for each downsampling
operation (each open-parenthesis).
num_encoder_layers (int or Tuple[int]): number of encoder layers for each stack
query_head_dim (int or Tuple[int]): dimension of query and key per attention
head; per stack, if a tuple.
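To make the structure notation concrete, here is a small standalone sketch of how a structure string decomposes (hypothetical helper; the class itself just counts characters inline, as the hunks below show):

    def describe_structure(structure: str):
        num_encoders = structure.count('S')       # encoder stacks
        num_downsamplers = structure.count('(')   # downsampling ops
        num_upsamplers = structure.count(')')     # matching upsampling ops
        return num_encoders, num_downsamplers, num_upsamplers

    print(describe_structure("S(S)S"))   # (3, 1, 1)
    print(describe_structure("S(S(S"))   # (3, 2, 0): nothing is upsampled back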
@@ -80,13 +82,15 @@ class Subformer(EncoderInterface):
"""
def __init__(
self,
encoder_dim: Union[int, Tuple[int]] = (384, 512, 384),
encoder_chunk_size: Union[int, Tuple[int]] = 128,
num_encoder_layers: Union[int, Tuple[int]] = 4,
query_head_dim: Union[int, Tuple[int]] = 24,
value_head_dim: Union[int, Tuple[int]] = 12,
num_heads: Union[int, Tuple[int]] = 8,
feedforward_dim: Union[int, Tuple[int]] = 1536,
structure: str = "S(S)S",
encoder_dim: Tuple[int, ...] = (384, 512, 384),
downsampling_factor: Tuple[int, ...] = (2,),
encoder_chunk_sizes: Tuple[Tuple[int, ...], ...] = ((128,),),
num_encoder_layers: Union[int, Tuple[int, ...]] = (4,),
query_head_dim: Tuple[int, ...] = (24,),
value_head_dim: Tuple[int, ...] = (12,),
num_heads: Tuple[int, ...] = (8,),
feedforward_dim: Tuple[int, ...] = (1536,),
memory_dim: int = -1,
pos_dim: int = 4,
dropout: Optional[FloatLike] = None, # see code below for default
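For reference, constructing the encoder with the new tuple-valued arguments might look like this (a sketch, assuming one-element tuples are broadcast to one entry per "S" in structure, as _to_tuple below implements):

    model = Subformer(
        structure="S(S)S",
        encoder_dim=(384, 512, 384),     # one per 'S'
        downsampling_factor=(2,),        # one per '('
        num_encoder_layers=(4,),         # broadcast to (4, 4, 4)
        query_head_dim=(24,),
        value_head_dim=(12,),
        num_heads=(8,),
        feedforward_dim=(1536,),
    )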
@@ -99,15 +103,20 @@ class Subformer(EncoderInterface):
dropout = ScheduledFloat((0.0, 0.3),
(20000.0, 0.1))
num_encoders = len([s for s in structure if s == 'S'])
num_downsamplers = len([s for s in structure if s == '('])
# when we upsample, we use the same downsampling object that we
# downsampled with, but we also need a BypassModule at that point.
num_bypass = len([s for s in structure if s == ')'])
def _to_tuple(x):
""" Converts a single int or a 1-tuple of an int to a tuple with the same length
as encoder_dim"""
if isinstance(x, int):
x = (x,)
as num_encoders"""
assert isinstance(x, tuple)
if len(x) == 1:
x = x * len(encoder_dim)
x = x * num_encoders
else:
assert len(x) == len(encoder_dim) and isinstance(x[0], int)
assert len(x) == num_encoders
return x
self.encoder_dim = encoder_dim
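A standalone replica of the new _to_tuple semantics, with num_encoders made an explicit argument for illustration (in the class it is captured from the enclosing scope):

    def _to_tuple(x, num_encoders: int):
        # One-element tuples are broadcast; longer tuples must match exactly.
        assert isinstance(x, tuple)
        if len(x) == 1:
            x = x * num_encoders
        else:
            assert len(x) == num_encoders
        return x

    print(_to_tuple((4,), 3))        # (4, 4, 4)
    print(_to_tuple((4, 6, 4), 3))   # (4, 6, 4)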
@@ -120,18 +129,18 @@ class Subformer(EncoderInterface):
self.causal = causal
if len(downsampling_factor) == 1:
downsampling_factor = downsampling_factor * num_downsamplers
assert len(downsampling_factor) == num_downsamplers
# each one will be SubformerEncoder or DownsampledSubformerEncoder
encoders = []
downsamplers = []
bypass = []
num_encoders = len(encoder_dim)
assert num_encoders % 2 == 1
downsampling_factor = [ 1 ]
while len(downsampling_factor) < num_encoders:
downsampling_factor = [ 1 ] + [ d * 2 for d in downsampling_factor ] + [ 1 ]
for i in range(num_encoders):
for s in structure:
if s == 'S':
i = len(encoders)
encoder_layer = SubformerEncoderLayer(
embed_dim=encoder_dim[i],
pos_dim=pos_dim,
@@ -144,8 +153,6 @@ class Subformer(EncoderInterface):
causal=causal,
)
# For the segment of the warmup period, we let the Conv2dSubsampling
# layer learn something. Then we start to warm up the other encoders.
encoder = SubformerEncoder(
encoder_layer,
num_encoder_layers[i],
@@ -155,9 +162,24 @@ class Subformer(EncoderInterface):
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
)
encoders.append(encoder)
pass
elif s == '(':
pass
else:
assert s == ')'
num_encoders = len(encoder_dim)
assert num_encoders % 2 == 1
downsampling_factor = [ 1 ]
while len(downsampling_factor) < num_encoders:
downsampling_factor = [ 1 ] + [ d * 2 for d in downsampling_factor ] + [ 1 ]
for i in range(num_encoders):
mid = len(encoders) // 2
encoder = DownsampledSubformerEncoder(
[ encoders[mid] ],
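The block above re-derives a per-stack downsampling schedule for an odd number of stacks, growing toward the middle and shrinking back out. A worked example of that loop:

    def stack_downsampling(num_encoders: int):
        assert num_encoders % 2 == 1
        factors = [1]
        while len(factors) < num_encoders:
            factors = [1] + [d * 2 for d in factors] + [1]
        return factors

    print(stack_downsampling(3))   # [1, 2, 1]
    print(stack_downsampling(5))   # [1, 2, 4, 2, 1]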
@@ -567,13 +589,13 @@ class SubformerEncoder(nn.Module):
dropout: float,
warmup_begin: float,
warmup_end: float,
chunk_size: int = 256,
chunk_sizes: Tuple[int, ...] = (128, 2048),
initial_layerdrop_rate: float = 0.5,
final_layerdrop_rate: float = 0.05,
) -> None:
super().__init__()
self.chunk_size = chunk_size
self.chunk_sizes = chunk_sizes
self.layers = nn.ModuleList(
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
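With chunk_sizes now a tuple (default (128, 2048)), successive layers can alternate attention granularity; _get_chunk_sizes further down assigns layers to chunk sizes round-robin. Separately, the deepcopy above gives every layer its own parameters; a minimal illustration of that pattern:

    import copy
    import torch.nn as nn

    template = nn.Linear(4, 4)
    layers = nn.ModuleList([copy.deepcopy(template) for _ in range(3)])
    assert layers[0].weight is not layers[1].weight   # independent copies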
@@ -668,13 +690,17 @@ class SubformerEncoder(nn.Module):
chunk_indexes: a list of indexes into chunk_sizes, one per layer.
"""
seq_len = src.shape[0]
if seq_len <= self.chunk_size or seq_len % self.chunk_size != 0:
return [ seq_len ], [ 0 ] * len(self.layers)
else:
num_layers = len(self.layers)
chunk_indexes = [0, 1] * (num_layers + 1 // 2)
return [ self.chunk_size, seq_len ], chunk_indexes[:num_layers]
chunk_indexes = []
chunk_sizes = []
for chunk_size in self.chunk_sizes:
chunk_sizes.append(chunk_size if seq_len % chunk_size == 0
else seq_len)
num_chunk_sizes = len(self.chunk_sizes)
for i in range(self.num_layers):
chunk_indexes.append(i % num_chunk_sizes)
return chunk_sizes, chunk_indexes
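A worked example of the new logic (standalone replica, assuming num_layers = 4 and the default chunk_sizes = (128, 2048)): any chunk size that does not evenly divide seq_len degrades to a single whole-sequence chunk, and layers cycle through the resulting list round-robin.

    def get_chunk_sizes(seq_len, chunk_sizes, num_layers):
        out = [c if seq_len % c == 0 else seq_len for c in chunk_sizes]
        indexes = [i % len(chunk_sizes) for i in range(num_layers)]
        return out, indexes

    print(get_chunk_sizes(256, (128, 2048), 4))
    # ([128, 256], [0, 1, 0, 1]): 128 divides 256, 2048 does not, so the
    # second entry falls back to one whole-sequence chunk.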
def _to_chunk_size(self, src: Tensor, chunk_size: int) -> Tensor:
"""
@@ -809,6 +835,7 @@ class LearnedDownsamplingModule(nn.Module):
def __init__(self,
embed_dim: int,
downsampling_factor: int):
assert downsampling_factor > 1
super().__init__()
@@ -864,9 +891,8 @@ class LearnedDownsamplingModule(nn.Module):
d = self.downsampling_factor
seq_len_reduced = (seq_len + d - 1) // d
weights_discarded = weights[:, seq_len_reduced:2*seq_len_reduced]
missing = weights_discarded.shape[1] - seq_len_reduced
missing = seq_len_reduced - weights_discarded.shape[1]
if missing != 0:
weights_discarded = torch.cat((weights_discarded,
torch.zeros(batch_size, missing,
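The sign fix matters near sequence boundaries. With downsampling_factor d = 2 and seq_len = 5, the slice weights[:, 3:6] can only supply 2 of the seq_len_reduced = 3 expected columns, so one column of zeros must be appended; the old expression produced missing = -1 here:

    seq_len, d = 5, 2
    seq_len_reduced = (seq_len + d - 1) // d                    # 3
    cols = min(2 * seq_len_reduced, seq_len) - seq_len_reduced  # 2 columns available
    missing = seq_len_reduced - cols                            # 1 (old code gave -1)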
@@ -986,6 +1012,12 @@ class LearnedDownsamplingModule(nn.Module):
assert len(attn_offset.shape) == 3 # (1, seq_len, seq_len) or (batch_size, seq_len, seq_len)
attn_offset = attn_offset.expand(batch_size, seq_len, seq_len)
if torch.is_autocast_enabled():
# it's possible to get large gradients at this point; clip them here
# to reduce the extent to which the GradScaler has to reduce its
# grad_scale.
weights = clip_grad(weights, 5000.0)
attn_offset = attn_offset.gather(dim=1, index=indexes.unsqueeze(-1).expand(
batch_size, seq_len_reduced, seq_len))
attn_offset = attn_offset.gather(dim=2, index=indexes.unsqueeze(1).expand(
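The two gathers subset the (batch, seq_len, seq_len) offset matrix to the kept frames along both the query and key dimensions. A minimal standalone sketch, with indexes holding the kept positions per batch element:

    import torch

    batch_size, seq_len, seq_len_reduced = 2, 6, 3
    attn_offset = torch.randn(batch_size, seq_len, seq_len)
    indexes = torch.tensor([[0, 2, 5], [1, 3, 4]])  # kept frames

    rows = attn_offset.gather(
        dim=1, index=indexes.unsqueeze(-1).expand(batch_size, seq_len_reduced, seq_len))
    sub = rows.gather(
        dim=2, index=indexes.unsqueeze(1).expand(batch_size, seq_len_reduced, seq_len_reduced))
    assert sub.shape == (batch_size, seq_len_reduced, seq_len_reduced)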
@@ -1849,7 +1881,7 @@ def _test_zipformer_main(causal: bool = False):
memory_dim=memory_dim,
)
batch_size = 5
seq_len = 20
seq_len = 128
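# (raised from 20, presumably so seq_len is a multiple of the default chunk size of 128)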
# Just make sure the forward pass runs.
f = c(
torch.randn(seq_len, batch_size, 64),