Support unbalanced structures

Daniel Povey 2023-05-29 13:13:29 +08:00
parent b85012aa0b
commit 0f27b14376


@@ -289,6 +289,25 @@ class Subformer(EncoderInterface):
         # d = self.output_downsampling_factor
         # lengths = (x_lens + d - 1) // d
 
+        # The next code block will only run in the case of "unbalanced" structures, e.g.
+        # if structure == "S(S(S)S", where there are unmatched left-parentheses.
+        cur_indexes = None
+        while len(downsample_info) > 0:
+            indexes, weights, x_orig = downsample_info.pop()
+            if cur_indexes is not None:
+                # keep only a subset of the indexes and weights, corresponding
+                # to later downsampling operations.
+                indexes = torch.gather(indexes, dim=1, index=cur_indexes)
+                weights = torch.gather(weights, dim=1, index=cur_indexes)
+
+            cur_indexes = indexes
+
+            x_lens = (weights != 0).sum(dim=1)
+
+            x_orig = convert_num_channels(x_orig, x.shape[-1])
+            x_orig, x = LearnedDownsamplingModule.apply_weights(x_orig, x, indexes, weights)
+
         return x, x_lens
 
     def _get_attn_offset(self, x: Tensor, src_key_padding_mask: Optional[Tensor]) -> Optional[Tensor]:
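
The composition step above is a pure index lookup: `indexes` maps each kept frame back to an original frame index, and `cur_indexes` selects which of those frames a later downsampling stage kept. A minimal sketch (illustrative values, not from the commit) of how the `torch.gather` call composes the two maps:

    import torch

    # First downsampling kept original frames 0, 2, 5, 7 (batch_size = 1).
    indexes = torch.tensor([[0, 2, 5, 7]])
    # A later downsampling kept only positions 1 and 3 of the reduced sequence.
    cur_indexes = torch.tensor([[1, 3]])

    # Compose the two maps: which original frames survive both stages?
    composed = torch.gather(indexes, dim=1, index=cur_indexes)
    print(composed)  # tensor([[2, 7]])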
@@ -1063,7 +1082,7 @@ class LearnedDownsamplingModule(nn.Module):
           x_orig: (seq_len, batch_size, num_channels)
           x: (seq_len_reduced, batch_size, num_channels)
           indexes: (batch_size, seq_len_reduced), contains original frame indexes
-          weights: optional tensor
+          weights: optional tensor of shape (batch_size, seq_len_reduced)
 
         Downsamples x via indexing with the indexes obtained from the
         forward() function.
@@ -1098,97 +1117,34 @@ class LearnedDownsamplingModule(nn.Module):
         # add in x_orig in the frames that were not originally kept.
         return ans + x_orig * orig_x_weight.t().unsqueeze(-1)
 
-
-class DownsampledSubformerEncoder(nn.Module):
-    """
-    DownsampledSubformerEncoder is a zipformer encoder stack possibly evaluated at a reduced
-    frame rate, after convolutional downsampling, and then upsampled again at the output, and combined
-    with the origin input, so that the output has the same shape as the input.
-    """
-    def __init__(self,
-                 encoders: List[nn.Module],
-                 input_num_channels: int,
-                 downsample: int):
-        super(DownsampledSubformerEncoder, self).__init__()
-        if downsample != 1:
-            self.downsampler = LearnedDownsamplingModule(input_num_channels,
-                                                         downsample)
-        self.encoders = nn.ModuleList(encoders)
-        self.out_combiner = BypassModule(self.embed_dim(),
-                                         straight_through_rate=0.0)
-
-    def embed_dim(self):  # return output embed_dim which is max dim.
-        return max(e.embed_dim() for e in self.encoders)
-
-    def forward(self,
-                src: Tensor,
-                pos_emb: Tensor,
-                attn_offset: Tensor,
-                feature_mask: Union[Tensor, float] = 1.0,
-                memory: Optional[Tensor] = None,
-                memory_key_padding_mask: Optional[Tensor] = None,
-                ) -> Tuple[Tensor, Tensor]:
-        r"""Downsample, go through encoder, upsample.
-
-        Args:
-          src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
-          pos_emb: the positional embedding, of shape (batch_size, seq_len, seq_len, pos_dim)
-          feature_mask: something that broadcasts with src, that we'll multiply `src`
-              by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
-          attn_offset: the attention offset, added to scores for attention of shape
-              (batch_size, seq_len, seq_len) or (seq_len, seq_len),
-              interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
-          memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
-          memory_key_padding_mask: optionally the mask for padding of memory input (for source-
-              attention), of shape (batch_size, memory_len); True means
-              masked position.  May be None.
-
-        Returns: a Tensor with the same shape as src.
-        """
-        src_orig = src
-        if hasattr(self, 'downsampler'):
-            indexes, weights, src = self.downsampler(src)
-            pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes)
-            attn_offset = self.downsampler.downsample_attn_offset(attn_offset,
-                                                                  indexes,
-                                                                  weights)
-
-        outputs = [ src ]
-        for encoder in self.encoders:
-            src = encoder(
-                src,
-                pos_emb,
-                attn_offset=attn_offset,
-                memory=memory,
-                memory_key_padding_mask=memory_key_padding_mask,
-            )
-            outputs.append(src)
-
-        def get_full_dim_output():
-            num_encoders = len(outputs)
-            output_dim = max(o.shape[-1] for o in outputs)
-            output_pieces = [ outputs[-1] ]
-            cur_dim = outputs[-1].shape[-1]
-            for i in range(num_encoders - 2, -1, -1):
-                d = outputs[i].shape[-1]
-                if d > cur_dim:
-                    this_output = outputs[i]
-                    output_pieces.append(this_output[..., cur_dim:d])
-                    cur_dim = d
-            assert cur_dim == output_dim
-            return torch.cat(output_pieces, dim=-1)
-
-        src = get_full_dim_output()
-        src_orig = convert_num_channels(src_orig, src.shape[-1])
-
-        if hasattr(self, 'downsampler'):
-            src = self.downsampler.upsample(src_orig, src, indexes, weights)
-
-        return self.out_combiner(src_orig, src)
+    @staticmethod
+    def apply_weights(x_orig: Tensor, x: Tensor, indexes: Tensor,
+                      weights: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
+        """
+        Downsamples x_orig to have the same shape as x and applies the weights,
+        returning interpolated x and downsampled x_orig.  This is similar to
+        `upsample`, but is for the case where you don't want to keep the frames
+        that were not sampled.
+
+        Args:
+          x_orig: (seq_len, batch_size, num_channels)
+          x: (seq_len_reduced, batch_size, num_channels)
+          indexes: (batch_size, seq_len_reduced), contains original frame indexes
+          weights: optional tensor of shape (batch_size, seq_len_reduced)
+
+        Returns (x_orig, x) after the downsampling and interpolation, of shapes
+        both (seq_len_reduced, batch_size, num_channels).
+        """
+        (seq_len, batch_size, num_channels) = x_orig.shape
+        weights = 1.0 if weights is None else weights.t().unsqueeze(-1)
+
+        indexes = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels)
+        # indexes now: (seq_len_reduced, batch_size, num_channels)
+        x_orig = torch.gather(x_orig, dim=0, index=indexes)
+        x = x * weights + x_orig * (1.0 - weights)
+
+        return x_orig, x
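
The new method mirrors `upsample`, but discards the frames that were not sampled instead of restoring them. A small usage sketch (dummy shapes and values, not from the repo) showing the gather-then-interpolate behaviour of the method body above:

    import torch

    seq_len, seq_len_reduced, batch_size, num_channels = 6, 2, 2, 4
    x_orig = torch.randn(seq_len, batch_size, num_channels)
    x = torch.randn(seq_len_reduced, batch_size, num_channels)
    indexes = torch.tensor([[0, 3], [1, 4]])            # (batch_size, seq_len_reduced)
    weights = torch.tensor([[1.0, 0.5], [1.0, 0.25]])   # (batch_size, seq_len_reduced)

    # Same computation as apply_weights: gather the kept frames of x_orig,
    # then interpolate x toward x_orig wherever the weights are below 1.
    w = weights.t().unsqueeze(-1)                       # (seq_len_reduced, batch_size, 1)
    idx = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels)
    x_orig_ds = torch.gather(x_orig, dim=0, index=idx)  # (seq_len_reduced, batch_size, num_channels)
    x_out = x * w + x_orig_ds * (1.0 - w)
    print(x_out.shape)                                  # torch.Size([2, 2, 4])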
@@ -1881,6 +1837,7 @@ def _test_zipformer_main(causal: bool = False):
     memory_dim = 100
     c = Subformer(
+        structure = "S(S)S" if causal else "S(S(S",
        encoder_dim=(64, 96, 64),
        num_heads=(4, 4, 8),
        causal=causal,
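
The non-causal test string "S(S(S" is what exercises the new code path: presumably each '(' in the structure string pushes an entry onto `downsample_info` and each ')' pops one, so unmatched left-parentheses leave entries for the new while-loop in `Subformer.forward()` to drain. A quick illustrative check (the helper is hypothetical, not in the repo):

    def unmatched_left_parens(structure: str) -> int:
        # Hypothetical helper: count '(' characters without a matching ')'.
        depth = 0
        for ch in structure:
            if ch == '(':
                depth += 1
            elif ch == ')':
                depth -= 1
        return depth

    assert unmatched_left_parens("S(S)S") == 0   # balanced: the new loop body never runs
    assert unmatched_left_parens("S(S(S") == 2   # two leftover downsamplings to drain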