diff --git a/egs/libriheavy/LM/zipformer1/subformer.py b/egs/libriheavy/LM/zipformer1/subformer.py
index 24cc24658..175c7121c 100644
--- a/egs/libriheavy/LM/zipformer1/subformer.py
+++ b/egs/libriheavy/LM/zipformer1/subformer.py
@@ -289,6 +289,25 @@ class Subformer(EncoderInterface):
 
         # d = self.output_downsampling_factor
         # lengths = (x_lens + d - 1) // d
+
+        # The next code block will only run in the case of "unbalanced" structures, e.g.
+        # if structure == "S(S(S)S", where there are unmatched left-parentheses.
+        cur_indexes = None
+        while len(downsample_info) > 0:
+            indexes, weights, x_orig = downsample_info.pop()
+            if cur_indexes is not None:
+                # keep only a subset of the indexes and weights, corresponding
+                # to later downsampling operations.
+                indexes = torch.gather(indexes, dim=1, index=cur_indexes)
+                weights = torch.gather(weights, dim=1, index=cur_indexes)
+
+            cur_indexes = indexes
+
+            x_lens = (weights != 0).sum(dim=1)
+            x_orig = convert_num_channels(x_orig, x.shape[-1])
+            x_orig, x = LearnedDownsamplingModule.apply_weights(x_orig, x, indexes, weights)
+
+        return x, x_lens
 
     def _get_attn_offset(self, x: Tensor, src_key_padding_mask: Optional[Tensor]) -> Optional[Tensor]:
 
@@ -1063,7 +1082,7 @@ class LearnedDownsamplingModule(nn.Module):
           x_orig: (seq_len, batch_size, num_channels)
           x: (seq_len_reduced, batch_size, num_channels)
          indexes: (batch_size, seq_len_reduced), contains original frame indexes
-          weights: optional tensor
+          weights: optional tensor of shape (batch_size, seq_len_reduced)
 
         Downsamples x via indexing with the indexes obtained from the forward() function.
 
@@ -1098,97 +1117,34 @@ class LearnedDownsamplingModule(nn.Module):
         # add in x_orig in the frames that were not originally kept.
         return ans + x_orig * orig_x_weight.t().unsqueeze(-1)
 
-
-class DownsampledSubformerEncoder(nn.Module):
-    """
-    DownsampledSubformerEncoder is a zipformer encoder stack possibly evaluated at a reduced
-    frame rate, after convolutional downsampling, and then upsampled again at the output, and combined
-    with the origin input, so that the output has the same shape as the input.
-    """
-    def __init__(self,
-                 encoders: List[nn.Module],
-                 input_num_channels: int,
-                 downsample: int):
-        super(DownsampledSubformerEncoder, self).__init__()
-        if downsample != 1:
-            self.downsampler = LearnedDownsamplingModule(input_num_channels,
-                                                         downsample)
-
-        self.encoders = nn.ModuleList(encoders)
-
-        self.out_combiner = BypassModule(self.embed_dim(),
-                                         straight_through_rate=0.0)
-
-    def embed_dim(self): # return output embed_dim which is max dim.
-        return max(e.embed_dim() for e in self.encoders)
-
-    def forward(self,
-                src: Tensor,
-                pos_emb: Tensor,
-                attn_offset: Tensor,
-                feature_mask: Union[Tensor, float] = 1.0,
-                memory: Optional[Tensor] = None,
-                memory_key_padding_mask: Optional[Tensor] = None,
-                ) -> Tuple[Tensor, Tensor]:
-        r"""Downsample, go through encoder, upsample.
+    @staticmethod
+    def apply_weights(x_orig: Tensor, x: Tensor, indexes: Tensor,
+                      weights: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
+        """
+        Downsamples x_orig to have the same shape as x and applies the weights,
+        returning interpolated x and downsampled x_orig. This is similar to
+        `upsample`, but is for the case where you don't want to keep the frames
+        that were not sampled.
 
         Args:
-            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
-            pos_emb: the positional embedding, of shape (batch_size, seq_len, seq_len, pos_dim)
-            feature_mask: something that broadcasts with src, that we'll multiply `src`
-                by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
-            attn_offset: the attention offset, added to scores for attention of shape
-                (batch_size, seq_len, seq_len) or (seq_len, seq_len),
-                interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
-            memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
-            memory_key_padding_mask: optionally the mask for padding of memory input (for source-
-                attention), of shape (batch_size, memory_len); True means
-                masked position. May be None.
+          x_orig: (seq_len, batch_size, num_channels)
+          x: (seq_len_reduced, batch_size, num_channels)
+          indexes: (batch_size, seq_len_reduced), contains original frame indexes
+          weights: optional tensor of shape (batch_size, seq_len_reduced)
 
-        Returns: a Tensor with the same shape as src.
+        Returns (x_orig, x) after the downsampling and interpolation, both of
+        shape (seq_len_reduced, batch_size, num_channels).
         """
-        src_orig = src
-        if hasattr(self, 'downsampler'):
-            indexes, weights, src = self.downsampler(src)
+        (seq_len, batch_size, num_channels) = x_orig.shape
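+        # If given, reshape weights from (batch_size, seq_len_reduced) to
+        # (seq_len_reduced, batch_size, 1) so that it broadcasts against x and
+        # against the gathered x_orig; weights=None is treated as all-ones.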
+        weights = 1.0 if weights is None else weights.t().unsqueeze(-1)
 
-            pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes)
+        indexes = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels)
+        # indexes now: (seq_len_reduced, batch_size, num_channels)
+        x_orig = torch.gather(x_orig, dim=0, index=indexes)
 
-            attn_offset = self.downsampler.downsample_attn_offset(attn_offset,
-                                                                  indexes,
-                                                                  weights)
-        outputs = [ src ]
+        x = x * weights + x_orig * (1.0 - weights)
 
-        for encoder in self.encoders:
-            src = encoder(
-                src,
-                pos_emb,
-                attn_offset=attn_offset,
-                memory=memory,
-                memory_key_padding_mask=memory_key_padding_mask,
-            )
-            outputs.append(src)
-
-        def get_full_dim_output():
-            num_encoders = len(outputs)
-            output_dim = max(o.shape[-1] for o in outputs)
-            output_pieces = [ outputs[-1] ]
-            cur_dim = outputs[-1].shape[-1]
-            for i in range(num_encoders - 2, -1, -1):
-                d = outputs[i].shape[-1]
-                if d > cur_dim:
-                    this_output = outputs[i]
-                    output_pieces.append(this_output[..., cur_dim:d])
-                    cur_dim = d
-            assert cur_dim == output_dim
-            return torch.cat(output_pieces, dim=-1)
-
-        src = get_full_dim_output()
-        src_orig = convert_num_channels(src_orig, src.shape[-1])
-
-        if hasattr(self, 'downsampler'):
-            src = self.downsampler.upsample(src_orig, src, indexes, weights)
-
-        return self.out_combiner(src_orig, src)
+        return x_orig, x
 
@@ -1881,6 +1837,7 @@ def _test_zipformer_main(causal: bool = False):
     memory_dim = 100
 
     c = Subformer(
+        structure="S(S)S" if causal else "S(S(S",
        encoder_dim=(64, 96, 64),
        num_heads=(4, 4, 8),
        causal=causal,
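
For illustration, a minimal self-contained sketch of the gather-and-interpolate
step that the new apply_weights() performs; the shapes, indexes and weights
below are made up:

    import torch

    seq_len, seq_len_reduced, batch_size, num_channels = 6, 3, 1, 2
    x_orig = torch.randn(seq_len, batch_size, num_channels)
    x = torch.randn(seq_len_reduced, batch_size, num_channels)
    indexes = torch.tensor([[0, 2, 5]])        # (batch_size, seq_len_reduced), kept frames
    weights = torch.tensor([[1.0, 0.5, 0.0]])  # (batch_size, seq_len_reduced), keep-weights

    w = weights.t().unsqueeze(-1)              # -> (seq_len_reduced, batch_size, 1)
    idx = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels)
    x_orig_ds = torch.gather(x_orig, dim=0, index=idx)  # downsampled x_orig
    out = x * w + x_orig_ds * (1.0 - w)        # (seq_len_reduced, batch_size, num_channels)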