mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
modified zipformer
This commit is contained in:
parent
f84270c935
commit
6072407669
@ -20,115 +20,162 @@ import torch.nn.functional as F
|
|||||||
from scaling import Balancer
|
from scaling import Balancer
|
||||||
|
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
class Decoder(torch.nn.Module):
|
||||||
"""This class modifies the stateless decoder from the following paper:
|
"""
|
||||||
|
This class modifies the stateless decoder from the following paper:
|
||||||
RNN-transducer with stateless prediction network
|
RNN-transducer with stateless prediction network
|
||||||
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
|
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
|
||||||
|
|
||||||
It removes the recurrent connection from the decoder, i.e., the prediction
|
It removes the recurrent connection from the decoder, i.e., the prediction network.
|
||||||
network. Different from the above paper, it adds an extra Conv1d
|
Different from the above paper, it adds an extra Conv1d right after the embedding layer.
|
||||||
right after the embedding layer.
|
"""
|
||||||
|
|
||||||
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
|
def __init__(
|
||||||
|
self, vocab_size: int, decoder_dim: int, context_size: int, device: torch.device,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Decoder initialization.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
vocab_size : int
|
||||||
|
A number of tokens or modeling units, includes blank.
|
||||||
|
decoder_dim : int
|
||||||
|
A dimension of the decoder embeddings, and the decoder output.
|
||||||
|
context_size : int
|
||||||
|
A number of previous words to use to predict the next word.
|
||||||
|
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
||||||
|
device : torch.device
|
||||||
|
The device used to store the layer weights. Should be
|
||||||
|
either torch.device("cpu") or torch.device("cuda").
|
||||||
|
"""
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.embedding = torch.nn.Embedding(vocab_size, decoder_dim)
|
||||||
|
|
||||||
|
if context_size < 1:
|
||||||
|
raise ValueError(
|
||||||
|
'RNN-T decoder context size should be an integer greater '
|
||||||
|
f'or equal than 1, but got {context_size}.',
|
||||||
|
)
|
||||||
|
self.context_size = context_size
|
||||||
|
|
||||||
|
self.conv = torch.nn.Conv1d(
|
||||||
|
decoder_dim,
|
||||||
|
decoder_dim,
|
||||||
|
context_size,
|
||||||
|
groups=decoder_dim // 4,
|
||||||
|
bias=False,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Does a forward pass of the stateless Decoder module. Returns an output decoder tensor.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
y : torch.Tensor[torch.int32]
|
||||||
|
The input integer tensor of shape (N, context_size).
|
||||||
|
The module input that corresponds to the last context_size decoded token indexes.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor[torch.float32]
|
||||||
|
An output float tensor of shape (N, 1, decoder_dim).
|
||||||
|
"""
|
||||||
|
|
||||||
|
# this stuff about clamp() is a fix for a mismatch at utterance start,
|
||||||
|
# we use negative ids in RNN-T decoding.
|
||||||
|
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(2)
|
||||||
|
|
||||||
|
if self.context_size > 1:
|
||||||
|
embedding_out = embedding_out.permute(0, 2, 1)
|
||||||
|
embedding_out = self.conv(embedding_out)
|
||||||
|
embedding_out = embedding_out.permute(0, 2, 1)
|
||||||
|
embedding_out = torch.nn.functional.relu(embedding_out)
|
||||||
|
|
||||||
|
return embedding_out
|
||||||
|
|
||||||
|
|
||||||
|
class DecoderModule(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
A helper module to combine decoder, decoder projection, and joiner inference together.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
decoder_dim: int,
|
decoder_dim: int,
|
||||||
blank_id: int,
|
joiner_dim: int,
|
||||||
context_size: int,
|
context_size: int,
|
||||||
):
|
beam: int,
|
||||||
|
device: torch.device,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Args:
|
DecoderModule initialization.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
vocab_size:
|
vocab_size:
|
||||||
Number of tokens of the modeling unit including blank.
|
A number of tokens or modeling units, includes blank.
|
||||||
decoder_dim:
|
decoder_dim : int
|
||||||
Dimension of the input embedding, and of the decoder output.
|
A dimension of the decoder embeddings, and the decoder output.
|
||||||
blank_id:
|
joiner_dim : int
|
||||||
The ID of the blank symbol.
|
Input joiner dimension.
|
||||||
context_size:
|
context_size : int
|
||||||
Number of previous words to use to predict the next word.
|
A number of previous words to use to predict the next word.
|
||||||
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
||||||
|
beam : int
|
||||||
|
A decoder beam.
|
||||||
|
device : torch.device
|
||||||
|
The device used to store the layer weights. Should be
|
||||||
|
either torch.device("cpu") or torch.device("cuda").
|
||||||
"""
|
"""
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.embedding = nn.Embedding(
|
self.decoder = Decoder(vocab_size, decoder_dim, context_size, device)
|
||||||
num_embeddings=vocab_size,
|
self.decoder_proj = torch.nn.Linear(decoder_dim, joiner_dim, device=device)
|
||||||
embedding_dim=decoder_dim,
|
self.joiner = Joiner(joiner_dim, vocab_size, device)
|
||||||
)
|
|
||||||
# the balancers are to avoid any drift in the magnitude of the
|
|
||||||
# embeddings, which would interact badly with parameter averaging.
|
|
||||||
self.balancer = Balancer(
|
|
||||||
decoder_dim,
|
|
||||||
channel_dim=-1,
|
|
||||||
min_positive=0.0,
|
|
||||||
max_positive=1.0,
|
|
||||||
min_abs=0.5,
|
|
||||||
max_abs=1.0,
|
|
||||||
prob=0.05,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.blank_id = blank_id
|
|
||||||
|
|
||||||
assert context_size >= 1, context_size
|
|
||||||
self.context_size = context_size
|
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
|
self.beam = beam
|
||||||
|
|
||||||
if context_size > 1:
|
def forward(
|
||||||
self.conv = nn.Conv1d(
|
self, decoder_input: torch.Tensor, encoder_out: torch.Tensor, hyps_log_prob: torch.Tensor,
|
||||||
in_channels=decoder_dim,
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
out_channels=decoder_dim,
|
|
||||||
kernel_size=context_size,
|
|
||||||
padding=0,
|
|
||||||
groups=decoder_dim // 4, # group size == 4
|
|
||||||
bias=False,
|
|
||||||
)
|
|
||||||
self.balancer2 = Balancer(
|
|
||||||
decoder_dim,
|
|
||||||
channel_dim=-1,
|
|
||||||
min_positive=0.0,
|
|
||||||
max_positive=1.0,
|
|
||||||
min_abs=0.5,
|
|
||||||
max_abs=1.0,
|
|
||||||
prob=0.05,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
|
|
||||||
# when inference with torch.jit.script and context_size == 1
|
|
||||||
self.conv = nn.Identity()
|
|
||||||
self.balancer2 = nn.Identity()
|
|
||||||
|
|
||||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
|
||||||
"""
|
"""
|
||||||
Args:
|
Does a forward pass of the stateless Decoder module. Returns an output decoder tensor.
|
||||||
y:
|
|
||||||
A 2-D tensor of shape (N, U).
|
Parameters
|
||||||
need_pad:
|
----------
|
||||||
True to left pad the input. Should be True during training.
|
decoder_input : torch.Tensor[torch.int32]
|
||||||
False to not pad the input. Should be False during inference.
|
The input integer tensor of shape (num_hyps, context_size).
|
||||||
Returns:
|
The module input that corresponds to the last context_size decoded token indexes.
|
||||||
Return a tensor of shape (N, U, decoder_dim).
|
encoder_out : torch.Tensor[torch.float32]
|
||||||
|
An output tensor from the encoder after projection of shape (num_hyps, joiner_dim).
|
||||||
|
hyps_log_prob : torch.Tensor[torch.float32]
|
||||||
|
Hypothesis probabilities in a logarithmic scale of shape (num_hyps, 1).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor[torch.float32]
|
||||||
|
A float output tensor of logit token probabilities of shape (num_hyps, vocab_size).
|
||||||
"""
|
"""
|
||||||
y = y.to(torch.int64)
|
|
||||||
# this stuff about clamp() is a temporary fix for a mismatch
|
|
||||||
# at utterance start, we use negative ids in beam_search.py
|
|
||||||
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
|
|
||||||
|
|
||||||
embedding_out = self.balancer(embedding_out)
|
decoder_out = self.decoder(decoder_input)
|
||||||
|
decoder_out = self.decoder_proj(decoder_out)
|
||||||
|
|
||||||
if self.context_size > 1:
|
logits = self.joiner(encoder_out, decoder_out[:, 0, :])
|
||||||
embedding_out = embedding_out.permute(0, 2, 1)
|
|
||||||
if need_pad is True:
|
|
||||||
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
|
|
||||||
else:
|
|
||||||
# During inference time, there is no need to do extra padding
|
|
||||||
# as we only need one output
|
|
||||||
assert embedding_out.size(-1) == self.context_size
|
|
||||||
embedding_out = self.conv(embedding_out)
|
|
||||||
embedding_out = embedding_out.permute(0, 2, 1)
|
|
||||||
embedding_out = F.relu(embedding_out)
|
|
||||||
embedding_out = self.balancer2(embedding_out)
|
|
||||||
|
|
||||||
return embedding_out
|
tokens_log_prob = torch.log_softmax(logits, dim=1)
|
||||||
|
log_probs = (tokens_log_prob + hyps_log_prob).reshape(-1)
|
||||||
|
|
||||||
|
hyps_topk_log_prob, topk_indexes = log_probs.topk(self.beam)
|
||||||
|
topk_hyp_indexes = torch.floor_divide(topk_indexes, self.vocab_size).to(torch.int32)
|
||||||
|
topk_token_indexes = torch.remainder(topk_indexes, self.vocab_size).to(torch.int32)
|
||||||
|
tokens_topk_prob = torch.exp(tokens_log_prob.reshape(-1)[topk_indexes])
|
||||||
|
|
||||||
|
return hyps_topk_log_prob, tokens_topk_prob, topk_hyp_indexes, topk_token_indexes
|
@ -19,49 +19,42 @@ import torch.nn as nn
|
|||||||
from scaling import ScaledLinear
|
from scaling import ScaledLinear
|
||||||
|
|
||||||
|
|
||||||
class Joiner(nn.Module):
|
class Joiner(torch.nn.Module):
|
||||||
def __init__(
|
|
||||||
self,
|
def __init__(self, joiner_dim: int, vocab_size: int, device: torch.device) -> None:
|
||||||
encoder_dim: int,
|
"""
|
||||||
decoder_dim: int,
|
Joiner initialization.
|
||||||
joiner_dim: int,
|
|
||||||
vocab_size: int,
|
Parameters
|
||||||
):
|
----------
|
||||||
|
joiner_dim : int
|
||||||
|
Input joiner dimension.
|
||||||
|
vocab_size : int
|
||||||
|
Output joiner dimension, the vocabulary size, the number of BPEs of the model.
|
||||||
|
device : torch.device
|
||||||
|
The device used to store the layer weights. Should be
|
||||||
|
either torch.device("cpu") or torch.device("cuda").
|
||||||
|
"""
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
|
self.output_linear = torch.nn.Linear(joiner_dim, vocab_size, device=device)
|
||||||
self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
|
|
||||||
self.output_linear = nn.Linear(joiner_dim, vocab_size)
|
|
||||||
|
|
||||||
def forward(
|
def forward(self, encoder_out: torch.Tensor, decoder_out: torch.Tensor) -> torch.Tensor:
|
||||||
self,
|
|
||||||
encoder_out: torch.Tensor,
|
|
||||||
decoder_out: torch.Tensor,
|
|
||||||
project_input: bool = True,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
"""
|
||||||
Args:
|
Does a forward pass of the Joiner module. Returns an output tensor after a simple joining.
|
||||||
encoder_out:
|
|
||||||
Output from the encoder. Its shape is (N, T, s_range, C).
|
Parameters
|
||||||
decoder_out:
|
----------
|
||||||
Output from the decoder. Its shape is (N, T, s_range, C).
|
encoder_out : torch.Tensor[torch.float32]
|
||||||
project_input:
|
An output tensor from the encoder after projection of shape (N, joiner_dim).
|
||||||
If true, apply input projections encoder_proj and decoder_proj.
|
decoder_out : torch.Tensor[torch.float32]
|
||||||
If this is false, it is the user's responsibility to do this
|
An output tensor from the decoder after projection of shape (N, joiner_dim).
|
||||||
manually.
|
|
||||||
Returns:
|
Returns
|
||||||
Return a tensor of shape (N, T, s_range, C).
|
-------
|
||||||
|
torch.Tensor[torch.float32]
|
||||||
|
A float output tensor of log token probabilities of shape (N, vocab_size).
|
||||||
"""
|
"""
|
||||||
assert encoder_out.ndim == decoder_out.ndim, (
|
|
||||||
encoder_out.shape,
|
|
||||||
decoder_out.shape,
|
|
||||||
)
|
|
||||||
|
|
||||||
if project_input:
|
return self.output_linear(torch.tanh(encoder_out + decoder_out))
|
||||||
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
|
|
||||||
else:
|
|
||||||
logit = encoder_out + decoder_out
|
|
||||||
|
|
||||||
logit = self.output_linear(torch.tanh(logit))
|
|
||||||
|
|
||||||
return logit
|
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -36,371 +36,259 @@ from scaling import (
|
|||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
|
||||||
class ConvNeXt(nn.Module):
|
class ConvNeXt(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf
|
The simplified ConvNeXt module interpretation based on https://arxiv.org/pdf/2206.14747.pdf.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_channels: int, device: torch.device) -> None:
|
||||||
|
"""
|
||||||
|
ConvNeXt initialization.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
num_channels : int
|
||||||
|
The number of input and output channels for ConvNeXt module.
|
||||||
|
device : torch.device
|
||||||
|
The device used to store the layer weights.
|
||||||
|
Either torch.device("cpu") or torch.device("cuda").
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
channels: int,
|
|
||||||
hidden_ratio: int = 3,
|
|
||||||
kernel_size: Tuple[int, int] = (7, 7),
|
|
||||||
layerdrop_rate: FloatLike = None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
|
|
||||||
hidden_channels = channels * hidden_ratio
|
|
||||||
if layerdrop_rate is None:
|
|
||||||
layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015))
|
|
||||||
self.layerdrop_rate = layerdrop_rate
|
|
||||||
|
|
||||||
self.depthwise_conv = nn.Conv2d(
|
self.padding = 3
|
||||||
in_channels=channels,
|
hidden_channels = num_channels * 3
|
||||||
out_channels=channels,
|
|
||||||
groups=channels,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
padding=self.padding,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.pointwise_conv1 = nn.Conv2d(
|
self.depthwise_conv = torch.nn.Conv2d(
|
||||||
in_channels=channels, out_channels=hidden_channels, kernel_size=1
|
num_channels,
|
||||||
)
|
num_channels,
|
||||||
|
7,
|
||||||
self.hidden_balancer = Balancer(
|
groups=num_channels,
|
||||||
hidden_channels,
|
padding=(0, self.padding), # time, freq
|
||||||
channel_dim=1,
|
device=device,
|
||||||
min_positive=0.3,
|
|
||||||
max_positive=1.0,
|
|
||||||
min_abs=0.75,
|
|
||||||
max_abs=5.0,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.activation = SwooshL()
|
self.activation = SwooshL()
|
||||||
self.pointwise_conv2 = ScaledConv2d(
|
self.pointwise_conv1 = torch.nn.Conv2d(num_channels, hidden_channels, 1, device=device)
|
||||||
in_channels=hidden_channels,
|
self.pointwise_conv2 = torch.nn.Conv2d(hidden_channels, num_channels, 1, device=device)
|
||||||
out_channels=channels,
|
|
||||||
kernel_size=1,
|
|
||||||
initial_scale=0.01,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.out_balancer = Balancer(
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
channels,
|
|
||||||
channel_dim=1,
|
|
||||||
min_positive=0.4,
|
|
||||||
max_positive=0.6,
|
|
||||||
min_abs=1.0,
|
|
||||||
max_abs=6.0,
|
|
||||||
)
|
|
||||||
self.out_whiten = Whiten(
|
|
||||||
num_groups=1,
|
|
||||||
whitening_limit=5.0,
|
|
||||||
prob=(0.025, 0.25),
|
|
||||||
grad_scale=0.01,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
|
|
||||||
return self.forward_internal(x)
|
|
||||||
layerdrop_rate = float(self.layerdrop_rate)
|
|
||||||
|
|
||||||
if layerdrop_rate != 0.0:
|
|
||||||
batch_size = x.shape[0]
|
|
||||||
mask = (
|
|
||||||
torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
|
|
||||||
> layerdrop_rate
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
mask = None
|
|
||||||
# turns out this caching idea does not work with --world-size > 1
|
|
||||||
# return caching_eval(self.forward_internal, x, mask)
|
|
||||||
return self.forward_internal(x, mask)
|
|
||||||
|
|
||||||
def forward_internal(
|
|
||||||
self, x: Tensor, layer_skip_mask: Optional[Tensor] = None
|
|
||||||
) -> Tensor:
|
|
||||||
"""
|
"""
|
||||||
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
|
Does a forward pass of the ConvNeXt module.
|
||||||
|
|
||||||
The returned value has the same shape as x.
|
Parameters
|
||||||
|
----------
|
||||||
|
x : torch.Tensor[torch.float32]
|
||||||
|
An input float tensor of shape (1, num_channels, num_input_frames, num_freqs).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor[torch.float32]
|
||||||
|
A output float tensor of the same shape as input,
|
||||||
|
(1, num_channels, num_output_frames, num_freqs).
|
||||||
"""
|
"""
|
||||||
bypass = x
|
|
||||||
|
bypass = x[:, :, self.padding: x.size(2) - self.padding]
|
||||||
|
|
||||||
x = self.depthwise_conv(x)
|
x = self.depthwise_conv(x)
|
||||||
x = self.pointwise_conv1(x)
|
x = self.pointwise_conv1(x)
|
||||||
x = self.hidden_balancer(x)
|
|
||||||
x = self.activation(x)
|
x = self.activation(x)
|
||||||
x = self.pointwise_conv2(x)
|
x = self.pointwise_conv2(x)
|
||||||
|
|
||||||
if layer_skip_mask is not None:
|
|
||||||
x = x * layer_skip_mask
|
|
||||||
|
|
||||||
x = bypass + x
|
x = bypass + x
|
||||||
x = self.out_balancer(x)
|
|
||||||
|
|
||||||
if x.requires_grad:
|
|
||||||
x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last
|
|
||||||
x = self.out_whiten(x)
|
|
||||||
x = x.transpose(1, 3) # (N, C, H, W)
|
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def streaming_forward(
|
|
||||||
self,
|
class Conv2dSubsampling(torch.nn.Module):
|
||||||
x: Tensor,
|
|
||||||
cached_left_pad: Tensor,
|
|
||||||
) -> Tuple[Tensor, Tensor]:
|
|
||||||
"""
|
"""
|
||||||
Args:
|
Convolutional 2D subsampling module. It performs the prior subsampling
|
||||||
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
|
(four times subsampling along the frequency axis and two times - along the time axis),
|
||||||
cached_left_pad: (batch_size, num_channels, left_pad, num_freqs)
|
and low-level descriptor feature extraction from the log mel feature input before passing
|
||||||
|
it to zipformer encoder.
|
||||||
Returns:
|
|
||||||
- The returned value has the same shape as x.
|
|
||||||
- Updated cached_left_pad.
|
|
||||||
"""
|
|
||||||
padding = self.padding
|
|
||||||
|
|
||||||
# The length without right padding for depth-wise conv
|
|
||||||
T = x.size(2) - padding[0]
|
|
||||||
|
|
||||||
bypass = x[:, :, :T, :]
|
|
||||||
|
|
||||||
# Pad left side
|
|
||||||
assert cached_left_pad.size(2) == padding[0], (
|
|
||||||
cached_left_pad.size(2),
|
|
||||||
padding[0],
|
|
||||||
)
|
|
||||||
x = torch.cat([cached_left_pad, x], dim=2)
|
|
||||||
# Update cached left padding
|
|
||||||
cached_left_pad = x[:, :, T : padding[0] + T, :]
|
|
||||||
|
|
||||||
# depthwise_conv
|
|
||||||
x = torch.nn.functional.conv2d(
|
|
||||||
x,
|
|
||||||
weight=self.depthwise_conv.weight,
|
|
||||||
bias=self.depthwise_conv.bias,
|
|
||||||
padding=(0, padding[1]),
|
|
||||||
groups=self.depthwise_conv.groups,
|
|
||||||
)
|
|
||||||
x = self.pointwise_conv1(x)
|
|
||||||
x = self.hidden_balancer(x)
|
|
||||||
x = self.activation(x)
|
|
||||||
x = self.pointwise_conv2(x)
|
|
||||||
|
|
||||||
x = bypass + x
|
|
||||||
return x, cached_left_pad
|
|
||||||
|
|
||||||
|
|
||||||
class Conv2dSubsampling(nn.Module):
|
|
||||||
"""Convolutional 2D subsampling (to 1/2 length).
|
|
||||||
|
|
||||||
Convert an input of shape (N, T, idim) to an output
|
|
||||||
with shape (N, T', odim), where
|
|
||||||
T' = (T-3)//2 - 2 == (T-7)//2
|
|
||||||
|
|
||||||
It is based on
|
|
||||||
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_channels: int,
|
input_dim: int,
|
||||||
out_channels: int,
|
output_dim: int,
|
||||||
layer1_channels: int = 8,
|
layer1_channels: int,
|
||||||
layer2_channels: int = 32,
|
layer2_channels: int,
|
||||||
layer3_channels: int = 128,
|
layer3_channels: int,
|
||||||
dropout: FloatLike = 0.1,
|
right_context: int,
|
||||||
|
device: torch.device,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Args:
|
Conv2dSubsampling initialization.
|
||||||
in_channels:
|
|
||||||
Number of channels in. The input shape is (N, T, in_channels).
|
Parameters
|
||||||
Caution: It requires: T >=7, in_channels >=7
|
----------
|
||||||
out_channels
|
input_dim : int
|
||||||
Output dim. The output shape is (N, (T-3)//2, out_channels)
|
The number of input channels. Corresponds to the
|
||||||
layer1_channels:
|
number of features in the input feature tensor.
|
||||||
Number of channels in layer1
|
output_dim : int
|
||||||
layer1_channels:
|
The number of output channels.
|
||||||
Number of channels in layer2
|
layer1_channels : int
|
||||||
bottleneck:
|
The number of output channels in the first Conv2d layer.
|
||||||
bottleneck dimension for 1d squeeze-excite
|
layer2_channels : int
|
||||||
|
The number of output channels in the second Conv2d layer.
|
||||||
|
layer3_channels : int
|
||||||
|
The number of output channels in the third Conv2d layer.
|
||||||
|
right_context: int
|
||||||
|
The look-ahead right context that is used to update the left cache.
|
||||||
|
device : torch.device
|
||||||
|
The device used to store the layer weights. Should be
|
||||||
|
either torch.device("cpu") or torch.device("cuda").
|
||||||
"""
|
"""
|
||||||
assert in_channels >= 7
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# The ScaleGrad module is there to prevent the gradients
|
if input_dim < 7:
|
||||||
# w.r.t. the weight or bias of the first Conv2d module in self.conv from
|
raise ValueError(
|
||||||
# exceeding the range of fp16 when using automatic mixed precision (amp)
|
'The input feature dimension of the Conv2dSubsampling layer, can not be less than '
|
||||||
# training. (The second one is necessary to stop its bias from getting
|
'seven, otherwise the frequency subsampling will result with an empty output. '
|
||||||
# a too-large gradient).
|
f'Expected input_dim to be at least 7 but got {input_dim}.',
|
||||||
|
)
|
||||||
|
|
||||||
self.conv = nn.Sequential(
|
self.right_context = right_context
|
||||||
nn.Conv2d(
|
|
||||||
|
# Assume batch size is 1 and the right padding is 10,
|
||||||
|
# see the forward method on why the right padding is 10.
|
||||||
|
self.right_pad = torch.full(
|
||||||
|
(1, 10, input_dim), ZERO_LOG_MEL, dtype=torch.float32, device=device,
|
||||||
|
)
|
||||||
|
self.conv = torch.nn.Sequential(
|
||||||
|
torch.nn.Conv2d(
|
||||||
in_channels=1,
|
in_channels=1,
|
||||||
out_channels=layer1_channels,
|
out_channels=layer1_channels,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
padding=(0, 1), # (time, freq)
|
padding=(0, 1), # (time, freq)
|
||||||
|
device=device,
|
||||||
),
|
),
|
||||||
ScaleGrad(0.2),
|
|
||||||
Balancer(layer1_channels, channel_dim=1, max_abs=1.0),
|
|
||||||
SwooshR(),
|
SwooshR(),
|
||||||
nn.Conv2d(
|
torch.nn.Conv2d(layer1_channels, layer2_channels, 3, stride=2, device=device),
|
||||||
in_channels=layer1_channels,
|
|
||||||
out_channels=layer2_channels,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=2,
|
|
||||||
padding=0,
|
|
||||||
),
|
|
||||||
Balancer(layer2_channels, channel_dim=1, max_abs=4.0),
|
|
||||||
SwooshR(),
|
SwooshR(),
|
||||||
nn.Conv2d(
|
torch.nn.Conv2d(layer2_channels, layer3_channels, 3, stride=(1, 2), device=device),
|
||||||
in_channels=layer2_channels,
|
|
||||||
out_channels=layer3_channels,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=(1, 2), # (time, freq)
|
|
||||||
),
|
|
||||||
Balancer(layer3_channels, channel_dim=1, max_abs=4.0),
|
|
||||||
SwooshR(),
|
SwooshR(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# just one convnext layer
|
self.convnext = ConvNeXt(layer3_channels, device=device)
|
||||||
self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))
|
|
||||||
|
|
||||||
# (in_channels-3)//4
|
out_width = (((input_dim - 1) // 2) - 1) // 2
|
||||||
self.out_width = (((in_channels - 1) // 2) - 1) // 2
|
self.out = torch.nn.Linear(out_width * layer3_channels, output_dim, device=device)
|
||||||
self.layer3_channels = layer3_channels
|
self.out_norm = BiasNorm(output_dim, device=device)
|
||||||
|
|
||||||
self.out = nn.Linear(self.out_width * layer3_channels, out_channels)
|
|
||||||
# use a larger than normal grad_scale on this whitening module; there is
|
|
||||||
# only one such module, so there is not a concern about adding together
|
|
||||||
# many copies of this extra gradient term.
|
|
||||||
self.out_whiten = Whiten(
|
|
||||||
num_groups=1,
|
|
||||||
whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0),
|
|
||||||
prob=(0.025, 0.25),
|
|
||||||
grad_scale=0.02,
|
|
||||||
)
|
|
||||||
|
|
||||||
# max_log_eps=0.0 is to prevent both eps and the output of self.out from
|
|
||||||
# getting large, there is an unnecessary degree of freedom.
|
|
||||||
self.out_norm = BiasNorm(out_channels)
|
|
||||||
self.dropout = Dropout3(dropout, shared_dim=1)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: torch.Tensor, x_lens: torch.Tensor
|
self, x: torch.Tensor, cached_left_pad: torch.Tensor,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Subsample x.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x:
|
|
||||||
Its shape is (N, T, idim).
|
|
||||||
x_lens:
|
|
||||||
A tensor of shape (batch_size,) containing the number of frames in
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- a tensor of shape (N, (T-7)//2, odim)
|
|
||||||
- output lengths, of shape (batch_size,)
|
|
||||||
"""
|
"""
|
||||||
# On entry, x is (N, T, idim)
|
Does a forward pass of the Conv2dSubsampling module.
|
||||||
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
|
|
||||||
# scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
|
Parameters
|
||||||
# training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
|
----------
|
||||||
# gradients.
|
x : torch.Tensor[torch.float32]
|
||||||
|
An input float tensor of shape (1, num_frames, input_dim). An input feature tensor.
|
||||||
|
cached_left_pad : torch.Tensor[torch.float32]
|
||||||
|
A left cache float tensor of shape (1, 10, input_dim). Left cache is required
|
||||||
|
to preserve the "same" left padding to the output of the Conv2dSubsampling module.
|
||||||
|
See the get_init_states() documentation to understand why we need exactly ten frames
|
||||||
|
of left padding for the Conv2dSubsampling module.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
tuple[torch.Tensor[torch.float32], torch.Tensor[torch.float32]]
|
||||||
|
A tuple of two float tensors:
|
||||||
|
- The processing output of the Conv2dSubsampling module
|
||||||
|
of shape (1, subsampled_num_frames, output_dim).
|
||||||
|
- The udated left cache tensor of shape (1, 10, input_dim).
|
||||||
|
"""
|
||||||
|
|
||||||
|
x = torch.cat((cached_left_pad, x), dim=1)
|
||||||
|
new_cached_left_pad = x[
|
||||||
|
:,
|
||||||
|
x.size(1) - self.right_context - cached_left_pad.size(1):
|
||||||
|
x.size(1) - self.right_context,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Now when we concatenated the left cache with the input, we need to perform the right
|
||||||
|
# padding of the input in a way to preserve the "same" type of padding, so that the output
|
||||||
|
# of the module has the same duration as input (taking 2 times subsampling into account).
|
||||||
|
# There are two possible outcomes depending on whether the the number of input frames is
|
||||||
|
# even or odd, but both scenarios can be covered by 10 frames right padding.
|
||||||
|
|
||||||
|
# x : right padding
|
||||||
|
# | | | | | | | | | | | |:| | | | | | | | | | input
|
||||||
|
# | | | | | | | | | | |:| | | | | | | | | first Conv2d output from self.conv
|
||||||
|
# | | | | | :| | | | second Conv2d output from self.conv
|
||||||
|
# | | | | :| | | third Conv2d output from self.conv
|
||||||
|
# | : Conv2d output from
|
||||||
|
# : self.convnext.depthwise_conv
|
||||||
|
# :
|
||||||
|
# x : right padding
|
||||||
|
# | | | | | | | | | | | | |:| | | | | | | | | | input
|
||||||
|
# | | | | | | | | | | | |:| | | | | | | | | first Conv2d output from self.conv
|
||||||
|
# | | | | | |: | | | | second Conv2d output from self.conv
|
||||||
|
# | | | | |: | | | third Conv2d output from self.conv
|
||||||
|
# | |: Conv2d output from
|
||||||
|
# : self.convnext.depthwise_conv
|
||||||
|
# :
|
||||||
|
|
||||||
|
x = torch.cat((x, self.right_pad), dim=1)
|
||||||
|
|
||||||
|
# (1, T, input_dim) -> (1, 1, T, input_dim) i.e., (N, C, H, W)
|
||||||
|
x = x.unsqueeze(1)
|
||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
x = self.convnext(x)
|
x = self.convnext(x)
|
||||||
|
|
||||||
# Now x is of shape (N, odim, (T-7)//2, (idim-3)//4)
|
# Now x is of shape (1, output_dim, T', ((input_dim - 1) // 2 - 1) // 2)
|
||||||
b, c, t, f = x.size()
|
b, c, t, f = x.size() # b is equal to 1
|
||||||
|
x = x.permute(0, 2, 1, 3).reshape(b, t, c * f)
|
||||||
x = x.transpose(1, 2).reshape(b, t, c * f)
|
# Now x is of shape (T', output_dim * layer3_channels))
|
||||||
# now x: (N, (T-7)//2, out_width * layer3_channels))
|
|
||||||
|
|
||||||
x = self.out(x)
|
x = self.out(x)
|
||||||
# Now x is of shape (N, (T-7)//2, odim)
|
# Now x is of shape (T', output_dim)
|
||||||
x = self.out_whiten(x)
|
|
||||||
x = self.out_norm(x)
|
|
||||||
x = self.dropout(x)
|
|
||||||
|
|
||||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
|
||||||
x_lens = (x_lens - 7) // 2
|
|
||||||
else:
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.simplefilter("ignore")
|
|
||||||
x_lens = (x_lens - 7) // 2
|
|
||||||
assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max())
|
|
||||||
|
|
||||||
return x, x_lens
|
|
||||||
|
|
||||||
def streaming_forward(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
x_lens: torch.Tensor,
|
|
||||||
cached_left_pad: Tensor,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
||||||
"""Subsample x.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x:
|
|
||||||
Its shape is (N, T, idim).
|
|
||||||
x_lens:
|
|
||||||
A tensor of shape (batch_size,) containing the number of frames in
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- a tensor of shape (N, (T-7)//2, odim)
|
|
||||||
- output lengths, of shape (batch_size,)
|
|
||||||
- updated cache
|
|
||||||
"""
|
|
||||||
# On entry, x is (N, T, idim)
|
|
||||||
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
|
|
||||||
|
|
||||||
# T' = (T-7)//2
|
|
||||||
x = self.conv(x)
|
|
||||||
|
|
||||||
# T' = (T-7)//2-3
|
|
||||||
x, cached_left_pad = self.convnext.streaming_forward(
|
|
||||||
x, cached_left_pad=cached_left_pad
|
|
||||||
)
|
|
||||||
|
|
||||||
# Now x is of shape (N, odim, T', ((idim-1)//2 - 1)//2)
|
|
||||||
b, c, t, f = x.size()
|
|
||||||
|
|
||||||
x = x.transpose(1, 2).reshape(b, t, c * f)
|
|
||||||
# now x: (N, T', out_width * layer3_channels))
|
|
||||||
|
|
||||||
x = self.out(x)
|
|
||||||
# Now x is of shape (N, T', odim)
|
|
||||||
x = self.out_norm(x)
|
x = self.out_norm(x)
|
||||||
|
|
||||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
return x, new_cached_left_pad
|
||||||
assert self.convnext.padding[0] == 3
|
|
||||||
# The ConvNeXt module needs 3 frames of right padding after subsampling
|
|
||||||
x_lens = (x_lens - 7) // 2 - 3
|
|
||||||
else:
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.simplefilter("ignore")
|
|
||||||
# The ConvNeXt module needs 3 frames of right padding after subsampling
|
|
||||||
assert self.convnext.padding[0] == 3
|
|
||||||
x_lens = (x_lens - 7) // 2 - 3
|
|
||||||
|
|
||||||
assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max())
|
def get_init_states(input_dim: int, device: torch.device) -> torch.Tensor:
|
||||||
|
|
||||||
return x, x_lens, cached_left_pad
|
|
||||||
|
|
||||||
@torch.jit.export
|
|
||||||
def get_init_states(
|
|
||||||
self,
|
|
||||||
batch_size: int = 1,
|
|
||||||
device: torch.device = torch.device("cpu"),
|
|
||||||
) -> Tensor:
|
|
||||||
"""Get initial states for Conv2dSubsampling module.
|
|
||||||
It is the cached left padding for ConvNeXt module,
|
|
||||||
of shape (batch_size, num_channels, left_pad, num_freqs)
|
|
||||||
"""
|
"""
|
||||||
left_pad = self.convnext.padding[0]
|
Get initial states for Conv2dSubsampling module. The Conv2dSubsampling.conv consists of three
|
||||||
freq = self.out_width
|
consecutive Conv2d layers with the kernel size 3 and no padding, also the middle Conv2d
|
||||||
channels = self.layer3_channels
|
has a stride 2, while the rest have the default stride 1. We want to pad the input from the
|
||||||
cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to(
|
left side with cached_left_pad in the "same" way, so when we pass it through
|
||||||
device
|
the Conv2dSubsampling.conv and Conv2dSubsampling.convnext we end up with exactly zero padding
|
||||||
|
frames from the left.
|
||||||
|
|
||||||
|
cached_left_pad : x
|
||||||
|
| | | | | | | | | |:| | | | | | | | | | | input
|
||||||
|
| | | | | | | | |:| | | | | | | | | | | first Conv2d output from Conv2dSubsampling.conv
|
||||||
|
| | | | :| | | | | | ... second Conv2d output from Conv2dSubsampling.conv
|
||||||
|
| | | :| | | | | | third Conv2d output from Conv2dSubsampling.conv
|
||||||
|
:| | | | | | Conv2d output from
|
||||||
|
: Conv2dSubsampling.convnext.depthwise_conv
|
||||||
|
|
||||||
|
As we can see from the picture above, in order to preserve the "same"
|
||||||
|
padding from the left side we need
|
||||||
|
((((pad - 1) - 1) // 2) - 1) - 3 = 0 --> pad = 10.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
input_dim : int
|
||||||
|
The number of input channels.
|
||||||
|
Corresponds to the number of features in the input of the Conv2dSubsampling module.
|
||||||
|
device : torch.device
|
||||||
|
The device used to store the left cache tensor.
|
||||||
|
Either torch.device("cpu") or torch.device("cuda").
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor[torch.float32]
|
||||||
|
A left cache float tensor. The output shape is (1, 10, input_dim).
|
||||||
|
"""
|
||||||
|
|
||||||
|
pad = 10
|
||||||
|
cached_left_pad = torch.full(
|
||||||
|
(1, pad, input_dim), ZERO_LOG_MEL, dtype=torch.float32, device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
return cached_embed_left_pad
|
return cached_left_pad
|
||||||
|
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user