split utterance over 512 frames into overlapping chunks
This commit is contained in:
  parent 215541c7c5
  commit 3dc33515c0
@@ -1612,7 +1612,7 @@ def unfold(
       blocks: (kernel, batch_size * num_blocks, channel)
     """
     seq_len, batch_size, channel = x.size()
-    x = x.permute(1, 2, 0)  # (B, D, T)
+    x = x.permute(1, 2, 0)  # (batch_size, channel, seq_len)

     x = nn.functional.pad(x, pad=(0, x_pad), value=0.0)

@@ -1629,6 +1629,34 @@ def unfold(
     return blocks


+def fold(
+    blocks: Tensor, seq_len: int, x_pad: int, num_blocks: int, kernel: int, stride: int, padding: int
+) -> Tensor:
+    """
+    Args:
+      blocks: (kernel, batch_size * num_blocks, channel)
+    Returns:
+      x: (seq_len, batch_size, channel)
+    """
+    batch_size = blocks.size(1) // num_blocks
+    channel = blocks.size(2)
+
+    blocks = blocks.reshape(kernel, batch_size, num_blocks, channel)
+    blocks = blocks.permute(1, 3, 0, 2).reshape(batch_size, channel * kernel, num_blocks)
+
+    x = nn.functional.fold(
+        blocks,
+        output_size=(seq_len + x_pad, 1),
+        kernel_size=(kernel, 1),
+        padding=(padding, 0),
+        stride=(stride, 1),
+    )
+    x = x.squeeze(-1).permute(2, 0, 1)
+    x = x[:seq_len]  # (seq_len, batch_size, channel)
+
+    return x
+
+
 def _test_whiten():
     for proportion in [0.1, 0.5, 10.0]:
         logging.info(f"_test_whiten(): proportion = {proportion}")
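For orientation, here is a minimal, self-contained sketch of the overlap-and-add round trip this unfold()/fold() pair implements. The names toy_unfold/toy_fold are made up for the example; toy_fold follows the committed fold() line for line, while the body of toy_unfold is an assumption modeled on the visible lines of unfold() above (it wraps nn.functional.unfold the same way fold() wraps nn.functional.fold).

import math
import torch
import torch.nn as nn


def toy_unfold(x, x_pad, num_blocks, kernel, stride, padding):
    # x: (seq_len, batch_size, channel) -> (kernel, batch_size * num_blocks, channel)
    seq_len, batch_size, channel = x.size()
    x = x.permute(1, 2, 0)  # (batch_size, channel, seq_len)
    x = nn.functional.pad(x, pad=(0, x_pad), value=0.0)
    x = x.unsqueeze(-1)  # (batch_size, channel, seq_len + x_pad, 1)
    blocks = nn.functional.unfold(
        x, kernel_size=(kernel, 1), stride=(stride, 1), padding=(padding, 0)
    )  # (batch_size, channel * kernel, num_blocks)
    blocks = blocks.reshape(batch_size, channel, kernel, num_blocks)
    return blocks.permute(2, 0, 3, 1).reshape(kernel, batch_size * num_blocks, channel)


def toy_fold(blocks, seq_len, x_pad, num_blocks, kernel, stride, padding):
    # Sums the contributions of the overlapping blocks back onto the time axis;
    # this mirrors the committed fold() above.
    batch_size = blocks.size(1) // num_blocks
    channel = blocks.size(2)
    blocks = blocks.reshape(kernel, batch_size, num_blocks, channel)
    blocks = blocks.permute(1, 3, 0, 2).reshape(batch_size, channel * kernel, num_blocks)
    x = nn.functional.fold(
        blocks, output_size=(seq_len + x_pad, 1),
        kernel_size=(kernel, 1), padding=(padding, 0), stride=(stride, 1),
    )
    return x.squeeze(-1).permute(2, 0, 1)[:seq_len]  # (seq_len, batch_size, channel)


# Toy sizes, chosen only so the math is easy to follow.
seq_len, batch_size, channel = 29, 2, 4
max_block_size, block_pad = 14, 2
num_blocks = math.ceil(seq_len / max_block_size)   # 3
block_size = math.ceil(seq_len / num_blocks)       # 10
pad_len = num_blocks * block_size - seq_len        # 1
kernel = block_size + 2 * block_pad                # 14

x = torch.randn(seq_len, batch_size, channel)
blocks = toy_unfold(x, pad_len, num_blocks, kernel=kernel, stride=block_size, padding=block_pad)
y = toy_fold(blocks, seq_len, pad_len, num_blocks, kernel=kernel, stride=block_size, padding=block_pad)

# Each frame is summed over every block that covers it, so dividing by the
# per-frame block count (fold of an all-ones tensor) recovers the input.
counts = toy_fold(torch.ones_like(blocks[:, :, :1]), seq_len, pad_len, num_blocks,
                  kernel=kernel, stride=block_size, padding=block_pad)
assert torch.allclose(y / counts, x, atol=1e-5)

This division by a folded all-ones mask is exactly what Zipformer2Encoder.forward does after running the layers (see the fold/mask lines further down).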
@@ -39,6 +39,7 @@ from scaling import (
     FloatLike,
     limit_param_value,
     convert_num_channels,
+    fold,
     unfold,
 )
 from torch import Tensor, nn
@@ -679,10 +680,10 @@ class Zipformer2EncoderLayer(nn.Module):
         self,
         src: Tensor,
         pos_emb: Tensor,
-        block_size: int = 0,
-        block_pad: int = 16,
         chunk_size: int = -1,
         attn_mask: Optional[Tensor] = None,
+        attn_offsets: Optional[Tensor] = None,
+        all_pad_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
     ) -> Tensor:
         """
@@ -713,20 +714,12 @@ class Zipformer2EncoderLayer(nn.Module):
         attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0

         # attn_weights: (num_heads, batch_size, seq_len, seq_len)
-        if block_size == 0:
-            attn_weights = self.self_attn_weights(
-                src,
-                pos_emb=pos_emb,
-                attn_mask=attn_mask,
-                key_padding_mask=src_key_padding_mask,
-            )
-        else:
-            attn_weights = self.self_attn_weights.forward_block(
-                src,
-                pos_emb=pos_emb,
-                block_size=block_size,
-                block_pad=block_pad,
-                key_padding_mask=src_key_padding_mask,
-            )
+        attn_weights = self.self_attn_weights(
+            src,
+            pos_emb=pos_emb,
+            attn_mask=attn_mask,
+            attn_offsets=attn_offsets,
+            all_pad_mask=all_pad_mask,
+        )

         src = src + self.feed_forward1(src)
@@ -745,20 +738,12 @@ class Zipformer2EncoderLayer(nn.Module):
             selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype)
             selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))

-        if block_size == 0:
-            na = self.nonlin_attention(src, selected_attn_weights)
-        else:
-            na = self.nonlin_attention.forward_block(
-                src, selected_attn_weights, block_size=block_size, block_pad=block_pad)
+        na = self.nonlin_attention(src, selected_attn_weights)
         na = self.balancer_na(na)

         src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask)

-        if block_size == 0:
-            self_attn = self.self_attn1(src, attn_weights)
-        else:
-            self_attn = self.self_attn1.forward_block(
-                src, attn_weights, block_size=block_size, block_pad=block_pad)
+        self_attn = self.self_attn1(src, attn_weights)

         src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)

@@ -780,11 +765,7 @@ class Zipformer2EncoderLayer(nn.Module):
         # bypass in the middle of the layer.
         src = self.bypass_mid(src_orig, src)

-        if block_size == 0:
-            self_attn = self.self_attn2(src, attn_weights)
-        else:
-            self_attn = self.self_attn2.forward_block(
-                src, attn_weights, block_size=block_size, block_pad=block_pad)
+        self_attn = self.self_attn2(src, attn_weights)

         src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)

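These three hunks drop the per-module block_size/forward_block branches because the chunking now happens once in Zipformer2Encoder.forward (see the hunks below): a long utterance is unfolded into overlapping blocks stacked along the batch dimension, so every sub-module still receives an ordinary (seq, batch, channel) tensor. A shape-only sketch, with assumed values (max_block_size=512 as suggested by the commit title, block_pad=16 as in the old default; neither value is fixed by these hunks):

import math

seq_len, batch_size, channel = 600, 8, 192   # assumed example sizes
max_block_size, block_pad = 512, 16          # assumed configuration

num_blocks = math.ceil(seq_len / max_block_size)   # 2
block_size = math.ceil(seq_len / num_blocks)       # 300
pad_len = num_blocks * block_size - seq_len        # 0
kernel_size = block_size + 2 * block_pad           # 332

# Shape every encoder layer sees after the encoder-level unfold():
blocked_shape = (kernel_size, batch_size * num_blocks, channel)
print(blocked_shape)  # (332, 16, 192)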
@@ -994,7 +975,7 @@ class Zipformer2Encoder(nn.Module):
             src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
             chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
             feature_mask: something that broadcasts with src, that we'll multiply `src`
-                by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
+                by at every layer: if a Tensor, likely of shape (1, batch_size, embedding_dim)
             attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
                 interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
                 True means masked position. May be None.
@@ -1003,20 +984,71 @@ class Zipformer2Encoder(nn.Module):

         Returns: a Tensor with the same shape as src.
         """
-        seq_len = src.size(0)
+        seq_len, batch_size, channel = src.size()
         max_block_size = self.max_block_size
         block_pad = self.block_pad

         if seq_len > max_block_size:
+            # divide into blocks with overlaps
             num_blocks = math.ceil(seq_len / max_block_size)
             block_size = math.ceil(seq_len / num_blocks)
-            pos_emb = self.encoder_pos(src, rel_pos=block_size + block_pad)
-            # if __name__ == "__main__":
-            if random.random() < 0.2:
-                logging.info(f"seq_len={seq_len}, block_size={block_size}")
+            pad_len = num_blocks * block_size - seq_len
+            kernel_size = block_size + 2 * block_pad
+            if random.random() < 0.2 or __name__ == "__main__":
+                logging.info(f"seq_len={seq_len}, block_size={block_size}, pad_len={pad_len}")
+
+            # (block_size + 2 * block_pad, batch_size * num_blocks, channel)
+            src = unfold(
+                src, pad_len, num_blocks,
+                kernel=kernel_size, stride=block_size, padding=block_pad
+            )
+
+            # Used to mask out the padding positions
+            attn_offsets = torch.ones(batch_size, seq_len, device=src.device)
+
+            if src_key_padding_mask is not None:
+                assert src_key_padding_mask.shape == (batch_size, seq_len), src_key_padding_mask.shape
+                attn_offsets = attn_offsets.masked_fill(src_key_padding_mask, 0.0)  # 0 at padding positions
+            # (seq_len, batch, 1)
+            attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(-1)
+            # (kernel_size, new_batch_size)
+            attn_offsets = unfold(
+                attn_offsets, pad_len, num_blocks,
+                kernel=kernel_size, stride=block_size, padding=block_pad,
+            ).squeeze(-1)
+
+            # Used for the blocks are all padding
+            all_pad_mask = (attn_offsets.sum(dim=0, keepdim=True) == 0)  # (1, new_batch_size)
+            all_pad_mask = all_pad_mask.unsqueeze(-1).unsqueeze(-1)  # (1, new_batch_size, 1, 1)
+
+            # (new_batch_size, kernel_size)
+            src_key_padding_mask = (attn_offsets == 0).transpose(0, 1)
+
+            attn_offsets = 1 - attn_offsets  # 1 at padding positions
+            attn_offsets[attn_offsets != 0] = -1000
+
+            # (1, new_batch_size, 1, kernel)
+            attn_offsets = attn_offsets.transpose(0, 1).unsqueeze(1).unsqueeze(0)
+
+            # feature_mask: (1, batch_size, channel)
+            if isinstance(feature_mask, Tensor):
+                feature_mask = feature_mask.unsqueeze(2).expand(-1, -1, num_blocks, -1)
+                # now (kernel_size, batch_size, num_blocks, channel)
+                feature_mask = feature_mask.reshape(1, batch_size * num_blocks, channel)
         else:
-            pos_emb = self.encoder_pos(src)
             block_size = 0
+
+            # Used to mask out the padding positions
+            attn_offsets = torch.zeros(batch_size, seq_len, device=src.device)
+            if src_key_padding_mask is not None:
+                assert src_key_padding_mask.shape == (batch_size, seq_len), src_key_padding_mask.shape
+                attn_offsets = attn_offsets.masked_fill(src_key_padding_mask, -1000)  # 0 at padding positions
+            # (1, batch_size, 1, seq_len)
+            attn_offsets = attn_offsets.unsqueeze(1).unsqueeze(0)
+
+            all_pad_mask = None
+
+        pos_emb = self.encoder_pos(src)
         output = src

         if not torch.jit.is_scripting() and not torch.jit.is_tracing():
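In both branches the key-padding information is turned into an additive offset on the attention scores rather than a boolean masked_fill inside the attention module. A minimal sketch of the no-chunking branch, with made-up sizes, showing that the two formulations give the same attention weights after the softmax:

import torch

num_heads, batch_size, seq_len = 2, 3, 5
attn_scores = torch.randn(num_heads, batch_size, seq_len, seq_len)
src_key_padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
src_key_padding_mask[:, -2:] = True  # pretend the last two frames are padding

# Previous approach: boolean masked_fill on the key axis inside the attention module.
old_scores = attn_scores.masked_fill(src_key_padding_mask.unsqueeze(1), -1000)

# New approach: precompute an additive offset once in the encoder ...
attn_offsets = torch.zeros(batch_size, seq_len)
attn_offsets = attn_offsets.masked_fill(src_key_padding_mask, -1000)
attn_offsets = attn_offsets.unsqueeze(1).unsqueeze(0)  # (1, batch_size, 1, seq_len)
# ... and simply add it to the scores inside the attention module.
new_scores = attn_scores + attn_offsets

# exp(-1000) underflows to zero either way, so padded keys get zero weight
# and the resulting attention weights match.
assert torch.allclose(torch.softmax(old_scores, dim=-1),
                      torch.softmax(new_scores, dim=-1))

The advantage of the additive form is that the same tensor works unchanged for the unfolded case, where the "batch" axis is batch_size * num_blocks and the key axis is the block kernel rather than the full sequence.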
@@ -1026,16 +1058,31 @@ class Zipformer2Encoder(nn.Module):
             output = mod(
                 output,
                 pos_emb,
-                block_size=block_size,
-                block_pad=block_pad,
                 chunk_size=chunk_size,
                 attn_mask=attn_mask,
+                attn_offsets=attn_offsets,
+                all_pad_mask=all_pad_mask,
                 src_key_padding_mask=src_key_padding_mask,
             )

         if not torch.jit.is_scripting() and not torch.jit.is_tracing():
             output = output * feature_mask

+        if seq_len > max_block_size:
+            # overlap-and-add
+            output = fold(
+                output, seq_len, pad_len, num_blocks,
+                kernel=kernel_size, stride=block_size, padding=block_pad
+            )  # (seq_len, batch_size, channel)
+            mask = torch.ones(
+                kernel_size, batch_size * num_blocks, 1, device=src.device,
+            )
+            mask = fold(
+                mask, seq_len, pad_len, num_blocks,
+                kernel=kernel_size, stride=block_size, padding=block_pad
+            )  # (seq_len, batch_size, 1)
+            output = output / mask
+
         return output

     def streaming_forward(
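The mask computed by folding an all-ones tensor counts how many overlapping blocks cover each output frame, so output / mask averages the frames that were processed more than once in the overlap regions. A small numeric check of that count vector, using toy sizes (not the real configuration) and nn.functional.fold directly:

import torch
import torch.nn.functional as F

seq_len, block_size, block_pad, num_blocks = 8, 4, 1, 2
kernel = block_size + 2 * block_pad  # 6
ones = torch.ones(1, kernel, num_blocks)  # (batch=1, channel*kernel=6, num_blocks=2), channel=1
counts = F.fold(
    ones, output_size=(seq_len, 1),
    kernel_size=(kernel, 1), padding=(block_pad, 0), stride=(block_size, 1),
).squeeze()
print(counts)  # tensor([1., 1., 1., 2., 2., 1., 1., 1.])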
@@ -1523,7 +1570,8 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self,
         x: Tensor,
         pos_emb: Tensor,
-        key_padding_mask: Optional[Tensor] = None,
+        attn_offsets: Optional[Tensor] = None,
+        all_pad_mask: Optional[Tensor] = None,
         attn_mask: Optional[Tensor] = None,
     ) -> Tensor:
         r"""
|
|||||||
assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
|
assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
|
||||||
|
|
||||||
if attn_mask is not None:
|
if attn_mask is not None:
|
||||||
|
assert attn_mask is None
|
||||||
assert attn_mask.dtype == torch.bool
|
assert attn_mask.dtype == torch.bool
|
||||||
# use -1000 to avoid nan's where attn_mask and key_padding_mask make
|
# use -1000 to avoid nan's where attn_mask and key_padding_mask make
|
||||||
# all scores zero. It's important that this be large enough that exp(-1000)
|
# all scores zero. It's important that this be large enough that exp(-1000)
|
||||||
@@ -1638,12 +1687,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
            # compares the final weights with zero.
            attn_scores = attn_scores.masked_fill(attn_mask, -1000)

-        if key_padding_mask is not None:
-            assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
-            attn_scores = attn_scores.masked_fill(
-                key_padding_mask.unsqueeze(1),
-                -1000,
-            )
+        if attn_offsets is not None:
+            # attn_offsets: (1, batch_size, 1, seq_len)
+            # or (1, new_batch_size, 1, kernel)
+            attn_scores = attn_scores + attn_offsets

        # We use our own version of softmax, defined in scaling.py, which should
        # save a little of the memory used in backprop by, if we are in
@@ -1651,6 +1698,11 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        # half-precision output for backprop purposes.
        attn_weights = softmax(attn_scores, dim=-1)

+        if all_pad_mask is not None:
+            # For the blocks are all padding
+            # all_pad_mask: (1, new_batch_size, 1, 1)
+            attn_weights = attn_weights.masked_fill(all_pad_mask, 0.0)
+
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            pass
        elif random.random() < 0.001 and not self.training:
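The post-softmax zeroing matters because softmax always normalizes a row to 1: in a block that is entirely padding, every key carries the -1000 offset, so the weights would still sum to 1 and the block would attend to garbage. A toy sketch, with assumed sizes, of why the extra step is needed:

import torch

attn_scores = torch.randn(1, 2, 3, 3)          # (num_heads, new_batch_size, query, key)
attn_offsets = torch.zeros(1, 2, 1, 3)
attn_offsets[0, 1] = -1000                      # pretend block 1 is entirely padding
attn_weights = torch.softmax(attn_scores + attn_offsets, dim=-1)
print(attn_weights[0, 1].sum(dim=-1))           # ~1.0 per query: softmax still normalizes

all_pad_mask = (attn_offsets != 0).all(dim=-1, keepdim=True)  # (1, 2, 1, 1)
attn_weights = attn_weights.masked_fill(all_pad_mask, 0.0)
print(attn_weights[0, 1].sum(dim=-1))           # 0.0: the all-padding block contributes nothing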
@@ -2586,13 +2638,13 @@ def _test_zipformer_main(causal: bool = False):
         encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
         downsampling_factor=(1, 2),
         max_block_size=14,
-        block_pad=1,
+        block_pad=2,
         causal=causal,
         chunk_size=(4,) if causal else (-1,),
         left_context_frames=(64,)
     )
     batch_size = 2
-    seq_len = 29
+    seq_len = 27
 
     # Just make sure the forward pass runs.
     x = torch.randn(seq_len, batch_size, 64)
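Plain arithmetic on the updated test constants confirms the forward pass exercises the blocking path with two overlapping chunks and one frame of padding:

import math

seq_len, max_block_size, block_pad = 27, 14, 2        # values from the updated test
num_blocks = math.ceil(seq_len / max_block_size)      # 2
block_size = math.ceil(seq_len / num_blocks)          # 14
pad_len = num_blocks * block_size - seq_len           # 1
kernel_size = block_size + 2 * block_pad              # 18
print(num_blocks, block_size, pad_len, kernel_size)   # 2 14 1 18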