Implement chunking

Daniel Povey 2023-02-10 14:53:47 +08:00
parent b2fb504aee
commit e7e7560bba
4 changed files with 324 additions and 83 deletions

View File

@ -84,6 +84,8 @@ class Transducer(nn.Module):
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
chunk_size: int = -1,
left_context_chunks: int = -1,
) -> torch.Tensor:
"""
Args:
@ -104,6 +106,9 @@ class Transducer(nn.Module):
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
chunk_size, left_context_chunks:
For chunkwise causal training; will be passed to the zipformer encoder.
chunk_size is specified in frames at 50Hz, i.e. after 2x downsampling.
Returns:
Return the transducer loss.
@ -119,7 +124,8 @@ class Transducer(nn.Module):
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, x_lens = self.encoder(x, x_lens)
encoder_out, x_lens = self.encoder(x, x_lens, chunk_size=chunk_size,
left_context_chunks=left_context_chunks)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network

View File

@ -1014,12 +1014,13 @@ def ScaledConv2d(*args,
initial_scale: float = 1.0,
**kwargs ) -> nn.Conv2d:
"""
Behaves like a constructor of a modified version of nn.Conv1d
Behaves like a constructor of a modified version of nn.Conv2d
that gives an easy way to set the default initial parameter scale.
Args:
Accepts the standard args and kwargs that nn.Conv2d accepts
e.g. in_channels, out_channels, kernel_size, bias=False.
e.g. in_channels, out_channels, kernel_size, bias=False, but:
NO PADDING-RELATED ARGS.
initial_scale: you can override this if you want to increase
or decrease the initial magnitude of the module's output
@ -1037,6 +1038,132 @@ def ScaledConv2d(*args,
return ans
class ChunkCausalDepthwiseConv1d(torch.nn.Module):
"""
Behaves like a depthwise 1d convolution, except that it is causal in
a chunkwise way, as if we had a block-triangular attention mask.
The chunk size is provided at test time (it should probably be
kept in sync with the attention mask).
This has a little more than twice the parameters of a conventional
depthwise conv1d module: we implement it by having one
depthwise convolution, of half the width, that is causal (via
left-padding); and one depthwise convolution that is applied only
within chunks, that we multiply by a scaling factor which depends
on the position within the chunk.
Args:
channels: number of input/output channels (the convolutions are depthwise,
so in_channels == out_channels == groups == channels).
kernel_size: kernel size of the chunkwise convolution; must be odd.
bias: if True, the chunkwise convolution has a bias term.
initial_scale: you can override this if you want to increase
or decrease the initial magnitude of the module's output
(scales the initial weights of both convolutions and the bias initialization).
Another option, if you want to do something like this, is
to re-initialize the parameters.
"""
def __init__(self,
channels: int,
kernel_size: int,
initial_scale: float = 1.0,
bias: bool = True):
super().__init__()
assert kernel_size % 2 == 1
half_kernel_size = (kernel_size + 1) // 2
# will pad manually, on one side.
self.causal_conv = nn.Conv1d(in_channels=channels,
out_channels=channels,
groups=channels,
kernel_size=half_kernel_size,
padding=0,
bias=True)
self.chunkwise_conv = nn.Conv1d(in_channels=channels,
out_channels=channels,
groups=channels,
kernel_size=kernel_size,
padding=kernel_size // 2,
bias=bias)
# first row is correction factors added to the scale near the left edge of the chunk,
# second row is correction factors added to the scale near the right edge of the chunk,
# both of these are added to a default scale of 1.0.
self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size))
self.kernel_size = kernel_size
with torch.no_grad():
self.causal_conv.weight[:] *= initial_scale
self.chunkwise_conv.weight[:] *= initial_scale
if bias:
torch.nn.init.uniform_(self.causal_conv.bias,
-0.1 * initial_scale,
0.1 * initial_scale)
def forward(self,
x: Tensor,
chunk_size: int = -1) -> Tensor:
"""
Forward function. Args:
x: a Tensor of shape (batch_size, channels, seq_len)
chunk_size: the chunk size, in frames; does not have to divide seq_len exactly.
"""
(batch_size, num_channels, seq_len) = x.shape
half_kernel_size = (self.kernel_size + 1) // 2
# left_pad is half_kernel_size - 1 where half_kernel_size is the size used
# in the causal conv. It's the amount by which we must pad on the left,
# to make the convolution causal.
left_pad = self.kernel_size // 2
if chunk_size < 0:
chunk_size = seq_len
right_pad = -seq_len % chunk_size
x = torch.nn.functional.pad(x, (left_pad, right_pad))
x_causal = self.causal_conv(x[..., :seq_len + left_pad])
assert x_causal.shape == (batch_size, num_channels, seq_len)
x_chunk = x[..., left_pad:]
num_chunks = x_chunk.shape[2] // chunk_size
x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size)
x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(batch_size * num_chunks,
num_channels, chunk_size)
x_chunk = self.chunkwise_conv(x_chunk) # does not change shape
chunk_scale = self._get_chunk_scale(chunk_size)
x_chunk = x_chunk * chunk_scale
x_chunk = x_chunk.reshape(batch_size, num_chunks,
num_channels, chunk_size).permute(0, 2, 1, 3)
x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[..., :seq_len]
return x_chunk + x_causal
def _get_chunk_scale(self, chunk_size: int):
"""Returns tensor of shape (num_channels, chunk_size) that will be used to
scale the output of self.chunkwise_conv."""
left_edge = self.chunkwise_conv_scale[0]
right_edge = self.chunkwise_conv_scale[1]
if chunk_size < self.kernel_size:
left_edge = left_edge[:, :chunk_size]
right_edge = right_edge[:, -chunk_size:]
else:
t = chunk_size - self.kernel_size
channels = left_edge.shape[0]
pad = torch.zeros(channels, t,
device=left_edge.device,
dtype=left_edge.dtype)
left_edge = torch.cat((left_edge, pad), dim=-1)
right_edge = torch.cat((pad, right_edge), dim=-1)
return 1.0 + (left_edge + right_edge)
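For orientation, here is a minimal usage sketch (not part of the commit), assuming the ChunkCausalDepthwiseConv1d class above is in scope; it only checks that chunked and unchunked calls preserve the input shape:
import torch
# Hypothetical smoke test: chunk_size=-1 treats the whole sequence as one chunk,
# and chunk_size does not have to divide seq_len exactly.
conv = ChunkCausalDepthwiseConv1d(channels=8, kernel_size=5)
x = torch.randn(2, 8, 47)            # (batch_size, channels, seq_len)
y_full = conv(x)                     # no chunking (chunk_size=-1)
y_chunked = conv(x, chunk_size=10)   # 5 chunks; the last one is padded internally
assert y_full.shape == x.shape and y_chunked.shape == x.shape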
class ActivationBalancer(torch.nn.Module):
"""
Modifies the backpropped derivatives of a function to try to encourage, for

View File

@ -47,6 +47,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
import argparse
import copy
import logging
import random
import warnings
from pathlib import Path
from shutil import copyfile
@ -225,6 +226,23 @@ def add_model_arguments(parser: argparse.ArgumentParser):
""",
)
parser.add_argument(
"--chunk-size",
type=str,
default="-1",
help=" Embedding dimension in encoder stacks: a single int or comma-separated list."
)
parser.add_argument(
"--chunk-left-context-frames",
type=str,
default="64,128,256,-1",
help="Left-contexts for chunkwise training, measured in frames (positive values must be "
"multiples of all positive elements of chunk-size). If --chunk-size is specified, "
"chunk left-context frames will be chosen randomly from this list."
)
def get_parser():
parser = argparse.ArgumentParser(
@ -526,6 +544,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
warmup_batches=4000.0,
causal=(params.chunk_size != "-1"),
)
return encoder
@ -686,6 +705,26 @@ def save_checkpoint(
copyfile(src=filename, dst=best_valid_filename)
def get_chunk_info(params: AttributeDict) -> Tuple[int, int]:
"""
Returns chunk_size and left_context_chunks.
"""
chunk_sizes = list(map(int, params.chunk_size.split(',')))
n = len(chunk_sizes)
chunk_size = random.choice(chunk_sizes)
if chunk_size == -1:
left_context_chunks = -1
else:
chunk_left_context_frames = list(map(int, params.chunk_left_context_frames.split(',')))
m = len(chunk_left_context_frames)
left_context_frames = random.choice(chunk_left_context_frames)
if left_context_frames != -1:
assert left_context_frames % chunk_size == 0, "Invalid --chunk-left-context-frames value"
# Note: in Python, -1 // n == -1 for n > 0
left_context_chunks = left_context_frames // chunk_size
return chunk_size, left_context_chunks
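As an illustration (not part of the commit) of how these values are sampled per batch, a small sketch assuming get_chunk_info() above is in scope; the flag values are hypothetical and SimpleNamespace stands in for the AttributeDict params object:
import random
from types import SimpleNamespace
params = SimpleNamespace(chunk_size="16,32,-1",
                         chunk_left_context_frames="64,128,-1")
random.seed(0)
for _ in range(3):
    # Possible outputs: (-1, -1) when chunking is disabled for this batch;
    # (32, 4) when chunk_size=32 and left_context_frames=128 (128 // 32);
    # (32, -1) when left_context_frames=-1, since -1 // 32 == -1 in Python.
    print(get_chunk_info(params))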
def compute_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
@ -731,6 +770,8 @@ def compute_loss(
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device)
chunk_size, left_context_chunks = get_chunk_info(params)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss = model(
x=feature,
@ -739,6 +780,8 @@ def compute_loss(
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
chunk_size=chunk_size,
left_context_chunks=left_context_chunks,
)
s = params.simple_loss_scale

View File

@ -37,6 +37,7 @@ from scaling import (
SwooshL,
SwooshR,
TanSwish,
ChunkCausalDepthwiseConv1d,
ScaledConv1d,
ScaledConv2d,
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
@ -96,8 +97,12 @@ class Zipformer(EncoderInterface):
dropout (float): dropout rate
warmup_batches (float): number of batches to warm up over; this controls
dropout of encoder layers.
causal (bool): if True, support chunkwise causal convolution. This should
not hurt WER as no modeling power is lost, but the convolution modules will be
slightly slower and use more memory. Enables use of the chunk_size and
left_context_chunks options in forward(), which simulates streaming
decoding.
"""
def __init__(
self,
num_features: int,
@ -116,6 +121,7 @@ class Zipformer(EncoderInterface):
pos_dim: int = 192,
dropout: FloatLike = None, # see code below for default
warmup_batches: float = 4000.0,
causal: bool = False,
) -> None:
super(Zipformer, self).__init__()
@ -144,6 +150,7 @@ class Zipformer(EncoderInterface):
self.num_features = num_features # int
self.output_downsampling_factor = output_downsampling_factor # int
self.downsampling_factor = downsampling_factor # tuple
# smallest n that is a multiple of every element of downsampling_factor, i.e. their least common multiple
self.downsampling_factor_gcd = next(n for n in range(1, 10000) if all(n % d == 0 for d in downsampling_factor))
self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple
self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(encoder_unmasked_dim) # tuple
num_encoder_layers = _to_tuple(num_encoder_layers)
@ -153,8 +160,7 @@ class Zipformer(EncoderInterface):
num_heads = _to_tuple(num_heads)
attention_share_layers = _to_tuple(attention_share_layers)
feedforward_dim = _to_tuple(feedforward_dim)
cnn_module_kernel = _to_tuple(cnn_module_kernel)
self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
for u,d in zip(encoder_unmasked_dim, encoder_dim):
assert u <= d
@ -187,6 +193,7 @@ class Zipformer(EncoderInterface):
feedforward_dim=feedforward_dim[i],
dropout=dropout,
cnn_module_kernel=cnn_module_kernel[i],
causal=causal,
)
# For the segment of the warmup period, we let the Conv2dSubsampling
@ -314,6 +321,8 @@ class Zipformer(EncoderInterface):
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor,
chunk_size: int = -1,
left_context_chunks: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@ -322,6 +331,14 @@ class Zipformer(EncoderInterface):
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
chunk_size: Number of frames per chunk (only set this if causal == True).
Must be a multiple of all elements of downsampling_factor. Measured at the
50Hz frame rate, i.e. after encoder_embed. If not specified, no chunking.
left_context_chunks: Number of left-context chunks for each chunk (affects
attention mask); only set this if chunk_size is specified. If -1, there
is no limit on the left context. If not -1, we require:
left_context_chunks * chunk_size >= downsampling_factor[i] *
(cnn_module_kernel[i] // 2).
Returns:
Return a tuple containing 2 tensors:
- embeddings: its shape is (batch_size, output_seq_len, max(encoder_dim))
@ -340,11 +357,13 @@ class Zipformer(EncoderInterface):
warnings.simplefilter("ignore")
lengths = (x_lens - 7) // 2
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
src_key_padding_mask = make_pad_mask(lengths)
outputs = []
feature_masks = self.get_feature_masks(x)
attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)
for i, module in enumerate(self.encoders):
ds = self.downsampling_factor[i]
if self.skip_layers[i] is not None:
@ -361,8 +380,12 @@ class Zipformer(EncoderInterface):
else:
x = skip_x
x = module(x,
chunk_size=chunk_size,
feature_mask=feature_masks[i],
src_key_padding_mask=None if mask is None else mask[...,::ds])
src_key_padding_mask=(None if src_key_padding_mask is None
else src_key_padding_mask[...,::ds]),
attn_mask=attn_mask,
)
outputs.append(x)
def get_full_dim_output():
@ -395,6 +418,42 @@ class Zipformer(EncoderInterface):
return x, lengths
def _get_attn_mask(self, x: Tensor,
chunk_size: int,
left_context_chunks: int
) -> Optional[Tensor]:
"""
Return None if chunk_size == -1, else return attention mask of shape
(seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len). True
means a masked position.
Args:
x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
chunk_size: chunk size; must be a multiple of all elements of self.downsampling_factor.
"""
if chunk_size <= 0:
return None
assert all(chunk_size % d == 0 for d in self.downsampling_factor)
if left_context_chunks >= 0:
num_encoders = len(self.encoder_dim)
assert all (chunk_size * left_context_chunks >=
(self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i]
for i in range(num_encoders))
else:
left_context_chunks = 1000000
seq_len = x.shape[0]
# t is frame index, shape (seq_len,)
t = torch.arange(seq_len, dtype=torch.int32)
# c is chunk index for each frame, shape (seq_len,)
c = t // chunk_size
src_c = c
tgt_c = c.unsqueeze(-1)
attn_mask = torch.logical_or(src_c > tgt_c,
src_c < tgt_c - left_context_chunks)
if __name__ == "__main__":
logging.info(f"attn_mask = {attn_mask}")
return attn_mask
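For intuition, a standalone sketch (not part of the commit) of the same mask construction, with hypothetical small sizes seq_len=6, chunk_size=2, left_context_chunks=1; True marks positions a query frame may not attend to:
import torch
seq_len, chunk_size, left_context_chunks = 6, 2, 1
t = torch.arange(seq_len, dtype=torch.int32)
c = t // chunk_size                 # chunk index per frame: [0, 0, 1, 1, 2, 2]
src_c = c
tgt_c = c.unsqueeze(-1)
attn_mask = torch.logical_or(src_c > tgt_c,                        # no future chunks
                             src_c < tgt_c - left_context_chunks)  # at most 1 chunk of left context
print(attn_mask)
# Each row (target frame) can attend to the False columns: its own chunk and the
# previous chunk, i.e. the block-triangular pattern described for ChunkCausalDepthwiseConv1d.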
@ -434,6 +493,7 @@ class ZipformerEncoderLayer(nn.Module):
feedforward_dim: int,
dropout: FloatLike = 0.1,
cnn_module_kernel: int = 31,
causal: bool = False,
# layer_skip_rate will be overwritten to change warmup begin and end times.
# treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
# to work correctly.
@ -487,7 +547,8 @@ class ZipformerEncoderLayer(nn.Module):
hidden_channels=3 * embed_dim // 4)
self.conv_module = ConvolutionModule(embed_dim,
cnn_module_kernel)
cnn_module_kernel,
causal=causal)
#self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
@ -566,27 +627,24 @@ class ZipformerEncoderLayer(nn.Module):
self,
src: Tensor,
pos_emb: Tensor,
src_mask: Optional[Tensor] = None,
chunk_size: int = -1,
attn_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
attn_weights: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
attn_weights: possibly attention weights computed by the previous layer,
to be used if self.self_attn_weights is None
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E)
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
chunk_size: the number of frames per chunk, if >= 0; if -1, no chunking.
feature_mask: something that broadcasts with src, that we'll multiply `src`
by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
True means masked position. May be None.
src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
masked position. May be None.
Returns:
(x, attn_weights) where x has the same shape as src, and attn_weights are of
@ -602,7 +660,7 @@ class ZipformerEncoderLayer(nn.Module):
attn_weights = self.self_attn_weights(
src,
pos_emb=pos_emb,
attn_mask=src_mask,
attn_mask=attn_mask,
key_padding_mask=src_key_padding_mask,
)
# else rely on the ones passed in
@ -642,7 +700,8 @@ class ZipformerEncoderLayer(nn.Module):
src, attn_weights)
if torch.jit.is_scripting() or random.random() >= float(self.conv_skip_rate):
src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
src = src + self.conv_module(src, chunk_size=chunk_size,
src_key_padding_mask=src_key_padding_mask)
if torch.jit.is_scripting() or random.random() >= float(self.ff2_skip_rate):
src = src + self.balancer_ff2(self.feed_forward2(src))
@ -660,7 +719,6 @@ class ZipformerEncoderLayer(nn.Module):
return src, attn_weights
class ZipformerEncoder(nn.Module):
r"""ZipformerEncoder is a stack of N encoder layers
@ -713,32 +771,29 @@ class ZipformerEncoder(nn.Module):
def forward(
self,
src: Tensor,
chunk_size: int = -1,
feature_mask: Union[Tensor, float] = 1.0,
mask: Optional[Tensor] = None,
attn_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
chunk_size: the number of frames per chunk, if >= 0; if -1, no chunking.
feature_mask: something that broadcasts with src, that we'll multiply `src`
by at every layer.
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
True means masked position. May be None.
src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
masked position. May be None.
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E)
mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
Returns: (x, x_no_combine), both of shape (S, N, E)
Returns: a Tensor with the same shape as src.
"""
pos_emb = self.encoder_pos(src)
output = src
rnd_seed = src.numel() + random.randint(0, 1000)
output = output * feature_mask
@ -749,7 +804,8 @@ class ZipformerEncoder(nn.Module):
output, attn_weights = mod(
output,
pos_emb,
src_mask=mask,
chunk_size=chunk_size,
attn_mask=attn_mask,
src_key_padding_mask=src_key_padding_mask,
attn_weights=attn_weights,
)
@ -774,7 +830,7 @@ class DownsampledZipformerEncoder(nn.Module):
super(DownsampledZipformerEncoder, self).__init__()
self.downsample_factor = downsample
self.downsample = SimpleDownsample(input_dim, output_dim,
downsample, dropout)
downsample, dropout)
self.encoder = encoder
self.upsample = SimpleUpsample(output_dim, downsample)
self.out_combiner = SimpleCombiner(input_dim,
@ -784,39 +840,37 @@ class DownsampledZipformerEncoder(nn.Module):
def forward(self,
src: Tensor,
chunk_size: int = -1,
feature_mask: Union[Tensor, float] = 1.0,
mask: Optional[Tensor] = None,
attn_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
r"""Downsample, go through encoder, upsample.
Args:
src: the sequence to the encoder (required).
src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
feature_mask: something that broadcasts with src, that we'll multiply `src`
by at every layer. feature_mask is expected to be already downsampled by
self.downsample_factor.
mask: the mask for the src sequence (optional). CAUTION: we need to downsample
this, if we are to support it. Won't work correctly yet.
src_key_padding_mask: the mask for the src keys per batch (optional). Should
be downsampled already.
by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
True means masked position. May be None.
src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
masked position. May be None.
Shape:
src: (S, N, E).
mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
Returns: output of shape (S, N, F) where F is the number of output features
(output_dim to constructor)
Returns: a Tensor with the same shape as src.
"""
src_orig = src
src = self.downsample(src)
ds = self.downsample_factor
if mask is not None:
mask = mask[::ds,::ds]
if attn_mask is not None:
attn_mask = attn_mask[::ds,::ds]
src = self.encoder(
src, feature_mask=feature_mask, mask=mask, src_key_padding_mask=mask,
src,
chunk_size=chunk_size // ds,
feature_mask=feature_mask,
attn_mask=attn_mask,
src_key_padding_mask=src_key_padding_mask,
)
src = self.upsample(src)
# remove any extra frames that are not a multiple of downsample_factor
@ -990,8 +1044,9 @@ class SmallConvolutionModule(nn.Module):
) -> None:
super().__init__()
self.depthwise_conv = nn.Conv1d(
self.depthwise_conv = ChunkCausalDepthwiseConv1d(
channels=channels,
kernel_size=kernel_size) if causal else nn.Conv1d(
in_channels=channels,
out_channels=channels,
groups=channels,
@ -1139,13 +1194,13 @@ class CompactRelPositionalEncoding(torch.nn.Module):
def forward(self, x: torch.Tensor) -> Tensor:
"""Add positional encoding.
"""Create positional encoding.
Args:
x (torch.Tensor): Input tensor (time, batch, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
positional embedding, of shape (1, 2*time-1, `*`).
"""
self.extend_pe(x)
@ -1235,6 +1290,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
self,
x: Tensor,
pos_emb: Tensor,
chunk_size: int = -1,
key_padding_mask: Optional[Tensor] = None,
attn_mask: Optional[Tensor] = None,
) -> Tensor:
@ -1242,6 +1298,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
Args:
x: input of shape (seq_len, batch_size, embed_dim)
pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 2, pos_dim)
chunk_size: the chunk size, in frames, used for chunkwise causal attention; -1 means no chunking.
key_padding_mask: a bool tensor of shape (batch_size, seq_len). Positions that
are True in this mask will be ignored as sources in the attention weighting.
attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
@ -1687,9 +1744,8 @@ class ConvolutionModule(nn.Module):
bias (bool): Whether to use bias in conv layers (default=True).
"""
def __init__(
self, channels: int, kernel_size: int,
self, channels: int, kernel_size: int, causal: bool,
) -> None:
"""Construct a ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
@ -1697,7 +1753,7 @@ class ConvolutionModule(nn.Module):
assert (kernel_size - 1) % 2 == 0
bottleneck_dim = channels
self.causal = causal
self.in_proj = nn.Linear(
channels, 2 * bottleneck_dim,
@ -1706,7 +1762,6 @@ class ConvolutionModule(nn.Module):
# sigmoid in glu.
self.in_proj.lr_scale = 0.9
# after in_proj we put x through a gated linear unit (nn.functional.glu).
# For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
# but sometimes, for some reason, for layer 0 the rms ends up being very large,
@ -1734,15 +1789,17 @@ class ConvolutionModule(nn.Module):
self.activation2 = Identity() # for diagnostics
self.depthwise_conv = nn.Conv1d(
bottleneck_dim,
bottleneck_dim,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
assert kernel_size % 2 == 1
self.depthwise_conv = ChunkCausalDepthwiseConv1d(
channels=bottleneck_dim,
kernel_size=kernel_size) if causal else nn.Conv1d(
in_channels=bottleneck_dim,
out_channels=bottleneck_dim,
groups=bottleneck_dim,
bias=True,
)
kernel_size=kernel_size,
padding=kernel_size // 2)
self.balancer2 = Balancer(
bottleneck_dim, channel_dim=1,
@ -1768,6 +1825,7 @@ class ConvolutionModule(nn.Module):
def forward(self,
x: Tensor,
src_key_padding_mask: Optional[Tensor] = None,
chunk_size: int = -1,
) -> Tensor:
"""Compute convolution module.
@ -1798,8 +1856,11 @@ class ConvolutionModule(nn.Module):
if src_key_padding_mask is not None:
x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
# 1D Depthwise Conv
x = self.depthwise_conv(x)
if chunk_size >= 0:
assert self.causal, "Must initialize model with causal=True if you use chunk_size"
x = self.depthwise_conv(x, chunk_size=chunk_size)
else:
x = self.depthwise_conv(x)
x = self.balancer2(x)
x = x.permute(2, 0, 1) # (time, batch, channels)
@ -2186,7 +2247,7 @@ def _test_random_combine():
assert torch.allclose(y, x[0]) # .. since actually all ones.
def _test_zipformer_main():
def _test_zipformer_main(causal: bool = False):
feature_dim = 50
batch_size = 5
seq_len = 20
@ -2194,7 +2255,8 @@ def _test_zipformer_main():
# Just make sure the forward pass runs.
c = Zipformer(
num_features=feature_dim, encoder_dim=(64,96), encoder_unmasked_dim=(48,64), num_heads=(4,4)
num_features=feature_dim, encoder_dim=(64,96), encoder_unmasked_dim=(48,64), num_heads=(4,4),
causal=causal,
)
batch_size = 5
seq_len = 20
@ -2202,6 +2264,7 @@ def _test_zipformer_main():
f = c(
torch.randn(batch_size, seq_len, feature_dim),
torch.full((batch_size,), seq_len, dtype=torch.int64),
chunk_size=4 if causal else -1,
)
f[0].sum().backward()
c.eval()
@ -2212,9 +2275,11 @@ def _test_zipformer_main():
f # to remove flake8 warnings
if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
_test_random_combine()
_test_zipformer_main()
_test_zipformer_main(False)
_test_zipformer_main(True)