Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
rename zipformer to subformer, remove some things that won't be used.
This commit is contained in:
parent 2e4b27a1c8
commit 5c470fe397
@@ -27,7 +27,6 @@ from scaling import (
     Balancer,
     BiasNorm,
     Dropout2,
-    ChunkCausalDepthwiseConv1d,
     ActivationDropoutAndLinear,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
     Whiten,
@@ -42,7 +41,7 @@ from scaling import (
 from torch import Tensor, nn


-class Zipformer2(EncoderInterface):
+class Subformer2(EncoderInterface):
     """
     Args:

@@ -70,7 +69,6 @@ class Zipformer2(EncoderInterface):
        num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
            Must be at least 4.
        feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
-       cnn_module_kernel (int or Tuple[int]): Kernel size of convolution module

        pos_dim (int): the dimension of each positional-encoding vector prior to projection,
            e.g. 128.
@@ -83,15 +81,9 @@ class Zipformer2(EncoderInterface):
            slightly slower and use more memory. Enables use of the chunk_size and
            left_context_chunks options in forward(), which simulates streaming
            decoding.
-       chunk_size: (list of int): only set this to other than [-1] if causal;
-           the chunk size will be randomly chosen from this list. -1 means no chunking.
-       left_context_frames: (list of int): determines the number of left-
-           context chunks for causal training; will be rounded to a number of
-           chunks. Must not be less than cnn_module_kernel (after factoring in
-           rounding and downsampling); an error will be thrown if this is violated.
        memory_dim: if supplied and >0, will be the dimension of the memory embeddings
            passed into the zipformer (e.g. this might be the output of another
-           Zipformer used to create embedding vectors.)
+           Subformer used to create embedding vectors.)
    """
    def __init__(
        self,
@@ -105,16 +97,13 @@ class Zipformer2(EncoderInterface):
        value_head_dim: Union[int, Tuple[int]] = 12,
        num_heads: Union[int, Tuple[int]] = 8,
        feedforward_dim: Union[int, Tuple[int]] = 1536,
-       cnn_module_kernel: Union[int, Tuple[int]] = 31,
        memory_dim: int = -1,
        pos_dim: int = 192,
        dropout: FloatLike = None,  # see code below for default
        warmup_batches: float = 4000.0,
        causal: bool = False,
-       chunk_size: Tuple[int] = (-1,),
-       left_context_frames: Tuple[int] = (-1,),
    ) -> None:
-       super(Zipformer2, self).__init__()
+       super(Subformer2, self).__init__()

        if dropout is None:
            dropout = ScheduledFloat((0.0, 0.3),
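The dropout default above is a ScheduledFloat: a value interpolated piecewise-linearly in the training batch count (its second breakpoint is truncated at this hunk boundary). Below is a minimal standalone sketch of that interpolation; the breakpoints in the example are illustrative, not the ones used in this file.

    # Sketch of a ScheduledFloat-style piecewise-linear schedule.
    # Breakpoints here are illustrative only.
    def scheduled_float(batch_count: float, *points) -> float:
        """points: (batch_count, value) pairs with increasing batch_count."""
        xs, ys = zip(*points)
        if batch_count <= xs[0]:
            return ys[0]
        if batch_count >= xs[-1]:
            return ys[-1]
        for (x0, y0), (x1, y1) in zip(points, points[1:]):
            if x0 <= batch_count <= x1:
                # linear interpolation between adjacent breakpoints
                return y0 + (y1 - y0) * (batch_count - x0) / (x1 - x0)

    print(scheduled_float(2000.0, (0.0, 0.3), (4000.0, 0.1)))  # 0.2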
@@ -141,22 +130,17 @@ class Zipformer2(EncoderInterface):
        pos_head_dim = _to_tuple(pos_head_dim)
        num_heads = _to_tuple(num_heads)
        feedforward_dim = _to_tuple(feedforward_dim)
-       self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)

        self.causal = causal
-       self.chunk_size = chunk_size
-       self.left_context_frames = left_context_frames

        for u,d in zip(encoder_unmasked_dim, encoder_dim):
            assert u <= d

-       # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
+       # each one will be Subformer2Encoder or DownsampledSubformer2Encoder
        encoders = []

        num_encoders = len(downsampling_factor)
        for i in range(num_encoders):

-           encoder_layer = Zipformer2EncoderLayer(
+           encoder_layer = Subformer2EncoderLayer(
                embed_dim=encoder_dim[i],
                pos_dim=pos_dim,
                num_heads=num_heads[i],
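The _to_tuple helper used above is defined elsewhere in this file and is not part of the diff. The sketch below is a hypothetical equivalent (name, signature, and the explicit num_encoders argument are assumptions) showing the intended broadcasting of int-or-tuple hyperparameters to one value per encoder stack.

    from typing import Tuple, Union

    # Hypothetical stand-in for _to_tuple: broadcast a scalar to one value
    # per encoder stack; pass tuples through after a length check.
    def to_tuple(x: Union[int, Tuple[int, ...]], num_encoders: int) -> Tuple[int, ...]:
        if isinstance(x, int):
            return (x,) * num_encoders
        assert len(x) == num_encoders
        return tuple(x)

    print(to_tuple(8, 3))          # (8, 8, 8)
    print(to_tuple((4, 4, 8), 3))  # (4, 4, 8)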
@@ -166,13 +150,12 @@ class Zipformer2(EncoderInterface):
                feedforward_dim=feedforward_dim[i],
                memory_dim=memory_dim,
                dropout=dropout,
-               cnn_module_kernel=cnn_module_kernel[i],
                causal=causal,
            )

            # For the segment of the warmup period, we let the Conv2dSubsampling
            # layer learn something.  Then we start to warm up the other encoders.
-           encoder = Zipformer2Encoder(
+           encoder = Subformer2Encoder(
                encoder_layer,
                num_encoder_layers[i],
                pos_dim=pos_dim,
@@ -183,7 +166,7 @@ class Zipformer2(EncoderInterface):
            )

            if downsampling_factor[i] != 1:
-               encoder = DownsampledZipformer2Encoder(
+               encoder = DownsampledSubformer2Encoder(
                    encoder,
                    dim=encoder_dim[i],
                    downsample=downsampling_factor[i],
@@ -257,24 +240,6 @@ class Zipformer2(EncoderInterface):
        return feature_masks


-   def get_chunk_info(self) -> Tuple[int, int]:
-       """
-       Returns chunk_size and left_context_chunks.
-       """
-       if not self.causal:
-           return -1, -1
-       chunk_size = random.choice(self.chunk_size)
-       if chunk_size == -1:
-           left_context_chunks = -1
-       else:
-           left_context_frames = random.choice(self.left_context_frames)
-           # Note: in Python, -1 // n == -1 for n > 0
-           left_context_chunks = left_context_frames // chunk_size
-           if left_context_chunks == 0:
-               left_context_chunks = 1
-       return chunk_size, left_context_chunks


    def forward(
        self,
        x: torch.Tensor,
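The removed get_chunk_info() leaned on the floor-division behaviour its comment notes: the -1 "no limit" sentinel survives the division because -1 // n == -1 for n > 0, while positive frame counts round down to whole chunks with a floor of one chunk. A small standalone demonstration of that rounding:

    import random

    # Worked example of the removed chunk-info logic.
    chunk_size = random.choice([16, 32])
    for left_context_frames in (-1, 64, 100):
        left_context_chunks = left_context_frames // chunk_size  # -1 // n == -1
        if left_context_chunks == 0:
            left_context_chunks = 1  # at least one chunk of left context
        print(chunk_size, left_context_frames, left_context_chunks)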
@@ -307,9 +272,7 @@ class Zipformer2(EncoderInterface):
        outputs = []
        feature_masks = self.get_feature_masks(x)

-       chunk_size, left_context_chunks = self.get_chunk_info()
-
-       attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)
+       attn_mask = self._get_attn_mask(x)

        if self.training and memory is not None:
            batch_size = x.shape[1]
@@ -361,45 +324,31 @@ class Zipformer2(EncoderInterface):

        return x, lengths

-   def _get_attn_mask(self, x: Tensor,
-                      chunk_size: int,
-                      left_context_chunks: int
-                      ) -> Optional[Tensor]:
+   def _get_attn_mask(self, x: Tensor) -> Optional[Tensor]:
        """
-       Return None if chunk_size == -1, else return attention mask of shape
+       Return None if not self.causal, else return an attention mask of shape
        (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len).  True
        means a masked position.
        Args:
          x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
-         chunk_size: chunk size, must divide
        """
-       if chunk_size <= 0:
+       if not self.causal:
            return None
-       assert all(chunk_size % d == 0 for d in self.downsampling_factor)
-       if left_context_chunks >= 0:
-           num_encoders = len(self.encoder_dim)
-           assert all(chunk_size * left_context_chunks >=
-                      (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i]
-                      for i in range(num_encoders))
-       else:
-           left_context_chunks = 1000000

        seq_len = x.shape[0]

        # t is frame index, shape (seq_len,)
        t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
-       # c is chunk index for each frame, shape (seq_len,)
-       c = t // chunk_size
        src_c = c
        tgt_c = c.unsqueeze(-1)

-       attn_mask = torch.logical_or(src_c > tgt_c,
-                                    src_c < tgt_c - left_context_chunks)
-       if __name__ == "__main__":
-           logging.info(f"attn_mask = {attn_mask}")
+       attn_mask = (src_c > tgt_c)

        return attn_mask


def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
    return ScheduledFloat((0.0, x),
                          (20000.0, ratio * x),
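Both mask constructions touched by this hunk are easy to check in isolation. The sketch below builds the removed chunk-causal mask (limited left context, in chunks) and the new purely causal one; variable names are illustrative, and True means a masked position, as the docstring says.

    import torch

    seq_len, chunk_size, left_context_chunks = 8, 2, 1
    t = torch.arange(seq_len, dtype=torch.int32)  # frame index

    # New, purely causal form: a target frame may not attend to later frames.
    causal_mask = (t > t.unsqueeze(-1))           # (seq_len, seq_len)

    # Removed chunk-causal form: frames are grouped into chunks, and a frame
    # may attend to its own chunk plus a limited number of left chunks.
    c = t // chunk_size                           # chunk index per frame
    src_c, tgt_c = c, c.unsqueeze(-1)
    chunk_mask = torch.logical_or(src_c > tgt_c,
                                  src_c < tgt_c - left_context_chunks)

    print(causal_mask.int())
    print(chunk_mask.int())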
@@ -410,17 +359,16 @@ def _balancer_schedule(min_prob: float):


-class Zipformer2EncoderLayer(nn.Module):
+class Subformer2EncoderLayer(nn.Module):
    """
    Args:
        embed_dim: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        feedforward_dim: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
-       cnn_module_kernel (int): Kernel size of convolution module.

    Examples::
-       >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
+       >>> encoder_layer = Subformer2EncoderLayer(embed_dim=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> pos_emb = torch.rand(32, 19, 512)
        >>> out = encoder_layer(src, pos_emb)
@@ -435,7 +383,6 @@ class Zipformer2EncoderLayer(nn.Module):
        value_head_dim: int,
        feedforward_dim: int,
        dropout: FloatLike = 0.1,
-       cnn_module_kernel: int = 31,
        causal: bool = False,
        memory_dim: int = -1,
        attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
@@ -445,7 +392,7 @@ class Zipformer2EncoderLayer(nn.Module):
        ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
        bypass_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0),
    ) -> None:
-       super(Zipformer2EncoderLayer, self).__init__()
+       super(Subformer2EncoderLayer, self).__init__()
        self.embed_dim = embed_dim

        # self.bypass implements layer skipping as well as bypass; see its default values.
@@ -509,14 +456,6 @@ class Zipformer2EncoderLayer(nn.Module):
        self.nonlin_attention = NonlinAttention(embed_dim,
                                                hidden_channels=3 * embed_dim // 4)

-       self.conv_module1 = ConvolutionModule(embed_dim,
-                                             cnn_module_kernel,
-                                             causal=causal)
-
-       self.conv_module2 = ConvolutionModule(embed_dim,
-                                             cnn_module_kernel,
-                                             causal=causal)
-

        #self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)

@@ -682,10 +621,6 @@ class Zipformer2EncoderLayer(nn.Module):
            src = src + self.sequence_dropout(self.src_attn1(memory, src_attn_weights),
                                              attention_skip_rate)

-       src = src + self.sequence_dropout(self.conv_module1(src, chunk_size=chunk_size,
-                                                           src_key_padding_mask=src_key_padding_mask),
-                                         float(self.conv_skip_rate))
-
        src = src + self.sequence_dropout(self.balancer_ff2(self.feed_forward2(src)),
                                          float(self.ff2_skip_rate))

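self.sequence_dropout is defined outside this diff. Judging from how it is called above, it stochastically drops an entire branch output (here the removed conv branch, elsewhere attention and feedforward branches) with a scheduled probability, so the residual src passes through unchanged for the dropped cases. A hypothetical per-sequence version, purely as a sketch of that idea:

    import torch
    from torch import Tensor

    # Hypothetical sketch of a sequence-level dropout: zero a branch output
    # for a random subset of sequences with probability p (no rescaling),
    # so the residual connection simply skips the branch for them.
    def sequence_dropout(x: Tensor, p: float, training: bool = True) -> Tensor:
        # x: (seq_len, batch_size, embed_dim)
        if not training or p == 0.0:
            return x
        keep = (torch.rand(1, x.shape[1], 1, device=x.device) > p)
        return x * keep

    src = torch.randn(10, 4, 64)
    branch_out = torch.randn(10, 4, 64)
    src = src + sequence_dropout(branch_out, p=0.1)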
@@ -701,10 +636,6 @@ class Zipformer2EncoderLayer(nn.Module):
            src = src + self.sequence_dropout(self.src_attn2(memory, src_attn_weights),
                                              attention_skip_rate)

-       src = src + self.sequence_dropout(self.conv_module2(src, chunk_size=chunk_size,
-                                                           src_key_padding_mask=src_key_padding_mask),
-                                         float(self.conv_skip_rate))
-
        src = src + self.sequence_dropout(self.balancer_ff3(self.feed_forward3(src)),
                                          float(self.ff3_skip_rate))

@@ -718,17 +649,17 @@ class Zipformer2EncoderLayer(nn.Module):

        return src

-class Zipformer2Encoder(nn.Module):
-   r"""Zipformer2Encoder is a stack of N encoder layers
+class Subformer2Encoder(nn.Module):
+   r"""Subformer2Encoder is a stack of N encoder layers

    Args:
-       encoder_layer: an instance of the Zipformer2EncoderLayer() class (required).
+       encoder_layer: an instance of the Subformer2EncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        pos_dim: the dimension for the relative positional encoding

    Examples::
-       >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
-       >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6)
+       >>> encoder_layer = Subformer2EncoderLayer(embed_dim=512, nhead=8)
+       >>> zipformer_encoder = Subformer2Encoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = zipformer_encoder(src)
    """
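As the docstring's example suggests, this class clones one template layer num_layers times and applies the copies in sequence. Its constructor body is outside this hunk; below is a minimal sketch of the usual PyTorch pattern (positional encodings, warmup, and feature masking omitted), with all names being illustrative.

    import copy
    import torch.nn as nn

    # Minimal sketch of an N-layer encoder stack built from one template layer.
    class EncoderStack(nn.Module):
        def __init__(self, encoder_layer: nn.Module, num_layers: int):
            super().__init__()
            # independent copies, each with its own parameters
            self.layers = nn.ModuleList(
                [copy.deepcopy(encoder_layer) for _ in range(num_layers)])

        def forward(self, src, **kwargs):
            for layer in self.layers:
                src = layer(src, **kwargs)
            return src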
@@ -874,9 +805,9 @@ class BypassModule(nn.Module):


-class DownsampledZipformer2Encoder(nn.Module):
+class DownsampledSubformer2Encoder(nn.Module):
    r"""
-   DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate,
+   DownsampledSubformer2Encoder is a zipformer encoder evaluated at a reduced frame rate,
    after convolutional downsampling, and then upsampled again at the output, and combined
    with the original input, so that the output has the same shape as the input.
    """
@@ -885,7 +816,7 @@ class DownsampledZipformer2Encoder(nn.Module):
                 dim: int,
                 downsample: int,
                 dropout: FloatLike):
-       super(DownsampledZipformer2Encoder, self).__init__()
+       super(DownsampledSubformer2Encoder, self).__init__()
        self.downsample_factor = downsample
        self.downsample = SimpleDownsample(dim,
                                           downsample, dropout)
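The wrapper above downsamples its input, runs the inner encoder at the reduced frame rate, upsamples the result, and adds it back to the original input so the output shape matches the input. A rough runnable sketch of that flow, with mean-pooling and repeat-interleave standing in for the learned SimpleDownsample/SimpleUpsample modules (which are defined outside this hunk):

    import torch
    from torch import Tensor, nn

    # Rough sketch of a downsampled-encoder wrapper; the inner encoder is
    # assumed to preserve its input shape.
    class DownsampledEncoder(nn.Module):
        def __init__(self, encoder: nn.Module, downsample: int):
            super().__init__()
            self.encoder = encoder
            self.ds = downsample

        def forward(self, src: Tensor) -> Tensor:
            # src: (seq_len, batch, dim)
            seq_len = src.shape[0]
            pad = (-seq_len) % self.ds  # pad by repeating the last frame
            x = torch.cat([src, src[-1:].expand(pad, -1, -1)], dim=0) if pad else src
            x = x.reshape(-1, self.ds, *x.shape[1:]).mean(dim=1)  # downsample
            x = self.encoder(x)                                   # low frame rate
            x = x.repeat_interleave(self.ds, dim=0)[:seq_len]     # upsample
            return src + x                                        # combine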
@@ -1577,7 +1508,7 @@ class MultiheadAttentionWeights(nn.Module):


 class FeedforwardModule(nn.Module):
-   """Feedforward module in Zipformer2 model.
+   """Feedforward module in Subformer2 model.
    """
    def __init__(self,
                 embed_dim: int,
@@ -1718,137 +1649,6 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
        return x


-class ConvolutionModule(nn.Module):
-   """ConvolutionModule in Zipformer2 model.
-   Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py
-
-   Args:
-       channels (int): The number of channels of conv layers.
-       kernel_size (int): Kernel size of conv layers.
-       bias (bool): Whether to use bias in conv layers (default=True).
-
-   """
-   def __init__(
-       self, channels: int, kernel_size: int, causal: bool,
-   ) -> None:
-       """Construct a ConvolutionModule object."""
-       super(ConvolutionModule, self).__init__()
-       # kernel_size should be an odd number for 'SAME' padding
-       assert (kernel_size - 1) % 2 == 0
-
-       bottleneck_dim = channels
-       self.causal = causal
-
-       self.in_proj = nn.Linear(
-           channels, 2 * bottleneck_dim,
-       )
-       # the gradients on in_proj are a little noisy, likely to do with the
-       # sigmoid in glu.
-
-       # after in_proj we put x through a gated linear unit (nn.functional.glu).
-       # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
-       # but sometimes, for some reason, for layer 0 the rms ends up being very large,
-       # between 50 and 100 for different channels.  This will cause very peaky and
-       # sparse derivatives for the sigmoid gating function, which will tend to make
-       # the loss function not learn effectively.  (for most layers the average absolute values
-       # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
-       # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different
-       # layers, which likely breaks down as 0.5 for the "linear" half and
-       # 0.2 to 0.3 for the part that goes into the sigmoid.  The idea is that if we
-       # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
-       # it will be in a better position to start learning something, i.e. to latch onto
-       # the correct range.)
-       self.balancer1 = Balancer(
-           bottleneck_dim, channel_dim=-1,
-           min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)),
-           max_positive=1.0,
-           min_abs=1.5,
-           max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
-       )
-
-       self.activation1 = Identity()  # for diagnostics
-
-       self.sigmoid = nn.Sigmoid()
-
-       self.activation2 = Identity()  # for diagnostics
-
-       assert kernel_size % 2 == 1
-
-       self.depthwise_conv = ChunkCausalDepthwiseConv1d(
-           channels=bottleneck_dim,
-           kernel_size=kernel_size) if causal else nn.Conv1d(
-           in_channels=bottleneck_dim,
-           out_channels=bottleneck_dim,
-           groups=bottleneck_dim,
-           kernel_size=kernel_size,
-           padding=kernel_size // 2)
-
-       self.balancer2 = Balancer(
-           bottleneck_dim, channel_dim=1,
-           min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
-           max_positive=1.0,
-           min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
-           max_abs=10.0,
-       )
-
-       self.whiten = Whiten(num_groups=1,
-                            whitening_limit=_whitening_schedule(7.5),
-                            prob=(0.025, 0.25),
-                            grad_scale=0.01)
-
-       self.out_proj = ActivationDropoutAndLinear(
-           bottleneck_dim, channels, activation='SwooshR',
-           dropout_p=0.0, initial_scale=0.05,
-       )
-
-   def forward(self,
-               x: Tensor,
-               src_key_padding_mask: Optional[Tensor] = None,
-               chunk_size: int = -1,
-   ) -> Tensor:
-       """Compute convolution module.
-
-       Args:
-           x: Input tensor (#time, batch, channels).
-           src_key_padding_mask: the mask for the src keys per batch (optional):
-               (batch, #time), contains True in masked positions.
-
-       Returns:
-           Tensor: Output tensor (#time, batch, channels).
-
-       """
-
-       x = self.in_proj(x)  # (time, batch, 2*channels)
-
-       x, s = x.chunk(2, dim=-1)
-       s = self.balancer1(s)
-       s = self.sigmoid(s)
-       x = self.activation1(x)  # identity.
-       x = x * s
-       x = self.activation2(x)  # identity
-
-       # (time, batch, channels)
-
-       # exchange the temporal dimension and the feature dimension
-       x = x.permute(1, 2, 0)  # (#batch, channels, time).
-
-       if src_key_padding_mask is not None:
-           x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
-
-       if chunk_size >= 0:
-           assert self.causal, "Must initialize model with causal=True if you use chunk_size"
-           x = self.depthwise_conv(x, chunk_size=chunk_size)
-       else:
-           x = self.depthwise_conv(x)
-
-       x = self.balancer2(x)
-       x = x.permute(2, 0, 1)  # (time, batch, channels)
-
-       x = self.whiten(x)  # (time, batch, channels)
-       x = self.out_proj(x)  # (time, batch, channels)
-
-       return x


 class ScalarMultiply(nn.Module):
    def __init__(self, scale: float):
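The long comment in the removed ConvolutionModule argues that when the gate half of the GLU has a very large rms, the sigmoid saturates and its derivatives become peaky and sparse, which stalls learning; the balancer's max_abs constraint keeps the gate inputs in a trainable range. A quick numeric check of that claim:

    import torch

    # Mean gradient of the sigmoid gate as the input scale grows: at rms ~1
    # the gate is trainable; at rms ~50 the gradient has mostly collapsed.
    for scale in (1.0, 50.0):
        s = torch.randn(100000) * scale
        sig = torch.sigmoid(s)
        grad = sig * (1 - sig)  # d(sigmoid)/ds
        print(f"rms={scale:5.1f}  mean sigmoid grad = {grad.mean():.4f}")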
@@ -1865,7 +1665,7 @@ def _test_zipformer_main(causal: bool = False):
    # Just make sure the forward pass runs.
    memory_dim = 100

-   c = Zipformer2(
+   c = Subformer2(
        encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
        causal=causal,
        chunk_size=(4,) if causal else (-1,),

@@ -1 +0,0 @@
-../../../librispeech/ASR/zipformer2/zipformer.py