Zipformer with Adam optimizer

This commit is contained in:
Han Zhu 2024-08-01 14:51:00 +08:00
parent 3b257dd5ae
commit db38ab044b
18 changed files with 7444 additions and 1 deletion

View File

@@ -36,7 +36,8 @@ The following table lists the differences among them.
| `lstm_transducer_stateless3` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gradient filter + delay penalty |
| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe |
| `zipformer_adapter` | Upgraded Zipformer | Embedding + Conv1d | It supports domain adaptation of Zipformer using parameter efficient adapters |
| `zipformer_lora` | Upgraded Zipformer | Embedding + Conv1d | Finetune Zipformer with LoRA |
| `zipformer_adam` | Upgraded Zipformer | Embedding + Conv1d | Zipformer with Adam optimizer |
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).

View File

@@ -554,6 +554,106 @@ for m in greedy_search modified_beam_search fast_beam_search; do
done
```
### zipformer_adam (zipformer + pruned stateless transducer + adam optimizer)

See <https://github.com/k2-fsa/icefall/pull/1708> for more details.

[zipformer_adam](./zipformer_adam)

#### Non-streaming

##### normal-scaled model, number of model parameters: 65595219, i.e., 65.60 M

You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/zhu-han/icefall-asr-librispeech-zipformer-adam-medium-2023-08-01>

You can use <https://github.com/k2-fsa/sherpa> to deploy it.

| decoding method      | test-clean | test-other | comment             |
|----------------------|------------|------------|---------------------|
| greedy_search        | 2.35       | 5.53       | --epoch 70 --avg 30 |
| modified_beam_search | 2.29       | 5.48       | --epoch 70 --avg 30 |
| fast_beam_search     | 2.31       | 5.52       | --epoch 70 --avg 30 |

The training command is:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./zipformer_adam/train.py \
--world-size 4 \
--num-epochs 70 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer_adam/exp \
--causal 0 \
--full-libri 1 \
--max-duration 1000
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in greedy_search modified_beam_search fast_beam_search; do
./zipformer_adam/decode.py \
--epoch 70 \
--avg 30 \
--use-averaged-model 1 \
--exp-dir ./zipformer_adam/exp \
--max-duration 600 \
--decoding-method $m
done
```
To decode with external language models, please refer to the documentation [here](https://k2-fsa.github.io/icefall/decoding-with-langugage-models/index.html).
##### large-scaled model, number of model parameters: 148514478, i.e., 148.5 M

You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/zhu-han/icefall-asr-librispeech-zipformer-adam-large-2023-08-01>

You can use <https://github.com/k2-fsa/sherpa> to deploy it.

| decoding method      | test-clean | test-other | comment             |
|----------------------|------------|------------|---------------------|
| greedy_search        | 2.27       | 5.25       | --epoch 70 --avg 20 |
| modified_beam_search | 2.23       | 5.17       | --epoch 70 --avg 20 |
| fast_beam_search     | 2.24       | 5.20       | --epoch 70 --avg 20 |

The training command is:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./zipformer_adam/train.py \
--world-size 4 \
--num-epochs 70 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer_adam/exp-large \
--causal 0 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--full-libri 1 \
--max-duration 1000
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in greedy_search modified_beam_search fast_beam_search; do
./zipformer_adam/decode.py \
--epoch 70 \
--avg 20 \
--exp-dir zipformer_adam/exp-large \
--max-duration 600 \
--causal 0 \
--decoding-method $m \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192
done
```
### Zipformer CTC
#### [zipformer_ctc](./zipformer_ctc)

View File

@@ -0,0 +1 @@
../transducer/asr_datamodule.py

View File

@@ -0,0 +1 @@
../zipformer/attention_decoder.py

View File

@@ -0,0 +1 @@
../pruned_transducer_stateless2/beam_search.py

File diff suppressed because it is too large

View File

@@ -0,0 +1,109 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Decoder(nn.Module):
"""This class modifies the stateless decoder from the following paper:
RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
It removes the recurrent connection from the decoder, i.e., the prediction
network. Different from the above paper, it adds an extra Conv1d
right after the embedding layer.
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
"""
def __init__(
self,
vocab_size: int,
decoder_dim: int,
blank_id: int,
context_size: int,
):
"""
Args:
vocab_size:
Number of tokens of the modeling unit including blank.
decoder_dim:
Dimension of the input embedding, and of the decoder output.
blank_id:
The ID of the blank symbol.
context_size:
Number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram.
"""
super().__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=decoder_dim,
)
self.blank_id = blank_id
assert context_size >= 1, context_size
self.context_size = context_size
self.vocab_size = vocab_size
if context_size > 1:
self.conv = nn.Conv1d(
in_channels=decoder_dim,
out_channels=decoder_dim,
kernel_size=context_size,
padding=0,
groups=decoder_dim // 4, # group size == 4
bias=False,
)
else:
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
# when inference with torch.jit.script and context_size == 1
self.conv = nn.Identity()
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U).
need_pad:
True to left pad the input. Should be True during training.
False to not pad the input. Should be False during inference.
Returns:
Return a tensor of shape (N, U, decoder_dim).
"""
y = y.to(torch.int64)
# The clamp() below is a workaround for a mismatch at utterance start:
# beam_search.py uses negative token ids there, so we clamp them to 0 and zero out their embeddings.
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
else:
# During inference time, there is no need to do extra padding
# as we only need one output
assert embedding_out.size(-1) == self.context_size
embedding_out = self.conv(embedding_out)
embedding_out = embedding_out.permute(0, 2, 1)
embedding_out = F.relu(embedding_out)
return embedding_out
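To make the shape contract above concrete, here is a minimal usage sketch (assuming this file is importable as `decoder`; the sizes are arbitrary):

```python
import torch

from decoder import Decoder  # assumed import path for this recipe's decoder.py

decoder = Decoder(vocab_size=500, decoder_dim=512, blank_id=0, context_size=2)

# Training: need_pad=True left-pads by (context_size - 1), so U is preserved.
y = torch.randint(low=1, high=500, size=(8, 20))  # (N, U)
out = decoder(y, need_pad=True)
assert out.shape == (8, 20, 512)

# Inference: feed exactly `context_size` previous tokens to get a single output frame.
y_ctx = torch.randint(low=1, high=500, size=(8, 2))  # (N, context_size)
out = decoder(y_ctx, need_pad=False)
assert out.shape == (8, 1, 512)
```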

View File

@@ -0,0 +1 @@
../transducer_stateless/encoder_interface.py

View File

@@ -0,0 +1,513 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
Note: This is an example for the LibriSpeech dataset; if you are using a
different dataset, you should change the argument values according to your dataset.
(1) Export to torchscript model using torch.jit.script()
- For non-streaming model:
./zipformer_adam/export.py \
--exp-dir ./zipformer_adam/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 70 \
--avg 30 \
--jit 1
It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("jit_script.pt")`.
Check ./jit_pretrained.py for its usage.
Check https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
- For streaming model:
./zipformer_adam/export.py \
--exp-dir ./zipformer_adam/exp \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 70 \
--avg 30 \
--jit 1
It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`.
You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`.
Check ./jit_pretrained_streaming.py for its usage.
Check https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
(2) Export `model.state_dict()`
- For non-streaming model:
./zipformer_adam/export.py \
--exp-dir ./zipformer_adam/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 70 \
--avg 30
- For streaming model:
./zipformer_adam/export.py \
--exp-dir ./zipformer_adam/exp \
--causal 1 \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 70 \
--avg 30
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
- For non-streaming model:
To use the generated file with `zipformer_adam/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./zipformer_adam/decode.py \
--exp-dir ./zipformer_adam/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
- For streaming model:
To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
# simulated streaming decoding
./zipformer_adam/decode.py \
--exp-dir ./zipformer_adam/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
- non-streaming model:
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
- streaming model:
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
# You will find the pre-trained models in exp dir
"""
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from torch import Tensor, nn
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, num_tokens, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--tokens",
type=str,
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
It will generate a file named jit_script.pt.
Check ./jit_pretrained.py for how to use it.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
class EncoderModel(nn.Module):
"""A wrapper for encoder and encoder_embed"""
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
super().__init__()
self.encoder = encoder
self.encoder_embed = encoder_embed
def forward(
self, features: Tensor, feature_lengths: Tensor
) -> Tuple[Tensor, Tensor]:
"""
Args:
features: (N, T, C)
feature_lengths: (N,)
"""
x, x_lens = self.encoder_embed(features, feature_lengths)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return encoder_out, encoder_out_lens
class StreamingEncoderModel(nn.Module):
"""A wrapper for encoder and encoder_embed"""
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
super().__init__()
assert len(encoder.chunk_size) == 1, encoder.chunk_size
assert len(encoder.left_context_frames) == 1, encoder.left_context_frames
self.chunk_size = encoder.chunk_size[0]
self.left_context_len = encoder.left_context_frames[0]
# The encoder_embed subsamples features from T to (T - 7) // 2 frames
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
self.pad_length = 7 + 2 * 3
self.encoder = encoder
self.encoder_embed = encoder_embed
def forward(
self, features: Tensor, feature_lengths: Tensor, states: List[Tensor]
) -> Tuple[Tensor, Tensor, List[Tensor]]:
"""Streaming forward for encoder_embed and encoder.
Args:
features: (N, T, C)
feature_lengths: (N,)
states: a list of Tensors
Returns encoder outputs, output lengths, and updated states.
"""
chunk_size = self.chunk_size
left_context_len = self.left_context_len
cached_embed_left_pad = states[-2]
x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
x=features,
x_lens=feature_lengths,
cached_left_pad=cached_embed_left_pad,
)
assert x.size(1) == chunk_size, (x.size(1), chunk_size)
src_key_padding_mask = make_pad_mask(x_lens)
# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
x.size(0), left_context_len
)
processed_lens = states[-1] # (batch,)
# (batch, left_context_size)
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
# Update processed lengths
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2]
(
encoder_out,
encoder_out_lens,
new_encoder_states,
) = self.encoder.streaming_forward(
x=x,
x_lens=x_lens,
states=encoder_states,
src_key_padding_mask=src_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
new_states = new_encoder_states + [
new_cached_embed_left_pad,
new_processed_lens,
]
return encoder_out, encoder_out_lens, new_states
@torch.jit.export
def get_init_states(
self,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
"""
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
states[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
"""
states = self.encoder.get_init_states(batch_size, device)
embed_states = self.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
states.append(processed_lens)
return states
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
# if torch.cuda.is_available():
# device = torch.device("cuda", 0)
logging.info(f"device: {device}")
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.sos_id = params.eos_id = token_table["<sos/eos>"]
params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.eval()
if params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
# Wrap encoder and encoder_embed as a module
if params.causal:
model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed)
chunk_size = model.encoder.chunk_size
left_context_len = model.encoder.left_context_len
filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt"
else:
model.encoder = EncoderModel(model.encoder, model.encoder_embed)
filename = "jit_script.pt"
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
model.save(str(params.exp_dir / filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torchscript. Export model.state_dict()")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@@ -0,0 +1,66 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
class Joiner(nn.Module):
def __init__(
self,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
):
super().__init__()
self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
self.output_linear = nn.Linear(joiner_dim, vocab_size)
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
project_input: bool = True,
) -> torch.Tensor:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, s_range, C).
decoder_out:
Output from the decoder. Its shape is (N, T, s_range, C).
project_input:
If true, apply input projections encoder_proj and decoder_proj.
If this is false, it is the user's responsibility to do this
manually.
Returns:
Return a tensor of shape (N, T, s_range, C).
"""
assert encoder_out.ndim == decoder_out.ndim, (
encoder_out.shape,
decoder_out.shape,
)
if project_input:
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
else:
logit = encoder_out + decoder_out
logit = self.output_linear(torch.tanh(logit))
return logit
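A minimal sketch of the joiner's shape contract (assuming this file is importable as `joiner`; the dimensions are arbitrary):

```python
import torch

from joiner import Joiner  # assumed import path for this recipe's joiner.py

joiner = Joiner(encoder_dim=384, decoder_dim=512, joiner_dim=512, vocab_size=500)

N, T, s_range = 2, 50, 5
encoder_out = torch.randn(N, T, s_range, 384)  # pruned encoder frames
decoder_out = torch.randn(N, T, s_range, 512)  # pruned decoder frames

# With project_input=True the joiner applies encoder_proj/decoder_proj itself;
# in the pruned-loss path the projections are applied before do_rnnt_pruning
# and project_input=False is used instead.
logits = joiner(encoder_out, decoder_out, project_input=True)
assert logits.shape == (N, T, s_range, 500)
```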

View File

@@ -0,0 +1 @@
../zipformer/label_smoothing.py

View File

@@ -0,0 +1,379 @@
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from icefall.utils import add_sos, make_pad_mask
class AsrModel(nn.Module):
def __init__(
self,
encoder_embed: nn.Module,
encoder: EncoderInterface,
decoder: Optional[nn.Module] = None,
joiner: Optional[nn.Module] = None,
attention_decoder: Optional[nn.Module] = None,
encoder_dim: int = 384,
decoder_dim: int = 512,
vocab_size: int = 500,
use_transducer: bool = True,
use_ctc: bool = False,
use_attention_decoder: bool = False,
):
"""A joint CTC & Transducer ASR model.
- Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
- Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
- Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)
Args:
encoder_embed:
It is a Convolutional 2D subsampling module. It converts
an input of shape (N, T, idim) to an output of shape
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
encoder:
It is the transcription network in the paper. It accepts
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
It is used when use_transducer is True.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
It is used when use_transducer is True.
use_transducer:
Whether to use the transducer head. Default: True.
use_ctc:
Whether to use the CTC head. Default: False.
use_attention_decoder:
Whether to use the attention-decoder head. Default: False.
"""
super().__init__()
assert (
use_transducer or use_ctc
), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
assert isinstance(encoder, EncoderInterface), type(encoder)
self.encoder_embed = encoder_embed
self.encoder = encoder
self.use_transducer = use_transducer
if use_transducer:
# Modules for Transducer head
assert decoder is not None
assert hasattr(decoder, "blank_id")
assert joiner is not None
self.decoder = decoder
self.joiner = joiner
self.simple_am_proj = nn.Linear(
encoder_dim, vocab_size
)
self.simple_lm_proj = nn.Linear(
decoder_dim, vocab_size
)
else:
assert decoder is None
assert joiner is None
self.use_ctc = use_ctc
if use_ctc:
# Modules for CTC head
self.ctc_output = nn.Sequential(
nn.Dropout(p=0.1),
nn.Linear(encoder_dim, vocab_size),
nn.LogSoftmax(dim=-1),
)
self.use_attention_decoder = use_attention_decoder
if use_attention_decoder:
self.attention_decoder = attention_decoder
else:
assert attention_decoder is None
def forward_encoder(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute encoder outputs.
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
Returns:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
"""
# logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
x, x_lens = self.encoder_embed(x, x_lens)
# logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
return encoder_out, encoder_out_lens
def forward_ctc(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
targets: torch.Tensor,
target_lengths: torch.Tensor,
) -> torch.Tensor:
"""Compute CTC loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
targets:
Target Tensor of shape (sum(target_lengths)). The targets are assumed
to be un-padded and concatenated within 1 dimension.
"""
# Compute CTC log-prob
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
targets=targets.cpu(),
input_lengths=encoder_out_lens.cpu(),
target_lengths=target_lengths.cpu(),
reduction="sum",
)
return ctc_loss
def forward_transducer(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
y: k2.RaggedTensor,
y_lens: torch.Tensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute Transducer loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for rnnt loss; it specifies how many symbols (context)
we consider for each frame when computing the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
"""
# Now for the decoder, i.e., the prediction network
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(encoder_out.size(0), 4),
dtype=torch.int64,
device=encoder_out.device,
)
boundary[:, 2] = y_lens
boundary[:, 3] = encoder_out_lens
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
# if self.training and random.random() < 0.25:
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
# if self.training and random.random() < 0.25:
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=self.joiner.encoder_proj(encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges,
)
# logits : [B, T, prune_range, vocab_size]
# project_input=False since we already applied the joiner's input projections
# (encoder_proj and decoder_proj) prior to do_rnnt_pruning (an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
return simple_loss, pruned_loss
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
prune_range:
The prune range for rnnt loss; it specifies how many symbols (context)
we consider for each frame when computing the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
Returns:
Return the transducer losses and CTC loss,
in form of (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss)
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
device = x.device
# Compute encoder outputs
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
if self.use_transducer:
# Compute transducer loss
simple_loss, pruned_loss = self.forward_transducer(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
y=y.to(device),
y_lens=y_lens,
prune_range=prune_range,
am_scale=am_scale,
lm_scale=lm_scale,
)
else:
simple_loss = torch.empty(0)
pruned_loss = torch.empty(0)
if self.use_ctc:
# Compute CTC loss
targets = y.values
ctc_loss = self.forward_ctc(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
targets=targets,
target_lengths=y_lens,
)
else:
ctc_loss = torch.empty(0)
if self.use_attention_decoder:
attention_decoder_loss = self.attention_decoder.calc_att_loss(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
ys=y.to(device),
ys_lens=y_lens.to(device),
)
else:
attention_decoder_loss = torch.empty(0)
return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss
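The forward() method above turns the ragged label tensor into per-utterance lengths and flattened CTC targets; a small sketch of that bookkeeping (requires k2 to be installed):

```python
import k2

# Two utterances with 3 and 2 labels respectively.
y = k2.RaggedTensor([[1, 2, 3], [4, 5]])

row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
assert y_lens.tolist() == [3, 2]

targets = y.values  # concatenated, un-padded labels, as used for the CTC loss
assert targets.tolist() == [1, 2, 3, 4, 5]
```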

View File

@@ -0,0 +1,136 @@
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
#
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict, List, Optional, Union
from torch.optim import Optimizer
class LRScheduler(object):
"""
Base-class for learning rate schedulers where the learning-rate depends on both the
batch and the epoch.
"""
def __init__(self, optimizer: Optimizer, verbose: bool = False):
# Attach optimizer
if not isinstance(optimizer, Optimizer):
raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
self.optimizer = optimizer
self.verbose = verbose
for group in optimizer.param_groups:
group.setdefault("base_lr", group["lr"])
self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]
self.epoch = 0
self.batch = 0
def state_dict(self):
"""Returns the state of the scheduler as a :class:`dict`.
It contains the base learning rates and the current epoch and batch
indices; the optimizer itself is not included.
"""
return {
"base_lrs": self.base_lrs,
"epoch": self.epoch,
"batch": self.batch,
}
def load_state_dict(self, state_dict):
"""Loads the schedulers state.
Args:
state_dict (dict): scheduler state. Should be an object returned
from a call to :meth:`state_dict`.
"""
self.__dict__.update(state_dict)
def get_last_lr(self) -> List[float]:
"""Return last computed learning rate by current scheduler. Will be a list of float."""
return self._last_lr
def get_lr(self):
# Compute list of learning rates from self.epoch and self.batch and
# self.base_lrs; this must be overloaded by the user.
# e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
raise NotImplementedError
def step_batch(self, batch: Optional[int] = None) -> None:
# Step the batch index, or just set it. If `batch` is specified, it
# must be the batch index from the start of training, i.e. summed over
# all epochs.
# You can call this in any order; if you don't provide 'batch', it should
# of course be called once per batch.
if batch is not None:
self.batch = batch
else:
self.batch = self.batch + 1
self._set_lrs()
def step_epoch(self, epoch: Optional[int] = None):
# Step the epoch index, or just set it. If you provide the 'epoch' arg,
# you should call this at the start of the epoch; if you don't provide the 'epoch'
# arg, you should call it at the end of the epoch.
if epoch is not None:
self.epoch = epoch
else:
self.epoch = self.epoch + 1
self._set_lrs()
def _set_lrs(self):
values = self.get_lr()
assert len(values) == len(self.optimizer.param_groups)
for i, data in enumerate(zip(self.optimizer.param_groups, values)):
param_group, lr = data
param_group["lr"] = lr
self.print_lr(self.verbose, i, lr)
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
def print_lr(self, is_verbose, group, lr):
"""Display the current learning rate."""
if is_verbose:
logging.warn(
f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
f" of group {group} to {lr:.4e}."
)
class Noam(LRScheduler):
"""
The LR scheduler proposed by Noam
Ref: "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
"""
def __init__(
self,
optimizer: Optimizer,
warmup_batches: Union[int, float] = 10000.0,
verbose: bool = False,
):
super().__init__(optimizer, verbose)
self.warmup_batches = warmup_batches
self.normalize = self.warmup_batches ** (-0.5)
def get_lr(self):
warmup_factor = 0 if self.batch == 0 else min(
self.batch ** (-0.5),
self.batch * self.warmup_batches ** (-1.5)
) / self.normalize
return [x * warmup_factor for x in self.base_lrs]
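Since this recipe pairs the Noam schedule with a plain Adam optimizer, here is a minimal sketch of how the two are typically wired together (the `optim` module name is an assumption about where this file lives in the recipe):

```python
import torch

from optim import Noam  # assumed module name for this file

model = torch.nn.Linear(10, 10)
# The lr passed to Adam acts as the peak LR: the Noam factor reaches 1.0
# at batch == warmup_batches and decays as batch**-0.5 afterwards.
optimizer = torch.optim.Adam(model.parameters(), lr=1.0e-03)
scheduler = Noam(optimizer, warmup_batches=10000.0)

for batch_idx in range(3):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    optimizer.step()
    scheduler.step_batch()  # call once per batch

print(scheduler.get_last_lr())
```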

View File

@@ -0,0 +1,904 @@
# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import Tuple, Union
import torch
import torch.nn as nn
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
max_value = torch.max(x, y)
diff = torch.abs(x - y)
return max_value + torch.log1p(torch.exp(-diff))
# RuntimeError: Exporting the operator logaddexp to ONNX opset version
# 14 is not supported. Please feel free to request support or submit
# a pull request on PyTorch GitHub.
#
# The following function is to solve the above error when exporting
# models to ONNX via torch.jit.trace()
def logaddexp(x: Tensor, y: Tensor) -> Tensor:
# Caution(fangjun): Put torch.jit.is_scripting() before
# torch.onnx.is_in_onnx_export();
# otherwise, it will cause errors for torch.jit.script().
#
# torch.logaddexp() works for both torch.jit.script() and
# torch.jit.trace() but it causes errors for ONNX export.
#
if torch.jit.is_scripting():
# Note: We cannot use torch.jit.is_tracing() here as it also
# matches torch.onnx.export().
return torch.logaddexp(x, y)
elif torch.onnx.is_in_onnx_export():
return logaddexp_onnx(x, y)
else:
# for torch.jit.trace()
return torch.logaddexp(x, y)
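A quick numerical check that the ONNX-friendly formulation agrees with torch.logaddexp (assuming this file is importable as `scaling`):

```python
import torch

from scaling import logaddexp, logaddexp_onnx  # assumed import path

x = torch.randn(1000)
y = torch.randn(1000)
assert torch.allclose(logaddexp_onnx(x, y), torch.logaddexp(x, y), atol=1.0e-06)
assert torch.allclose(logaddexp(x, y), torch.logaddexp(x, y))
```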
class PiecewiseLinear(object):
"""
Piecewise linear function, from float to float, specified as a nonempty list of (x,y) pairs
with the x values in order. x values < [initial x] or > [final x] are mapped to [initial y],
[final y] respectively.
"""
def __init__(self, *args):
assert len(args) >= 1, len(args)
if len(args) == 1 and isinstance(args[0], PiecewiseLinear):
self.pairs = list(args[0].pairs)
else:
self.pairs = [(float(x), float(y)) for x, y in args]
for x, y in self.pairs:
assert isinstance(x, (float, int)), type(x)
assert isinstance(y, (float, int)), type(y)
for i in range(len(self.pairs) - 1):
assert self.pairs[i + 1][0] > self.pairs[i][0], (
i,
self.pairs[i],
self.pairs[i + 1],
)
def __str__(self):
# e.g. 'PiecewiseLinear((0., 10.), (100., 0.))'
return f"PiecewiseLinear({str(self.pairs)[1:-1]})"
def __call__(self, x):
if x <= self.pairs[0][0]:
return self.pairs[0][1]
elif x >= self.pairs[-1][0]:
return self.pairs[-1][1]
else:
cur_x, cur_y = self.pairs[0]
for i in range(1, len(self.pairs)):
next_x, next_y = self.pairs[i]
if x >= cur_x and x <= next_x:
return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x)
cur_x, cur_y = next_x, next_y
assert False
def __mul__(self, alpha):
return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs])
def __add__(self, x):
if isinstance(x, (float, int)):
return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs])
s, x = self.get_common_basis(x)
return PiecewiseLinear(
*[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)]
)
def max(self, x):
if isinstance(x, (float, int)):
x = PiecewiseLinear((0, x))
s, x = self.get_common_basis(x, include_crossings=True)
return PiecewiseLinear(
*[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]
)
def min(self, x):
if isinstance(x, float) or isinstance(x, int):
x = PiecewiseLinear((0, x))
s, x = self.get_common_basis(x, include_crossings=True)
return PiecewiseLinear(
*[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]
)
def __eq__(self, other):
return self.pairs == other.pairs
def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False):
"""
Returns (self_mod, p_mod) which are equivalent piecewise linear
functions to self and p, but with the same x values.
p: the other piecewise linear function
include_crossings: if true, include in the x values the positions
where the functions indicated by self and p cross.
"""
assert isinstance(p, PiecewiseLinear), type(p)
# get sorted x-values without repetition.
x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs]))
y_vals1 = [self(x) for x in x_vals]
y_vals2 = [p(x) for x in x_vals]
if include_crossings:
extra_x_vals = []
for i in range(len(x_vals) - 1):
if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]):
# if the two lines in this subsegment potentially cross each other..
diff_cur = abs(y_vals1[i] - y_vals2[i])
diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1])
# `pos`, between 0 and 1, gives the relative x position,
# with 0 being x_vals[i] and 1 being x_vals[i+1].
pos = diff_cur / (diff_cur + diff_next)
extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i])
extra_x_vals.append(extra_x_val)
if len(extra_x_vals) > 0:
x_vals = sorted(set(x_vals + extra_x_vals))
y_vals1 = [self(x) for x in x_vals]
y_vals2 = [p(x) for x in x_vals]
return (
PiecewiseLinear(*zip(x_vals, y_vals1)),
PiecewiseLinear(*zip(x_vals, y_vals2)),
)
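A small sketch of how PiecewiseLinear extrapolates, interpolates, and combines under max() (assuming this file is importable as `scaling`):

```python
from scaling import PiecewiseLinear  # assumed import path

p = PiecewiseLinear((0.0, 10.0), (100.0, 0.0))
assert p(-5.0) == 10.0    # before the first x: first y
assert p(50.0) == 5.0     # linear interpolation in between
assert p(200.0) == 0.0    # after the last x: last y

q = p.max(4.0)            # pointwise max with the constant function 4.0
assert q(0.0) == 10.0 and q(100.0) == 4.0
```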
class ScheduledFloat(torch.nn.Module):
"""
This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
it does not have a working forward() function. You are supposed to cast it to float, as
in, float(parent_module.whatever), and use it as something like a dropout prob.
It is a floating point value whose value changes depending on the batch count of the
training loop. It is a piecewise linear function where you specify the (x,y) pairs
in sorted order on x; x corresponds to the batch index. For batch-index values before the
first x or after the last x, we just use the first or last y value.
Example:
self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
`default` is used when self.batch_count is not set or not in training mode or in
torch.jit scripting mode.
"""
def __init__(self, *args, default: float = 0.0):
super().__init__()
# self.batch_count and self.name will be written to in the training loop.
self.batch_count = None
self.name = None
self.default = default
self.schedule = PiecewiseLinear(*args)
def extra_repr(self) -> str:
return (
f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}"
)
def __float__(self):
batch_count = self.batch_count
if (
batch_count is None
or not self.training
or torch.jit.is_scripting()
or torch.jit.is_tracing()
):
return float(self.default)
else:
ans = self.schedule(self.batch_count)
if random.random() < 0.0002:
logging.info(
f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}"
)
return ans
def __add__(self, x):
if isinstance(x, float) or isinstance(x, int):
return ScheduledFloat(self.schedule + x, default=self.default)
else:
return ScheduledFloat(
self.schedule + x.schedule, default=self.default + x.default
)
def max(self, x):
if isinstance(x, float) or isinstance(x, int):
return ScheduledFloat(self.schedule.max(x), default=self.default)
else:
return ScheduledFloat(
self.schedule.max(x.schedule), default=max(self.default, x.default)
)
FloatLike = Union[float, ScheduledFloat]
def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
"""
A randomized way of casting a floating point value to half precision.
"""
if x.dtype == torch.float16:
return x
x_abs = x.abs()
is_too_small = x_abs < min_abs
# for elements where is_too_small is true, random_val will contain +-min_abs with
# probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
# for those elements].
random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
return torch.where(is_too_small, random_val, x).to(torch.float16)
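A quick Monte-Carlo check that random_cast_to_half preserves expectations for values below min_abs (assuming this file is importable as `scaling`):

```python
import torch

from scaling import random_cast_to_half  # assumed import path

x = torch.full((200000,), 1.0e-06)  # well below the default min_abs = 5.0e-06
y = random_cast_to_half(x)          # float16: mostly zeros, occasionally +/- min_abs
mean = y.to(torch.float32).mean().item()
assert abs(mean - 1.0e-06) < 5.0e-07  # empirical mean stays close to E[x]
```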
class CutoffEstimator:
"""
Estimates cutoffs of an arbitrary numerical quantity such that a specified
proportion of items will be above the cutoff on average.
p is the proportion of items that should be above the cutoff.
"""
def __init__(self, p: float):
self.p = p
# total count of items
self.count = 0
# total count of items that were above the cutoff
self.count_above = 0
# initial cutoff value
self.cutoff = 0
def __call__(self, x: float) -> bool:
"""
Returns true if x is above the cutoff.
"""
ans = x > self.cutoff
self.count += 1
if ans:
self.count_above += 1
cur_p = self.count_above / self.count
delta_p = cur_p - self.p
if (delta_p > 0) == ans:
q = abs(delta_p)
self.cutoff = x * q + self.cutoff * (1 - q)
return ans
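A short sketch of the estimator adapting its cutoff so that roughly a proportion p of inputs end up above it (assuming this file is importable as `scaling`):

```python
import random

from scaling import CutoffEstimator  # assumed import path

est = CutoffEstimator(p=0.1)
random.seed(0)
above = sum(est(random.random()) for _ in range(10000))
# The running cutoff drifts toward the 90th percentile of the inputs,
# so the fraction above it ends up near p.
print(above / 10000, est.cutoff)
```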
class SoftmaxFunction(torch.autograd.Function):
"""
Tries to handle half-precision derivatives in a randomized way that should
be more accurate for training than the default behavior.
"""
@staticmethod
def forward(ctx, x: Tensor, dim: int):
ans = x.softmax(dim=dim)
# if x dtype is float16, x.softmax() returns a float32 because
# (presumably) that op does not support float16, and autocast
# is enabled.
if torch.is_autocast_enabled():
ans = ans.to(torch.float16)
ctx.save_for_backward(ans)
ctx.x_dtype = x.dtype
ctx.dim = dim
return ans
@staticmethod
def backward(ctx, ans_grad: Tensor):
(ans,) = ctx.saved_tensors
with torch.cuda.amp.autocast(enabled=False):
ans_grad = ans_grad.to(torch.float32)
ans = ans.to(torch.float32)
x_grad = ans_grad * ans
x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
return x_grad, None
def softmax(x: Tensor, dim: int):
if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing():
return x.softmax(dim=dim)
return SoftmaxFunction.apply(x, dim)
class MaxEigLimiterFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x: Tensor,
coeffs: Tensor,
direction: Tensor,
channel_dim: int,
grad_scale: float,
) -> Tensor:
ctx.channel_dim = channel_dim
ctx.grad_scale = grad_scale
ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
return x
@staticmethod
def backward(ctx, x_grad, *args):
with torch.enable_grad():
(x_orig, coeffs, new_direction) = ctx.saved_tensors
x_orig.requires_grad = True
num_channels = x_orig.shape[ctx.channel_dim]
x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
new_direction.requires_grad = False
x = x - x.mean(dim=0)
x_var = (x**2).mean()
x_residual = x - coeffs * new_direction
x_residual_var = (x_residual**2).mean()
# `variance_proportion` is the proportion of the variance accounted for
# by the top eigen-direction. This is to be minimized.
variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
variance_proportion.backward()
x_orig_grad = x_orig.grad
x_extra_grad = (
x_orig.grad
* ctx.grad_scale
* x_grad.norm()
/ (x_orig_grad.norm() + 1.0e-20)
)
return x_grad + x_extra_grad.detach(), None, None, None, None
class BiasNormFunction(torch.autograd.Function):
# This computes:
# scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
# return x * scales
# (after unsqueezing the bias), but it does it in a memory-efficient way so that
# it can just store the returned value (chances are, this will also be needed for
# some other reason, related to the next operation, so we can save memory).
@staticmethod
def forward(
ctx,
x: Tensor,
bias: Tensor,
log_scale: Tensor,
channel_dim: int,
store_output_for_backprop: bool,
) -> Tensor:
assert bias.ndim == 1
if channel_dim < 0:
channel_dim = channel_dim + x.ndim
ctx.store_output_for_backprop = store_output_for_backprop
ctx.channel_dim = channel_dim
for _ in range(channel_dim + 1, x.ndim):
bias = bias.unsqueeze(-1)
scales = (
torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
) * log_scale.exp()
ans = x * scales
ctx.save_for_backward(
ans.detach() if store_output_for_backprop else x,
scales.detach(),
bias.detach(),
log_scale.detach(),
)
return ans
@staticmethod
def backward(ctx, ans_grad: Tensor) -> Tensor:
ans_or_x, scales, bias, log_scale = ctx.saved_tensors
if ctx.store_output_for_backprop:
x = ans_or_x / scales
else:
x = ans_or_x
x = x.detach()
x.requires_grad = True
bias.requires_grad = True
log_scale.requires_grad = True
with torch.enable_grad():
# recompute scales from x, bias and log_scale.
scales = (
torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5
) * log_scale.exp()
ans = x * scales
ans.backward(gradient=ans_grad)
return x.grad, bias.grad.flatten(), log_scale.grad, None, None
class BiasNorm(torch.nn.Module):
"""
This is intended to be a simpler, and hopefully cheaper, replacement for
LayerNorm. The observation this is based on, is that Transformer-type
networks, especially with pre-norm, sometimes seem to set one of the
feature dimensions to a large constant value (e.g. 50), which "defeats"
the LayerNorm because the output magnitude is then not strongly dependent
on the other (useful) features. Presumably the weight and bias of the
LayerNorm are required to allow it to do this.
Instead, we give the BiasNorm a trainable bias that it can use when
computing the scale for normalization. We also give it a (scalar)
trainable scale on the output.
Args:
num_channels: the number of channels, e.g. 512.
channel_dim: the axis/dimension corresponding to the channel,
interpreted as an offset from the input's ndim if negative.
This is NOT the num_channels; it should typically be one of
{-2, -1, 0, 1, 2, 3}.
log_scale: the initial log-scale that we multiply the output by; this
is learnable.
log_scale_min: FloatLike, minimum allowed value of log_scale
log_scale_max: FloatLike, maximum allowed value of log_scale
store_output_for_backprop: only possibly affects memory use; recommend
to set to True if you think the output of this module is more likely
than the input of this module to be required to be stored for the
backprop.
"""
def __init__(
self,
num_channels: int,
channel_dim: int = -1, # CAUTION: see documentation.
log_scale: float = 1.0,
log_scale_min: float = -1.5,
log_scale_max: float = 1.5,
store_output_for_backprop: bool = False,
) -> None:
super(BiasNorm, self).__init__()
self.num_channels = num_channels
self.channel_dim = channel_dim
self.log_scale = nn.Parameter(torch.tensor(log_scale))
self.bias = nn.Parameter(torch.empty(num_channels).normal_(mean=0, std=1e-4))
self.log_scale_min = log_scale_min
self.log_scale_max = log_scale_max
self.store_output_for_backprop = store_output_for_backprop
def forward(self, x: Tensor) -> Tensor:
assert x.shape[self.channel_dim] == self.num_channels
if torch.jit.is_scripting() or torch.jit.is_tracing():
channel_dim = self.channel_dim
if channel_dim < 0:
channel_dim += x.ndim
bias = self.bias
for _ in range(channel_dim + 1, x.ndim):
bias = bias.unsqueeze(-1)
scales = (
torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
) * self.log_scale.exp()
return x * scales
log_scale = limit_param_value(
self.log_scale,
min=float(self.log_scale_min),
max=float(self.log_scale_max),
training=self.training,
)
return BiasNormFunction.apply(
x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop
)
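A plain-PyTorch restatement of what BiasNorm computes, checked against the module itself (assuming this file is importable as `scaling`):

```python
import torch

from scaling import BiasNorm  # assumed import path

norm = BiasNorm(num_channels=8, channel_dim=-1)
norm.eval()

x = torch.randn(3, 5, 8)
with torch.no_grad():
    ref_scales = (
        torch.mean((x - norm.bias) ** 2, dim=-1, keepdim=True) ** -0.5
    ) * norm.log_scale.exp()
    assert torch.allclose(norm(x), x * ref_scales, atol=1.0e-06)
```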
class ChunkCausalDepthwiseConv1d(torch.nn.Module):
"""
Behaves like a depthwise 1d convolution, except that it is causal in
a chunkwise way, as if we had a block-triangular attention mask.
The chunk size is provided at test time (it should probably be
kept in sync with the attention mask).
This has a little more than twice the parameters of a conventional
depthwise conv1d module: we implement it by having one
depthwise convolution, of half the width, that is causal (via
right-padding); and one depthwise convolution that is applied only
within chunks, that we multiply by a scaling factor which depends
on the position within the chunk.
Args:
channels: number of channels; the convolution is depthwise, so
in_channels == out_channels.
kernel_size: kernel size of the chunkwise convolution; must be odd.
initial_scale: you can override this if you want to increase
or decrease the initial magnitude of the module's output
(it scales the initial convolution weights and biases).
Another option, if you want to do something like this, is
to re-initialize the parameters.
bias: if True, the chunkwise convolution has a bias term.
def __init__(
self,
channels: int,
kernel_size: int,
initial_scale: float = 1.0,
bias: bool = True,
):
super().__init__()
assert kernel_size % 2 == 1
half_kernel_size = (kernel_size + 1) // 2
# will pad manually, on one side.
self.causal_conv = nn.Conv1d(
in_channels=channels,
out_channels=channels,
groups=channels,
kernel_size=half_kernel_size,
padding=0,
bias=True,
)
self.chunkwise_conv = nn.Conv1d(
in_channels=channels,
out_channels=channels,
groups=channels,
kernel_size=kernel_size,
padding=kernel_size // 2,
bias=bias,
)
# first row is correction factors added to the scale near the left edge of the chunk,
# second row is correction factors added to the scale near the right edge of the chunk,
# both of these are added to a default scale of 1.0.
self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size))
self.kernel_size = kernel_size
with torch.no_grad():
self.causal_conv.weight[:] *= initial_scale
self.chunkwise_conv.weight[:] *= initial_scale
if bias:
torch.nn.init.uniform_(
self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale
)
def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor:
"""Forward function.
Args:
x: a Tensor of shape (batch_size, channels, seq_len)
chunk_size: the chunk size, in frames; does not have to divide seq_len exactly.
"""
(batch_size, num_channels, seq_len) = x.shape
# half_kernel_size = (self.kernel_size + 1) // 2
# left_pad is half_kernel_size - 1 where half_kernel_size is the size used
# in the causal conv. It's the amount by which we must pad on the left,
# to make the convolution causal.
left_pad = self.kernel_size // 2
if chunk_size < 0 or chunk_size > seq_len:
chunk_size = seq_len
right_pad = -seq_len % chunk_size
x = torch.nn.functional.pad(x, (left_pad, right_pad))
x_causal = self.causal_conv(x[..., : left_pad + seq_len])
assert x_causal.shape == (batch_size, num_channels, seq_len)
x_chunk = x[..., left_pad:]
num_chunks = x_chunk.shape[2] // chunk_size
x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size)
x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(
batch_size * num_chunks, num_channels, chunk_size
)
x_chunk = self.chunkwise_conv(x_chunk) # does not change shape
chunk_scale = self._get_chunk_scale(chunk_size)
x_chunk = x_chunk * chunk_scale
x_chunk = x_chunk.reshape(
batch_size, num_chunks, num_channels, chunk_size
).permute(0, 2, 1, 3)
x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[
..., :seq_len
]
return x_chunk + x_causal
def _get_chunk_scale(self, chunk_size: int):
"""Returns tensor of shape (num_channels, chunk_size) that will be used to
scale the output of self.chunkwise_conv."""
left_edge = self.chunkwise_conv_scale[0]
right_edge = self.chunkwise_conv_scale[1]
if chunk_size < self.kernel_size:
left_edge = left_edge[:, :chunk_size]
right_edge = right_edge[:, -chunk_size:]
else:
t = chunk_size - self.kernel_size
channels = left_edge.shape[0]
pad = torch.zeros(
channels, t, device=left_edge.device, dtype=left_edge.dtype
)
left_edge = torch.cat((left_edge, pad), dim=-1)
right_edge = torch.cat((pad, right_edge), dim=-1)
return 1.0 + (left_edge + right_edge)
def streaming_forward(
self,
x: Tensor,
cache: Tensor,
) -> Tuple[Tensor, Tensor]:
"""Streaming Forward function.
Args:
x: a Tensor of shape (batch_size, channels, seq_len)
cache: cached left context of shape (batch_size, channels, left_pad)
"""
(batch_size, num_channels, seq_len) = x.shape
# left_pad is half_kernel_size - 1 where half_kernel_size is the size used
# in the causal conv. It's the amount by which we must pad on the left,
# to make the convolution causal.
left_pad = self.kernel_size // 2
# Pad cache
assert cache.shape[-1] == left_pad, (cache.shape[-1], left_pad)
x = torch.cat([cache, x], dim=2)
# Update cache
cache = x[..., -left_pad:]
x_causal = self.causal_conv(x)
assert x_causal.shape == (batch_size, num_channels, seq_len)
x_chunk = x[..., left_pad:]
x_chunk = self.chunkwise_conv(x_chunk) # does not change shape
chunk_scale = self._get_chunk_scale(chunk_size=seq_len)
x_chunk = x_chunk * chunk_scale
return x_chunk + x_causal, cache
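# A minimal self-test sketch (added for illustration; not part of the original
# recipe).  It checks that forward() preserves the input shape and that feeding
# the same sequence chunk by chunk through streaming_forward() matches a single
# chunkwise forward() call.
def _test_chunk_causal_depthwise_conv1d():
    channels, kernel_size, chunk_size, num_chunks = 8, 31, 16, 4
    m = ChunkCausalDepthwiseConv1d(channels=channels, kernel_size=kernel_size)
    x = torch.randn(2, channels, chunk_size * num_chunks)
    y = m(x, chunk_size=chunk_size)
    assert y.shape == x.shape
    # streaming: feed one chunk at a time, carrying the cached left context.
    cache = torch.zeros(2, channels, kernel_size // 2)
    outs = []
    for i in range(num_chunks):
        chunk = x[..., i * chunk_size : (i + 1) * chunk_size]
        out, cache = m.streaming_forward(chunk, cache)
        outs.append(out)
    y_stream = torch.cat(outs, dim=2)
    assert torch.allclose(y, y_stream, atol=1.0e-05)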
def penalize_abs_values_gt(
x: Tensor, limit: float, penalty: float, name: str = None
) -> Tensor:
"""
Returns x unmodified, but in backprop will put a penalty for the excess of
the absolute values of elements of x over the limit "limit". E.g. if
limit == 10.0, then if x has any values over 10 it will get a penalty.
Caution: the value of this penalty will be affected by the grad scaling used
in automatic mixed precision training.  For the purpose we use it for, that
shouldn't really matter, and may even be helpful; we only use this to stop
really implausible score values from being fed to softmax.
The name is used for occasionally printed debug info.
"""
x_sign = x.sign()
over_limit = (x.abs() - limit) > 0
# The following is a memory efficient way to penalize the absolute values of
# x that's over the limit. (The memory efficiency comes when you think
# about which items torch needs to cache for the autograd, and which ones it
# can throw away). The numerical value of aux_loss as computed here will
# actually be larger than it should be, by limit * over_limit.sum(), but it
# has the same derivative as the real aux_loss which is penalty * (x.abs() -
# limit).relu().
aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
# note: we don't do sum() on aux_loss here, but it's as if we had done
# sum(), due to how with_loss() works.
x = with_loss(x, aux_loss, name)
# you must use x for something, or this will be ineffective.
return x
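# Illustrative sketch (added for clarity; not part of the original file):
# penalize_abs_values_gt() leaves the forward value untouched but adds
# +/- `penalty` to the gradient of every element whose absolute value
# exceeds `limit`, pushing such elements back towards the allowed range.
def _demo_penalize_abs_values_gt():
    scores = torch.tensor([[0.5, 20.0, -30.0]], requires_grad=True)
    y = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-03, name="demo")
    y.sum().backward()
    # grad is 1.0 everywhere from sum(), plus +/- penalty on the two
    # out-of-range elements: [[1.0, 1.001, 0.999]]
    print(scores.grad)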
class WithLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, x: Tensor, y: Tensor, name: str):
ctx.y_shape = y.shape
if random.random() < 0.002 and name is not None:
loss_sum = y.sum().item()
logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}")
return x
@staticmethod
def backward(ctx, ans_grad: Tensor):
return (
ans_grad,
torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device),
None,
)
def with_loss(x, y, name):
# returns x but adds y.sum() to the loss function.
return WithLoss.apply(x, y, name)
class ScaleGradFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: Tensor, alpha: float) -> Tensor:
ctx.alpha = alpha
return x
@staticmethod
def backward(ctx, grad: Tensor):
return grad * ctx.alpha, None
def scale_grad(x: Tensor, alpha: float):
return ScaleGradFunction.apply(x, alpha)
class ScaleGrad(nn.Module):
def __init__(self, alpha: float):
super().__init__()
self.alpha = alpha
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
return x
return scale_grad(x, self.alpha)
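# Illustrative sketch (added for clarity): scale_grad() is the identity in the
# forward pass but multiplies the incoming gradient by `alpha` in backprop.
def _demo_scale_grad():
    x = torch.ones(3, requires_grad=True)
    y = scale_grad(x, 0.2)
    y.sum().backward()
    print(x.grad)  # tensor([0.2000, 0.2000, 0.2000])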
class LimitParamValue(torch.autograd.Function):
@staticmethod
def forward(ctx, x: Tensor, min: float, max: float):
ctx.save_for_backward(x)
assert max >= min
ctx.min = min
ctx.max = max
return x
@staticmethod
def backward(ctx, x_grad: Tensor):
(x,) = ctx.saved_tensors
# where x < ctx.min, ensure all grads are negative (this will tend to make
# x more positive).
x_grad = x_grad * torch.where(
torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0
)
# where x > ctx.max, ensure all grads are positive (this will tend to make
# x more negative).
x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
return x_grad, None, None
def limit_param_value(
x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True
):
# You apply this to (typically) an nn.Parameter during training to ensure that its
# elements (mostly) stay within a supplied range.  This is done by modifying the
# gradients in backprop.
# It's not necessary to do this on every batch: we only do it some of the time,
# to save a little time.
if training and random.random() < prob:
return LimitParamValue.apply(x, min, max)
else:
return x
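# Illustrative sketch (added for clarity): limit_param_value() is typically
# applied to an nn.Parameter inside forward().  Gradients that would push a
# value further outside [min, max] have their sign flipped, so the parameter
# drifts back into range over the course of training.
class _DemoLimitedScale(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor([5.0]))

    def forward(self, x: Tensor) -> Tensor:
        # prob=1.0 applies the limit on every call; the recipe typically uses a
        # smaller prob to save a little time.
        scale = limit_param_value(
            self.scale, min=0.0, max=1.0, prob=1.0, training=self.training
        )
        return x * scale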
def _no_op(x: Tensor) -> Tensor:
if torch.jit.is_scripting() or torch.jit.is_tracing():
return x
else:
# a no-op function that will have a node in the autograd graph,
# to avoid certain bugs relating to backward hooks
return x.chunk(1, dim=-1)[0]
class Identity(torch.nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return _no_op(x)
# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates.
class Dropout2(nn.Module):
def __init__(self, p: FloatLike):
super().__init__()
self.p = p
def forward(self, x: Tensor) -> Tensor:
return torch.nn.functional.dropout(x, p=float(self.p), training=self.training)
class MulForDropout3(torch.autograd.Function):
# returns (x * y * alpha) where alpha is a float and y doesn't require
# grad and is zero-or-one.
@staticmethod
@custom_fwd
def forward(ctx, x, y, alpha):
assert not y.requires_grad
ans = x * y * alpha
ctx.save_for_backward(ans)
ctx.alpha = alpha
return ans
@staticmethod
@custom_bwd
def backward(ctx, ans_grad):
(ans,) = ctx.saved_tensors
x_grad = ctx.alpha * ans_grad * (ans != 0)
return x_grad, None, None
# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates,
# and it lets you choose one dimension to share the dropout mask over
class Dropout3(nn.Module):
def __init__(self, p: FloatLike, shared_dim: int):
super().__init__()
self.p = p
self.shared_dim = shared_dim
def forward(self, x: Tensor) -> Tensor:
p = float(self.p)
if not self.training or p == 0:
return _no_op(x)
scale = 1.0 / (1 - p)
rand_shape = list(x.shape)
rand_shape[self.shared_dim] = 1
mask = torch.rand(*rand_shape, device=x.device) > p
ans = MulForDropout3.apply(x, mask, scale)
return ans
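# Illustrative sketch (added for clarity): with shared_dim=1 the same dropout
# mask is reused across the whole time axis of a (batch, time, channels)
# tensor, i.e. a channel is either kept or dropped for the entire sequence.
def _demo_dropout3():
    drop = Dropout3(p=0.5, shared_dim=1)
    drop.train()
    x = torch.ones(2, 100, 4)  # (batch, time, channels)
    y = drop(x)
    # each (batch, channel) pair is either zeroed or kept at every time step
    assert ((y == 0).all(dim=1) | (y != 0).all(dim=1)).all()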
def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
if num_channels <= x.shape[-1]:
return x[..., :num_channels]
else:
shape = list(x.shape)
shape[-1] = num_channels - shape[-1]
zeros = torch.zeros(shape, dtype=x.dtype, device=x.device)
return torch.cat((x, zeros), dim=-1)
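# Illustrative sketch (added for clarity): convert_num_channels() either
# truncates or zero-pads the last dimension, e.g. when feature dims of
# different encoder stacks have to be matched.
def _demo_convert_num_channels():
    x = torch.randn(2, 10, 384)
    assert convert_num_channels(x, 256).shape == (2, 10, 256)  # truncate
    assert convert_num_channels(x, 512).shape == (2, 10, 512)  # zero-pad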
def _test_softmax():
a = torch.randn(2, 10, dtype=torch.float64)
b = a.clone()
a.requires_grad = True
b.requires_grad = True
a.softmax(dim=1)[:, 0].sum().backward()
print("a grad = ", a.grad)
softmax(b, dim=1)[:, 0].sum().backward()
print("b grad = ", b.grad)
assert torch.allclose(a.grad, b.grad)
def _test_piecewise_linear():
p = PiecewiseLinear((0, 10.0))
for x in [-100, 0, 100]:
assert p(x) == 10.0
p = PiecewiseLinear((0, 10.0), (1, 0.0))
for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]:
print("x, y = ", x, y)
assert p(x) == y, (x, p(x), y)
q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0))
x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0]
pq = p.max(q)
for x in x_vals:
y1 = max(p(x), q(x))
y2 = pq(x)
assert abs(y1 - y2) < 0.001
pq = p.min(q)
for x in x_vals:
y1 = min(p(x), q(x))
y2 = pq(x)
assert abs(y1 - y2) < 0.001
pq = p + q
for x in x_vals:
y1 = p(x) + q(x)
y2 = pq(x)
assert abs(y1 - y2) < 0.001
if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
_test_piecewise_linear()
_test_softmax()

View File

@ -0,0 +1,93 @@
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file replaces various modules in a model.
Specifically, BasicNorm is replaced by a module with `exp` removed.
"""
import copy
from typing import List
import torch
import torch.nn as nn
from scaling import (
Dropout3,
ScaleGrad,
)
from zipformer import CompactRelPositionalEncoding
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
# get_submodule was added to nn.Module at v1.9.0
def get_submodule(model, target):
if target == "":
return model
atoms: List[str] = target.split(".")
mod: torch.nn.Module = model
for item in atoms:
if not hasattr(mod, item):
raise AttributeError(
mod._get_name() + " has no " "attribute `" + item + "`"
)
mod = getattr(mod, item)
if not isinstance(mod, torch.nn.Module):
raise AttributeError("`" + item + "` is not " "an nn.Module")
return mod
def convert_scaled_to_non_scaled(
model: nn.Module,
inplace: bool = False,
is_pnnx: bool = False,
is_onnx: bool = False,
):
"""
Args:
model:
The model to be converted.
inplace:
If True, the input model is modified inplace.
If False, the input model is copied and we modify the copied version.
is_pnnx:
True if we are going to export the model for PNNX.
is_onnx:
True if we are going to export the model for ONNX.
Return:
Return a model without scaled layers.
"""
if not inplace:
model = copy.deepcopy(model)
d = {}
for name, m in model.named_modules():
if isinstance(m, (Dropout3, ScaleGrad)):
d[name] = nn.Identity()
elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
# We want the positional encoding to be recomputed when the input
# length changes, so this module has to be handled by torch.jit.script()
# rather than by torch.jit.trace().
d[name] = torch.jit.script(m)
for k, v in d.items():
if "." in k:
parent, child = k.rsplit(".", maxsplit=1)
setattr(get_submodule(model, parent), child, v)
else:
setattr(model, k, v)
return model
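# Minimal usage sketch (added for illustration; the export scripts pass the
# real trained model here).  With inplace=False the input model is left
# untouched and the returned copy has Dropout3/ScaleGrad swapped for
# nn.Identity.
def _test_convert_scaled_to_non_scaled():
    toy = nn.Sequential(nn.Linear(4, 4), Dropout3(0.1, shared_dim=1))
    converted = convert_scaled_to_non_scaled(toy, inplace=False)
    assert isinstance(converted[1], nn.Identity)
    assert isinstance(toy[1], Dropout3)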

View File

@ -0,0 +1,357 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Tuple
import torch
from scaling import (
BiasNorm,
Dropout3,
FloatLike,
Optional,
ScaleGrad,
ScheduledFloat,
)
from torch import Tensor, nn
class ConvNeXt(nn.Module):
"""
Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf
"""
def __init__(
self,
channels: int,
hidden_ratio: int = 3,
kernel_size: Tuple[int, int] = (7, 7),
layerdrop_rate: FloatLike = None,
):
super().__init__()
self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
hidden_channels = channels * hidden_ratio
if layerdrop_rate is None:
layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015))
self.layerdrop_rate = layerdrop_rate
self.depthwise_conv = nn.Conv2d(
in_channels=channels,
out_channels=channels,
groups=channels,
kernel_size=kernel_size,
padding=self.padding,
)
self.pointwise_conv1 = nn.Conv2d(
in_channels=channels, out_channels=hidden_channels, kernel_size=1
)
self.activation = nn.SiLU(inplace=True)
self.pointwise_conv2 = nn.Conv2d(
in_channels=hidden_channels,
out_channels=channels,
kernel_size=1,
)
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
return self.forward_internal(x)
layerdrop_rate = float(self.layerdrop_rate)
if layerdrop_rate != 0.0:
batch_size = x.shape[0]
mask = (
torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
> layerdrop_rate
)
else:
mask = None
# turns out this caching idea does not work with --world-size > 1
# return caching_eval(self.forward_internal, x, mask)
return self.forward_internal(x, mask)
def forward_internal(
self, x: Tensor, layer_skip_mask: Optional[Tensor] = None
) -> Tensor:
"""
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
The returned value has the same shape as x.
"""
bypass = x
x = self.depthwise_conv(x)
x = self.pointwise_conv1(x)
x = self.activation(x)
x = self.pointwise_conv2(x)
if layer_skip_mask is not None:
x = x * layer_skip_mask
x = bypass + x
return x
def streaming_forward(
self,
x: Tensor,
cached_left_pad: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
Args:
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
cached_left_pad: (batch_size, num_channels, left_pad, num_freqs)
Returns:
- The returned value has the same shape as x.
- Updated cached_left_pad.
"""
padding = self.padding
# The length without right padding for depth-wise conv
T = x.size(2) - padding[0]
bypass = x[:, :, :T, :]
# Pad left side
assert cached_left_pad.size(2) == padding[0], (
cached_left_pad.size(2),
padding[0],
)
x = torch.cat([cached_left_pad, x], dim=2)
# Update cached left padding
cached_left_pad = x[:, :, T : padding[0] + T, :]
# depthwise_conv
x = torch.nn.functional.conv2d(
x,
weight=self.depthwise_conv.weight,
bias=self.depthwise_conv.bias,
padding=(0, padding[1]),
groups=self.depthwise_conv.groups,
)
x = self.pointwise_conv1(x)
x = self.activation(x)
x = self.pointwise_conv2(x)
x = bypass + x
return x, cached_left_pad
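# Minimal usage sketch (added for illustration): ConvNeXt keeps the
# (N, C, H, W) shape unchanged; with the default (7, 7) kernel the streaming
# variant carries a cached left pad of 3 frames along the time axis.
def _demo_convnext():
    m = ConvNeXt(channels=16)
    m.eval()
    x = torch.randn(2, 16, 20, 10)  # (batch, channels, frames, freqs)
    assert m(x).shape == x.shape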
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/2 length).
Convert an input of shape (N, T, idim) to an output
with shape (N, T', odim), where
T' = (T-3)//2 - 2 == (T-7)//2
It is based on
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
"""
def __init__(
self,
in_channels: int,
out_channels: int,
layer1_channels: int = 8,
layer2_channels: int = 32,
layer3_channels: int = 128,
dropout: FloatLike = 0.1,
) -> None:
"""
Args:
in_channels:
Number of channels in. The input shape is (N, T, in_channels).
Caution: It requires: T >=7, in_channels >=7
out_channels:
Output dim. The output shape is (N, (T-7)//2, out_channels)
layer1_channels:
Number of channels in layer1
layer2_channels:
Number of channels in layer2
layer3_channels:
Number of channels in layer3
dropout:
Dropout rate (a FloatLike, so it may follow a schedule), applied after the output projection.
"""
assert in_channels >= 7
super().__init__()
# The ScaleGrad module is there to prevent the gradients
# w.r.t. the weight or bias of the first Conv2d module in self.conv from
# exceeding the range of fp16 when using automatic mixed precision (amp)
# training. (The second one is necessary to stop its bias from getting
# a too-large gradient).
self.conv = nn.Sequential(
nn.Conv2d(
in_channels=1,
out_channels=layer1_channels,
kernel_size=3,
padding=(0, 1), # (time, freq)
),
ScaleGrad(0.2),
nn.SiLU(),
nn.Conv2d(
in_channels=layer1_channels,
out_channels=layer2_channels,
kernel_size=3,
stride=2,
padding=0,
),
nn.SiLU(),
nn.Conv2d(
in_channels=layer2_channels,
out_channels=layer3_channels,
kernel_size=3,
stride=(1, 2), # (time, freq)
),
nn.SiLU()
)
# just one convnext layer
self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))
# (in_channels-3)//4
self.out_width = (((in_channels - 1) // 2) - 1) // 2
self.layer3_channels = layer3_channels
self.out = nn.Linear(self.out_width * layer3_channels, out_channels)
self.out_norm = BiasNorm(out_channels)
self.dropout = Dropout3(dropout, shared_dim=1)
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
Returns:
- a tensor of shape (N, (T-7)//2, odim)
- output lengths, of shape (batch_size,)
"""
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
# Gradient scaling for fp16 "amp" (automatic mixed precision) training is
# handled by the ScaleGrad module inside self.conv (see the comment in
# __init__); the weights of the first convolution would otherwise be the
# limiting factor for getting infinite gradients.
x = self.conv(x)
x = self.convnext(x)
# Now x is of shape (N, odim, (T-7)//2, (idim-3)//4)
b, c, t, f = x.size()
x = x.transpose(1, 2).reshape(b, t, c * f)
# now x: (N, (T-7)//2, out_width * layer3_channels))
x = self.out(x)
# Now x is of shape (N, (T-7)//2, odim)
x = self.out_norm(x)
x = self.dropout(x)
if torch.jit.is_scripting() or torch.jit.is_tracing():
x_lens = (x_lens - 7) // 2
else:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
x_lens = (x_lens - 7) // 2
assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max())
return x, x_lens
def streaming_forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
cached_left_pad: Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
Returns:
- a tensor of shape (N, (T-7)//2, odim)
- output lengths, of shape (batch_size,)
- updated cache
"""
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
# T' = (T-7)//2
x = self.conv(x)
# T' = (T-7)//2-3
x, cached_left_pad = self.convnext.streaming_forward(
x, cached_left_pad=cached_left_pad
)
# Now x is of shape (N, odim, T', ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
x = x.transpose(1, 2).reshape(b, t, c * f)
# now x: (N, T', out_width * layer3_channels))
x = self.out(x)
# Now x is of shape (N, T', odim)
x = self.out_norm(x)
if torch.jit.is_scripting() or torch.jit.is_tracing():
assert self.convnext.padding[0] == 3
# The ConvNeXt module needs 3 frames of right padding after subsampling
x_lens = (x_lens - 7) // 2 - 3
else:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# The ConvNeXt module needs 3 frames of right padding after subsampling
assert self.convnext.padding[0] == 3
x_lens = (x_lens - 7) // 2 - 3
assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max())
return x, x_lens, cached_left_pad
@torch.jit.export
def get_init_states(
self,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> Tensor:
"""Get initial states for Conv2dSubsampling module.
It is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
"""
left_pad = self.convnext.padding[0]
freq = self.out_width
channels = self.layer3_channels
cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to(
device
)
return cached_embed_left_pad
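# Minimal usage sketch (added for illustration): 80-dim fbank frames are
# subsampled roughly 2x in time, T -> (T-7)//2, and projected to out_channels.
def _demo_conv2d_subsampling():
    m = Conv2dSubsampling(in_channels=80, out_channels=192)
    m.eval()
    x = torch.randn(2, 100, 80)  # (N, T, idim)
    x_lens = torch.full((2,), 100, dtype=torch.int64)
    y, y_lens = m(x, x_lens)
    assert y.shape == (2, (100 - 7) // 2, 192)
    assert torch.all(y_lens == (100 - 7) // 2)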

File diff suppressed because it is too large

File diff suppressed because it is too large