mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
separate Conv2dSubsampling from Zipformer

This commit is contained in:
  parent 0ec31c84da
  commit 55a1abc9da
@@ -127,6 +127,7 @@ from icefall.checkpoint import (
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    make_pad_mask,
     setup_logger,
     store_transcripts,
     str2bool,
@@ -365,9 +366,15 @@ def decode_one_batch(
         value=LOG_EPS,
     )
 
+    x, x_lens = model.encoder_embed(feature, feature_lens)
+
+    src_key_padding_mask = make_pad_mask(x_lens)
+    x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+
     encoder_out, encoder_out_lens = model.encoder(
-        x=feature, x_lens=feature_lens
+        x, x_lens, src_key_padding_mask
     )
+    encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
 
     hyps = []
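Note on the new src_key_padding_mask argument above: it marks padded frames so the encoder can ignore them, and (as the Zipformer2.forward docstring later in this commit puts it) True means a masked position. A minimal standalone sketch of the behaviour expected from make_pad_mask, written only for illustration and not the icefall utility itself:

import torch

def make_pad_mask_sketch(lengths: torch.Tensor) -> torch.Tensor:
    # lengths: (batch_size,) numbers of valid frames; True marks padding positions.
    max_len = int(lengths.max())
    positions = torch.arange(max_len, device=lengths.device)   # (max_len,)
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)      # (batch_size, max_len)

print(make_pad_mask_sketch(torch.tensor([4, 2, 3])))
# tensor([[False, False, False, False],
#         [False, False,  True,  True],
#         [False, False, False,  True]])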
@@ -19,14 +19,11 @@ import k2
 import torch
 import torch.nn as nn
 import random
+import warnings
 from encoder_interface import EncoderInterface
 
-from icefall.utils import add_sos
-from scaling import (
-    penalize_abs_values_gt,
-    ScaledLinear
-)
+from icefall.utils import add_sos, make_pad_mask
+from scaling import penalize_abs_values_gt, ScaledLinear
 
 
-
 class Transducer(nn.Module):
@@ -36,6 +33,7 @@ class Transducer(nn.Module):
 
     def __init__(
         self,
+        encoder_embed: nn.Module,
         encoder: EncoderInterface,
         decoder: nn.Module,
         joiner: nn.Module,
@@ -46,6 +44,10 @@ class Transducer(nn.Module):
     ):
         """
         Args:
+          encoder_embed:
+            It is a Convolutional 2D subsampling module. It converts
+            an input of shape (N, T, idim) to an output of shape
+            (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
           encoder:
             It is the transcription network in the paper. It accepts
             two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
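A quick numeric check of the subsampling formula quoted in the new docstring; plain integer arithmetic, with example lengths chosen only for illustration:

# T' = (T - 3) // 2 - 2 == (T - 7) // 2
for T in (7, 8, 100, 101):
    assert (T - 3) // 2 - 2 == (T - 7) // 2
    print(T, "->", (T - 7) // 2)   # 7 -> 0, 8 -> 0, 100 -> 46, 101 -> 47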
@@ -64,18 +66,22 @@ class Transducer(nn.Module):
         assert isinstance(encoder, EncoderInterface), type(encoder)
         assert hasattr(decoder, "blank_id")
 
+        self.encoder_embed = encoder_embed
         self.encoder = encoder
         self.decoder = decoder
         self.joiner = joiner
 
         self.simple_am_proj = ScaledLinear(
-            encoder_dim, vocab_size, initial_scale=0.25,
+            encoder_dim,
+            vocab_size,
+            initial_scale=0.25,
         )
         self.simple_lm_proj = ScaledLinear(
-            decoder_dim, vocab_size, initial_scale=0.25,
+            decoder_dim,
+            vocab_size,
+            initial_scale=0.25,
         )
-
 
     def forward(
         self,
         x: torch.Tensor,
@@ -119,7 +125,15 @@ class Transducer(nn.Module):
 
         assert x.size(0) == x_lens.size(0) == y.dim0
 
-        encoder_out, x_lens = self.encoder(x, x_lens)
+        # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
+        x, x_lens = self.encoder_embed(x, x_lens)
+        # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
+
+        src_key_padding_mask = make_pad_mask(x_lens)
+        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+
+        encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask)
+        encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
 
         assert torch.all(x_lens > 0)
 
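For orientation, the data layout around the encoder after this change is batch-first outside the encoder and time-first inside it; the permutes are exactly the ones in the hunk above, and the sizes below are illustrative only:

import torch

N, T, feat_dim, enc_dim = 2, 100, 80, 384   # made-up sizes for illustration
T_sub = (T - 7) // 2                        # what encoder_embed reduces T to
x = torch.randn(N, T_sub, enc_dim)          # stand-in for encoder_embed output, (N, T', C)
x = x.permute(1, 0, 2)                      # (T', N, C): Zipformer2 runs time-first
# ... self.encoder(x, x_lens, src_key_padding_mask) keeps the (T', N, C) layout ...
encoder_out = x.permute(1, 0, 2)            # back to (N, T', C) for the joiner
print(encoder_out.shape)                    # torch.Size([2, 46, 384])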
@@ -142,7 +156,9 @@ class Transducer(nn.Module):
 
         y_padded = y_padded.to(torch.int64)
         boundary = torch.zeros(
-            (x.size(0), 4), dtype=torch.int64, device=x.device
+            (encoder_out.size(0), 4),
+            dtype=torch.int64,
+            device=encoder_out.device,
         )
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens
@@ -150,9 +166,9 @@ class Transducer(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)
 
-        #if self.training and random.random() < 0.25:
+        # if self.training and random.random() < 0.25:
         #    lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
-        #if self.training and random.random() < 0.25:
+        # if self.training and random.random() < 0.25:
         #    am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
 
         with torch.cuda.amp.autocast(enabled=False):
egs/librispeech/ASR/pruned_transducer_stateless7/subsampling.py  (new file, 280 lines)
@@ -0,0 +1,280 @@
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.  (authors: Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple
import warnings

import torch
from torch import Tensor, nn
from scaling import (
    Balancer,
    BiasNorm,
    Dropout3,
    FloatLike,
    Optional,
    ScaledConv2d,
    ScaleGrad,
    ScheduledFloat,
    SwooshL,
    SwooshR,
    Whiten,
)
class ConvNeXt(nn.Module):
    """
    Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf
    """

    def __init__(
        self,
        channels: int,
        hidden_ratio: int = 3,
        kernel_size: Tuple[int, int] = (7, 7),
        layerdrop_rate: FloatLike = None,
    ):
        super().__init__()
        padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
        hidden_channels = channels * hidden_ratio
        if layerdrop_rate is None:
            layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015))
        self.layerdrop_rate = layerdrop_rate

        self.depthwise_conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            groups=channels,
            kernel_size=kernel_size,
            padding=padding,
        )

        self.pointwise_conv1 = nn.Conv2d(
            in_channels=channels, out_channels=hidden_channels, kernel_size=1
        )

        self.hidden_balancer = Balancer(
            hidden_channels,
            channel_dim=1,
            min_positive=0.3,
            max_positive=1.0,
            min_abs=0.75,
            max_abs=5.0,
        )

        self.activation = SwooshL()
        self.pointwise_conv2 = ScaledConv2d(
            in_channels=hidden_channels,
            out_channels=channels,
            kernel_size=1,
            initial_scale=0.01,
        )

        self.out_balancer = Balancer(
            channels,
            channel_dim=1,
            min_positive=0.4,
            max_positive=0.6,
            min_abs=1.0,
            max_abs=6.0,
        )
        self.out_whiten = Whiten(
            num_groups=1,
            whitening_limit=5.0,
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

    def forward(self, x: Tensor) -> Tensor:
        if torch.jit.is_scripting() or not self.training:
            return self.forward_internal(x)
        layerdrop_rate = float(self.layerdrop_rate)

        if layerdrop_rate != 0.0:
            batch_size = x.shape[0]
            mask = (
                torch.rand(
                    (batch_size, 1, 1, 1), dtype=x.dtype, device=x.device
                )
                > layerdrop_rate
            )
        else:
            mask = None
        # turns out this caching idea does not work with --world-size > 1
        # return caching_eval(self.forward_internal, x, mask)
        return self.forward_internal(x, mask)

    def forward_internal(
        self, x: Tensor, layer_skip_mask: Optional[Tensor] = None
    ) -> Tensor:
        """
        x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)

        The returned value has the same shape as x.
        """
        bypass = x
        x = self.depthwise_conv(x)
        x = self.pointwise_conv1(x)
        x = self.hidden_balancer(x)
        x = self.activation(x)
        x = self.pointwise_conv2(x)

        if layer_skip_mask is not None:
            x = x * layer_skip_mask

        x = bypass + x
        x = self.out_balancer(x)
        x = x.transpose(1, 3)  # (N, W, H, C); need channel dim to be last
        x = self.out_whiten(x)
        x = x.transpose(1, 3)  # (N, C, H, W)

        return x
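The depthwise convolution above uses padding = ((k-1)//2, (k-1)//2), so with the default 7x7 kernel the ConvNeXt block preserves the (H, W) extent and the residual bypass + x lines up. A small check with a plain nn.Conv2d standing in for the depthwise conv (illustrative values only):

import torch
import torch.nn as nn

channels, kernel_size = 128, (7, 7)
padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)   # (3, 3)
conv = nn.Conv2d(channels, channels, groups=channels,
                 kernel_size=kernel_size, padding=padding)
x = torch.randn(2, channels, 46, 19)    # (N, C, num_frames, num_freqs)
print(conv(x).shape)                    # torch.Size([2, 128, 46, 19]), same H and W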
class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/2 length).

    Convert an input of shape (N, T, idim) to an output
    with shape (N, T', odim), where
    T' = (T-3)//2 - 2 == (T-7)//2

    It is based on
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        layer1_channels: int = 8,
        layer2_channels: int = 32,
        layer3_channels: int = 128,
        dropout: FloatLike = 0.1,
    ) -> None:
        """
        Args:
          in_channels:
            Number of channels in. The input shape is (N, T, in_channels).
            Caution: It requires: T >= 7, in_channels >= 7
          out_channels
            Output dim. The output shape is (N, (T-3)//2, out_channels)
          layer1_channels:
            Number of channels in layer1
          layer1_channels:
            Number of channels in layer2
          bottleneck:
            bottleneck dimension for 1d squeeze-excite
        """
        assert in_channels >= 7
        super().__init__()

        # The ScaleGrad module is there to prevent the gradients
        # w.r.t. the weight or bias of the first Conv2d module in self.conv from
        # exceeding the range of fp16 when using automatic mixed precision (amp)
        # training.  (The second one is necessary to stop its bias from getting
        # a too-large gradient).

        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=layer1_channels,
                kernel_size=3,
                padding=(0, 1),  # (time, freq)
            ),
            ScaleGrad(0.2),
            Balancer(layer1_channels, channel_dim=1, max_abs=1.0),
            SwooshR(),
            nn.Conv2d(
                in_channels=layer1_channels,
                out_channels=layer2_channels,
                kernel_size=3,
                stride=2,
                padding=0,
            ),
            Balancer(layer2_channels, channel_dim=1, max_abs=4.0),
            SwooshR(),
            nn.Conv2d(
                in_channels=layer2_channels,
                out_channels=layer3_channels,
                kernel_size=3,
                stride=(1, 2),  # (time, freq)
            ),
            Balancer(layer3_channels, channel_dim=1, max_abs=4.0),
            SwooshR(),
        )

        # just one convnext layer
        self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))

        out_width = (((in_channels - 1) // 2) - 1) // 2

        self.out = nn.Linear(out_width * layer3_channels, out_channels)
        # use a larger than normal grad_scale on this whitening module; there is
        # only one such module, so there is not a concern about adding together
        # many copies of this extra gradient term.
        self.out_whiten = Whiten(
            num_groups=1,
            whitening_limit=ScheduledFloat(
                (0.0, 4.0), (20000.0, 8.0), default=4.0
            ),
            prob=(0.025, 0.25),
            grad_scale=0.02,
        )

        # max_log_eps=0.0 is to prevent both eps and the output of self.out from
        # getting large, there is an unnecessary degree of freedom.
        self.out_norm = BiasNorm(out_channels)
        self.dropout = Dropout3(dropout, shared_dim=1)

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor:
        """Subsample x.

        Args:
          x:
            Its shape is (N, T, idim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `x` before padding.

        Returns:
          - a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
          - output lengths, of shape (batch_size,)
        """
        # On entry, x is (N, T, idim)
        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
        # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
        # training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
        # gradients.
        x = self.conv(x)
        x = self.convnext(x)

        # Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2)
        b, c, t, f = x.size()

        x = x.transpose(1, 2).reshape(b, t, c * f)
        # now x: (N, ((T-1)//2 - 1))//2, out_width * layer3_channels))

        x = self.out(x)
        # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
        x = self.out_whiten(x)
        x = self.out_norm(x)
        x = self.dropout(x)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            x_lens = (x_lens - 7) // 2
        assert x.size(1) == x_lens.max().item()

        return x, x_lens
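A quick check of the width bookkeeping above: the frequency axis is halved twice by the strided convolutions, so the flattened feature size fed to self.out is out_width * layer3_channels. Plain arithmetic with an illustrative 80-bin input:

in_channels, layer3_channels = 80, 128              # e.g. 80 mel bins (illustrative)
out_width = (((in_channels - 1) // 2) - 1) // 2     # 19
print(out_width, out_width * layer3_channels)       # 19 2432  -> in_features of self.out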
@@ -64,6 +64,7 @@ from zipformer import Zipformer2
 from scaling import ScheduledFloat
 from decoder import Decoder
 from joiner import Joiner
+from subsampling import Conv2dSubsampling
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
@@ -525,29 +526,46 @@ def get_params() -> AttributeDict:
     return params
 
 
+def _to_int_tuple(s: str):
+    return tuple(map(int, s.split(',')))
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+    # encoder_embed converts the input of shape (N, T, num_features)
+    # to the shape (N, (T - 7) // 2, encoder_dims).
+    # That is, it does two things simultaneously:
+    #   (1) subsampling: T -> (T - 7) // 2
+    #   (2) embedding: num_features -> encoder_dims
+    # In the normal configuration, we will downsample once more at the end
+    # by a factor of 2, and most of the encoder stacks will run at a lower
+    # sampling rate.
+    encoder_embed = Conv2dSubsampling(
+        in_channels=params.feature_dim,
+        out_channels=_to_int_tuple(params.encoder_dim)[0],
+        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1))
+    )
+    return encoder_embed
+
+
 def get_encoder_model(params: AttributeDict) -> nn.Module:
-    # TODO: We can add an option to switch between Zipformer and Transformer
-    def to_int_tuple(s: str):
-        return tuple(map(int, s.split(',')))
     encoder = Zipformer2(
-        num_features=params.feature_dim,
         output_downsampling_factor=2,
-        downsampling_factor=to_int_tuple(params.downsampling_factor),
+        downsampling_factor=_to_int_tuple(params.downsampling_factor),
-        num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+        num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
-        encoder_dim=to_int_tuple(params.encoder_dim),
+        encoder_dim=_to_int_tuple(params.encoder_dim),
-        encoder_unmasked_dim=to_int_tuple(params.encoder_unmasked_dim),
+        encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
-        query_head_dim=to_int_tuple(params.query_head_dim),
+        query_head_dim=_to_int_tuple(params.query_head_dim),
-        pos_head_dim=to_int_tuple(params.pos_head_dim),
+        pos_head_dim=_to_int_tuple(params.pos_head_dim),
-        value_head_dim=to_int_tuple(params.value_head_dim),
+        value_head_dim=_to_int_tuple(params.value_head_dim),
         pos_dim=params.pos_dim,
-        num_heads=to_int_tuple(params.num_heads),
+        num_heads=_to_int_tuple(params.num_heads),
-        feedforward_dim=to_int_tuple(params.feedforward_dim),
+        feedforward_dim=_to_int_tuple(params.feedforward_dim),
-        cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
+        cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
         dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
         warmup_batches=4000.0,
         causal=params.causal,
-        chunk_size=to_int_tuple(params.chunk_size),
+        chunk_size=_to_int_tuple(params.chunk_size),
-        left_context_frames=to_int_tuple(params.left_context_frames),
+        left_context_frames=_to_int_tuple(params.left_context_frames),
     )
     return encoder
 
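The encoder and embedding dimensions arrive as comma-separated strings from the command line; _to_int_tuple turns them into per-stack tuples, and get_encoder_embed picks the first stack's dimension as the Conv2dSubsampling output size. For example (values illustrative only):

def _to_int_tuple(s: str):
    return tuple(map(int, s.split(',')))

encoder_dim = "192,256,384,512,384,256"      # an illustrative --encoder-dim value
print(_to_int_tuple(encoder_dim))            # (192, 256, 384, 512, 384, 256)
print(_to_int_tuple(encoder_dim)[0])         # 192 -> out_channels of Conv2dSubsampling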
@@ -564,7 +582,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
 
 def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        encoder_dim=int(max(params.encoder_dim.split(','))),
+        encoder_dim=max(_to_int_tuple(params.encoder_dim)),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
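A likely motivation for this one-line change (a reading of the diff, not stated in the commit message): max over the split strings compares lexicographically, which only returns the numerically largest dimension when all values happen to have the same number of digits, whereas max over the int tuple is always numerically correct:

dims = "96,128,96"
print(max(dims.split(',')))                      # '96'  (lexicographic)
print(max(tuple(map(int, dims.split(',')))))     # 128   (numeric, as intended)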
@@ -573,11 +591,13 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
 
 
 def get_transducer_model(params: AttributeDict) -> nn.Module:
+    encoder_embed = get_encoder_embed(params)
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
 
     model = Transducer(
+        encoder_embed=encoder_embed,
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
@@ -18,7 +18,6 @@
 import copy
 import math
 import warnings
-import itertools
 from typing import List, Optional, Tuple, Union
 import logging
 import torch
@@ -28,13 +27,8 @@ from scaling import (
     Balancer,
     BiasNorm,
     Dropout2,
-    Dropout3,
-    SwooshL,
-    SwooshR,
     ChunkCausalDepthwiseConv1d,
     ActivationDropoutAndLinear,
-    ScaledConv1d,
-    ScaledConv2d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
     Whiten,
     Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
@@ -44,13 +38,9 @@ from scaling import (
     FloatLike,
     limit_param_value,
     convert_num_channels,
-    ScaleGrad,
 )
 from torch import Tensor, nn
 
-from icefall.utils import make_pad_mask
-from icefall.dist import get_rank
-
 
 class Zipformer2(EncoderInterface):
     """
@@ -60,8 +50,6 @@ class Zipformer2(EncoderInterface):
     as downsampling_factor if they are single ints or one-element tuples.  The length of
     downsampling_factor defines the number of stacks.
 
-
-      num_features (int): Number of input features, e.g. 40.
       output_downsampling_factor (int): how much to downsample at the output.  Note:
           we also downsample by a factor of 2 in the Conv2dSubsampling encoder.
           You should probably leave this at 2.
@@ -104,7 +92,6 @@ class Zipformer2(EncoderInterface):
     """
     def __init__(
         self,
-        num_features: int,
         output_downsampling_factor: int = 2,
         downsampling_factor: Tuple[int] = (2, 4),
         encoder_dim: Union[int, Tuple[int]] = 384,
@@ -140,7 +127,6 @@ class Zipformer2(EncoderInterface):
             assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
             return x
 
-        self.num_features = num_features  # int
         self.output_downsampling_factor = output_downsampling_factor  # int
         self.downsampling_factor = downsampling_factor  # tuple
         self.encoder_dim = encoder_dim = _to_tuple(encoder_dim)  # tuple
@@ -160,18 +146,6 @@ class Zipformer2(EncoderInterface):
         for u,d in zip(encoder_unmasked_dim, encoder_dim):
             assert u <= d
 
-        # self.encoder_embed converts the input of shape (N, T, num_features)
-        # to the shape (N, (T - 7) // 2, encoder_dims).
-        # That is, it does two things simultaneously:
-        #   (1) subsampling: T -> (T - 7) // 2
-        #   (2) embedding: num_features -> encoder_dims
-        # In the normal configuration, we will downsample once more at the end
-        # by a factor of 2, and most of the encoder stacks will run at a lower
-        # sampling rate.
-        self.encoder_embed = Conv2dSubsampling(num_features, encoder_dim[0],
-                                               dropout=dropout)
-
-
         # each one will be Zipformer2Encoder or DownsampledZipformer2Encoder
         encoders = []
 
@@ -297,7 +271,9 @@ class Zipformer2(EncoderInterface):
 
 
     def forward(
-        self, x: torch.Tensor, x_lens: torch.Tensor,
+        self, x: torch.Tensor,
+        x_lens: torch.Tensor,
+        src_key_padding_mask: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -306,34 +282,15 @@ class Zipformer2(EncoderInterface):
           x_lens:
             A tensor of shape (batch_size,) containing the number of frames in
             `x` before padding.
-          chunk_size: Number of frames per chunk (only set this if causal == True).
-             Must divide all elements of downsampling_factor.  At 50hz frame
-             rate, i.e. after encoder_embed.  If not specified, no chunking.
-          left_context_chunks: Number of left-context chunks for each chunk (affects
-             attention mask); only set this if chunk_size specified.  If -1, there
-             is no limit on the left context.  If not -1, require:
-               left_context_chunks * context_size >= downsampling_factor[i] *
-               cnn_module_kernel[i] // 2.
+          src_key_padding_mask:
+            The mask for padding, of shape (batch_size, seq_len); True means
+            masked position. May be None.
         Returns:
           Return a tuple containing 2 tensors:
             - embeddings: its shape is (batch_size, output_seq_len, max(encoder_dim))
             - lengths, a tensor of shape (batch_size,) containing the number
               of frames in `embeddings` before padding.
         """
-        # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
-
-        x = self.encoder_embed(x)
-
-        # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
-
-        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
-
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            lengths = (x_lens - 7) // 2
-            assert x.size(0) == lengths.max().item()
-        src_key_padding_mask = make_pad_mask(lengths)
-
         outputs = []
         feature_masks = self.get_feature_masks(x)
 
@@ -379,9 +336,7 @@ class Zipformer2(EncoderInterface):
         assert self.output_downsampling_factor == 2
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
-            lengths = (lengths + 1) // 2
+            lengths = (x_lens + 1) // 2
 
-        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
-
         return x, lengths
 
@@ -700,12 +655,9 @@ class Zipformer2EncoderLayer(nn.Module):
                                           src_key_padding_mask=src_key_padding_mask),
                                       float(self.conv_skip_rate))
 
-
         src = src + self.sequence_dropout(self.balancer_ff3(self.feed_forward3(src)),
                                           float(self.ff3_skip_rate))
 
-
-
         src = self.balancer1(src)
         src = self.norm(src)
 
@@ -1686,244 +1638,13 @@ class ScalarMultiply(nn.Module):
         return x * self.scale
 
 
-
-class ConvNeXt(nn.Module):
-    """
-    Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf
-    """
-    def __init__(self,
-                 channels: int,
-                 hidden_ratio: int = 3,
-                 kernel_size: Tuple[int, int] = (7, 7),
-                 layerdrop_rate: FloatLike = None):
-        super().__init__()
-        padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
-        hidden_channels = channels * hidden_ratio
-        if layerdrop_rate is None:
-            layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015))
-        self.layerdrop_rate = layerdrop_rate
-
-        self.depthwise_conv = nn.Conv2d(
-            in_channels=channels,
-            out_channels=channels,
-            groups=channels,
-            kernel_size=kernel_size,
-            padding=padding)
-
-        self.pointwise_conv1 = nn.Conv2d(
-            in_channels=channels,
-            out_channels=hidden_channels,
-            kernel_size=1)
-
-        self.hidden_balancer = Balancer(hidden_channels,
-                                        channel_dim=1,
-                                        min_positive=0.3,
-                                        max_positive=1.0,
-                                        min_abs=0.75,
-                                        max_abs=5.0)
-
-        self.activation = SwooshL()
-        self.pointwise_conv2 = ScaledConv2d(
-            in_channels=hidden_channels,
-            out_channels=channels,
-            kernel_size=1,
-            initial_scale=0.01)
-
-        self.out_balancer = Balancer(
-            channels, channel_dim=1,
-            min_positive=0.4, max_positive=0.6,
-            min_abs=1.0, max_abs=6.0,
-        )
-        self.out_whiten = Whiten(num_groups=1,
-                                 whitening_limit=5.0,
-                                 prob=(0.025, 0.25),
-                                 grad_scale=0.01)
-
-    def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or not self.training:
-            return self.forward_internal(x)
-        layerdrop_rate = float(self.layerdrop_rate)
-
-        if layerdrop_rate != 0.0:
-            batch_size = x.shape[0]
-            mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate
-        else:
-            mask = None
-        # turns out this caching idea does not work with --world-size > 1
-        #return caching_eval(self.forward_internal, x, mask)
-        return self.forward_internal(x, mask)
-
-
-    def forward_internal(self,
-                         x: Tensor,
-                         layer_skip_mask: Optional[Tensor] = None) -> Tensor:
-        """
-        x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
-
-        The returned value has the same shape as x.
-        """
-        bypass = x
-        x = self.depthwise_conv(x)
-        x = self.pointwise_conv1(x)
-        x = self.hidden_balancer(x)
-        x = self.activation(x)
-        x = self.pointwise_conv2(x)
-
-        if layer_skip_mask is not None:
-            x = x * layer_skip_mask
-
-        x = bypass + x
-        x = self.out_balancer(x)
-        x = x.transpose(1, 3)  # (N, W, H, C); need channel dim to be last
-        x = self.out_whiten(x)
-        x = x.transpose(1, 3)  # (N, C, H, W)
-
-        return x
-
-
-
-class Conv2dSubsampling(nn.Module):
-    """Convolutional 2D subsampling (to 1/2 length).
-
-    Convert an input of shape (N, T, idim) to an output
-    with shape (N, T', odim), where
-    T' = (T-3)//2 - 2 == (T-7)//2
-
-    It is based on
-    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
-    """
-
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        layer1_channels: int = 8,
-        layer2_channels: int = 32,
-        layer3_channels: int = 128,
-        dropout: FloatLike = 0.1,
-    ) -> None:
-        """
-        Args:
-          in_channels:
-            Number of channels in. The input shape is (N, T, in_channels).
-            Caution: It requires: T >= 7, in_channels >= 7
-          out_channels
-            Output dim. The output shape is (N, (T-3)//2, out_channels)
-          layer1_channels:
-            Number of channels in layer1
-          layer1_channels:
-            Number of channels in layer2
-          bottleneck:
-            bottleneck dimension for 1d squeeze-excite
-        """
-        assert in_channels >= 7
-        super().__init__()
-
-        # The ScaleGrad module is there to prevent the gradients
-        # w.r.t. the weight or bias of the first Conv2d module in self.conv from
-        # exceeding the range of fp16 when using automatic mixed precision (amp)
-        # training.  (The second one is necessary to stop its bias from getting
-        # a too-large gradient).
-
-        self.conv = nn.Sequential(
-            nn.Conv2d(
-                in_channels=1,
-                out_channels=layer1_channels,
-                kernel_size=3,
-                padding=(0, 1),  # (time, freq)
-            ),
-            ScaleGrad(0.2),
-            Balancer(layer1_channels,
-                     channel_dim=1,
-                     max_abs=1.0),
-            SwooshR(),
-            nn.Conv2d(
-                in_channels=layer1_channels,
-                out_channels=layer2_channels,
-                kernel_size=3,
-                stride=2,
-                padding=0,
-            ),
-            Balancer(layer2_channels,
-                     channel_dim=1,
-                     max_abs=4.0),
-            SwooshR(),
-            nn.Conv2d(
-                in_channels=layer2_channels,
-                out_channels=layer3_channels,
-                kernel_size=3,
-                stride=(1, 2),  # (time, freq)
-            ),
-            Balancer(layer3_channels,
-                     channel_dim=1,
-                     max_abs=4.0),
-            SwooshR(),
-        )
-
-        cur_width = (in_channels - 1) // 2
-
-        # just one convnext layer
-        self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))
-
-        out_width = (((in_channels - 1) // 2) - 1) // 2
-
-        self.out = nn.Linear(out_width * layer3_channels, out_channels)
-        # use a larger than normal grad_scale on this whitening module; there is
-        # only one such module, so there is not a concern about adding together
-        # many copies of this extra gradient term.
-        self.out_whiten = Whiten(num_groups=1,
-                                 whitening_limit=_whitening_schedule(4.0),
-                                 prob=(0.025, 0.25),
-                                 grad_scale=0.02)
-
-        # max_log_eps=0.0 is to prevent both eps and the output of self.out from
-        # getting large, there is an unnecessary degree of freedom.
-        self.out_norm = BiasNorm(out_channels)
-        self.dropout = Dropout3(dropout, shared_dim=1)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Subsample x.
-
-        Args:
-          x:
-            Its shape is (N, T, idim).
-
-        Returns:
-          Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
-        """
-        # On entry, x is (N, T, idim)
-        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
-        # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
-        # training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
-        # gradients.
-        x = self.conv(x)
-        x = self.convnext(x)
-
-        # Now x is of shape (N, odim, ((T-3)//2 - 1)//2, ((idim-1)//2 - 1)//2)
-        b, c, t, f = x.size()
-
-        x = x.transpose(1, 2).reshape(b, t, c * f)
-        # now x: (N, ((T-1)//2 - 1))//2, out_width * layer3_channels))
-
-        x = self.out(x)
-        # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
-        x = self.out_whiten(x)
-        x = self.out_norm(x)
-        x = self.dropout(x)
-        return x
-
-
-
 def _test_zipformer_main(causal: bool = False):
-    feature_dim = 50
     batch_size = 5
     seq_len = 20
-    feature_dim = 50
     # Just make sure the forward pass runs.
 
     c = Zipformer2(
-        num_features=feature_dim, encoder_dim=(64,96), encoder_unmasked_dim=(48,64), num_heads=(4,4),
+        encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
         causal=causal,
         chunk_size=(4,) if causal else (-1,),
         left_context_frames=(64,)
@@ -1932,19 +1653,18 @@ def _test_zipformer_main(causal: bool = False):
     seq_len = 20
     # Just make sure the forward pass runs.
     f = c(
-        torch.randn(batch_size, seq_len, feature_dim),
+        torch.randn(seq_len, batch_size, 64),
         torch.full((batch_size,), seq_len, dtype=torch.int64),
     )
     f[0].sum().backward()
     c.eval()
     f = c(
-        torch.randn(batch_size, seq_len, feature_dim),
+        torch.randn(seq_len, batch_size, 64),
         torch.full((batch_size,), seq_len, dtype=torch.int64),
     )
     f  # to remove flake8 warnings
 
-
 
 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)