diff --git a/egs/libriheavy/LM/zipformer1/subformer.py b/egs/libriheavy/LM/zipformer1/subformer.py
index bff07d8ea..f81062256 100644
--- a/egs/libriheavy/LM/zipformer1/subformer.py
+++ b/egs/libriheavy/LM/zipformer1/subformer.py
@@ -27,7 +27,6 @@ from scaling import (
     Balancer,
     BiasNorm,
     Dropout2,
-    ChunkCausalDepthwiseConv1d,
     ActivationDropoutAndLinear,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
     Whiten,
@@ -42,7 +41,7 @@ from scaling import (
 from torch import Tensor, nn
 
 
-class Zipformer2(EncoderInterface):
+class Subformer2(EncoderInterface):
     """
     Args:
 
@@ -70,7 +69,6 @@ class Zipformer2(EncoderInterface):
        num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
           Must be at least 4.
        feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
-       cnn_module_kernel (int or Tuple[int])): Kernel size of convolution module
        pos_dim (int): the dimension of each positional-encoding vector prior to projection,
           e.g. 128.
@@ -83,15 +81,9 @@ class Zipformer2(EncoderInterface):
          slightly slower and use more memory.  Enables use of the chunk_size and
          left_context_chunks options in forward(), which simulates streaming
          decoding.
-       chunk_size: (list of int): only set this to other than [-1] if causal;
-          the chunk size will be randomly chosen from this list.  -1 means no chunking.
-       left_context_frames: (list of int): determines the number of left-
-          context chunks for causal training; will be rounded to a number of
-          chunks.  Must not be less than cnn_module_kernel (after factoring in
-          rounding and downsampling); an error will be thrown if this is violated.
       memory_dim: if supplied and >0, will be the dimension of the memory embeddings
           passed into the zipformer (e.g. this might be the output of another
-          Zipformer used to create embedding vectors.)
+          Subformer used to create embedding vectors.)
""" def __init__( self, @@ -105,16 +97,13 @@ class Zipformer2(EncoderInterface): value_head_dim: Union[int, Tuple[int]] = 12, num_heads: Union[int, Tuple[int]] = 8, feedforward_dim: Union[int, Tuple[int]] = 1536, - cnn_module_kernel: Union[int, Tuple[int]] = 31, memory_dim: int = -1, pos_dim: int = 192, dropout: FloatLike = None, # see code below for default warmup_batches: float = 4000.0, causal: bool = False, - chunk_size: Tuple[int] = (-1,), - left_context_frames: Tuple[int] = (-1,), ) -> None: - super(Zipformer2, self).__init__() + super(Subformer2, self).__init__() if dropout is None: dropout = ScheduledFloat((0.0, 0.3), @@ -141,22 +130,17 @@ class Zipformer2(EncoderInterface): pos_head_dim = _to_tuple(pos_head_dim) num_heads = _to_tuple(num_heads) feedforward_dim = _to_tuple(feedforward_dim) - self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) - - self.causal = causal - self.chunk_size = chunk_size - self.left_context_frames = left_context_frames for u,d in zip(encoder_unmasked_dim, encoder_dim): assert u <= d - # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder + # each one will be Subformer2Encoder or DownsampledSubformer2Encoder encoders = [] num_encoders = len(downsampling_factor) for i in range(num_encoders): - encoder_layer = Zipformer2EncoderLayer( + encoder_layer = Subformer2EncoderLayer( embed_dim=encoder_dim[i], pos_dim=pos_dim, num_heads=num_heads[i], @@ -166,13 +150,12 @@ class Zipformer2(EncoderInterface): feedforward_dim=feedforward_dim[i], memory_dim=memory_dim, dropout=dropout, - cnn_module_kernel=cnn_module_kernel[i], causal=causal, ) # For the segment of the warmup period, we let the Conv2dSubsampling # layer learn something. Then we start to warm up the other encoders. - encoder = Zipformer2Encoder( + encoder = Subformer2Encoder( encoder_layer, num_encoder_layers[i], pos_dim=pos_dim, @@ -183,7 +166,7 @@ class Zipformer2(EncoderInterface): ) if downsampling_factor[i] != 1: - encoder = DownsampledZipformer2Encoder( + encoder = DownsampledSubformer2Encoder( encoder, dim=encoder_dim[i], downsample=downsampling_factor[i], @@ -257,24 +240,6 @@ class Zipformer2(EncoderInterface): return feature_masks - def get_chunk_info(self) -> Tuple[int, int]: - """ - Returns chunk_size and left_context_chunks. - """ - if not self.causal: - return -1, -1 - chunk_size = random.choice(self.chunk_size) - if chunk_size == -1: - left_context_chunks = -1 - else: - left_context_frames = random.choice(self.left_context_frames) - # Note: in Python, -1 // n == -1 for n > 0 - left_context_chunks = left_context_frames // chunk_size - if left_context_chunks == 0: - left_context_chunks = 1 - return chunk_size, left_context_chunks - - def forward( self, x: torch.Tensor, @@ -307,9 +272,7 @@ class Zipformer2(EncoderInterface): outputs = [] feature_masks = self.get_feature_masks(x) - chunk_size, left_context_chunks = self.get_chunk_info() - - attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) + attn_mask = self._get_attn_mask(x) if self.training and memory is not None: batch_size = x.shape[1] @@ -361,45 +324,31 @@ class Zipformer2(EncoderInterface): return x, lengths - def _get_attn_mask(self, x: Tensor, - chunk_size: int, - left_context_chunks: int - ) -> Optional[Tensor]: + def _get_attn_mask(self, x: Tensor) -> Optional[Tensor]: """ - Return None if chunk_size == -1, else return attention mask of shape + Return None if not self.causal is false else return attention mask of shape (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). 
           True means a masked position.
        Args:
           x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
           chunk_size: chunk size, must divide
        """
-        if chunk_size <= 0:
+        if not self.causal:
             return None
-        assert all(chunk_size % d == 0 for d in self.downsampling_factor)
-        if left_context_chunks >= 0:
-            num_encoders = len(self.encoder_dim)
-            assert all (chunk_size * left_context_chunks >=
-                        (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i]
-                        for i in range(num_encoders))
-        else:
-            left_context_chunks = 1000000
 
         seq_len = x.shape[0]
 
         # t is frame index, shape (seq_len,)
         t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
-        # c is chunk index for each frame, shape (seq_len,)
-        c = t // chunk_size
-        src_c = c
-        tgt_c = c.unsqueeze(-1)
+        src_c = t
+        tgt_c = t.unsqueeze(-1)
 
-        attn_mask = torch.logical_or(src_c > tgt_c,
-                                     src_c < tgt_c - left_context_chunks)
-        if __name__ == "__main__":
-            logging.info(f"attn_mask = {attn_mask}")
+        attn_mask = (src_c > tgt_c)
+
         return attn_mask
 
 
 def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
     return ScheduledFloat((0.0, x),
                           (20000.0, ratio * x),
@@ -410,17 +359,16 @@ def _balancer_schedule(min_prob: float):
 
 
-class Zipformer2EncoderLayer(nn.Module):
+class Subformer2EncoderLayer(nn.Module):
     """
     Args:
        embed_dim: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        feedforward_dim: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
-       cnn_module_kernel (int): Kernel size of convolution module.
 
     Examples::
-        >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
+        >>> encoder_layer = Subformer2EncoderLayer(embed_dim=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> pos_emb = torch.rand(32, 19, 512)
        >>> out = encoder_layer(src, pos_emb)
@@ -435,7 +383,6 @@ class Zipformer2EncoderLayer(nn.Module):
             value_head_dim: int,
             feedforward_dim: int,
             dropout: FloatLike = 0.1,
-            cnn_module_kernel: int = 31,
             causal: bool = False,
             memory_dim: int = -1,
             attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
@@ -445,7 +392,7 @@
             ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
             bypass_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0),
     ) -> None:
-        super(Zipformer2EncoderLayer, self).__init__()
+        super(Subformer2EncoderLayer, self).__init__()
         self.embed_dim = embed_dim
 
         # self.bypass implements layer skipping as well as bypass; see its default values.
@@ -509,14 +456,6 @@ class Zipformer2EncoderLayer(nn.Module):
         self.nonlin_attention = NonlinAttention(embed_dim,
                                                 hidden_channels=3 * embed_dim // 4)
 
-        self.conv_module1 = ConvolutionModule(embed_dim,
-                                              cnn_module_kernel,
-                                              causal=causal)
-
-        self.conv_module2 = ConvolutionModule(embed_dim,
-                                              cnn_module_kernel,
-                                              causal=causal)
-
         #self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
 
@@ -682,10 +621,6 @@ class Zipformer2EncoderLayer(nn.Module):
             src = src + self.sequence_dropout(self.src_attn1(memory, src_attn_weights),
                                               attention_skip_rate)
 
-        src = src + self.sequence_dropout(self.conv_module1(src, chunk_size=chunk_size,
-                                                            src_key_padding_mask=src_key_padding_mask),
-                                          float(self.conv_skip_rate))
-
         src = src + self.sequence_dropout(self.balancer_ff2(self.feed_forward2(src)),
                                           float(self.ff2_skip_rate))
@@ -701,10 +636,6 @@ class Zipformer2EncoderLayer(nn.Module):
             src = src + self.sequence_dropout(self.src_attn2(memory, src_attn_weights),
                                               attention_skip_rate)
 
-        src = src + self.sequence_dropout(self.conv_module2(src, chunk_size=chunk_size,
-                                                            src_key_padding_mask=src_key_padding_mask),
-                                          float(self.conv_skip_rate))
-
         src = src + self.sequence_dropout(self.balancer_ff3(self.feed_forward3(src)),
                                           float(self.ff3_skip_rate))
@@ -718,17 +649,17 @@ class Zipformer2EncoderLayer(nn.Module):
 
         return src
 
-class Zipformer2Encoder(nn.Module):
-    r"""Zipformer2Encoder is a stack of N encoder layers
+class Subformer2Encoder(nn.Module):
+    r"""Subformer2Encoder is a stack of N encoder layers
 
     Args:
-        encoder_layer: an instance of the Zipformer2EncoderLayer() class (required).
+        encoder_layer: an instance of the Subformer2EncoderLayer() class (required).
         num_layers: the number of sub-encoder-layers in the encoder (required).
         pos_dim: the dimension for the relative positional encoding
 
     Examples::
-        >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
-        >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6)
+        >>> encoder_layer = Subformer2EncoderLayer(embed_dim=512, nhead=8)
+        >>> zipformer_encoder = Subformer2Encoder(encoder_layer, num_layers=6)
         >>> src = torch.rand(10, 32, 512)
         >>> out = zipformer_encoder(src)
     """
@@ -874,9 +805,9 @@ class BypassModule(nn.Module):
 
 
-class DownsampledZipformer2Encoder(nn.Module):
+class DownsampledSubformer2Encoder(nn.Module):
     r"""
-    DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate,
+    DownsampledSubformer2Encoder is a zipformer encoder evaluated at a reduced frame rate,
     after convolutional downsampling, and then upsampled again at the output, and combined
     with the origin input, so that the output has the same shape as the input.
     """
     def __init__(self,
@@ -885,7 +816,7 @@
                  dim: int,
                  downsample: int,
                  dropout: FloatLike):
-        super(DownsampledZipformer2Encoder, self).__init__()
+        super(DownsampledSubformer2Encoder, self).__init__()
         self.downsample_factor = downsample
         self.downsample = SimpleDownsample(dim, downsample, dropout)
 
@@ -1577,7 +1508,7 @@ class MultiheadAttentionWeights(nn.Module):
 
 
 class FeedforwardModule(nn.Module):
-    """Feedforward module in Zipformer2 model.
+    """Feedforward module in Subformer2 model.
     """
     def __init__(self,
                  embed_dim: int,
@@ -1718,137 +1649,6 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
 
         return x
 
-class ConvolutionModule(nn.Module):
-    """ConvolutionModule in Zipformer2 model.
-    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py
-
-    Args:
-        channels (int): The number of channels of conv layers.
-        kernel_size (int): Kernerl size of conv layers.
-        bias (bool): Whether to use bias in conv layers (default=True).
-
-    """
-    def __init__(
-            self, channels: int, kernel_size: int, causal: bool,
-    ) -> None:
-        """Construct a ConvolutionModule object."""
-        super(ConvolutionModule, self).__init__()
-        # kernerl_size should be a odd number for 'SAME' padding
-        assert (kernel_size - 1) % 2 == 0
-
-        bottleneck_dim = channels
-        self.causal = causal
-
-        self.in_proj = nn.Linear(
-            channels, 2 * bottleneck_dim,
-        )
-        # the gradients on in_proj are a little noisy, likely to do with the
-        # sigmoid in glu.
-
-        # after in_proj we put x through a gated linear unit (nn.functional.glu).
-        # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
-        # but sometimes, for some reason, for layer 0 the rms ends up being very large,
-        # between 50 and 100 for different channels.  This will cause very peaky and
-        # sparse derivatives for the sigmoid gating function, which will tend to make
-        # the loss function not learn effectively.  (for most layers the average absolute values
-        # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
-        # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different
-        # layers, which likely breaks down as 0.5 for the "linear" half and
-        # 0.2 to 0.3 for the part that goes into the sigmoid.  The idea is that if we
-        # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
-        # it will be in a better position to start learning something, i.e. to latch onto
-        # the correct range.
-        self.balancer1 = Balancer(
-            bottleneck_dim, channel_dim=-1,
-            min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)),
-            max_positive=1.0,
-            min_abs=1.5,
-            max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
-        )
-
-        self.activation1 = Identity()  # for diagnostics
-
-        self.sigmoid = nn.Sigmoid()
-
-        self.activation2 = Identity()  # for diagnostics
-
-        assert kernel_size % 2 == 1
-
-        self.depthwise_conv = ChunkCausalDepthwiseConv1d(
-            channels=bottleneck_dim,
-            kernel_size=kernel_size) if causal else nn.Conv1d(
-            in_channels=bottleneck_dim,
-            out_channels=bottleneck_dim,
-            groups=bottleneck_dim,
-            kernel_size=kernel_size,
-            padding=kernel_size // 2)
-
-        self.balancer2 = Balancer(
-            bottleneck_dim, channel_dim=1,
-            min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
-            max_positive=1.0,
-            min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
-            max_abs=10.0,
-        )
-
-        self.whiten = Whiten(num_groups=1,
-                             whitening_limit=_whitening_schedule(7.5),
-                             prob=(0.025, 0.25),
-                             grad_scale=0.01)
-
-        self.out_proj = ActivationDropoutAndLinear(
-            bottleneck_dim, channels, activation='SwooshR',
-            dropout_p=0.0, initial_scale=0.05,
-        )
-
-    def forward(self,
-                x: Tensor,
-                src_key_padding_mask: Optional[Tensor] = None,
-                chunk_size: int = -1,
-    ) -> Tensor:
-        """Compute convolution module.
-
-        Args:
-            x: Input tensor (#time, batch, channels).
-            src_key_padding_mask: the mask for the src keys per batch (optional):
-                (batch, #time), contains True in masked positions.
-
-        Returns:
-            Tensor: Output tensor (#time, batch, channels).
-
-        """
-
-        x = self.in_proj(x)  # (time, batch, 2*channels)
-
-        x, s = x.chunk(2, dim=-1)
-        s = self.balancer1(s)
-        s = self.sigmoid(s)
-        x = self.activation1(x)  # identity.
-        x = x * s
-        x = self.activation2(x)  # identity
-
-        # (time, batch, channels)
-
-        # exchange the temporal dimension and the feature dimension
-        x = x.permute(1, 2, 0)  # (#batch, channels, time).
-
-        if src_key_padding_mask is not None:
-            x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
-
-        if chunk_size >= 0:
-            assert self.causal, "Must initialize model with causal=True if you use chunk_size"
-            x = self.depthwise_conv(x, chunk_size=chunk_size)
-        else:
-            x = self.depthwise_conv(x)
-
-        x = self.balancer2(x)
-        x = x.permute(2, 0, 1)  # (time, batch, channels)
-
-        x = self.whiten(x)  # (time, batch, channels)
-        x = self.out_proj(x)  # (time, batch, channels)
-
-        return x
-
 
 class ScalarMultiply(nn.Module):
     def __init__(self, scale: float):
@@ -1865,7 +1665,7 @@
 def _test_zipformer_main(causal: bool = False):
     # Just make sure the forward pass runs.
     memory_dim = 100
-    c = Zipformer2(
+    c = Subformer2(
         encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
         causal=causal,
         chunk_size=(4,) if causal else (-1,),
diff --git a/egs/libriheavy/LM/zipformer1/zipformer.py b/egs/libriheavy/LM/zipformer1/zipformer.py
deleted file mode 120000
index d053ea6de..000000000
--- a/egs/libriheavy/LM/zipformer1/zipformer.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/zipformer2/zipformer.py
\ No newline at end of file
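
Note on the simplified masking (reader context, not part of the patch above): with the chunked-convolution machinery removed, _get_attn_mask reduces to a plain causal mask in which each frame may attend only to itself and to earlier frames. A minimal standalone sketch of that construction, assuming only PyTorch (the function name causal_attn_mask is illustrative):

    import torch

    def causal_attn_mask(seq_len: int, device=None) -> torch.Tensor:
        # t is the frame index of each position, shape (seq_len,)
        t = torch.arange(seq_len, dtype=torch.int32, device=device)
        # True marks a masked position: target frame i must not see source frame j > i,
        # i.e. no attention to future frames.  Result shape: (tgt_seq_len, src_seq_len).
        return t > t.unsqueeze(-1)

    # For seq_len = 4, only the strictly upper-triangular entries are True:
    # [[F, T, T, T],
    #  [F, F, T, T],
    #  [F, F, F, T],
    #  [F, F, F, F]]
    print(causal_attn_mask(4))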