mirror of https://github.com/k2-fsa/icefall.git

Support unbalanced structures

parent b85012aa0b
commit 0f27b14376
@@ -289,6 +289,25 @@ class Subformer(EncoderInterface):
         # d = self.output_downsampling_factor
         # lengths = (x_lens + d - 1) // d
 
 
+        # The next code block will only run in the case of "unbalanced" structures, e.g.
+        # if structure == "S(S(S)S", where there are unmatched right-parentheses.
+        cur_indexes = None
+        while len(downsample_info) > 0:
+            indexes, weights, x_orig = downsample_info.pop()
+            if cur_indexes is not None:
+                # keep only a subset of the indexes and weights, corresponding
+                # to later downsampling operations.
+                indexes = torch.gather(indexes, dim=1, index=cur_indexes)
+                weights = torch.gather(weights, dim=1, index=cur_indexes)
+
+            cur_indexes = indexes
+
+            x_lens = (weights != 0).sum(dim=1)
+            x_orig = convert_num_channels(x_orig, x.shape[-1])
+            x_orig, x = LearnedDownsamplingModule.apply_weights(x_orig, x, indexes, weights)
+
+
         return x, x_lens
 
     def _get_attn_offset(self, x: Tensor, src_key_padding_mask: Optional[Tensor]) -> Optional[Tensor]:
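
The gather-based composition in this new loop can be checked in isolation. Below is a minimal sketch (not part of the patch; the toy tensors idx1 and idx2 are made up) of how torch.gather composes two rounds of index-based downsampling, so that the frames surviving the deeper round are re-addressed in the coordinates of the level above, which is what the loop does with cur_indexes:

    import torch

    idx1 = torch.tensor([[0, 2, 4, 6]])   # first downsample: keep 4 of 8 frames
    idx2 = torch.tensor([[1, 3]])         # second downsample: keep 2 of those 4

    # out[b][j] = idx1[b][idx2[b][j]]: positions of the final frames,
    # expressed in the coordinates of the sequence before idx1 was applied.
    composed = torch.gather(idx1, dim=1, index=idx2)
    print(composed)                       # tensor([[2, 6]])
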
@@ -1063,7 +1082,7 @@ class LearnedDownsamplingModule(nn.Module):
           x_orig: (seq_len, batch_size, num_channels)
           x: (seq_len_reduced, batch_size, num_channels)
           indexes: (batch_size, seq_len_reduced), contains original frame indexes
-          weights: optional tensor
+          weights: optional tensor of shape (batch_size, seq_len_reduced)
 
         Downsamples x via indexing with the indexes obtained from the
         forward() function.
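
The added shape note matters because elsewhere in this commit the weights are used to recover sequence lengths. A toy example (made-up values, not from the patch) of why (batch_size, seq_len_reduced) is the required shape: each retained frame carries one soft keep-weight, with zeros marking frames that are effectively padding, so counting nonzeros along dim=1 yields the reduced lengths, as in the new loop's `x_lens = (weights != 0).sum(dim=1)`:

    import torch

    weights = torch.tensor([[0.9, 0.4, 0.0],
                            [0.7, 0.0, 0.0]])   # (batch_size=2, seq_len_reduced=3)
    x_lens = (weights != 0).sum(dim=1)
    print(x_lens)                               # tensor([2, 1])
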
@@ -1098,97 +1117,34 @@ class LearnedDownsamplingModule(nn.Module):
         # add in x_orig in the frames that were not originally kept.
         return ans + x_orig * orig_x_weight.t().unsqueeze(-1)
 
+    @staticmethod
+    def apply_weights(x_orig: Tensor, x: Tensor, indexes: Tensor,
+                      weights: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
+        """
+        Downsamples x_orig to have the same shape as x and applies the weights,
+        returning interpolated x and downsampled x_orig.  This is similar to
+        `upsample`, but is for the case where you don't want to keep the frames
+        that were not sampled.
+
+        Args:
+          x_orig: (seq_len, batch_size, num_channels)
+          x: (seq_len_reduced, batch_size, num_channels)
+          indexes: (batch_size, seq_len_reduced), contains original frame indexes
+          weights: optional tensor of shape (batch_size, seq_len_reduced)
+
+        Returns (x_orig, x) after the downsampling and interpolation, of shapes
+        both (seq_len_reduced, batch_size, num_channels).
+        """
+        (seq_len, batch_size, num_channels) = x_orig.shape
+        weights = 1.0 if weights is None else weights.t().unsqueeze(-1)
+
+        indexes = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels)
+        # indexes now: (seq_len_reduced, batch_size, num_channels)
+        x_orig = torch.gather(x_orig, dim=0, index=indexes)
+
+        x = x * weights + x_orig * (1.0 - weights)
+
+        return x_orig, x
 
-
-class DownsampledSubformerEncoder(nn.Module):
-    """
-    DownsampledSubformerEncoder is a zipformer encoder stack possibly evaluated at a reduced
-    frame rate, after convolutional downsampling, and then upsampled again at the output, and combined
-    with the origin input, so that the output has the same shape as the input.
-    """
-    def __init__(self,
-                 encoders: List[nn.Module],
-                 input_num_channels: int,
-                 downsample: int):
-        super(DownsampledSubformerEncoder, self).__init__()
-        if downsample != 1:
-            self.downsampler = LearnedDownsamplingModule(input_num_channels,
-                                                         downsample)
-
-        self.encoders = nn.ModuleList(encoders)
-
-        self.out_combiner = BypassModule(self.embed_dim(),
-                                         straight_through_rate=0.0)
-
-    def embed_dim(self):  # return output embed_dim which is max dim.
-        return max(e.embed_dim() for e in self.encoders)
-
-    def forward(self,
-                src: Tensor,
-                pos_emb: Tensor,
-                attn_offset: Tensor,
-                feature_mask: Union[Tensor, float] = 1.0,
-                memory: Optional[Tensor] = None,
-                memory_key_padding_mask: Optional[Tensor] = None,
-               ) -> Tuple[Tensor, Tensor]:
-        r"""Downsample, go through encoder, upsample.
-
-        Args:
-            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
-            pos_emb: the positional embedding, of shape (batch_size, seq_len, seq_len, pos_dim)
-            feature_mask: something that broadcasts with src, that we'll multiply `src`
-                by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
-            attn_offset: the attention offset, added to scores for attention of shape
-                (batch_size, seq_len, seq_len) or (seq_len, seq_len),
-                interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
-            memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
-            memory_key_padding_mask: optionally the mask for padding of memory input (for source-
-                attention), of shape (batch_size, memory_len); True means
-                masked position.  May be None.
-
-        Returns: a Tensor with the same shape as src.
-        """
-        src_orig = src
-        if hasattr(self, 'downsampler'):
-            indexes, weights, src = self.downsampler(src)
-
-            pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes)
-
-            attn_offset = self.downsampler.downsample_attn_offset(attn_offset,
-                                                                  indexes,
-                                                                  weights)
-        outputs = [ src ]
-
-        for encoder in self.encoders:
-            src = encoder(
-                src,
-                pos_emb,
-                attn_offset=attn_offset,
-                memory=memory,
-                memory_key_padding_mask=memory_key_padding_mask,
-            )
-            outputs.append(src)
-
-        def get_full_dim_output():
-            num_encoders = len(outputs)
-            output_dim = max(o.shape[-1] for o in outputs)
-            output_pieces = [ outputs[-1] ]
-            cur_dim = outputs[-1].shape[-1]
-            for i in range(num_encoders - 2, -1, -1):
-                d = outputs[i].shape[-1]
-                if d > cur_dim:
-                    this_output = outputs[i]
-                    output_pieces.append(this_output[..., cur_dim:d])
-                    cur_dim = d
-            assert cur_dim == output_dim
-            return torch.cat(output_pieces, dim=-1)
-
-        src = get_full_dim_output()
-        src_orig = convert_num_channels(src_orig, src.shape[-1])
-
-        if hasattr(self, 'downsampler'):
-            src = self.downsampler.upsample(src_orig, src, indexes, weights)
-
-        return self.out_combiner(src_orig, src)
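
A quick way to sanity-check the tensor plumbing of the new apply_weights: the toy shapes below are made up, but the steps mirror the method body (transpose the per-frame weights, gather x_orig down to the reduced length along the time axis, then interpolate between x and the gathered x_orig):

    import torch

    seq_len, seq_red, batch, chans = 6, 3, 2, 4
    x_orig = torch.randn(seq_len, batch, chans)
    x = torch.randn(seq_red, batch, chans)
    indexes = torch.tensor([[0, 2, 5], [1, 3, 4]])            # (batch, seq_red)
    weights = torch.rand(batch, seq_red)

    w = weights.t().unsqueeze(-1)                             # (seq_red, batch, 1)
    idx = indexes.t().unsqueeze(-1).expand(-1, batch, chans)  # (seq_red, batch, chans)
    x_orig_ds = torch.gather(x_orig, dim=0, index=idx)
    x_out = x * w + x_orig_ds * (1.0 - w)

    assert x_out.shape == (seq_red, batch, chans)
    assert x_orig_ds.shape == (seq_red, batch, chans)
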
@@ -1881,6 +1837,7 @@ def _test_zipformer_main(causal: bool = False):
     memory_dim = 100
 
     c = Subformer(
+        structure = "S(S)S" if causal else "S(S(S",
         encoder_dim=(64, 96, 64),
        num_heads=(4, 4, 8),
        causal=causal,
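
The test exercises the new code path through the structure string. As a rough illustration (the depth-counting helper below is made up for this note, not icefall API), "balanced" versus "unbalanced" comes down to whether every downsampling opened by "(" is closed again by ")":

    def paren_depth(structure: str) -> int:
        depth = 0
        for ch in structure:
            if ch == '(':
                depth += 1
            elif ch == ')':
                depth -= 1
        return depth

    print(paren_depth("S(S)S"))   # 0: balanced, every downsampling is undone
    print(paren_depth("S(S(S"))   # 2: unbalanced, two downsamplings left open

Presumably the leftover open parentheses correspond to the entries still sitting in downsample_info when the encoder stack finishes, which is exactly what the new while-loop in Subformer.forward drains.
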