rename zipformer to subformer, remove some things that won't be used.
parent 2e4b27a1c8
commit 5c470fe397
@@ -27,7 +27,6 @@ from scaling import (
     Balancer,
     BiasNorm,
     Dropout2,
-    ChunkCausalDepthwiseConv1d,
     ActivationDropoutAndLinear,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
     Whiten,
@@ -42,7 +41,7 @@ from scaling import (
 from torch import Tensor, nn


-class Zipformer2(EncoderInterface):
+class Subformer2(EncoderInterface):
     """
     Args:

@@ -70,7 +69,6 @@ class Zipformer2(EncoderInterface):
         num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism.
             Must be at least 4.
         feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules
-        cnn_module_kernel (int or Tuple[int]): Kernel size of convolution module

         pos_dim (int): the dimension of each positional-encoding vector prior to projection,
             e.g. 128.
@@ -83,15 +81,9 @@ class Zipformer2(EncoderInterface):
             slightly slower and use more memory.  Enables use of the chunk_size and
             left_context_chunks options in forward(), which simulates streaming
             decoding.
-        chunk_size: (list of int): only set this to other than [-1] if causal;
-            the chunk size will be randomly chosen from this list.  -1 means no chunking.
-        left_context_frames: (list of int): determines the number of left-
-            context chunks for causal training; will be rounded to a number of
-            chunks.  Must not be less than cnn_module_kernel (after factoring in
-            rounding and downsampling); an error will be thrown if this is violated.
         memory_dim: if supplied and >0, will be the dimension of the memory embeddings
             passed into the zipformer (e.g. this might be the output of another
-            Zipformer used to create embedding vectors.)
+            Subformer used to create embedding vectors.)
     """
     def __init__(
         self,
@@ -105,16 +97,13 @@ class Zipformer2(EncoderInterface):
         value_head_dim: Union[int, Tuple[int]] = 12,
         num_heads: Union[int, Tuple[int]] = 8,
         feedforward_dim: Union[int, Tuple[int]] = 1536,
-        cnn_module_kernel: Union[int, Tuple[int]] = 31,
         memory_dim: int = -1,
         pos_dim: int = 192,
         dropout: FloatLike = None,  # see code below for default
         warmup_batches: float = 4000.0,
         causal: bool = False,
-        chunk_size: Tuple[int] = (-1,),
-        left_context_frames: Tuple[int] = (-1,),
     ) -> None:
-        super(Zipformer2, self).__init__()
+        super(Subformer2, self).__init__()

         if dropout is None:
             dropout = ScheduledFloat((0.0, 0.3),
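
The dropout default above is a ScheduledFloat: a float-like value interpolated piecewise-linearly in the training batch count. A minimal standalone sketch of that behavior (the real class lives in scaling.py and tracks the batch count itself; the second breakpoint of the dropout schedule is cut off by this hunk, so the values below are only illustrative):

    # Sketch of ScheduledFloat-style piecewise-linear scheduling; the real
    # implementation is in scaling.py and reads the batch count internally.
    def scheduled_float(points, batch_idx):
        # points: [(batch_count, value), ...], sorted by batch_count
        b0, v0 = points[0]
        if batch_idx <= b0:
            return v0
        for b1, v1 in points[1:]:
            if batch_idx <= b1:
                return v0 + (v1 - v0) * (batch_idx - b0) / (b1 - b0)
            b0, v0 = b1, v1
        return v0

    # e.g. a dropout schedule starting at 0.3 and decaying over training:
    print(scheduled_float([(0.0, 0.3), (20000.0, 0.1)], 10000.0))  # 0.2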
@@ -141,22 +130,17 @@ class Zipformer2(EncoderInterface):
         pos_head_dim = _to_tuple(pos_head_dim)
         num_heads = _to_tuple(num_heads)
         feedforward_dim = _to_tuple(feedforward_dim)
-        self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)

         self.causal = causal
-        self.chunk_size = chunk_size
-        self.left_context_frames = left_context_frames

         for u, d in zip(encoder_unmasked_dim, encoder_dim):
             assert u <= d

-        # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
+        # each one will be Subformer2Encoder or DownsampledSubformer2Encoder
         encoders = []

         num_encoders = len(downsampling_factor)
         for i in range(num_encoders):

-            encoder_layer = Zipformer2EncoderLayer(
+            encoder_layer = Subformer2EncoderLayer(
                 embed_dim=encoder_dim[i],
                 pos_dim=pos_dim,
                 num_heads=num_heads[i],
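
Most size-related constructor arguments accept either a scalar or a per-stack tuple; `_to_tuple` broadcasts scalars so there is one value per encoder stack. It is defined in this file but outside the hunks shown; a plausible sketch of its contract:

    # Assumed contract of _to_tuple: broadcast an int to one value per
    # encoder stack (the stack count comes from len(downsampling_factor)).
    from typing import Tuple, Union

    def to_tuple(x: Union[int, Tuple[int, ...]], num_encoders: int) -> Tuple[int, ...]:
        if isinstance(x, int):
            return (x,) * num_encoders
        assert len(x) == num_encoders
        return tuple(x)

    print(to_tuple(8, 2))       # (8, 8)
    print(to_tuple((4, 8), 2))  # (4, 8)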
@@ -166,13 +150,12 @@ class Zipformer2(EncoderInterface):
                 feedforward_dim=feedforward_dim[i],
                 memory_dim=memory_dim,
                 dropout=dropout,
-                cnn_module_kernel=cnn_module_kernel[i],
                 causal=causal,
             )

             # For the segment of the warmup period, we let the Conv2dSubsampling
             # layer learn something.  Then we start to warm up the other encoders.
-            encoder = Zipformer2Encoder(
+            encoder = Subformer2Encoder(
                 encoder_layer,
                 num_encoder_layers[i],
                 pos_dim=pos_dim,
@@ -183,7 +166,7 @@ class Zipformer2(EncoderInterface):
             )

             if downsampling_factor[i] != 1:
-                encoder = DownsampledZipformer2Encoder(
+                encoder = DownsampledSubformer2Encoder(
                     encoder,
                     dim=encoder_dim[i],
                     downsample=downsampling_factor[i],
@@ -257,24 +240,6 @@ class Zipformer2(EncoderInterface):
         return feature_masks


-    def get_chunk_info(self) -> Tuple[int, int]:
-        """
-        Returns chunk_size and left_context_chunks.
-        """
-        if not self.causal:
-            return -1, -1
-        chunk_size = random.choice(self.chunk_size)
-        if chunk_size == -1:
-            left_context_chunks = -1
-        else:
-            left_context_frames = random.choice(self.left_context_frames)
-            # Note: in Python, -1 // n == -1 for n > 0
-            left_context_chunks = left_context_frames // chunk_size
-            if left_context_chunks == 0:
-                left_context_chunks = 1
-        return chunk_size, left_context_chunks
-
-
     def forward(
         self,
         x: torch.Tensor,
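
The deleted get_chunk_info leaned on a Python detail worth spelling out: floor division rounds toward negative infinity, so the -1 sentinel ("unlimited left context") passed through the frames-to-chunks conversion unchanged:

    # -1 // n == -1 for any n > 0, so left_context_frames == -1 stays -1
    # after left_context_chunks = left_context_frames // chunk_size.
    for chunk_size in (16, 32, 64):
        assert -1 // chunk_size == -1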
@@ -307,9 +272,7 @@ class Zipformer2(EncoderInterface):
         outputs = []
         feature_masks = self.get_feature_masks(x)

-        chunk_size, left_context_chunks = self.get_chunk_info()
-
-        attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)
+        attn_mask = self._get_attn_mask(x)

         if self.training and memory is not None:
             batch_size = x.shape[1]
@@ -361,45 +324,31 @@ class Zipformer2(EncoderInterface):

         return x, lengths

-    def _get_attn_mask(self, x: Tensor,
-                       chunk_size: int,
-                       left_context_chunks: int
-                       ) -> Optional[Tensor]:
+    def _get_attn_mask(self, x: Tensor) -> Optional[Tensor]:
         """
-        Return None if chunk_size == -1, else return attention mask of shape
+        Return None if self.causal is False, else return an attention mask of shape
         (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len).  True
         means a masked position.
         Args:
            x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
-           chunk_size: chunk size, must divide
         """
-        if chunk_size <= 0:
+        if not self.causal:
             return None
-        assert all(chunk_size % d == 0 for d in self.downsampling_factor)
-        if left_context_chunks >= 0:
-            num_encoders = len(self.encoder_dim)
-            assert all(chunk_size * left_context_chunks >=
-                       (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i]
-                       for i in range(num_encoders))
-        else:
-            left_context_chunks = 1000000

         seq_len = x.shape[0]

         # t is frame index, shape (seq_len,)
         t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
-        # c is chunk index for each frame, shape (seq_len,)
-        c = t // chunk_size
-        src_c = c
-        tgt_c = c.unsqueeze(-1)
+        src_t = t
+        tgt_t = t.unsqueeze(-1)

-        attn_mask = torch.logical_or(src_c > tgt_c,
-                                     src_c < tgt_c - left_context_chunks)
-        if __name__ == "__main__":
-            logging.info(f"attn_mask = {attn_mask}")
+        attn_mask = (src_t > tgt_t)
         return attn_mask



 def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
     return ScheduledFloat((0.0, x),
                           (20000.0, ratio * x),
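
With chunking removed, the mask reduces to a plain lower-triangular causal constraint: a query at position i may attend only to positions j <= i. A small self-contained check of the broadcasting logic used above:

    import torch

    seq_len = 4
    t = torch.arange(seq_len, dtype=torch.int32)
    src_t = t                    # source/key positions, shape (seq_len,)
    tgt_t = t.unsqueeze(-1)      # target/query positions, shape (seq_len, 1)
    attn_mask = (src_t > tgt_t)  # True = masked (future) position
    print(attn_mask)
    # tensor([[False,  True,  True,  True],
    #         [False, False,  True,  True],
    #         [False, False, False,  True],
    #         [False, False, False, False]])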
@@ -410,17 +359,16 @@ def _balancer_schedule(min_prob: float):



-class Zipformer2EncoderLayer(nn.Module):
+class Subformer2EncoderLayer(nn.Module):
     """
     Args:
         embed_dim: the number of expected features in the input (required).
         nhead: the number of heads in the multiheadattention models (required).
         feedforward_dim: the dimension of the feedforward network model (default=2048).
         dropout: the dropout value (default=0.1).
-        cnn_module_kernel (int): Kernel size of convolution module.

     Examples::
-        >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
+        >>> encoder_layer = Subformer2EncoderLayer(embed_dim=512, nhead=8)
         >>> src = torch.rand(10, 32, 512)
         >>> pos_emb = torch.rand(32, 19, 512)
         >>> out = encoder_layer(src, pos_emb)
@@ -435,7 +383,6 @@ class Zipformer2EncoderLayer(nn.Module):
         value_head_dim: int,
         feedforward_dim: int,
         dropout: FloatLike = 0.1,
-        cnn_module_kernel: int = 31,
         causal: bool = False,
         memory_dim: int = -1,
         attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
@@ -445,7 +392,7 @@ class Zipformer2EncoderLayer(nn.Module):
         ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
         bypass_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0),
     ) -> None:
-        super(Zipformer2EncoderLayer, self).__init__()
+        super(Subformer2EncoderLayer, self).__init__()
         self.embed_dim = embed_dim

         # self.bypass implements layer skipping as well as bypass; see its default values.
@@ -509,14 +456,6 @@ class Zipformer2EncoderLayer(nn.Module):
         self.nonlin_attention = NonlinAttention(embed_dim,
                                                 hidden_channels=3 * embed_dim // 4)

-        self.conv_module1 = ConvolutionModule(embed_dim,
-                                              cnn_module_kernel,
-                                              causal=causal)
-
-        self.conv_module2 = ConvolutionModule(embed_dim,
-                                              cnn_module_kernel,
-                                              causal=causal)
-
         #self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)

@@ -682,10 +621,6 @@ class Zipformer2EncoderLayer(nn.Module):
             src = src + self.sequence_dropout(self.src_attn1(memory, src_attn_weights),
                                               attention_skip_rate)

-        src = src + self.sequence_dropout(self.conv_module1(src, chunk_size=chunk_size,
-                                                            src_key_padding_mask=src_key_padding_mask),
-                                          float(self.conv_skip_rate))
-
         src = src + self.sequence_dropout(self.balancer_ff2(self.feed_forward2(src)),
                                           float(self.ff2_skip_rate))

@@ -701,10 +636,6 @@ class Zipformer2EncoderLayer(nn.Module):
             src = src + self.sequence_dropout(self.src_attn2(memory, src_attn_weights),
                                               attention_skip_rate)

-        src = src + self.sequence_dropout(self.conv_module2(src, chunk_size=chunk_size,
-                                                            src_key_padding_mask=src_key_padding_mask),
-                                          float(self.conv_skip_rate))
-
         src = src + self.sequence_dropout(self.balancer_ff3(self.feed_forward3(src)),
                                           float(self.ff3_skip_rate))

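
Each sub-module's residual contribution above is gated by self.sequence_dropout with a scheduled skip rate that starts high and decays toward zero over training. The helper is defined outside these hunks; a rough sketch of its presumed semantics, zeroing a module's output for whole sequences at random during training:

    import torch
    from torch import Tensor

    # Assumed semantics only: one keep/drop decision per sequence in the
    # batch, so the residual branch is skipped entirely for that sequence.
    def sequence_dropout(x: Tensor, skip_rate: float, training: bool = True) -> Tensor:
        if not training or skip_rate == 0.0:
            return x
        # x has shape (seq_len, batch_size, embed_dim)
        keep = torch.rand(1, x.shape[1], 1, device=x.device) > skip_rate
        return x * keep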
@@ -718,17 +649,17 @@ class Zipformer2EncoderLayer(nn.Module):

         return src

-class Zipformer2Encoder(nn.Module):
-    r"""Zipformer2Encoder is a stack of N encoder layers
+class Subformer2Encoder(nn.Module):
+    r"""Subformer2Encoder is a stack of N encoder layers

     Args:
-        encoder_layer: an instance of the Zipformer2EncoderLayer() class (required).
+        encoder_layer: an instance of the Subformer2EncoderLayer() class (required).
         num_layers: the number of sub-encoder-layers in the encoder (required).
         pos_dim: the dimension for the relative positional encoding

     Examples::
-        >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
-        >>> zipformer_encoder = Zipformer2Encoder(encoder_layer, num_layers=6)
+        >>> encoder_layer = Subformer2EncoderLayer(embed_dim=512, nhead=8)
+        >>> zipformer_encoder = Subformer2Encoder(encoder_layer, num_layers=6)
         >>> src = torch.rand(10, 32, 512)
         >>> out = zipformer_encoder(src)
     """
@@ -874,9 +805,9 @@ class BypassModule(nn.Module):



-class DownsampledZipformer2Encoder(nn.Module):
+class DownsampledSubformer2Encoder(nn.Module):
     r"""
-    DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate,
+    DownsampledSubformer2Encoder is a Subformer encoder evaluated at a reduced frame rate,
     after convolutional downsampling, and then upsampled again at the output, and combined
     with the original input, so that the output has the same shape as the input.
     """
@@ -885,7 +816,7 @@ class DownsampledZipformer2Encoder(nn.Module):
                  dim: int,
                  downsample: int,
                  dropout: FloatLike):
-        super(DownsampledZipformer2Encoder, self).__init__()
+        super(DownsampledSubformer2Encoder, self).__init__()
         self.downsample_factor = downsample
         self.downsample = SimpleDownsample(dim,
                                            downsample, dropout)
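
DownsampledSubformer2Encoder runs its inner encoder at a reduced frame rate and restores the original rate at the output. A rough dataflow sketch, assuming downsampling averages each group of `downsample` frames and upsampling repeats frames (the real SimpleDownsample uses a learned weighted combination, and the combination with the full-rate input goes through a bypass module):

    import torch
    from torch import Tensor, nn

    def downsampled_encoder(x: Tensor, encoder: nn.Module, downsample: int) -> Tensor:
        # x: (seq_len, batch, dim)
        seq_len, batch, dim = x.shape
        pad = (-seq_len) % downsample              # pad so seq_len divides evenly
        if pad:
            x_pad = torch.cat([x, x[-1:].expand(pad, batch, dim)], dim=0)
        else:
            x_pad = x
        y = x_pad.reshape(-1, downsample, batch, dim).mean(dim=1)
        y = encoder(y)                             # run at the reduced frame rate
        y = y.repeat_interleave(downsample, dim=0)[:seq_len]
        return x + y                               # combine with the input

    out = downsampled_encoder(torch.randn(10, 3, 8), nn.Identity(), downsample=4)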
@@ -1577,7 +1508,7 @@ class MultiheadAttentionWeights(nn.Module):


 class FeedforwardModule(nn.Module):
-    """Feedforward module in Zipformer2 model.
+    """Feedforward module in Subformer2 model.
     """
     def __init__(self,
                  embed_dim: int,
@@ -1718,137 +1649,6 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         return x


-class ConvolutionModule(nn.Module):
-    """ConvolutionModule in Zipformer2 model.
-    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py
-
-    Args:
-        channels (int): The number of channels of conv layers.
-        kernel_size (int): Kernel size of conv layers.
-        bias (bool): Whether to use bias in conv layers (default=True).
-
-    """
-    def __init__(
-            self, channels: int, kernel_size: int, causal: bool,
-    ) -> None:
-        """Construct a ConvolutionModule object."""
-        super(ConvolutionModule, self).__init__()
-        # kernel_size should be an odd number for 'SAME' padding
-        assert (kernel_size - 1) % 2 == 0
-
-        bottleneck_dim = channels
-        self.causal = causal
-
-        self.in_proj = nn.Linear(
-            channels, 2 * bottleneck_dim,
-        )
-        # the gradients on in_proj are a little noisy, likely to do with the
-        # sigmoid in glu.
-
-        # after in_proj we put x through a gated linear unit (nn.functional.glu).
-        # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
-        # but sometimes, for some reason, for layer 0 the rms ends up being very large,
-        # between 50 and 100 for different channels.  This will cause very peaky and
-        # sparse derivatives for the sigmoid gating function, which will tend to make
-        # the loss function not learn effectively.  (for most layers the average absolute values
-        # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
-        # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different
-        # layers, which likely breaks down as 0.5 for the "linear" half and
-        # 0.2 to 0.3 for the part that goes into the sigmoid.  The idea is that if we
-        # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
-        # it will be in a better position to start learning something, i.e. to latch onto
-        # the correct range.
-        self.balancer1 = Balancer(
-            bottleneck_dim, channel_dim=-1,
-            min_positive=ScheduledFloat((0.0, 0.05), (8000.0, 0.025)),
-            max_positive=1.0,
-            min_abs=1.5,
-            max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
-        )
-
-        self.activation1 = Identity()  # for diagnostics
-
-        self.sigmoid = nn.Sigmoid()
-
-        self.activation2 = Identity()  # for diagnostics
-
-        assert kernel_size % 2 == 1
-
-        self.depthwise_conv = ChunkCausalDepthwiseConv1d(
-            channels=bottleneck_dim,
-            kernel_size=kernel_size) if causal else nn.Conv1d(
-            in_channels=bottleneck_dim,
-            out_channels=bottleneck_dim,
-            groups=bottleneck_dim,
-            kernel_size=kernel_size,
-            padding=kernel_size // 2)
-
-        self.balancer2 = Balancer(
-            bottleneck_dim, channel_dim=1,
-            min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
-            max_positive=1.0,
-            min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 0.5)),
-            max_abs=10.0,
-        )
-
-        self.whiten = Whiten(num_groups=1,
-                             whitening_limit=_whitening_schedule(7.5),
-                             prob=(0.025, 0.25),
-                             grad_scale=0.01)
-
-        self.out_proj = ActivationDropoutAndLinear(
-            bottleneck_dim, channels, activation='SwooshR',
-            dropout_p=0.0, initial_scale=0.05,
-        )
-
-    def forward(self,
-                x: Tensor,
-                src_key_padding_mask: Optional[Tensor] = None,
-                chunk_size: int = -1,
-    ) -> Tensor:
-        """Compute convolution module.
-
-        Args:
-            x: Input tensor (#time, batch, channels).
-            src_key_padding_mask: the mask for the src keys per batch (optional):
-                (batch, #time), contains True in masked positions.
-
-        Returns:
-            Tensor: Output tensor (#time, batch, channels).
-
-        """
-
-        x = self.in_proj(x)  # (time, batch, 2*channels)
-
-        x, s = x.chunk(2, dim=-1)
-        s = self.balancer1(s)
-        s = self.sigmoid(s)
-        x = self.activation1(x)  # identity.
-        x = x * s
-        x = self.activation2(x)  # identity
-
-        # (time, batch, channels)
-
-        # exchange the temporal dimension and the feature dimension
-        x = x.permute(1, 2, 0)  # (#batch, channels, time).
-
-        if src_key_padding_mask is not None:
-            x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
-
-        if chunk_size >= 0:
-            assert self.causal, "Must initialize model with causal=True if you use chunk_size"
-            x = self.depthwise_conv(x, chunk_size=chunk_size)
-        else:
-            x = self.depthwise_conv(x)
-
-        x = self.balancer2(x)
-        x = x.permute(2, 0, 1)  # (time, batch, channels)
-
-        x = self.whiten(x)  # (time, batch, channels)
-        x = self.out_proj(x)  # (time, batch, channels)
-
-        return x
-
-
 class ScalarMultiply(nn.Module):
     def __init__(self, scale: float):
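
The long comment in the removed ConvolutionModule motivates constraining the pre-sigmoid half of the gated linear unit. The gating mechanism itself, in isolation and without the balancers, is just GLU:

    import torch
    from torch import nn

    # Project to 2*channels, split into a linear half and a gate half, and
    # multiply by sigmoid(gate); equivalent to nn.functional.glu.
    channels = 8
    in_proj = nn.Linear(channels, 2 * channels)
    x = in_proj(torch.randn(10, 3, channels))  # (time, batch, 2*channels)
    lin, gate = x.chunk(2, dim=-1)
    out = lin * torch.sigmoid(gate)
    assert torch.allclose(out, nn.functional.glu(x, dim=-1))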
@@ -1865,7 +1665,7 @@ def _test_zipformer_main(causal: bool = False):
     # Just make sure the forward pass runs.
     memory_dim = 100

-    c = Zipformer2(
+    c = Subformer2(
         encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
         causal=causal,
-        chunk_size=(4,) if causal else (-1,),
@@ -1 +0,0 @@
-../../../librispeech/ASR/zipformer2/zipformer.py