From 8483ca2e8f06a00d22c1b1c6a1360e39c556c20c Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 24 May 2023 16:04:05 +0800
Subject: [PATCH] More partial work

---
 egs/libriheavy/LM/zipformer1/subformer.py | 173 +++++++++++++---------
 egs/libriheavy/LM/zipformer1/train.py     |  16 +-
 2 files changed, 119 insertions(+), 70 deletions(-)

diff --git a/egs/libriheavy/LM/zipformer1/subformer.py b/egs/libriheavy/LM/zipformer1/subformer.py
index 7ad28a6cb..ae589813d 100644
--- a/egs/libriheavy/LM/zipformer1/subformer.py
+++ b/egs/libriheavy/LM/zipformer1/subformer.py
@@ -57,6 +57,10 @@ class Subformer(EncoderInterface):
            the whole stack to downsample.)
         encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks,
             one per encoder stack (i.e. one per "S" in structure).
+        encoder_chunk_sizes (Tuple[Tuple[int]]): A tuple containing either one tuple or
+            one tuple per encoder stack.  Each element is a tuple of the chunk sizes
+            that we use during training, e.g. (128, 1024); we go through these
+            round-robin in successive layers.
         downsampling_factor (Tuple[int]): downsampling factor for each downsampling
            operation (each open-parenthesis).
         num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack
@@ -85,7 +89,7 @@ class Subformer(EncoderInterface):
         structure: str = "S(S)S",
         encoder_dim: Tuple[int, ...] = (384, 512, 384),
         downsampling_factor: Tuple[int, ...] = (2,),
-        encoder_chunk_sizes: Tuple[Tuple[int, ...]] = (128,),
+        encoder_chunk_sizes: Tuple[Tuple[int, ...]] = ((128, 1024),),
         num_encoder_layers: Union[int, Tuple[int, ...]] = (4,),
         query_head_dim: Tuple[int, ...] = (24,),
         value_head_dim: Tuple[int, ...] = (12,),
@@ -120,7 +124,7 @@ class Subformer(EncoderInterface):
             return x

         self.encoder_dim = encoder_dim
-        encoder_chunk_size = _to_tuple(encoder_chunk_size)
+        encoder_chunk_sizes = _to_tuple(encoder_chunk_sizes)
         num_encoder_layers = _to_tuple(num_encoder_layers)
         query_head_dim = _to_tuple(query_head_dim)
         value_head_dim = _to_tuple(value_head_dim)
@@ -136,7 +140,17 @@ class Subformer(EncoderInterface):
         # each one will be SubformerEncoder or DownsampledSubformerEncoder
         encoders = []
         downsamplers = []
-        bypass = []
+        bypasses = []
+
+        layer_indexes = []
+
+        cur_max_dim = encoder_dim[0]
+
+        downsampling_factors_list = []
+        def cur_downsampling_factor():
+            c = 1
+            for d in downsampling_factors_list: c *= d
+            return c

         for s in structure:
             if s == 'S':
@@ -152,61 +166,45 @@ class Subformer(EncoderInterface):
                     dropout=dropout,
                     causal=causal,
                 )
-
+                cur_max_dim = max(cur_max_dim, encoder_dim[i])
                 encoder = SubformerEncoder(
                     encoder_layer,
                     num_encoder_layers[i],
+                    embed_dim=cur_max_dim,
                     dropout=dropout,
-                    chunk_size=encoder_chunk_size[i],
+                    chunk_sizes=encoder_chunk_sizes[i],
                     warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                     warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
-                    final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
+                    final_layerdrop_rate=0.035 * (cur_downsampling_factor() ** 0.5),
                 )
+                layer_indexes.append(len(encoders))
                 encoders.append(encoder)
-
-            pass
         elif s =='(':
-            pass
+            i = len(downsamplers)
+            downsampler = LearnedDownsamplingModule(cur_max_dim,
+                                                    downsampling_factor[i])
+            downsampling_factors_list.append(downsampling_factor[i])
+            layer_indexes.append(len(downsamplers))
+            downsamplers.append(downsampler)
         else:
             assert s == ')'
+            bypass = BypassModule(cur_max_dim, straight_through_rate=0.0)
+            layer_indexes.append(len(bypasses))
+            bypasses.append(bypass)
+            downsampling_factors_list.pop()
+
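+        # `structure` must have balanced parentheses, so by this point every
+        # factor pushed at '(' has been popped at ')' and the product logged
+        # below is always 1.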
+        logging.info(f"cur_downsampling_factor={cur_downsampling_factor()}")

-        num_encoders = len(encoder_dim)
-        assert num_encoders % 2 == 1
-        downsampling_factor = [ 1 ]
-        while len(downsampling_factor) < num_encoders:
-            downsampling_factor = [ 1 ] + [ d * 2 for d in downsampling_factor ] + [ 1 ]
-
-        for i in range(num_encoders):
-
-
-        mid = len(encoders) // 2
-        encoder = DownsampledSubformerEncoder(
-            [ encoders[mid] ],
-            input_num_channels=encoder_dim[mid-1],
-            downsample=2
-        )
-        for i in range(1, mid+1):
-            this_list = [ encoders[mid-i],
-                          encoder,
-                          encoders[mid+i] ]
-            encoder = DownsampledSubformerEncoder(
-                this_list,
-                input_num_channels=encoder_dim[max(0, mid-i-1)],
-                downsample=2 if i != mid else 1
-            )
-
-        self.encoder = encoder
+        self.layer_indexes = layer_indexes
+        self.structure = structure
+        self.encoders = nn.ModuleList(encoders)
+        self.downsamplers = nn.ModuleList(downsamplers)
+        self.bypasses = nn.ModuleList(bypasses)

         self.encoder_pos = CompactRelPositionalEncoding(64, pos_dim,
                                                         dropout_rate=0.15,
                                                         length_factor=1.0)

-        #self.downsample_output = SimpleDownsample(max(encoder_dim),
-        #                                          downsample=output_downsampling_factor,
-        #                                          dropout=dropout)
-
     def forward(
         self,
@@ -239,7 +237,6 @@ class Subformer(EncoderInterface):
         """
         outputs = []
-        attn_offset = self._get_attn_offset(x, src_key_padding_mask)

         if self.training and memory is not None:
             batch_size = x.shape[1]
@@ -249,14 +246,42 @@ class Subformer(EncoderInterface):
             memory = memory * (torch.rand(batch_size, 1, device=memory.device) >
                                memory_dropout_rate)

-        pos_emb = self.encoder_pos(x)
+        attn_offsets = [ self._get_attn_offset(x, src_key_padding_mask) ]
+        pos_embs = [ self.encoder_pos(x) ]
+        downsample_info = []
+
+        for s, i in zip(self.structure, self.layer_indexes):
+            if s == 'S':
+                encoder = self.encoders[i]  # one encoder stack
+                x = encoder(x,
+                            pos_embs[-1],
+                            attn_offset=attn_offsets[-1],
+                            memory=memory,
+                            memory_key_padding_mask=memory_key_padding_mask)
+                # x will have the maximum dimension seen so far, even if
+                # `encoder` uses a lower dim in its layers.
+            elif s == '(':
+                downsampler = self.downsamplers[i]
+
+                indexes, weights, x_new = downsampler(x)
+                downsample_info.append((indexes, weights, x))
+                x = x_new
+
+                pos_embs.append(downsampler.downsample_pos_emb(pos_embs[-1], indexes))
+
+                attn_offsets.append(downsampler.downsample_attn_offset(attn_offsets[-1],
+                                                                       indexes,
+                                                                       weights))
+
+            else:
+                assert s == ')'  # upsample
+                indexes, weights, x_orig = downsample_info.pop()
+                _attn_offset = attn_offsets.pop()
+                _pos_emb = pos_embs.pop()
+                x_orig = convert_num_channels(x_orig, x.shape[-1])
+
+                x = LearnedDownsamplingModule.upsample(x_orig, x, indexes, weights)

-        x = self.encoder(x,
-                         pos_emb,
-                         attn_offset=attn_offset,
-                         memory=memory,
-                         memory_key_padding_mask=memory_key_padding_mask,
-                         )

         # d = self.output_downsampling_factor
         # lengths = (x_lens + d - 1) // d
@@ -575,6 +600,9 @@ class SubformerEncoder(nn.Module):
     Args:
         encoder_layer: an instance of the SubformerEncoderLayer() class (required).
         num_layers: the number of sub-encoder-layers in the encoder (required).
+        embed_dim: the embedding dimension to use for the bypass (this may exceed
+            the dimension of encoder_layer, since the layers may not operate on
+            the full dimension).

     Examples::
         >>> encoder_layer = SubformerEncoderLayer(embed_dim=512, nhead=8)
@@ -586,6 +614,7 @@ class SubformerEncoder(nn.Module):
         self,
         encoder_layer: nn.Module,
         num_layers: int,
+        embed_dim: int,
         dropout: float,
         warmup_begin: float,
         warmup_end: float,
@@ -602,7 +631,7 @@ class SubformerEncoder(nn.Module):
         )
         self.num_layers = num_layers

-        self.bypass = BypassModule(self.embed_dim())
+        self.bypass = BypassModule(embed_dim)

         assert 0 <= warmup_begin <= warmup_end
@@ -616,7 +645,7 @@ class SubformerEncoder(nn.Module):
             cur_begin = cur_end

     def embed_dim(self):
-        return self.layers[0].embed_dim
+        return self.bypass.embed_dim()

     def forward(
         self,
@@ -644,14 +673,7 @@ class SubformerEncoder(nn.Module):
         Returns: a Tensor with the same shape as src.
         """
-        src = convert_num_channels(src, self.embed_dim())
-        output = src
-
-        rnd_seed = src.numel() + random.randint(0, 1000)
-
-        #if feature_mask is not None:
-        #    output = output * feature_mask
-
+        output = convert_num_channels(src, self.layers[0].embed_dim)

         chunk_sizes, chunk_indexes = self._get_chunk_sizes(src)

         b = src.shape[1]  # batch_size
@@ -678,6 +700,9 @@ class SubformerEncoder(nn.Module):

         output = self._to_chunk_size(output, src.shape[0])

+        output = convert_num_channels(output, self.bypass.embed_dim())
+        src = convert_num_channels(src, self.bypass.embed_dim())
+
         return self.bypass(src, output)

     def _get_chunk_sizes(self, src: Tensor) -> Tuple[List[int], List[int]]:
@@ -784,6 +809,8 @@ class BypassModule(nn.Module):
         self.scale_min = copy.deepcopy(scale_min)
         self.scale_max = copy.deepcopy(scale_max)

+    def embed_dim(self):
+        return self.bypass_scale.numel()

     def _get_bypass_scale(self, batch_size: int):
         # returns bypass-scale of shape (num_channels,),
@@ -840,7 +867,7 @@ class LearnedDownsamplingModule(nn.Module):
         super().__init__()

         self.to_scores = nn.Linear(embed_dim, 1, bias=False)
-        self.to_scores.lr_factor = 0.5
+        self.to_scores.lr_scale = 0.5
         # score_balancer is just to keep the magnitudes of the scores in
         # a fixed range and keep them balanced around zero, to stop
         # these drifting around.
@@ -1028,7 +1055,9 @@ class LearnedDownsamplingModule(nn.Module):

         return attn_offset

-    def upsample(self, x_orig: Tensor, x: Tensor, indexes: Tensor) -> Tensor:
+    @staticmethod
+    def upsample(x_orig: Tensor, x: Tensor, indexes: Tensor,
+                 weights: Optional[Tensor] = None) -> Tensor:
         """
         Upsamples, reversing the downsample() operation and filling in any
         not-chosen frames with their original value before downsampling

          x_orig: (seq_len, batch_size, num_channels)
          x: (seq_len_reduced, batch_size, num_channels)
          indexes: (batch_size, seq_len_reduced), contains original frame indexes
+         weights: optional tensor of shape (batch_size, seq_len_reduced); see Args.

         Downsamples x via indexing with the indexes obtained from the forward()
         function.

         Args:
-           x: tensor of shape (seq_len, batch_size, indexes)
+          x: tensor of shape (seq_len_reduced, batch_size, num_channels)
+          weights: a tensor of shape (batch_size, seq_len_reduced) containing weights between
+              0 and 1, where 1 means fully use this x value and 0 means use x_orig
           indexes: integer indexes of shape (batch_size, seq_len_reduced), with
              elements 0 <= indexes < seq_len.
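+          Example: with seq_len = 4, indexes = [[0, 2]] and
+              weights = [[1.0, 0.5]], output frame 0 equals x[0], frame 2
+              equals 0.5 * x[1] + 0.5 * x_orig[2], and frames 1 and 3 keep
+              their x_orig values unchanged.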
""" (seq_len, batch_size, num_channels) = x_orig.shape - not_kept = torch.ones(batch_size, seq_len, dtype=torch.bool, - device=x.device) - not_kept.scatter_(dim=1, index=indexes, value=False) + x_weight = 1.0 if weights is None else weights.t().unsqueeze(-1) + # x_weight: (seq_len_reduced, batch_size, 1) if a tensor + + orig_x_weight = torch.ones(batch_size, seq_len, + device=x.device, dtype=x.dtype) + if weights is None: + orig_x_weight.scatter_(dim=1, index=indexes, value=0.) + else: + orig_x_weight.scatter_(dim=1, index=indexes, + src=(1. - weights).to(x.dtype)) indexes = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels) # indexes now: (seq_len_reduced, batch_size, num_channels) ans = torch.zeros_like(x_orig) - ans.scatter_(dim=0, index=indexes, src=x) + ans.scatter_(dim=0, index=indexes, src=(x * x_weight)) # add in x_orig in the frames that were not originally kept. - return ans + x_orig * not_kept.t().unsqueeze(-1) + return ans + x_orig * orig_x_weight.t().unsqueeze(-1) class DownsampledSubformerEncoder(nn.Module): @@ -1151,7 +1190,7 @@ class DownsampledSubformerEncoder(nn.Module): src_orig = convert_num_channels(src_orig, src.shape[-1]) if hasattr(self, 'downsampler'): - src = self.downsampler.upsample(src_orig, src, indexes) + src = self.downsampler.upsample(src_orig, src, indexes, weights) return self.out_combiner(src_orig, src) diff --git a/egs/libriheavy/LM/zipformer1/train.py b/egs/libriheavy/LM/zipformer1/train.py index 16e96bfb5..080bed8cc 100755 --- a/egs/libriheavy/LM/zipformer1/train.py +++ b/egs/libriheavy/LM/zipformer1/train.py @@ -118,6 +118,7 @@ def set_batch_count( def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( "--num-encoder-layers", type=str, @@ -147,13 +148,21 @@ def add_model_arguments(parser: argparse.ArgumentParser): ) parser.add_argument( - "--encoder-chunk-size", + "--encoder-chunk-sizes", type=str, - default="128", + default="128,1024", help="Base chunk size for attention in encoder stacks; alternate layers will use this value or " "double this value." ) + parser.add_argument( + "--encoder-structure", + type=str, + default="S(S(S(S)S)S)S", + help="Structure of encoder, determines order of encoder stacks and (downsampling/upsampling) " + "operations." + ) + parser.add_argument( "--query-head-dim", type=str, @@ -421,9 +430,10 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module: def get_encoder_model(params: AttributeDict) -> nn.Module: #chunk_size = _to_int_tuple(params.downsampling_factor)[-1] encoder = Subformer( + structure=params.encoder_structure, num_encoder_layers=_to_int_tuple(params.num_encoder_layers), encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_chunk_size=_to_int_tuple(params.encoder_chunk_size), + encoder_chunk_sizes=(_to_int_tuple(params.encoder_chunk_sizes),), query_head_dim=_to_int_tuple(params.query_head_dim), pos_dim=int(params.pos_dim), value_head_dim=_to_int_tuple(params.value_head_dim),