Mirror of https://github.com/k2-fsa/icefall.git

Commit e7e7560bba: "Implement chunking" (parent b2fb504aee)
@@ -84,6 +84,8 @@ class Transducer(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        chunk_size: int = -1,
+        left_context_chunks: int = -1,
     ) -> torch.Tensor:
         """
         Args:
@@ -104,6 +106,9 @@ class Transducer(nn.Module):
           lm_scale:
             The scale to smooth the loss with lm (output of predictor network)
             part
+          chunk_size, left_context_chunks:
+            For chunkwise causal training; will be passed to the zipformer encoder.
+            chunk_size is specified in frames at 50Hz, i.e. after 2x downsampling.
         Returns:
           Return the transducer loss.

@@ -119,7 +124,8 @@ class Transducer(nn.Module):

         assert x.size(0) == x_lens.size(0) == y.dim0

-        encoder_out, x_lens = self.encoder(x, x_lens)
+        encoder_out, x_lens = self.encoder(x, x_lens, chunk_size=chunk_size,
+                                           left_context_chunks=left_context_chunks)
         assert torch.all(x_lens > 0)

         # Now for the decoder, i.e., the prediction network
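For orientation, a minimal sketch (not part of the commit) of what a training call with the new arguments might look like; the tensor names and chunk values are illustrative, and chunk_size is counted in 50Hz frames, so 32 frames is roughly 0.64 seconds of audio:

# Illustrative sketch only; `features`, `feature_lens` and `y` are placeholders.
simple_loss, pruned_loss = model(
    x=features,             # (batch, num_frames, feature_dim) fbank features
    x_lens=feature_lens,    # (batch,) number of valid frames per utterance
    y=y,                    # k2.RaggedTensor of token IDs
    prune_range=5,
    chunk_size=32,          # 32 frames at 50Hz, about 0.64s per chunk; -1 disables chunking
    left_context_chunks=4,  # each chunk may attend to the 4 preceding chunks
)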
@@ -1014,12 +1014,13 @@ def ScaledConv2d(*args,
                  initial_scale: float = 1.0,
                  **kwargs ) -> nn.Conv2d:
     """
-    Behaves like a constructor of a modified version of nn.Conv1d
+    Behaves like a constructor of a modified version of nn.Conv2d
     that gives an easy way to set the default initial parameter scale.

     Args:
        Accepts the standard args and kwargs that nn.Linear accepts
-       e.g. in_features, out_features, bias=False.
+       e.g. in_features, out_features, bias=False, but:
+       NO PADDING-RELATED ARGS.

        initial_scale: you can override this if you want to increase
        or decrease the initial magnitude of the module's output
@@ -1037,6 +1038,132 @@ def ScaledConv2d(*args,
     return ans


+class ChunkCausalDepthwiseConv1d(torch.nn.Module):
+    """
+    Behaves like a depthwise 1d convolution, except that it is causal in
+    a chunkwise way, as if we had a block-triangular attention mask.
+    The chunk size is provided at test time (it should probably be
+    kept in sync with the attention mask).
+
+    This has a little more than twice the parameters of a conventional
+    depthwise conv1d module: we implement it by having one
+    depthwise convolution, of half the width, that is causal (via
+    right-padding); and one depthwise convolution that is applied only
+    within chunks, that we multiply by a scaling factor which depends
+    on the position within the chunk.
+
+    Args:
+        Accepts the standard args and kwargs that nn.Linear accepts
+        e.g. in_features, out_features, bias=False.
+
+        initial_scale: you can override this if you want to increase
+           or decrease the initial magnitude of the module's output
+           (affects the initialization of weight_scale and bias_scale).
+           Another option, if you want to do something like this, is
+           to re-initialize the parameters.
+    """
+    def __init__(self,
+                 channels: int,
+                 kernel_size: int,
+                 initial_scale: float = 1.0,
+                 bias: bool = True):
+        super().__init__()
+        assert kernel_size % 2 == 1
+
+        half_kernel_size = (kernel_size + 1) // 2
+        # will pad manually, on one side.
+        self.causal_conv = nn.Conv1d(in_channels=channels,
+                                     out_channels=channels,
+                                     groups=channels,
+                                     kernel_size=half_kernel_size,
+                                     padding=0,
+                                     bias=True)
+
+        self.chunkwise_conv = nn.Conv1d(in_channels=channels,
+                                        out_channels=channels,
+                                        groups=channels,
+                                        kernel_size=kernel_size,
+                                        padding=kernel_size // 2,
+                                        bias=bias)
+
+        # first row is correction factors added to the scale near the left edge of the chunk,
+        # second row is correction factors added to the scale near the right edge of the chunk,
+        # both of these are added to a default scale of 1.0.
+        self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size))
+        self.kernel_size = kernel_size
+
+        with torch.no_grad():
+            self.causal_conv.weight[:] *= initial_scale
+            self.chunkwise_conv.weight[:] *= initial_scale
+            if bias:
+                torch.nn.init.uniform_(self.causal_conv.bias,
+                                       -0.1 * initial_scale,
+                                       0.1 * initial_scale)
+
+    def forward(self,
+                x: Tensor,
+                chunk_size: int = -1) -> Tensor:
+        """
+        Forward function.  Args:
+          x: a Tensor of shape (batch_size, channels, seq_len)
+          chunk_size: the chunk size, in frames; does not have to divide seq_len exactly.
+        """
+        (batch_size, num_channels, seq_len) = x.shape
+
+        half_kernel_size = (self.kernel_size + 1) // 2
+        # left_pad is half_kernel_size - 1 where half_kernel_size is the size used
+        # in the causal conv.  It's the amount by which we must pad on the left,
+        # to make the convolution causal.
+        left_pad = self.kernel_size // 2
+
+        if chunk_size < 0:
+            chunk_size = seq_len
+        right_pad = -seq_len % chunk_size
+
+        x = torch.nn.functional.pad(x, (left_pad, right_pad))
+
+        x_causal = self.causal_conv(x[..., :seq_len + left_pad])
+        assert x_causal.shape == (batch_size, num_channels, seq_len)
+
+        x_chunk = x[..., left_pad:]
+        num_chunks = x_chunk.shape[2] // chunk_size
+        x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size)
+        x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(batch_size * num_chunks,
+                                                      num_channels, chunk_size)
+        x_chunk = self.chunkwise_conv(x_chunk)  # does not change shape
+
+        chunk_scale = self._get_chunk_scale(chunk_size)
+
+        x_chunk = x_chunk * chunk_scale
+        x_chunk = x_chunk.reshape(batch_size, num_chunks,
+                                  num_channels, chunk_size).permute(0, 2, 1, 3)
+        x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[..., :seq_len]
+
+        return x_chunk + x_causal
+
+    def _get_chunk_scale(self, chunk_size: int):
+        """Returns tensor of shape (num_channels, chunk_size) that will be used to
+           scale the output of self.chunkwise_conv."""
+        left_edge = self.chunkwise_conv_scale[0]
+        right_edge = self.chunkwise_conv_scale[1]
+        if chunk_size < self.kernel_size:
+            left_edge = left_edge[:, :chunk_size]
+            right_edge = right_edge[:, -chunk_size:]
+        else:
+            t = chunk_size - self.kernel_size
+            channels = left_edge.shape[0]
+            pad = torch.zeros(channels, t,
+                              device=left_edge.device,
+                              dtype=left_edge.dtype)
+            left_edge = torch.cat((left_edge, pad), dim=-1)
+            right_edge = torch.cat((pad, right_edge), dim=-1)
+        return 1.0 + (left_edge + right_edge)
+
+
 class ActivationBalancer(torch.nn.Module):
     """
     Modifies the backpropped derivatives of a function to try to encourage, for
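A small sketch of the shape bookkeeping in the forward() above; the kernel size and tensor sizes here are assumptions chosen to make the arithmetic concrete:

# Sketch only (sizes are assumptions): kernel_size=31, so the causal branch uses a
# width-16 conv and left-pads by kernel_size // 2 = 15 frames; the chunkwise branch
# right-pads the sequence up to a multiple of chunk_size and runs an ordinary
# same-padded depthwise conv independently on each chunk.
import torch
m = ChunkCausalDepthwiseConv1d(channels=8, kernel_size=31)
x = torch.randn(2, 8, 100)      # (batch, channels, seq_len)
y_train = m(x, chunk_size=16)   # 100 frames -> right_pad of 12 -> 7 chunks of 16
y_test = m(x)                   # chunk_size=-1: the whole sequence is one chunk
assert y_train.shape == y_test.shape == x.shape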
@@ -47,6 +47,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 import argparse
 import copy
 import logging
+import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
@@ -225,6 +226,23 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         """,
     )

+    parser.add_argument(
+        "--chunk-size",
+        type=str,
+        default="-1",
+        help="Chunk sizes for chunkwise (causal) training, measured in frames at the 50Hz "
+        "frame rate: a single int or comma-separated list; -1 means no chunking. "
+        "A value is chosen randomly from this list for each batch.",
+    )
+
+    parser.add_argument(
+        "--chunk-left-context-frames",
+        type=str,
+        default="64,128,256,-1",
+        help="Left-contexts for chunkwise training, measured in frames (positive values must be "
+        "multiples of all positive elements of chunk-size). If --chunk-size is specified, "
+        "chunk left-context frames will be chosen randomly from this list."
+    )
+
+
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -526,6 +544,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
         dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
         warmup_batches=4000.0,
+        causal=(params.chunk_size != "-1"),
     )
     return encoder

@@ -686,6 +705,26 @@ def save_checkpoint(
         copyfile(src=filename, dst=best_valid_filename)


+def get_chunk_info(params: AttributeDict) -> Tuple[int, int]:
+    """
+    Returns chunk_size and left_context_chunks.
+    """
+    chunk_sizes = list(map(int, params.chunk_size.split(',')))
+    n = len(chunk_sizes)
+    chunk_size = random.choice(chunk_sizes)
+    if chunk_size == -1:
+        left_context_chunks = -1
+    else:
+        chunk_left_context_frames = list(map(int, params.chunk_left_context_frames.split(',')))
+        m = len(chunk_left_context_frames)
+        left_context_frames = random.choice(chunk_left_context_frames)
+        if left_context_frames != -1:
+            assert left_context_frames % chunk_size == 0, "Invalid --chunk-left-context-frames value"
+        # Note: in Python, -1 // n == -1 for n > 0
+        left_context_chunks = left_context_frames // chunk_size
+    return chunk_size, left_context_chunks
+
+
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
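For intuition, a short sketch of the per-batch sampling this function performs; the --chunk-size list below is an assumption, only the --chunk-left-context-frames default comes from the diff:

# Sketch only: per-batch sampling of chunking parameters.
chunk_sizes = [16, 32, 64, -1]                 # hypothetical --chunk-size "16,32,64,-1"
left_context_frames_list = [64, 128, 256, -1]  # the default --chunk-left-context-frames
# One batch might draw chunk_size=32 and left_context_frames=128, giving
# left_context_chunks = 128 // 32 = 4, i.e. about 2.56s of left context at 50Hz.
# Drawing chunk_size=-1 disables chunking for that batch (full-context training);
# left_context_frames=-1 gives -1 // chunk_size == -1, i.e. unlimited left context.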
@@ -731,6 +770,8 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y).to(device)

+    chunk_size, left_context_chunks = get_chunk_info(params)
+
     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss = model(
             x=feature,
@@ -739,6 +780,8 @@ def compute_loss(
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
+            chunk_size=chunk_size,
+            left_context_chunks=left_context_chunks,
         )

         s = params.simple_loss_scale
@@ -37,6 +37,7 @@ from scaling import (
     SwooshL,
     SwooshR,
     TanSwish,
+    ChunkCausalDepthwiseConv1d,
     ScaledConv1d,
     ScaledConv2d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
@@ -96,8 +97,12 @@ class Zipformer(EncoderInterface):
           dropout (float): dropout rate
           warmup_batches (float): number of batches to warm up over; this controls
              dropout of encoder layers.
+          causal (bool): if True, support chunkwise causal convolution.  This should
+             not hurt WER as no modeling power is lost, but the convolution modules will be
+             slightly slower and use more memory.  Enables use of the chunk_size and
+             left_context_chunk options in forward(), which simulates streaming
+             decoding.
     """

     def __init__(
         self,
         num_features: int,
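A brief usage sketch of the causal option, modeled on _test_zipformer_main() at the bottom of this file (the dimensions are the test's, not recommended settings):

# Sketch only; mirrors the test code further down in this diff.
encoder = Zipformer(
    num_features=50, encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
    causal=True,                    # enables chunkwise-causal convolution and chunked attention
)
feats = torch.randn(5, 20, 50)                        # (batch, num_frames, num_features)
feat_lens = torch.full((5,), 20, dtype=torch.int64)
out, out_lens = encoder(feats, feat_lens,
                        chunk_size=4,                 # frames per chunk, after encoder_embed
                        left_context_chunks=-1)       # -1: no limit on the left context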
@@ -116,6 +121,7 @@ class Zipformer(EncoderInterface):
             pos_dim: int = 192,
             dropout: FloatLike = None,  # see code below for default
             warmup_batches: float = 4000.0,
+            causal: bool = False,
     ) -> None:
         super(Zipformer, self).__init__()

@@ -144,6 +150,7 @@ class Zipformer(EncoderInterface):
         self.num_features = num_features  # int
         self.output_downsampling_factor = output_downsampling_factor  # int
         self.downsampling_factor = downsampling_factor  # tuple
+        self.downsampling_factor_gcd = next(n for n in range(1, 10000) if all(n % d == 0 for d in downsampling_factor))
         self.encoder_dim = encoder_dim = _to_tuple(encoder_dim)  # tuple
         self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(encoder_unmasked_dim)  # tuple
         num_encoder_layers = _to_tuple(num_encoder_layers)
@@ -153,8 +160,7 @@ class Zipformer(EncoderInterface):
         num_heads = _to_tuple(num_heads)
         attention_share_layers = _to_tuple(attention_share_layers)
         feedforward_dim = _to_tuple(feedforward_dim)
-        cnn_module_kernel = _to_tuple(cnn_module_kernel)
-
+        self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)

         for u,d in zip(encoder_unmasked_dim, encoder_dim):
             assert u <= d
@@ -187,6 +193,7 @@ class Zipformer(EncoderInterface):
                 feedforward_dim=feedforward_dim[i],
                 dropout=dropout,
                 cnn_module_kernel=cnn_module_kernel[i],
+                causal=causal,
             )

         # For the segment of the warmup period, we let the Conv2dSubsampling
@@ -314,6 +321,8 @@ class Zipformer(EncoderInterface):

     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor,
+        chunk_size: int = -1,
+        left_context_chunks: int = -1,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -322,6 +331,14 @@ class Zipformer(EncoderInterface):
           x_lens:
             A tensor of shape (batch_size,) containing the number of frames in
             `x` before padding.
+          chunk_size: Number of frames per chunk (only set this if causal == True).
+             Must divide all elements of downsampling_factor.  At 50hz frame
+             rate, i.e. after encoder_embed.  If not specified, no chunking.
+          left_context_chunks: Number of left-context chunks for each chunk (affects
+             attention mask); only set this if chunk_size specified.  If -1, there
+             is no limit on the left context.  If not -1, require:
+                 left_context_chunks * chunk_size >= downsampling_factor[i] *
+                 cnn_module_kernel[i] // 2.
         Returns:
           Return a tuple containing 2 tensors:
             - embeddings: its shape is (batch_size, output_seq_len, max(encoder_dim))
@@ -340,11 +357,13 @@ class Zipformer(EncoderInterface):
             warnings.simplefilter("ignore")
             lengths = (x_lens - 7) // 2
         assert x.size(0) == lengths.max().item()
-        mask = make_pad_mask(lengths)
+        src_key_padding_mask = make_pad_mask(lengths)

         outputs = []
         feature_masks = self.get_feature_masks(x)

+        attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)
+
         for i, module in enumerate(self.encoders):
             ds = self.downsampling_factor[i]
             if self.skip_layers[i] is not None:
@@ -361,8 +380,12 @@ class Zipformer(EncoderInterface):
                 else:
                     x = skip_x
             x = module(x,
+                       chunk_size=chunk_size,
                        feature_mask=feature_masks[i],
-                       src_key_padding_mask=None if mask is None else mask[...,::ds])
+                       src_key_padding_mask=(None if src_key_padding_mask is None
+                                             else src_key_padding_mask[...,::ds]),
+                       attn_mask=attn_mask,
+                       )
             outputs.append(x)

         def get_full_dim_output():
@@ -395,6 +418,42 @@ class Zipformer(EncoderInterface):

         return x, lengths

+    def _get_attn_mask(self, x: Tensor,
+                       chunk_size: int,
+                       left_context_chunks: int
+                       ) -> Optional[Tensor]:
+        """
+        Return None if chunk_size == -1, else return attention mask of shape
+        (seq_len, seq_len), interpreted as (tgt_seq_len, src_seq_len).  True
+        means a masked position.
+        Args:
+           x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
+           chunk_size: chunk size; must divide all elements of self.downsampling_factor.
+        """
+        if chunk_size <= 0:
+            return None
+        assert all(chunk_size % d == 0 for d in self.downsampling_factor)
+        if left_context_chunks >= 0:
+            num_encoders = len(self.encoder_dim)
+            assert all (chunk_size * left_context_chunks >=
+                        (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i]
+                        for i in range(num_encoders))
+        else:
+            left_context_chunks = 1000000
+
+        seq_len = x.shape[0]
+
+        # t is frame index, shape (seq_len,)
+        t = torch.arange(seq_len, dtype=torch.int32)
+        # c is chunk index for each frame, shape (seq_len,)
+        c = t // chunk_size
+        src_c = c
+        tgt_c = c.unsqueeze(-1)
+
+        attn_mask = torch.logical_or(src_c > tgt_c,
+                                     src_c < tgt_c - left_context_chunks)
+        if __name__ == "__main__":
+            logging.info(f"attn_mask = {attn_mask}")
+        return attn_mask
+
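As a concrete illustration of the block-triangular mask this produces (a sketch, not part of the commit), with a tiny sequence of 6 frames, 2-frame chunks and one chunk of left context:

# Sketch: reproduce the mask logic above for seq_len=6, chunk_size=2, left_context_chunks=1.
import torch
seq_len, chunk_size, left_context_chunks = 6, 2, 1
t = torch.arange(seq_len, dtype=torch.int32)
c = t // chunk_size          # chunk index per frame: [0, 0, 1, 1, 2, 2]
src_c = c
tgt_c = c.unsqueeze(-1)
attn_mask = torch.logical_or(src_c > tgt_c,                       # no looking into future chunks
                             src_c < tgt_c - left_context_chunks) # limited left context
# Target frame i may attend to source frame j wherever attn_mask[i, j] is False:
# frames in chunk 2 (rows 4-5) can see chunks 1 and 2 (columns 2-5) but not chunk 0.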
@@ -434,6 +493,7 @@ class ZipformerEncoderLayer(nn.Module):
             feedforward_dim: int,
             dropout: FloatLike = 0.1,
             cnn_module_kernel: int = 31,
+            causal: bool = False,
             # layer_skip_rate will be overwritten to change warmup begin and end times.
             # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
             # to work correctly.
@@ -487,7 +547,8 @@ class ZipformerEncoderLayer(nn.Module):
                                                 hidden_channels=3 * embed_dim // 4)

         self.conv_module = ConvolutionModule(embed_dim,
-                                             cnn_module_kernel)
+                                             cnn_module_kernel,
+                                             causal=causal)


         #self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)
@@ -566,27 +627,24 @@ class ZipformerEncoderLayer(nn.Module):
             self,
             src: Tensor,
             pos_emb: Tensor,
-            src_mask: Optional[Tensor] = None,
+            chunk_size: int = -1,
+            attn_mask: Optional[Tensor] = None,
             src_key_padding_mask: Optional[Tensor] = None,
             attn_weights: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
         """
         Pass the input through the encoder layer.

         Args:
-            src: the sequence to the encoder layer (required).
-            pos_emb: Positional embedding tensor (required).
-            src_mask: the mask for the src sequence (optional).
-            src_key_padding_mask: the mask for the src keys per batch (optional).
-            attn_weights: possibly attention weights computed by the previous layer,
-                to be used if self.self_attn_weights is None
-
-        Shape:
-            src: (S, N, E).
-            pos_emb: (N, 2*S-1, E)
-            src_mask: (S, S).
-            src_key_padding_mask: (N, S).
-            S is the source sequence length, N is the batch size, E is the feature number
+            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+            pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
+            chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
+            feature_mask: something that broadcasts with src, that we'll multiply `src`
+               by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
+            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+               interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+               True means masked position. May be None.
+            src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
+               masked position. May be None.

         Returns:
            (x, attn_weights) where x has the same shape as src, and attn_weights are of
@@ -602,7 +660,7 @@ class ZipformerEncoderLayer(nn.Module):
             attn_weights = self.self_attn_weights(
                 src,
                 pos_emb=pos_emb,
-                attn_mask=src_mask,
+                attn_mask=attn_mask,
                 key_padding_mask=src_key_padding_mask,
             )
         # else rely on the ones passed in
@@ -642,7 +700,8 @@ class ZipformerEncoderLayer(nn.Module):
                                          src, attn_weights)

         if torch.jit.is_scripting() or random.random() >= float(self.conv_skip_rate):
-            src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
+            src = src + self.conv_module(src, chunk_size=chunk_size,
+                                         src_key_padding_mask=src_key_padding_mask)

         if torch.jit.is_scripting() or random.random() >= float(self.ff2_skip_rate):
             src = src + self.balancer_ff2(self.feed_forward2(src))
@@ -660,7 +719,6 @@ class ZipformerEncoderLayer(nn.Module):

         return src, attn_weights

-
 class ZipformerEncoder(nn.Module):
     r"""ZipformerEncoder is a stack of N encoder layers

@@ -713,32 +771,29 @@ class ZipformerEncoder(nn.Module):
     def forward(
         self,
         src: Tensor,
+        chunk_size: int = -1,
         feature_mask: Union[Tensor, float] = 1.0,
-        mask: Optional[Tensor] = None,
+        attn_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.

         Args:
-            src: the sequence to the encoder (required).
-            feature_mask: something that broadcasts with src, that we'll multiply `src`
-                by at every layer.
-            mask: the mask for the src sequence (optional).
-            src_key_padding_mask: the mask for the src keys per batch (optional).
-
-        Shape:
-            src: (S, N, E).
-            pos_emb: (N, 2*S-1, E)
-            mask: (S, S).
-            src_key_padding_mask: (N, S).
-            S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
-
-        Returns: (x, x_no_combine), both of shape (S, N, E)
+            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+            chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
+            feature_mask: something that broadcasts with src, that we'll multiply `src`
+                by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
+            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+                interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+                True means masked position. May be None.
+            src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
+                masked position. May be None.
+
+        Returns: a Tensor with the same shape as src.
         """
         pos_emb = self.encoder_pos(src)
         output = src


         rnd_seed = src.numel() + random.randint(0, 1000)

         output = output * feature_mask
@@ -749,7 +804,8 @@ class ZipformerEncoder(nn.Module):
             output, attn_weights = mod(
                 output,
                 pos_emb,
-                src_mask=mask,
+                chunk_size=chunk_size,
+                attn_mask=attn_mask,
                 src_key_padding_mask=src_key_padding_mask,
                 attn_weights=attn_weights,
             )
@@ -774,7 +830,7 @@ class DownsampledZipformerEncoder(nn.Module):
         super(DownsampledZipformerEncoder, self).__init__()
         self.downsample_factor = downsample
         self.downsample = SimpleDownsample(input_dim, output_dim,
                                            downsample, dropout)
         self.encoder = encoder
         self.upsample = SimpleUpsample(output_dim, downsample)
         self.out_combiner = SimpleCombiner(input_dim,
@@ -784,39 +840,37 @@ class DownsampledZipformerEncoder(nn.Module):

     def forward(self,
                 src: Tensor,
+                chunk_size: int = -1,
                 feature_mask: Union[Tensor, float] = 1.0,
-                mask: Optional[Tensor] = None,
+                attn_mask: Optional[Tensor] = None,
                 src_key_padding_mask: Optional[Tensor] = None,
                 ) -> Tuple[Tensor, Tensor]:
         r"""Downsample, go through encoder, upsample.

         Args:
-            src: the sequence to the encoder (required).
-            feature_mask: something that broadcasts with src, that we'll multiply `src`
-                by at every layer. feature_mask is expected to be already downsampled by
-                self.downsample_factor.
-            mask: the mask for the src sequence (optional). CAUTION: we need to downsample
-                this, if we are to support it. Won't work correctly yet.
-            src_key_padding_mask: the mask for the src keys per batch (optional). Should
-                be downsampled already.
-
-        Shape:
-            src: (S, N, E).
-            mask: (S, S).
-            src_key_padding_mask: (N, S).
-            S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
-
-        Returns: output of shape (S, N, F) where F is the number of output features
-            (output_dim to constructor)
+            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
+            feature_mask: something that broadcasts with src, that we'll multiply `src`
+                by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
+            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
+                interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
+                True means masked position. May be None.
+            src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
+                masked position. May be None.
+
+        Returns: a Tensor with the same shape as src.
         """
         src_orig = src
         src = self.downsample(src)
         ds = self.downsample_factor
-        if mask is not None:
-            mask = mask[::ds,::ds]
+        if attn_mask is not None:
+            attn_mask = attn_mask[::ds,::ds]

         src = self.encoder(
-            src, feature_mask=feature_mask, mask=mask, src_key_padding_mask=mask,
+            src,
+            chunk_size=chunk_size // ds,
+            feature_mask=feature_mask,
+            attn_mask=attn_mask,
+            src_key_padding_mask=src_key_padding_mask,
         )
         src = self.upsample(src)
         # remove any extra frames that are not a multiple of downsample_factor
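A small worked example of the downsampling arithmetic here; the factor and chunk size are assumed values:

# Sketch of the arithmetic above (values are assumptions):
ds = 4                                # downsample factor of this encoder stack
chunk_size = 16                       # frames per chunk at the 50Hz rate
inner_chunk_size = chunk_size // ds   # -> 4 frames per chunk inside this stack
no_chunking = -1 // ds                # -> -1: Python floor division keeps the "disabled" value
# attn_mask[::ds, ::ds] keeps every ds-th row and column, so the block-triangular
# mask still lines up with the downsampled frames.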
@@ -990,8 +1044,9 @@ class SmallConvolutionModule(nn.Module):
     ) -> None:
         super().__init__()

-        self.depthwise_conv = nn.Conv1d(
+        self.depthwise_conv = ChunkCausalDepthwiseConv1d(
+            channels=channels,
+            kernel_size=kernel_size) if causal else nn.Conv1d(
             in_channels=channels,
             out_channels=channels,
             groups=channels,
@@ -1139,13 +1194,13 @@ class CompactRelPositionalEncoding(torch.nn.Module):


     def forward(self, x: torch.Tensor) -> Tensor:
-        """Add positional encoding.
+        """Create positional encoding.

         Args:
             x (torch.Tensor): Input tensor (time, batch, `*`).

         Returns:
-            torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
+            positional embedding, of shape (1, 2*time-1, `*`).

         """
         self.extend_pe(x)
@@ -1235,6 +1290,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self,
         x: Tensor,
         pos_emb: Tensor,
+        chunk_size: int = -1,
         key_padding_mask: Optional[Tensor] = None,
         attn_mask: Optional[Tensor] = None,
     ) -> Tensor:
@@ -1242,6 +1298,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         Args:
             x: input of shape (seq_len, batch_size, embed_dim)
             pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 2, pos_dim)
+            chunk_size
             key_padding_mask: a bool tensor of shape (batch_size, seq_len).  Positions that
                are True in this mask will be ignored as sources in the attention weighting.
             attn_mask: mask of shape (seq_len, seq_len) or (batch_size, seq_len, seq_len),
@@ -1687,9 +1744,8 @@ class ConvolutionModule(nn.Module):
         bias (bool): Whether to use bias in conv layers (default=True).

     """
-
     def __init__(
-        self, channels: int, kernel_size: int,
+        self, channels: int, kernel_size: int, causal: bool,
     ) -> None:
         """Construct a ConvolutionModule object."""
         super(ConvolutionModule, self).__init__()
@@ -1697,7 +1753,7 @@ class ConvolutionModule(nn.Module):
         assert (kernel_size - 1) % 2 == 0

         bottleneck_dim = channels
+        self.causal = causal

         self.in_proj = nn.Linear(
             channels, 2 * bottleneck_dim,
@@ -1706,7 +1762,6 @@ class ConvolutionModule(nn.Module):
         # sigmoid in glu.
         self.in_proj.lr_scale = 0.9

-
         # after in_proj we put x through a gated linear unit (nn.functional.glu).
         # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
         # but sometimes, for some reason, for layer 0 the rms ends up being very large,
@@ -1734,15 +1789,17 @@ class ConvolutionModule(nn.Module):

         self.activation2 = Identity()  # for diagnostics

-        self.depthwise_conv = nn.Conv1d(
-            bottleneck_dim,
-            bottleneck_dim,
-            kernel_size,
-            stride=1,
-            padding=(kernel_size - 1) // 2,
-            groups=bottleneck_dim,
-            bias=True,
-        )
+        assert kernel_size % 2 == 1
+
+        self.depthwise_conv = ChunkCausalDepthwiseConv1d(
+            channels=bottleneck_dim,
+            kernel_size=kernel_size) if causal else nn.Conv1d(
+            in_channels=bottleneck_dim,
+            out_channels=bottleneck_dim,
+            groups=bottleneck_dim,
+            kernel_size=kernel_size,
+            padding=kernel_size // 2)


         self.balancer2 = Balancer(
             bottleneck_dim, channel_dim=1,
@@ -1768,6 +1825,7 @@ class ConvolutionModule(nn.Module):
     def forward(self,
                 x: Tensor,
                 src_key_padding_mask: Optional[Tensor] = None,
+                chunk_size: int = -1,
                 ) -> Tensor:
         """Compute convolution module.

@@ -1798,8 +1856,11 @@ class ConvolutionModule(nn.Module):
         if src_key_padding_mask is not None:
             x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)

-        # 1D Depthwise Conv
-        x = self.depthwise_conv(x)
+        if chunk_size >= 0:
+            assert self.causal, "Must initialize model with causal=True if you use chunk_size"
+            x = self.depthwise_conv(x, chunk_size=chunk_size)
+        else:
+            x = self.depthwise_conv(x)

         x = self.balancer2(x)
         x = x.permute(2, 0, 1)  # (time, batch, channels)
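A minimal sketch of how the two construction modes behave; the dimensions are assumptions, and the input layout assumes the usual (seq_len, batch, channels) that this module expects:

# Sketch only (dims and layout are assumptions).
conv = ConvolutionModule(channels=256, kernel_size=31, causal=True)
x = torch.randn(100, 4, 256)        # (seq_len, batch, channels)
y_full = conv(x, chunk_size=-1)     # causal conv, whole sequence treated as one chunk
y_chunked = conv(x, chunk_size=16)  # chunkwise-causal path via ChunkCausalDepthwiseConv1d

non_causal = ConvolutionModule(channels=256, kernel_size=31, causal=False)
# non_causal(x, chunk_size=16) would trip the assert above: chunking requires causal=True.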
@@ -2186,7 +2247,7 @@ def _test_random_combine():
     assert torch.allclose(y, x[0])  # .. since actually all ones.


-def _test_zipformer_main():
+def _test_zipformer_main(causal: bool = False):
     feature_dim = 50
     batch_size = 5
     seq_len = 20
@@ -2194,7 +2255,8 @@ def _test_zipformer_main():
     # Just make sure the forward pass runs.

     c = Zipformer(
-        num_features=feature_dim, encoder_dim=(64,96), encoder_unmasked_dim=(48,64), num_heads=(4,4)
+        num_features=feature_dim, encoder_dim=(64,96), encoder_unmasked_dim=(48,64), num_heads=(4,4),
+        causal=causal,
     )
     batch_size = 5
     seq_len = 20
@@ -2202,6 +2264,7 @@ def _test_zipformer_main():
     f = c(
         torch.randn(batch_size, seq_len, feature_dim),
         torch.full((batch_size,), seq_len, dtype=torch.int64),
+        chunk_size=4 if causal else -1,
     )
     f[0].sum().backward()
     c.eval()
@@ -2212,9 +2275,11 @@ def _test_zipformer_main():
     f  # to remove flake8 warnings


 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
     _test_random_combine()
-    _test_zipformer_main()
+    _test_zipformer_main(False)
+    _test_zipformer_main(True)