Support unbalanced structures

Daniel Povey 2023-05-29 13:13:29 +08:00
parent b85012aa0b
commit 0f27b14376


@@ -289,6 +289,25 @@ class Subformer(EncoderInterface):
         # d = self.output_downsampling_factor
         # lengths = (x_lens + d - 1) // d
 
+        # The next code block will only run in the case of "unbalanced" structures, e.g.
+        # if structure == "S(S(S)S", where there are unmatched left-parentheses.
+        cur_indexes = None
+        while len(downsample_info) > 0:
+            indexes, weights, x_orig = downsample_info.pop()
+            if cur_indexes is not None:
+                # keep only a subset of the indexes and weights, corresponding
+                # to later downsampling operations.
+                indexes = torch.gather(indexes, dim=1, index=cur_indexes)
+                weights = torch.gather(weights, dim=1, index=cur_indexes)
+            cur_indexes = indexes
+
+            x_lens = (weights != 0).sum(dim=1)
+
+            x_orig = convert_num_channels(x_orig, x.shape[-1])
+            x_orig, x = LearnedDownsamplingModule.apply_weights(x_orig, x, indexes, weights)
+
         return x, x_lens
 
     def _get_attn_offset(self, x: Tensor, src_key_padding_mask: Optional[Tensor]) -> Optional[Tensor]:
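To make the new unbalanced-structure loop above easier to follow, here is a minimal standalone sketch (not part of the commit; all tensor values are invented) of the `torch.gather` chaining it performs: the indexes kept by a later downsampling stage select a subset of the indexes kept by an earlier stage, so composing them recovers positions in the original frame numbering, and the nonzero gathered weights give per-utterance lengths as in `x_lens`.

```python
# Standalone illustration of the index/weight chaining in the loop above.
import torch

# Stage 1 keeps 4 of 8 original frames per utterance (indexes into the original sequence).
stage1_indexes = torch.tensor([[0, 2, 5, 7],
                               [1, 3, 4, 6]])          # (batch_size, 4)
stage1_weights = torch.tensor([[1.0, 0.8, 0.6, 0.0],
                               [1.0, 0.9, 0.0, 0.0]])  # (batch_size, 4)

# Stage 2 keeps 2 of those 4 frames (indexes into stage 1's *output*).
stage2_indexes = torch.tensor([[0, 2],
                               [1, 3]])                # (batch_size, 2)

# Compose as in the loop: gather stage 1's indexes (and weights) at the
# positions that stage 2 kept, giving indexes in the original numbering.
composed_indexes = torch.gather(stage1_indexes, dim=1, index=stage2_indexes)
composed_weights = torch.gather(stage1_weights, dim=1, index=stage2_indexes)

print(composed_indexes)                     # tensor([[0, 5], [3, 6]])
print(composed_weights)                     # tensor([[1.0000, 0.6000], [0.9000, 0.0000]])
print((composed_weights != 0).sum(dim=1))   # tensor([2, 1]) -- per-utterance lengths
```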
@@ -1063,7 +1082,7 @@ class LearnedDownsamplingModule(nn.Module):
           x_orig: (seq_len, batch_size, num_channels)
           x: (seq_len_reduced, batch_size, num_channels)
           indexes: (batch_size, seq_len_reduced), contains original frame indexes
-          weights: optional tensor
+          weights: optional tensor of shape (batch_size, seq_len_reduced)
 
         Downsamples x via indexing with the indexes obtained from the
         forward() function.
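For readers unfamiliar with the indexing convention in this docstring, the following standalone sketch (not from the commit; shapes and values are invented) shows how per-utterance frame indexes of shape (batch_size, seq_len_reduced) select frames from a time-major (seq_len, batch_size, num_channels) tensor via `torch.gather` along dim=0.

```python
# Standalone illustration of index-based frame selection on a time-major tensor.
import torch

seq_len, batch_size, num_channels = 10, 2, 3
x_orig = torch.randn(seq_len, batch_size, num_channels)

# Original frame indexes kept for each utterance: (batch_size, seq_len_reduced).
indexes = torch.tensor([[0, 3, 7],
                        [2, 4, 9]])

# torch.gather operates along dim=0 of x_orig, so the indexes are transposed
# to (seq_len_reduced, batch_size) and expanded over the channel dimension.
expanded = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels)
x_reduced = torch.gather(x_orig, dim=0, index=expanded)

print(x_reduced.shape)                              # torch.Size([3, 2, 3])
print(torch.equal(x_reduced[1, 0], x_orig[3, 0]))   # True: frame 3 of utterance 0
print(torch.equal(x_reduced[2, 1], x_orig[9, 1]))   # True: frame 9 of utterance 1
```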
@@ -1098,97 +1117,34 @@ class LearnedDownsamplingModule(nn.Module):
         # add in x_orig in the frames that were not originally kept.
         return ans + x_orig * orig_x_weight.t().unsqueeze(-1)
 
-
-class DownsampledSubformerEncoder(nn.Module):
-    """
-    DownsampledSubformerEncoder is a zipformer encoder stack possibly evaluated at a reduced
-    frame rate, after convolutional downsampling, and then upsampled again at the output, and combined
-    with the original input, so that the output has the same shape as the input.
-    """
-    def __init__(self,
-                 encoders: List[nn.Module],
-                 input_num_channels: int,
-                 downsample: int):
-        super(DownsampledSubformerEncoder, self).__init__()
-        if downsample != 1:
-            self.downsampler = LearnedDownsamplingModule(input_num_channels,
-                                                         downsample)
-        self.encoders = nn.ModuleList(encoders)
-        self.out_combiner = BypassModule(self.embed_dim(),
-                                         straight_through_rate=0.0)
-
-    def embed_dim(self):  # return output embed_dim which is max dim.
-        return max(e.embed_dim() for e in self.encoders)
-
-    def forward(self,
-                src: Tensor,
-                pos_emb: Tensor,
-                attn_offset: Tensor,
-                feature_mask: Union[Tensor, float] = 1.0,
-                memory: Optional[Tensor] = None,
-                memory_key_padding_mask: Optional[Tensor] = None,
-                ) -> Tuple[Tensor, Tensor]:
-        r"""Downsample, go through encoder, upsample.
-
-        Args:
-            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
-            pos_emb: the positional embedding, of shape (batch_size, seq_len, seq_len, pos_dim)
-            feature_mask: something that broadcasts with src, that we'll multiply `src`
-                by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
-            attn_offset: the attention offset, added to scores for attention of shape
-                (batch_size, seq_len, seq_len) or (seq_len, seq_len),
-                interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
-            memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
-            memory_key_padding_mask: optionally the mask for padding of memory input (for source-
-                attention), of shape (batch_size, memory_len); True means
-                masked position. May be None.
-
-        Returns: a Tensor with the same shape as src.
-        """
-        src_orig = src
-        if hasattr(self, 'downsampler'):
-            indexes, weights, src = self.downsampler(src)
-            pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes)
-            attn_offset = self.downsampler.downsample_attn_offset(attn_offset,
-                                                                  indexes,
-                                                                  weights)
-
-        outputs = [ src ]
-        for encoder in self.encoders:
-            src = encoder(
-                src,
-                pos_emb,
-                attn_offset=attn_offset,
-                memory=memory,
-                memory_key_padding_mask=memory_key_padding_mask,
-            )
-            outputs.append(src)
-
-        def get_full_dim_output():
-            num_encoders = len(outputs)
-            output_dim = max(o.shape[-1] for o in outputs)
-            output_pieces = [ outputs[-1] ]
-            cur_dim = outputs[-1].shape[-1]
-            for i in range(num_encoders - 2, -1, -1):
-                d = outputs[i].shape[-1]
-                if d > cur_dim:
-                    this_output = outputs[i]
-                    output_pieces.append(this_output[..., cur_dim:d])
-                    cur_dim = d
-            assert cur_dim == output_dim
-            return torch.cat(output_pieces, dim=-1)
-
-        src = get_full_dim_output()
-        src_orig = convert_num_channels(src_orig, src.shape[-1])
-
-        if hasattr(self, 'downsampler'):
-            src = self.downsampler.upsample(src_orig, src, indexes, weights)
-
-        return self.out_combiner(src_orig, src)
+
+    @staticmethod
+    def apply_weights(x_orig: Tensor, x: Tensor, indexes: Tensor,
+                      weights: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
+        """
+        Downsamples x_orig to have the same shape as x and applies the weights,
+        returning interpolated x and downsampled x_orig.  This is similar to
+        `upsample`, but is for the case where you don't want to keep the frames
+        that were not sampled.
+
+          x_orig: (seq_len, batch_size, num_channels)
+          x: (seq_len_reduced, batch_size, num_channels)
+          indexes: (batch_size, seq_len_reduced), contains original frame indexes
+          weights: optional tensor of shape (batch_size, seq_len_reduced)
+
+        Returns (x_orig, x) after the downsampling and interpolation, of shapes
+        both (seq_len_reduced, batch_size, num_channels).
+        """
+        (seq_len, batch_size, num_channels) = x_orig.shape
+
+        weights = 1.0 if weights is None else weights.t().unsqueeze(-1)
+
+        indexes = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels)
+        # indexes now: (seq_len_reduced, batch_size, num_channels)
+        x_orig = torch.gather(x_orig, dim=0, index=indexes)
+
+        x = x * weights + x_orig * (1.0 - weights)
+
+        return x_orig, x
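The heart of the new apply_weights method is the final blend. The short standalone sketch below (not from the commit; toy tensors with invented values) shows how a weight of 1 keeps the downsampled stream x, a weight of 0 falls back to the gathered original x_orig, and intermediate weights interpolate, after reshaping the weights from (batch_size, seq_len_reduced) so that they broadcast over (seq_len_reduced, batch_size, num_channels).

```python
# Standalone illustration of the weighted blend at the end of apply_weights.
import torch

seq_len_reduced, batch_size, num_channels = 3, 1, 2
x_orig = torch.zeros(seq_len_reduced, batch_size, num_channels)  # stands in for the gathered original
x      = torch.ones(seq_len_reduced, batch_size, num_channels)   # stands in for the encoder stream

weights = torch.tensor([[1.0, 0.5, 0.0]])   # (batch_size, seq_len_reduced)
w = weights.t().unsqueeze(-1)               # (seq_len_reduced, batch_size, 1), broadcastable

blended = x * w + x_orig * (1.0 - w)
print(blended[:, 0, 0])                     # tensor([1.0000, 0.5000, 0.0000])
```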
@@ -1881,6 +1837,7 @@ def _test_zipformer_main(causal: bool = False):
     memory_dim = 100
     c = Subformer(
+        structure = "S(S)S" if causal else "S(S(S",
         encoder_dim=(64, 96, 64),
         num_heads=(4, 4, 8),
         causal=causal,