mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 22:24:19 +00:00
update
This commit is contained in:
parent
77125064cb
commit
75c5389979
File diff suppressed because it is too large
Load Diff
1
egs/librispeech/SSL/hubert/beam_search.py
Symbolic link
1
egs/librispeech/SSL/hubert/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/zipformer/beam_search.py
|
@ -22,39 +22,39 @@
|
||||
Usage:
|
||||
|
||||
(1) ctc-decoding
|
||||
./zipformer/ctc_decode.py \
|
||||
./hubert/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--exp-dir ./hubert/exp \
|
||||
--use-ctc 1 \
|
||||
--max-duration 600 \
|
||||
--decoding-method ctc-decoding
|
||||
|
||||
(2) 1best
|
||||
./zipformer/ctc_decode.py \
|
||||
./hubert/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--exp-dir ./hubert/exp \
|
||||
--use-ctc 1 \
|
||||
--max-duration 600 \
|
||||
--hlg-scale 0.6 \
|
||||
--decoding-method 1best
|
||||
|
||||
(3) nbest
|
||||
./zipformer/ctc_decode.py \
|
||||
./hubert/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--exp-dir ./hubert/exp \
|
||||
--use-ctc 1 \
|
||||
--max-duration 600 \
|
||||
--hlg-scale 0.6 \
|
||||
--decoding-method nbest
|
||||
|
||||
(4) nbest-rescoring
|
||||
./zipformer/ctc_decode.py \
|
||||
./hubert/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--exp-dir ./hubert/exp \
|
||||
--use-ctc 1 \
|
||||
--max-duration 600 \
|
||||
--hlg-scale 0.6 \
|
||||
@ -63,10 +63,10 @@ Usage:
|
||||
--decoding-method nbest-rescoring
|
||||
|
||||
(5) whole-lattice-rescoring
|
||||
./zipformer/ctc_decode.py \
|
||||
./hubert/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--exp-dir ./hubert/exp \
|
||||
--use-ctc 1 \
|
||||
--max-duration 600 \
|
||||
--hlg-scale 0.6 \
|
||||
@ -164,7 +164,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="zipformer/exp",
|
||||
default="hubert/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
@ -340,7 +340,7 @@ def decode_one_batch(
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
if params.causal:
|
||||
# this seems to cause insertions at the end of the utterance if used with zipformer.
|
||||
# this seems to cause insertions at the end of the utterance if used with hubert.
|
||||
pad_len = 30
|
||||
feature_lens += pad_len
|
||||
feature = torch.nn.functional.pad(
|
||||
|
@ -92,9 +92,9 @@ class HubertAsrDataset(torch.utils.data.Dataset):
|
||||
feature_size=1,
|
||||
sampling_rate=16000,
|
||||
padding_side="right",
|
||||
padding_value=0.0,
|
||||
padding_value=0,
|
||||
do_normalize=True,
|
||||
return_attention_mask=True,
|
||||
return_attention_mask=False,
|
||||
)
|
||||
|
||||
def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
|
||||
@ -148,7 +148,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
pass
|
||||
print(batch["audio"])
|
||||
print(batch["audio_lens"])
|
||||
print(batch["supervisions"]["text"])
|
||||
print(batch["cuts"])
|
||||
|
@ -1,134 +0,0 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from scaling import Balancer
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""This class modifies the stateless decoder from the following paper:
|
||||
|
||||
RNN-transducer with stateless prediction network
|
||||
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
|
||||
|
||||
It removes the recurrent connection from the decoder, i.e., the prediction
|
||||
network. Different from the above paper, it adds an extra Conv1d
|
||||
right after the embedding layer.
|
||||
|
||||
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
decoder_dim: int,
|
||||
blank_id: int,
|
||||
context_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
vocab_size:
|
||||
Number of tokens of the modeling unit including blank.
|
||||
decoder_dim:
|
||||
Dimension of the input embedding, and of the decoder output.
|
||||
blank_id:
|
||||
The ID of the blank symbol.
|
||||
context_size:
|
||||
Number of previous words to use to predict the next word.
|
||||
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.embedding = nn.Embedding(
|
||||
num_embeddings=vocab_size,
|
||||
embedding_dim=decoder_dim,
|
||||
)
|
||||
# the balancers are to avoid any drift in the magnitude of the
|
||||
# embeddings, which would interact badly with parameter averaging.
|
||||
self.balancer = Balancer(
|
||||
decoder_dim,
|
||||
channel_dim=-1,
|
||||
min_positive=0.0,
|
||||
max_positive=1.0,
|
||||
min_abs=0.5,
|
||||
max_abs=1.0,
|
||||
prob=0.05,
|
||||
)
|
||||
|
||||
self.blank_id = blank_id
|
||||
|
||||
assert context_size >= 1, context_size
|
||||
self.context_size = context_size
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
if context_size > 1:
|
||||
self.conv = nn.Conv1d(
|
||||
in_channels=decoder_dim,
|
||||
out_channels=decoder_dim,
|
||||
kernel_size=context_size,
|
||||
padding=0,
|
||||
groups=decoder_dim // 4, # group size == 4
|
||||
bias=False,
|
||||
)
|
||||
self.balancer2 = Balancer(
|
||||
decoder_dim,
|
||||
channel_dim=-1,
|
||||
min_positive=0.0,
|
||||
max_positive=1.0,
|
||||
min_abs=0.5,
|
||||
max_abs=1.0,
|
||||
prob=0.05,
|
||||
)
|
||||
else:
|
||||
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
|
||||
# when inference with torch.jit.script and context_size == 1
|
||||
self.conv = nn.Identity()
|
||||
self.balancer2 = nn.Identity()
|
||||
|
||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, U).
|
||||
need_pad:
|
||||
True to left pad the input. Should be True during training.
|
||||
False to not pad the input. Should be False during inference.
|
||||
Returns:
|
||||
Return a tensor of shape (N, U, decoder_dim).
|
||||
"""
|
||||
y = y.to(torch.int64)
|
||||
# this stuff about clamp() is a temporary fix for a mismatch
|
||||
# at utterance start, we use negative ids in beam_search.py
|
||||
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
|
||||
|
||||
embedding_out = self.balancer(embedding_out)
|
||||
|
||||
if self.context_size > 1:
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
if need_pad is True:
|
||||
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
|
||||
else:
|
||||
# During inference time, there is no need to do extra padding
|
||||
# as we only need one output
|
||||
assert embedding_out.size(-1) == self.context_size
|
||||
embedding_out = self.conv(embedding_out)
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
embedding_out = F.relu(embedding_out)
|
||||
embedding_out = self.balancer2(embedding_out)
|
||||
|
||||
return embedding_out
|
1
egs/librispeech/SSL/hubert/decoder.py
Symbolic link
1
egs/librispeech/SSL/hubert/decoder.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/zipformer/decoder.py
|
@ -64,7 +64,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import AsrModel
|
||||
from optim import Eden, ScaledAdam
|
||||
from scaling import ScheduledFloat
|
||||
from subsampling import Conv2dSubsampling
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@ -152,7 +151,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--do-stable-layer-norm",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--feat-extract-activation",
|
||||
@ -162,12 +161,12 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--feat-extract-norm",
|
||||
type=str,
|
||||
default="layer",
|
||||
default="group",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--feat-proj-dropout",
|
||||
type=float,
|
||||
default=0.0,
|
||||
default=0.1,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--feat-proj-layer-norm",
|
||||
@ -192,7 +191,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--hidden-size",
|
||||
type=int,
|
||||
default=1024,
|
||||
default=768,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initializer-range",
|
||||
@ -202,7 +201,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--intermediate-size",
|
||||
type=int,
|
||||
default=4096,
|
||||
default=3072,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--layer-norm-eps",
|
||||
@ -247,7 +246,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--num-attention-heads",
|
||||
type=int,
|
||||
default=16,
|
||||
default=12,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-conv-pos-embedding-groups",
|
||||
@ -262,14 +261,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--num-hidden-layers",
|
||||
type=int,
|
||||
default=24,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-dim",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Embedding dimension in encoder model.",
|
||||
default=12,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -366,6 +358,14 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--pretrained-dir",
|
||||
type=str,
|
||||
default="download/hubert-base-ls960",
|
||||
help="""The pretrained model dir.
|
||||
It specifies the directory where the pretrained checkpoint is saved.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
@ -657,7 +657,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||
|
||||
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
joiner = Joiner(
|
||||
encoder_dim=params.encoder_dim,
|
||||
encoder_dim=params.hidden_size,
|
||||
decoder_dim=params.decoder_dim,
|
||||
joiner_dim=params.joiner_dim,
|
||||
vocab_size=params.vocab_size,
|
||||
@ -685,7 +685,7 @@ def get_model(params: AttributeDict) -> nn.Module:
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
joiner=joiner,
|
||||
encoder_dim=params.encoder_dim,
|
||||
encoder_dim=params.hidden_size,
|
||||
decoder_dim=params.decoder_dim,
|
||||
vocab_size=params.vocab_size,
|
||||
use_transducer=params.use_transducer,
|
||||
@ -731,6 +731,8 @@ def load_checkpoint_if_available(
|
||||
elif params.start_epoch > 1:
|
||||
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
||||
else:
|
||||
logging.info(f"Loading {params.pretrained_dir}")
|
||||
model.encoder = HubertModel.from_pretrained(params.pretrained_dir)
|
||||
return None
|
||||
|
||||
assert filename.is_file(), f"{filename} does not exist!"
|
||||
|
@ -1,67 +0,0 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling import ScaledLinear
|
||||
|
||||
|
||||
class Joiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder_dim: int,
|
||||
decoder_dim: int,
|
||||
joiner_dim: int,
|
||||
vocab_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
|
||||
self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
|
||||
self.output_linear = nn.Linear(joiner_dim, vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
project_input: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
Output from the encoder. Its shape is (N, T, s_range, C).
|
||||
decoder_out:
|
||||
Output from the decoder. Its shape is (N, T, s_range, C).
|
||||
project_input:
|
||||
If true, apply input projections encoder_proj and decoder_proj.
|
||||
If this is false, it is the user's responsibility to do this
|
||||
manually.
|
||||
Returns:
|
||||
Return a tensor of shape (N, T, s_range, C).
|
||||
"""
|
||||
assert encoder_out.ndim == decoder_out.ndim, (
|
||||
encoder_out.shape,
|
||||
decoder_out.shape,
|
||||
)
|
||||
|
||||
if project_input:
|
||||
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
|
||||
else:
|
||||
logit = encoder_out + decoder_out
|
||||
|
||||
logit = self.output_linear(torch.tanh(logit))
|
||||
|
||||
return logit
|
1
egs/librispeech/SSL/hubert/joiner.py
Symbolic link
1
egs/librispeech/SSL/hubert/joiner.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/zipformer/joiner.py
|
File diff suppressed because it is too large
Load Diff
1
egs/librispeech/SSL/hubert/optim.py
Symbolic link
1
egs/librispeech/SSL/hubert/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/zipformer/optim.py
|
File diff suppressed because it is too large
Load Diff
1
egs/librispeech/SSL/hubert/scaling.py
Symbolic link
1
egs/librispeech/SSL/hubert/scaling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/zipformer/scaling.py
|
@ -1,406 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from scaling import (
|
||||
Balancer,
|
||||
BiasNorm,
|
||||
Dropout3,
|
||||
FloatLike,
|
||||
Optional,
|
||||
ScaledConv2d,
|
||||
ScaleGrad,
|
||||
ScheduledFloat,
|
||||
SwooshL,
|
||||
SwooshR,
|
||||
Whiten,
|
||||
)
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class ConvNeXt(nn.Module):
|
||||
"""
|
||||
Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
hidden_ratio: int = 3,
|
||||
kernel_size: Tuple[int, int] = (7, 7),
|
||||
layerdrop_rate: FloatLike = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
|
||||
hidden_channels = channels * hidden_ratio
|
||||
if layerdrop_rate is None:
|
||||
layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015))
|
||||
self.layerdrop_rate = layerdrop_rate
|
||||
|
||||
self.depthwise_conv = nn.Conv2d(
|
||||
in_channels=channels,
|
||||
out_channels=channels,
|
||||
groups=channels,
|
||||
kernel_size=kernel_size,
|
||||
padding=self.padding,
|
||||
)
|
||||
|
||||
self.pointwise_conv1 = nn.Conv2d(
|
||||
in_channels=channels, out_channels=hidden_channels, kernel_size=1
|
||||
)
|
||||
|
||||
self.hidden_balancer = Balancer(
|
||||
hidden_channels,
|
||||
channel_dim=1,
|
||||
min_positive=0.3,
|
||||
max_positive=1.0,
|
||||
min_abs=0.75,
|
||||
max_abs=5.0,
|
||||
)
|
||||
|
||||
self.activation = SwooshL()
|
||||
self.pointwise_conv2 = ScaledConv2d(
|
||||
in_channels=hidden_channels,
|
||||
out_channels=channels,
|
||||
kernel_size=1,
|
||||
initial_scale=0.01,
|
||||
)
|
||||
|
||||
self.out_balancer = Balancer(
|
||||
channels,
|
||||
channel_dim=1,
|
||||
min_positive=0.4,
|
||||
max_positive=0.6,
|
||||
min_abs=1.0,
|
||||
max_abs=6.0,
|
||||
)
|
||||
self.out_whiten = Whiten(
|
||||
num_groups=1,
|
||||
whitening_limit=5.0,
|
||||
prob=(0.025, 0.25),
|
||||
grad_scale=0.01,
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
|
||||
return self.forward_internal(x)
|
||||
layerdrop_rate = float(self.layerdrop_rate)
|
||||
|
||||
if layerdrop_rate != 0.0:
|
||||
batch_size = x.shape[0]
|
||||
mask = (
|
||||
torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
|
||||
> layerdrop_rate
|
||||
)
|
||||
else:
|
||||
mask = None
|
||||
# turns out this caching idea does not work with --world-size > 1
|
||||
# return caching_eval(self.forward_internal, x, mask)
|
||||
return self.forward_internal(x, mask)
|
||||
|
||||
def forward_internal(
|
||||
self, x: Tensor, layer_skip_mask: Optional[Tensor] = None
|
||||
) -> Tensor:
|
||||
"""
|
||||
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
|
||||
|
||||
The returned value has the same shape as x.
|
||||
"""
|
||||
bypass = x
|
||||
x = self.depthwise_conv(x)
|
||||
x = self.pointwise_conv1(x)
|
||||
x = self.hidden_balancer(x)
|
||||
x = self.activation(x)
|
||||
x = self.pointwise_conv2(x)
|
||||
|
||||
if layer_skip_mask is not None:
|
||||
x = x * layer_skip_mask
|
||||
|
||||
x = bypass + x
|
||||
x = self.out_balancer(x)
|
||||
|
||||
if x.requires_grad:
|
||||
x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last
|
||||
x = self.out_whiten(x)
|
||||
x = x.transpose(1, 3) # (N, C, H, W)
|
||||
|
||||
return x
|
||||
|
||||
def streaming_forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
cached_left_pad: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
|
||||
cached_left_pad: (batch_size, num_channels, left_pad, num_freqs)
|
||||
|
||||
Returns:
|
||||
- The returned value has the same shape as x.
|
||||
- Updated cached_left_pad.
|
||||
"""
|
||||
padding = self.padding
|
||||
|
||||
# The length without right padding for depth-wise conv
|
||||
T = x.size(2) - padding[0]
|
||||
|
||||
bypass = x[:, :, :T, :]
|
||||
|
||||
# Pad left side
|
||||
assert cached_left_pad.size(2) == padding[0], (
|
||||
cached_left_pad.size(2),
|
||||
padding[0],
|
||||
)
|
||||
x = torch.cat([cached_left_pad, x], dim=2)
|
||||
# Update cached left padding
|
||||
cached_left_pad = x[:, :, T : padding[0] + T, :]
|
||||
|
||||
# depthwise_conv
|
||||
x = torch.nn.functional.conv2d(
|
||||
x,
|
||||
weight=self.depthwise_conv.weight,
|
||||
bias=self.depthwise_conv.bias,
|
||||
padding=(0, padding[1]),
|
||||
groups=self.depthwise_conv.groups,
|
||||
)
|
||||
x = self.pointwise_conv1(x)
|
||||
x = self.hidden_balancer(x)
|
||||
x = self.activation(x)
|
||||
x = self.pointwise_conv2(x)
|
||||
|
||||
x = bypass + x
|
||||
return x, cached_left_pad
|
||||
|
||||
|
||||
class Conv2dSubsampling(nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/2 length).
|
||||
|
||||
Convert an input of shape (N, T, idim) to an output
|
||||
with shape (N, T', odim), where
|
||||
T' = (T-3)//2 - 2 == (T-7)//2
|
||||
|
||||
It is based on
|
||||
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
layer1_channels: int = 8,
|
||||
layer2_channels: int = 32,
|
||||
layer3_channels: int = 128,
|
||||
dropout: FloatLike = 0.1,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
in_channels:
|
||||
Number of channels in. The input shape is (N, T, in_channels).
|
||||
Caution: It requires: T >=7, in_channels >=7
|
||||
out_channels
|
||||
Output dim. The output shape is (N, (T-3)//2, out_channels)
|
||||
layer1_channels:
|
||||
Number of channels in layer1
|
||||
layer1_channels:
|
||||
Number of channels in layer2
|
||||
bottleneck:
|
||||
bottleneck dimension for 1d squeeze-excite
|
||||
"""
|
||||
assert in_channels >= 7
|
||||
super().__init__()
|
||||
|
||||
# The ScaleGrad module is there to prevent the gradients
|
||||
# w.r.t. the weight or bias of the first Conv2d module in self.conv from
|
||||
# exceeding the range of fp16 when using automatic mixed precision (amp)
|
||||
# training. (The second one is necessary to stop its bias from getting
|
||||
# a too-large gradient).
|
||||
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels=1,
|
||||
out_channels=layer1_channels,
|
||||
kernel_size=3,
|
||||
padding=(0, 1), # (time, freq)
|
||||
),
|
||||
ScaleGrad(0.2),
|
||||
Balancer(layer1_channels, channel_dim=1, max_abs=1.0),
|
||||
SwooshR(),
|
||||
nn.Conv2d(
|
||||
in_channels=layer1_channels,
|
||||
out_channels=layer2_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=0,
|
||||
),
|
||||
Balancer(layer2_channels, channel_dim=1, max_abs=4.0),
|
||||
SwooshR(),
|
||||
nn.Conv2d(
|
||||
in_channels=layer2_channels,
|
||||
out_channels=layer3_channels,
|
||||
kernel_size=3,
|
||||
stride=(1, 2), # (time, freq)
|
||||
),
|
||||
Balancer(layer3_channels, channel_dim=1, max_abs=4.0),
|
||||
SwooshR(),
|
||||
)
|
||||
|
||||
# just one convnext layer
|
||||
self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))
|
||||
|
||||
# (in_channels-3)//4
|
||||
self.out_width = (((in_channels - 1) // 2) - 1) // 2
|
||||
self.layer3_channels = layer3_channels
|
||||
|
||||
self.out = nn.Linear(self.out_width * layer3_channels, out_channels)
|
||||
# use a larger than normal grad_scale on this whitening module; there is
|
||||
# only one such module, so there is not a concern about adding together
|
||||
# many copies of this extra gradient term.
|
||||
self.out_whiten = Whiten(
|
||||
num_groups=1,
|
||||
whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0),
|
||||
prob=(0.025, 0.25),
|
||||
grad_scale=0.02,
|
||||
)
|
||||
|
||||
# max_log_eps=0.0 is to prevent both eps and the output of self.out from
|
||||
# getting large, there is an unnecessary degree of freedom.
|
||||
self.out_norm = BiasNorm(out_channels)
|
||||
self.dropout = Dropout3(dropout, shared_dim=1)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Subsample x.
|
||||
|
||||
Args:
|
||||
x:
|
||||
Its shape is (N, T, idim).
|
||||
x_lens:
|
||||
A tensor of shape (batch_size,) containing the number of frames in
|
||||
|
||||
Returns:
|
||||
- a tensor of shape (N, (T-7)//2, odim)
|
||||
- output lengths, of shape (batch_size,)
|
||||
"""
|
||||
# On entry, x is (N, T, idim)
|
||||
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
|
||||
# scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
|
||||
# training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
|
||||
# gradients.
|
||||
x = self.conv(x)
|
||||
x = self.convnext(x)
|
||||
|
||||
# Now x is of shape (N, odim, (T-7)//2, (idim-3)//4)
|
||||
b, c, t, f = x.size()
|
||||
|
||||
x = x.transpose(1, 2).reshape(b, t, c * f)
|
||||
# now x: (N, (T-7)//2, out_width * layer3_channels))
|
||||
|
||||
x = self.out(x)
|
||||
# Now x is of shape (N, (T-7)//2, odim)
|
||||
x = self.out_whiten(x)
|
||||
x = self.out_norm(x)
|
||||
x = self.dropout(x)
|
||||
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
x_lens = (x_lens - 7) // 2
|
||||
else:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
x_lens = (x_lens - 7) // 2
|
||||
assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max())
|
||||
|
||||
return x, x_lens
|
||||
|
||||
def streaming_forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
cached_left_pad: Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Subsample x.
|
||||
|
||||
Args:
|
||||
x:
|
||||
Its shape is (N, T, idim).
|
||||
x_lens:
|
||||
A tensor of shape (batch_size,) containing the number of frames in
|
||||
|
||||
Returns:
|
||||
- a tensor of shape (N, (T-7)//2, odim)
|
||||
- output lengths, of shape (batch_size,)
|
||||
- updated cache
|
||||
"""
|
||||
# On entry, x is (N, T, idim)
|
||||
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
|
||||
|
||||
# T' = (T-7)//2
|
||||
x = self.conv(x)
|
||||
|
||||
# T' = (T-7)//2-3
|
||||
x, cached_left_pad = self.convnext.streaming_forward(
|
||||
x, cached_left_pad=cached_left_pad
|
||||
)
|
||||
|
||||
# Now x is of shape (N, odim, T', ((idim-1)//2 - 1)//2)
|
||||
b, c, t, f = x.size()
|
||||
|
||||
x = x.transpose(1, 2).reshape(b, t, c * f)
|
||||
# now x: (N, T', out_width * layer3_channels))
|
||||
|
||||
x = self.out(x)
|
||||
# Now x is of shape (N, T', odim)
|
||||
x = self.out_norm(x)
|
||||
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
assert self.convnext.padding[0] == 3
|
||||
# The ConvNeXt module needs 3 frames of right padding after subsampling
|
||||
x_lens = (x_lens - 7) // 2 - 3
|
||||
else:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
# The ConvNeXt module needs 3 frames of right padding after subsampling
|
||||
assert self.convnext.padding[0] == 3
|
||||
x_lens = (x_lens - 7) // 2 - 3
|
||||
|
||||
assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max())
|
||||
|
||||
return x, x_lens, cached_left_pad
|
||||
|
||||
@torch.jit.export
|
||||
def get_init_states(
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> Tensor:
|
||||
"""Get initial states for Conv2dSubsampling module.
|
||||
It is the cached left padding for ConvNeXt module,
|
||||
of shape (batch_size, num_channels, left_pad, num_freqs)
|
||||
"""
|
||||
left_pad = self.convnext.padding[0]
|
||||
freq = self.out_width
|
||||
channels = self.layer3_channels
|
||||
cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to(
|
||||
device
|
||||
)
|
||||
|
||||
return cached_embed_left_pad
|
Loading…
x
Reference in New Issue
Block a user