mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-19 05:54:20 +00:00
Support unbalanced structures
This commit is contained in:
parent b85012aa0b
commit 0f27b14376
@@ -289,6 +289,25 @@ class Subformer(EncoderInterface):
        # d = self.output_downsampling_factor
        # lengths = (x_lens + d - 1) // d

        # The next code block will only run in the case of "unbalanced" structures, e.g.
        # if structure == "S(S(S)S", where there are unmatched left-parentheses.
        cur_indexes = None
        while len(downsample_info) > 0:
            indexes, weights, x_orig = downsample_info.pop()
            if cur_indexes is not None:
                # keep only a subset of the indexes and weights, corresponding
                # to later downsampling operations.
                indexes = torch.gather(indexes, dim=1, index=cur_indexes)
                weights = torch.gather(weights, dim=1, index=cur_indexes)

            cur_indexes = indexes

            x_lens = (weights != 0).sum(dim=1)
            x_orig = convert_num_channels(x_orig, x.shape[-1])
            x_orig, x = LearnedDownsamplingModule.apply_weights(x_orig, x, indexes, weights)

        return x, x_lens

    def _get_attn_offset(self, x: Tensor, src_key_padding_mask: Optional[Tensor]) -> Optional[Tensor]:
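A minimal standalone sketch (not code from this commit; the index values are made up) of what the torch.gather composition in the loop above does when two downsampling operations are stacked: gathering the earlier stage's frame-index map with the later stage's indexes maps the finally kept frames straight back to the original frame numbering.

import torch

# First downsampling: keeps original frames 0, 2, 5 and 7, shape (batch_size, seq_len_reduced).
first_stage = torch.tensor([[0, 2, 5, 7]])
# A later downsampling: keeps positions 1 and 3 of the already-reduced sequence.
second_stage = torch.tensor([[1, 3]])

# Composing the two maps, as in `indexes = torch.gather(indexes, dim=1, index=cur_indexes)` above.
composed = torch.gather(first_stage, dim=1, index=second_stage)
print(composed)  # tensor([[2, 7]]) -- original frame numbers of the surviving frames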
@@ -1063,7 +1082,7 @@ class LearnedDownsamplingModule(nn.Module):
x_orig: (seq_len, batch_size, num_channels)
x: (seq_len_reduced, batch_size, num_channels)
indexes: (batch_size, seq_len_reduced), contains original frame indexes
weights: optional tensor
weights: optional tensor of shape (batch_size, seq_len_reduced)

Downsamples x via indexing with the indexes obtained from the
forward() function.
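As an illustration of "downsamples x via indexing" (a standalone sketch, not code from the diff; tensor sizes are invented): the (batch_size, seq_len_reduced) indexes are expanded over the channel dimension and used to gather along the time axis, the same expand-and-gather pattern that appears in the apply_weights code further down.

import torch

seq_len, batch_size, num_channels = 6, 2, 3
x_orig = torch.randn(seq_len, batch_size, num_channels)

# Frame indexes chosen per batch element, shape (batch_size, seq_len_reduced).
indexes = torch.tensor([[0, 3, 5],
                        [1, 2, 4]])

# Expand to (seq_len_reduced, batch_size, num_channels) and gather along the time dimension.
expanded = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels)
x_down = torch.gather(x_orig, dim=0, index=expanded)
print(x_down.shape)  # torch.Size([3, 2, 3])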
@@ -1098,97 +1117,34 @@ class LearnedDownsamplingModule(nn.Module):
# add in x_orig in the frames that were not originally kept.
return ans + x_orig * orig_x_weight.t().unsqueeze(-1)


class DownsampledSubformerEncoder(nn.Module):
"""
DownsampledSubformerEncoder is a zipformer encoder stack possibly evaluated at a reduced
frame rate, after convolutional downsampling, and then upsampled again at the output, and combined
with the original input, so that the output has the same shape as the input.
"""
|
||||
def __init__(self,
|
||||
encoders: List[nn.Module],
|
||||
input_num_channels: int,
|
||||
downsample: int):
|
||||
super(DownsampledSubformerEncoder, self).__init__()
|
||||
if downsample != 1:
|
||||
self.downsampler = LearnedDownsamplingModule(input_num_channels,
|
||||
downsample)
|
||||
|
||||
self.encoders = nn.ModuleList(encoders)
|
||||
|
||||
self.out_combiner = BypassModule(self.embed_dim(),
|
||||
straight_through_rate=0.0)
|
||||
|
||||
def embed_dim(self): # return output embed_dim which is max dim.
|
||||
return max(e.embed_dim() for e in self.encoders)
|
||||
|
||||
def forward(self,
|
||||
src: Tensor,
|
||||
pos_emb: Tensor,
|
||||
attn_offset: Tensor,
|
||||
feature_mask: Union[Tensor, float] = 1.0,
|
||||
memory: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
r"""Downsample, go through encoder, upsample.
|
||||
@staticmethod
def apply_weights(x_orig: Tensor, x: Tensor, indexes: Tensor,
weights: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
"""
Downsamples x_orig to have the same shape as x and applies the weights,
returning interpolated x and downsampled x_orig. This is similar to
`upsample`, but is for the case where you don't want to keep the frames
that were not sampled.

Args:
src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
pos_emb: the positional embedding, of shape (batch_size, seq_len, seq_len, pos_dim)
feature_mask: something that broadcasts with src, that we'll multiply `src`
by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
attn_offset: the attention offset, added to scores for attention of shape
(batch_size, seq_len, seq_len) or (seq_len, seq_len),
interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
memory_key_padding_mask: optionally the mask for padding of memory input (for source-
attention), of shape (batch_size, memory_len); True means
masked position. May be None.
x_orig: (seq_len, batch_size, num_channels)
x: (seq_len_reduced, batch_size, num_channels)
indexes: (batch_size, seq_len_reduced), contains original frame indexes
weights: optional tensor of shape (batch_size, seq_len_reduced)

Returns: a Tensor with the same shape as src.
Returns (x_orig, x) after the downsampling and interpolation, of shapes
both (seq_len_reduced, batch_size, num_channels).
"""
src_orig = src
if hasattr(self, 'downsampler'):
indexes, weights, src = self.downsampler(src)
(seq_len, batch_size, num_channels) = x_orig.shape
weights = 1.0 if weights is None else weights.t().unsqueeze(-1)

pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes)
indexes = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels)
# indexes now: (seq_len_reduced, batch_size, num_channels)
x_orig = torch.gather(x_orig, dim=0, index=indexes)

attn_offset = self.downsampler.downsample_attn_offset(attn_offset,
indexes,
weights)
outputs = [ src ]
x = x * weights + x_orig * (1.0 - weights)

for encoder in self.encoders:
src = encoder(
src,
pos_emb,
attn_offset=attn_offset,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
)
outputs.append(src)

def get_full_dim_output():
num_encoders = len(outputs)
output_dim = max(o.shape[-1] for o in outputs)
output_pieces = [ outputs[-1] ]
cur_dim = outputs[-1].shape[-1]
for i in range(num_encoders - 2, -1, -1):
d = outputs[i].shape[-1]
if d > cur_dim:
this_output = outputs[i]
output_pieces.append(this_output[..., cur_dim:d])
cur_dim = d
assert cur_dim == output_dim
return torch.cat(output_pieces, dim=-1)

src = get_full_dim_output()
src_orig = convert_num_channels(src_orig, src.shape[-1])

if hasattr(self, 'downsampler'):
src = self.downsampler.upsample(src_orig, src, indexes, weights)

return self.out_combiner(src_orig, src)
return x_orig, x
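The added apply_weights path boils down to two steps: gather x_orig down to the kept frames, then blend x and the gathered x_orig frame-by-frame with the weights. A self-contained sketch with made-up tensors (not code from the commit), assuming the shapes documented in the docstring above:

import torch

seq_len, batch_size, num_channels = 6, 1, 2
x_orig = torch.randn(seq_len, batch_size, num_channels)

# Frames kept by the downsampler and their soft keep-weights.
indexes = torch.tensor([[0, 2, 5]])        # (batch_size, seq_len_reduced)
weights = torch.tensor([[1.0, 0.7, 0.0]])  # (batch_size, seq_len_reduced)
x = torch.randn(indexes.shape[1], batch_size, num_channels)  # stand-in for the encoder-side frames

# Step 1: downsample x_orig to the kept frames.
gather_idx = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels)
x_orig_down = torch.gather(x_orig, dim=0, index=gather_idx)

# Step 2: interpolate; weight 1 keeps x, weight 0 falls back to the downsampled x_orig.
w = weights.t().unsqueeze(-1)              # (seq_len_reduced, batch_size, 1)
x_blended = x * w + x_orig_down * (1.0 - w)
print(x_blended.shape)  # torch.Size([3, 1, 2])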
@@ -1881,6 +1837,7 @@ def _test_zipformer_main(causal: bool = False):
    memory_dim = 100

    c = Subformer(
        structure = "S(S)S" if causal else "S(S(S",
        encoder_dim=(64, 96, 64),
        num_heads=(4, 4, 8),
        causal=causal,
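For orientation (an illustrative helper, not part of icefall): "unbalanced" here means the structure string contains '(' characters, i.e. downsampling operations, with no matching ')'. The test above deliberately passes the unbalanced "S(S(S" in the non-causal case. A quick way to check a structure string, assuming only that '(' opens a downsampling and ')' closes it:

def count_unmatched_open_parens(structure: str) -> int:
    """Return how many '(' in `structure` have no matching ')'. Hypothetical helper, for illustration only."""
    depth = 0
    for ch in structure:
        if ch == "(":
            depth += 1
        elif ch == ")" and depth > 0:
            depth -= 1
    return depth

print(count_unmatched_open_parens("S(S)S"))    # 0: balanced
print(count_unmatched_open_parens("S(S(S"))    # 2: unbalanced, as used in the test above
print(count_unmatched_open_parens("S(S(S)S"))  # 1: unbalanced, the example from the comment near the top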