From e51a2c91700e2a623bd568bb0fb717854ac4b944 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 23 May 2023 14:01:04 +0800
Subject: [PATCH] Partial work

---
 egs/libriheavy/LM/zipformer1/subformer.py | 142 +++++++++++++---
 1 file changed, 87 insertions(+), 55 deletions(-)

diff --git a/egs/libriheavy/LM/zipformer1/subformer.py b/egs/libriheavy/LM/zipformer1/subformer.py
index 6b92b5fc8..7ad28a6cb 100644
--- a/egs/libriheavy/LM/zipformer1/subformer.py
+++ b/egs/libriheavy/LM/zipformer1/subformer.py
@@ -36,6 +36,7 @@ from scaling import (
     ScheduledFloat,
     FloatLike,
     limit_param_value,
+    clip_grad,
     convert_num_channels,
 )
 from torch import Tensor, nn
@@ -49,14 +50,15 @@ class Subformer(EncoderInterface):
         as downsampling_factor if they are single ints or one-element tuples.  The length of
         downsampling_factor defines the number of stacks.

-        output_downsampling_factor (int): how much to downsample at the output.  Note:
-            we also downsample by a factor of 2 in the Conv2dSubsampling encoder.
-            You should probably leave this at 2.
-        downsampling_factor (Tuple[int]): downsampling factor for each encoder stack.
-           Note: this is in addition to the downsampling factor of 2 that is applied in
-           the frontend (self.encoder_embed).
+
+        structure (str): determines the structure of the module: "S" is an encoder stack,
+            an open-parenthesis is a downsampling operation and a close-parenthesis is the
+            corresponding upsampling operation (not every parenthesis has to be closed, if
+            you want the output of the whole stack to stay downsampled).
         encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per
-            encoder stack.
+            encoder stack (i.e. one per "S" in structure).
+        downsampling_factor (Tuple[int]): downsampling factor for each downsampling
+            operation (each open-parenthesis).
         num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack
         query_head_dim (int or Tuple[int]): dimension of query and key per attention head:
            per stack, if a tuple..
@@ -80,13 +82,15 @@ class Subformer(EncoderInterface):
     """
     def __init__(
         self,
-        encoder_dim: Union[int, Tuple[int]] = (384, 512, 384),
-        encoder_chunk_size: Union[int, Tuple[int]] = 128,
-        num_encoder_layers: Union[int, Tuple[int]] = 4,
-        query_head_dim: Union[int, Tuple[int]] = 24,
-        value_head_dim: Union[int, Tuple[int]] = 12,
-        num_heads: Union[int, Tuple[int]] = 8,
-        feedforward_dim: Union[int, Tuple[int]] = 1536,
+        structure: str = "S(S)S",
+        encoder_dim: Tuple[int, ...] = (384, 512, 384),
+        downsampling_factor: Tuple[int, ...] = (2,),
+        encoder_chunk_sizes: Tuple[Tuple[int, ...], ...] = ((128, 2048),),
+        num_encoder_layers: Tuple[int, ...] = (4,),
+        query_head_dim: Tuple[int, ...] = (24,),
+        value_head_dim: Tuple[int, ...] = (12,),
+        num_heads: Tuple[int, ...] = (8,),
+        feedforward_dim: Tuple[int, ...] = (1536,),
         memory_dim: int = -1,
         pos_dim: int = 4,
         dropout: Optional[FloatLike] = None,  # see code below for default
@@ -99,15 +103,20 @@ class Subformer(EncoderInterface):
             dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))

+        num_encoders = len([s for s in structure if s == 'S'])
+        num_downsamplers = len([s for s in structure if s == '('])
+        # when we upsample, we use the same downsampling object that we
+        # downsampled with, but we also need a BypassModule at that point.
+        num_bypass = len([s for s in structure if s == ')'])
+
         def _to_tuple(x):
             """ Converts a single int or a 1-tuple of an int to a tuple with the same length
-            as encoder_dim"""
-            if isinstance(x, int):
-                x = (x,)
+            as num_encoders"""
+            assert isinstance(x, tuple)
             if len(x) == 1:
-                x = x * len(encoder_dim)
+                x = x * num_encoders
             else:
-                assert len(x) == len(encoder_dim) and isinstance(x[0], int)
+                assert len(x) == num_encoders
             return x

         self.encoder_dim = encoder_dim
@@ -120,8 +129,46 @@ class Subformer(EncoderInterface):

         self.causal = causal

+        if len(downsampling_factor) == 1:
+            downsampling_factor = downsampling_factor * num_downsamplers
+        assert len(downsampling_factor) == num_downsamplers
+
         # each one will be SubformerEncoder or DownsampledSubformerEncoder
         encoders = []
+        downsamplers = []
+        bypass = []
+
+        for s in structure:
+            if s == 'S':
+                i = len(encoders)
+                encoder_layer = SubformerEncoderLayer(
+                    embed_dim=encoder_dim[i],
+                    pos_dim=pos_dim,
+                    num_heads=num_heads[i],
+                    query_head_dim=query_head_dim[i],
+                    value_head_dim=value_head_dim[i],
+                    feedforward_dim=feedforward_dim[i],
+                    memory_dim=memory_dim,
+                    dropout=dropout,
+                    causal=causal,
+                )
+
+                encoder = SubformerEncoder(
+                    encoder_layer,
+                    num_encoder_layers[i],
+                    dropout=dropout,
+                    chunk_sizes=encoder_chunk_sizes[i],
+                    warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
+                    warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
+                    final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
+                )
+                encoders.append(encoder)
+
+                pass
+            elif s == '(':
+                pass
+            else:
+                assert s == ')'

         num_encoders = len(encoder_dim)

@@ -132,31 +179,6 @@ class Subformer(EncoderInterface):

         for i in range(num_encoders):

-            encoder_layer = SubformerEncoderLayer(
-                embed_dim=encoder_dim[i],
-                pos_dim=pos_dim,
-                num_heads=num_heads[i],
-                query_head_dim=query_head_dim[i],
-                value_head_dim=value_head_dim[i],
-                feedforward_dim=feedforward_dim[i],
-                memory_dim=memory_dim,
-                dropout=dropout,
-                causal=causal,
-            )
-
-            # For the segment of the warmup period, we let the Conv2dSubsampling
-            # layer learn something.  Then we start to warm up the other encoders.
-            encoder = SubformerEncoder(
-                encoder_layer,
-                num_encoder_layers[i],
-                dropout=dropout,
-                chunk_size=encoder_chunk_size[i],
-                warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
-                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
-                final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
-            )
-
-            encoders.append(encoder)

         mid = len(encoders) // 2
         encoder = DownsampledSubformerEncoder(
@@ -567,13 +589,13 @@ class SubformerEncoder(nn.Module):
         dropout: float,
         warmup_begin: float,
         warmup_end: float,
-        chunk_size: int = 256,
+        chunk_sizes: Tuple[int, ...] = (128, 2048),
         initial_layerdrop_rate: float = 0.5,
         final_layerdrop_rate: float = 0.05,
     ) -> None:
         super().__init__()

-        self.chunk_size = chunk_size
+        self.chunk_sizes = chunk_sizes

         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
@@ -668,13 +690,17 @@ class SubformerEncoder(nn.Module):
            chunk_indexes: a list of indexes into chunk_sizes, one per layer.
""" seq_len = src.shape[0] - if seq_len <= self.chunk_size or seq_len % self.chunk_size != 0: - return [ seq_len ], [ 0 ] * len(self.layers) - else: - num_layers = len(self.layers) - chunk_indexes = [0, 1] * (num_layers + 1 // 2) - return [ self.chunk_size, seq_len ], chunk_indexes[:num_layers] + chunk_indexes = [] + chunk_sizes = [] + for i, chunk_size in enumerate(self.chunk_sizes): + chunk_sizes.append(chunk_size if seq_len % chunk_size == 0 + else seq_len) + num_chunk_sizes = len(self.chunk_sizes) + for i in range(self.num_layers): + chunk_indexes.append(i % num_chunk_sizes) + + return chunk_sizes, chunk_indexes def _to_chunk_size(self, src: Tensor, chunk_size: int) -> Tensor: """ @@ -809,6 +835,7 @@ class LearnedDownsamplingModule(nn.Module): def __init__(self, embed_dim: int, downsampling_factor: int): + assert downsampling_factor > 1 super().__init__() @@ -864,9 +891,8 @@ class LearnedDownsamplingModule(nn.Module): d = self.downsampling_factor seq_len_reduced = (seq_len + d - 1) // d - weights_discarded = weights[:, seq_len_reduced:2*seq_len_reduced] - missing = weights_discarded.shape[1] - seq_len_reduced + missing = seq_len_reduced - weights_discarded.shape[1] if missing != 0: weights_discarded = torch.cat((weights_discarded, torch.zeros(batch_size, missing, @@ -986,6 +1012,12 @@ class LearnedDownsamplingModule(nn.Module): assert len(attn_offset.shape) == 3 # (1, seq_len, seq_len) or (batch_size, seq_len, seq_len) attn_offset = attn_offset.expand(batch_size, seq_len, seq_len) + if torch.is_autocast_enabled(): + # it's possible to get large gradients at this point; clip these at + # this point to reduce the extent to which it has to reduce the + # grad_scale. + weights = clip_grad(weights, 5000.0) + attn_offset = attn_offset.gather(dim=1, index=indexes.unsqueeze(-1).expand( batch_size, seq_len_reduced, seq_len)) attn_offset = attn_offset.gather(dim=2, index=indexes.unsqueeze(1).expand( @@ -1849,7 +1881,7 @@ def _test_zipformer_main(causal: bool = False): memory_dim=memory_dim, ) batch_size = 5 - seq_len = 20 + seq_len = 128 # Just make sure the forward pass runs. f = c( torch.randn(seq_len, batch_size, 64),