mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Zipformer with Adam optimizer
This commit is contained in:
parent
3b257dd5ae
commit
db38ab044b
@ -36,7 +36,8 @@ The following table lists the differences among them.
|
||||
| `lstm_transducer_stateless3` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gradient filter + delay penalty |
|
||||
| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe |
|
||||
| `zipformer_adapter` | Upgraded Zipformer | Embedding + Conv1d | It supports domain adaptation of Zipformer using parameter efficient adapters |
|
||||
| `zipformer_adapter` | Upgraded Zipformer | Embedding + Conv1d | Finetune Zipformer with LoRA |
|
||||
| `zipformer_lora` | Upgraded Zipformer | Embedding + Conv1d | Finetune Zipformer with LoRA |
|
||||
| `zipformer_adam` | Upgraded Zipformer | Embedding + Conv1d | Zipformer with Adam optimizer |
|
||||
|
||||
The decoder in `transducer_stateless` is modified from the paper
|
||||
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
|
||||
|
@ -554,6 +554,106 @@ for m in greedy_search modified_beam_search fast_beam_search; do
|
||||
done
|
||||
```
|
||||
|
||||
### zipformer (zipformer + pruned stateless transducer + adam optimizer)
|
||||
|
||||
See <https://github.com/k2-fsa/icefall/pull/1708> for more details.
|
||||
|
||||
[zipformer_adam](./zipformer_adam)
|
||||
|
||||
#### Non-streaming
|
||||
|
||||
##### normal-scaled model, number of model parameters: 65595219, i.e., 65.60 M
|
||||
|
||||
You can find a pretrained model, training logs, decoding logs, and decoding results at:
|
||||
<https://huggingface.co/zhu-han/icefall-asr-librispeech-zipformer-adam-medium-2023-08-01>
|
||||
|
||||
You can use <https://github.com/k2-fsa/sherpa> to deploy it.
|
||||
|
||||
| decoding method | test-clean | test-other | comment |
|
||||
|----------------------|------------|------------|--------------------|
|
||||
| greedy_search | 2.35 | 5.53 | --epoch 70 --avg 30 |
|
||||
| modified_beam_search | 2.29 | 5.48 | --epoch 70 --avg 30 |
|
||||
| fast_beam_search | 2.31 | 5.52 | --epoch 70 --avg 30 |
|
||||
|
||||
The training command is:
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
./zipformer_adam/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 70 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir zipformer_adam/exp \
|
||||
--causal 0 \
|
||||
--full-libri 1 \
|
||||
--max-duration 1000
|
||||
```
|
||||
|
||||
The decoding command is:
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES="0"
|
||||
for m in greedy_search modified_beam_search fast_beam_search; do
|
||||
./zipformer/decode.py \
|
||||
--epoch 70 \
|
||||
--avg 30 \
|
||||
--use-averaged-model 1 \
|
||||
--exp-dir ./zipformer_adam/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method $m
|
||||
done
|
||||
```
|
||||
|
||||
To decode with external language models, please refer to the documentation [here](https://k2-fsa.github.io/icefall/decoding-with-langugage-models/index.html).
|
||||
|
||||
##### large-scaled model, number of model parameters: 148514478, i.e., 148.5 M
|
||||
|
||||
You can find a pretrained model, training logs, decoding logs, and decoding results at:
|
||||
<https://huggingface.co/zhu-han/icefall-asr-librispeech-zipformer-adam-large-2023-08-01>
|
||||
|
||||
You can use <https://github.com/k2-fsa/sherpa> to deploy it.
|
||||
|
||||
| decoding method | test-clean | test-other | comment |
|
||||
|----------------------|------------|------------|--------------------|
|
||||
| greedy_search | 2.27 | 5.25 | --epoch 70 --avg 20 |
|
||||
| modified_beam_search | 2.23 | 5.17 | --epoch 70 --avg 20 |
|
||||
| fast_beam_search | 2.24 | 5.2 | --epoch 70 --avg 20 |
|
||||
|
||||
The training command is:
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
./zipformer/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 70 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir zipformer_adam/exp-large \
|
||||
--causal 0 \
|
||||
--num-encoder-layers 2,2,4,5,4,2 \
|
||||
--feedforward-dim 512,768,1536,2048,1536,768 \
|
||||
--encoder-dim 192,256,512,768,512,256 \
|
||||
--encoder-unmasked-dim 192,192,256,320,256,192 \
|
||||
--full-libri 1 \
|
||||
--max-duration 1000
|
||||
```
|
||||
|
||||
The decoding command is:
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES="0"
|
||||
for m in greedy_search modified_beam_search fast_beam_search; do
|
||||
./zipformer/decode.py \
|
||||
--epoch 70 \
|
||||
--avg 20 \
|
||||
--exp-dir zipformer_adam/exp-large \
|
||||
--max-duration 600 \
|
||||
--causal 0 \
|
||||
--decoding-method $m \
|
||||
--num-encoder-layers 2,2,4,5,4,2 \
|
||||
--feedforward-dim 512,768,1536,2048,1536,768 \
|
||||
--encoder-dim 192,256,512,768,512,256 \
|
||||
--encoder-unmasked-dim 192,192,256,320,256,192
|
||||
done
|
||||
```
|
||||
|
||||
### Zipformer CTC
|
||||
|
||||
#### [zipformer_ctc](./zipformer_ctc)
|
||||
|
1
egs/librispeech/ASR/zipformer_adam/asr_datamodule.py
Symbolic link
1
egs/librispeech/ASR/zipformer_adam/asr_datamodule.py
Symbolic link
@ -0,0 +1 @@
|
||||
../transducer/asr_datamodule.py
|
1
egs/librispeech/ASR/zipformer_adam/attention_decoder.py
Symbolic link
1
egs/librispeech/ASR/zipformer_adam/attention_decoder.py
Symbolic link
@ -0,0 +1 @@
|
||||
../zipformer/attention_decoder.py
|
1
egs/librispeech/ASR/zipformer_adam/beam_search.py
Symbolic link
1
egs/librispeech/ASR/zipformer_adam/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/beam_search.py
|
1051
egs/librispeech/ASR/zipformer_adam/decode.py
Executable file
1051
egs/librispeech/ASR/zipformer_adam/decode.py
Executable file
File diff suppressed because it is too large
Load Diff
109
egs/librispeech/ASR/zipformer_adam/decoder.py
Normal file
109
egs/librispeech/ASR/zipformer_adam/decoder.py
Normal file
@ -0,0 +1,109 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""This class modifies the stateless decoder from the following paper:
|
||||
|
||||
RNN-transducer with stateless prediction network
|
||||
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
|
||||
|
||||
It removes the recurrent connection from the decoder, i.e., the prediction
|
||||
network. Different from the above paper, it adds an extra Conv1d
|
||||
right after the embedding layer.
|
||||
|
||||
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
decoder_dim: int,
|
||||
blank_id: int,
|
||||
context_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
vocab_size:
|
||||
Number of tokens of the modeling unit including blank.
|
||||
decoder_dim:
|
||||
Dimension of the input embedding, and of the decoder output.
|
||||
blank_id:
|
||||
The ID of the blank symbol.
|
||||
context_size:
|
||||
Number of previous words to use to predict the next word.
|
||||
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.embedding = nn.Embedding(
|
||||
num_embeddings=vocab_size,
|
||||
embedding_dim=decoder_dim,
|
||||
)
|
||||
|
||||
self.blank_id = blank_id
|
||||
|
||||
assert context_size >= 1, context_size
|
||||
self.context_size = context_size
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
if context_size > 1:
|
||||
self.conv = nn.Conv1d(
|
||||
in_channels=decoder_dim,
|
||||
out_channels=decoder_dim,
|
||||
kernel_size=context_size,
|
||||
padding=0,
|
||||
groups=decoder_dim // 4, # group size == 4
|
||||
bias=False,
|
||||
)
|
||||
else:
|
||||
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
|
||||
# when inference with torch.jit.script and context_size == 1
|
||||
self.conv = nn.Identity()
|
||||
|
||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, U).
|
||||
need_pad:
|
||||
True to left pad the input. Should be True during training.
|
||||
False to not pad the input. Should be False during inference.
|
||||
Returns:
|
||||
Return a tensor of shape (N, U, decoder_dim).
|
||||
"""
|
||||
y = y.to(torch.int64)
|
||||
# this stuff about clamp() is a temporary fix for a mismatch
|
||||
# at utterance start, we use negative ids in beam_search.py
|
||||
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
|
||||
|
||||
if self.context_size > 1:
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
if need_pad is True:
|
||||
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
|
||||
else:
|
||||
# During inference time, there is no need to do extra padding
|
||||
# as we only need one output
|
||||
assert embedding_out.size(-1) == self.context_size
|
||||
embedding_out = self.conv(embedding_out)
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
embedding_out = F.relu(embedding_out)
|
||||
|
||||
return embedding_out
|
1
egs/librispeech/ASR/zipformer_adam/encoder_interface.py
Symbolic link
1
egs/librispeech/ASR/zipformer_adam/encoder_interface.py
Symbolic link
@ -0,0 +1 @@
|
||||
../transducer_stateless/encoder_interface.py
|
513
egs/librispeech/ASR/zipformer_adam/export.py
Executable file
513
egs/librispeech/ASR/zipformer_adam/export.py
Executable file
@ -0,0 +1,513 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao,
|
||||
# Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
|
||||
Usage:
|
||||
|
||||
Note: This is a example for librispeech dataset, if you are using different
|
||||
dataset, you should change the argument values according to your dataset.
|
||||
|
||||
(1) Export to torchscript model using torch.jit.script()
|
||||
|
||||
- For non-streaming model:
|
||||
|
||||
./zipformer_adam/export.py \
|
||||
--exp-dir ./zipformer_adam/exp \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 70 \
|
||||
--avg 30 \
|
||||
--jit 1
|
||||
|
||||
It will generate a file `jit_script.pt` in the given `exp_dir`. You can later
|
||||
load it by `torch.jit.load("jit_script.pt")`.
|
||||
|
||||
Check ./jit_pretrained.py for its usage.
|
||||
|
||||
Check https://github.com/k2-fsa/sherpa
|
||||
for how to use the exported models outside of icefall.
|
||||
|
||||
- For streaming model:
|
||||
|
||||
./zipformer_adam/export.py \
|
||||
--exp-dir ./zipformer_adam/exp \
|
||||
--causal 1 \
|
||||
--chunk-size 16 \
|
||||
--left-context-frames 128 \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 70 \
|
||||
--avg 30 \
|
||||
--jit 1
|
||||
|
||||
It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`.
|
||||
|
||||
Check ./jit_pretrained_streaming.py for its usage.
|
||||
|
||||
Check https://github.com/k2-fsa/sherpa
|
||||
for how to use the exported models outside of icefall.
|
||||
|
||||
(2) Export `model.state_dict()`
|
||||
|
||||
- For non-streaming model:
|
||||
|
||||
./zipformer_adam/export.py \
|
||||
--exp-dir ./zipformer_adam/exp \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 70 \
|
||||
--avg 30
|
||||
|
||||
- For streaming model:
|
||||
|
||||
./zipformer_adam/export.py \
|
||||
--exp-dir ./zipformer_adam/exp \
|
||||
--causal 1 \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 70 \
|
||||
--avg 30
|
||||
|
||||
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
|
||||
load it by `icefall.checkpoint.load_checkpoint()`.
|
||||
|
||||
- For non-streaming model:
|
||||
|
||||
To use the generated file with `zipformer/decode.py`,
|
||||
you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/librispeech/ASR
|
||||
./zipformer_adam/decode.py \
|
||||
--exp-dir ./zipformer_adam/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search \
|
||||
--bpe-model data/lang_bpe_500/bpe.model
|
||||
|
||||
- For streaming model:
|
||||
|
||||
To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/librispeech/ASR
|
||||
|
||||
# simulated streaming decoding
|
||||
./zipformer_adam/decode.py \
|
||||
--exp-dir ./zipformer_adam/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 600 \
|
||||
--causal 1 \
|
||||
--chunk-size 16 \
|
||||
--left-context-frames 128 \
|
||||
--decoding-method greedy_search \
|
||||
--bpe-model data/lang_bpe_500/bpe.model
|
||||
|
||||
Check ./pretrained.py for its usage.
|
||||
|
||||
Note: If you don't want to train a model from scratch, we have
|
||||
provided one for you. You can get it at
|
||||
|
||||
- non-streaming model:
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
|
||||
- streaming model:
|
||||
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
|
||||
|
||||
with the following commands:
|
||||
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
|
||||
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
|
||||
# You will find the pre-trained models in exp dir
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from torch import Tensor, nn
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import make_pad_mask, num_tokens, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="zipformer/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/tokens.txt",
|
||||
help="Path to the tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to save a model after applying torch.jit.script.
|
||||
It will generate a file named jit_script.pt.
|
||||
Check ./jit_pretrained.py for how to use it.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class EncoderModel(nn.Module):
|
||||
"""A wrapper for encoder and encoder_embed"""
|
||||
|
||||
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.encoder_embed = encoder_embed
|
||||
|
||||
def forward(
|
||||
self, features: Tensor, feature_lengths: Tensor
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Args:
|
||||
features: (N, T, C)
|
||||
feature_lengths: (N,)
|
||||
"""
|
||||
x, x_lens = self.encoder_embed(features, feature_lengths)
|
||||
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
|
||||
class StreamingEncoderModel(nn.Module):
|
||||
"""A wrapper for encoder and encoder_embed"""
|
||||
|
||||
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
|
||||
super().__init__()
|
||||
assert len(encoder.chunk_size) == 1, encoder.chunk_size
|
||||
assert len(encoder.left_context_frames) == 1, encoder.left_context_frames
|
||||
self.chunk_size = encoder.chunk_size[0]
|
||||
self.left_context_len = encoder.left_context_frames[0]
|
||||
|
||||
# The encoder_embed subsample features (T - 7) // 2
|
||||
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
|
||||
self.pad_length = 7 + 2 * 3
|
||||
|
||||
self.encoder = encoder
|
||||
self.encoder_embed = encoder_embed
|
||||
|
||||
def forward(
|
||||
self, features: Tensor, feature_lengths: Tensor, states: List[Tensor]
|
||||
) -> Tuple[Tensor, Tensor, List[Tensor]]:
|
||||
"""Streaming forward for encoder_embed and encoder.
|
||||
|
||||
Args:
|
||||
features: (N, T, C)
|
||||
feature_lengths: (N,)
|
||||
states: a list of Tensors
|
||||
|
||||
Returns encoder outputs, output lengths, and updated states.
|
||||
"""
|
||||
chunk_size = self.chunk_size
|
||||
left_context_len = self.left_context_len
|
||||
|
||||
cached_embed_left_pad = states[-2]
|
||||
x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
|
||||
x=features,
|
||||
x_lens=feature_lengths,
|
||||
cached_left_pad=cached_embed_left_pad,
|
||||
)
|
||||
assert x.size(1) == chunk_size, (x.size(1), chunk_size)
|
||||
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
|
||||
# processed_mask is used to mask out initial states
|
||||
processed_mask = torch.arange(left_context_len, device=x.device).expand(
|
||||
x.size(0), left_context_len
|
||||
)
|
||||
processed_lens = states[-1] # (batch,)
|
||||
# (batch, left_context_size)
|
||||
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
|
||||
# Update processed lengths
|
||||
new_processed_lens = processed_lens + x_lens
|
||||
|
||||
# (batch, left_context_size + chunk_size)
|
||||
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
|
||||
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
encoder_states = states[:-2]
|
||||
|
||||
(
|
||||
encoder_out,
|
||||
encoder_out_lens,
|
||||
new_encoder_states,
|
||||
) = self.encoder.streaming_forward(
|
||||
x=x,
|
||||
x_lens=x_lens,
|
||||
states=encoder_states,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
)
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
new_states = new_encoder_states + [
|
||||
new_cached_embed_left_pad,
|
||||
new_processed_lens,
|
||||
]
|
||||
return encoder_out, encoder_out_lens, new_states
|
||||
|
||||
@torch.jit.export
|
||||
def get_init_states(
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> List[torch.Tensor]:
|
||||
"""
|
||||
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
|
||||
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
|
||||
states[-2] is the cached left padding for ConvNeXt module,
|
||||
of shape (batch_size, num_channels, left_pad, num_freqs)
|
||||
states[-1] is processed_lens of shape (batch,), which records the number
|
||||
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
|
||||
"""
|
||||
states = self.encoder.get_init_states(batch_size, device)
|
||||
|
||||
embed_states = self.encoder_embed.get_init_states(batch_size, device)
|
||||
states.append(embed_states)
|
||||
|
||||
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||
states.append(processed_lens)
|
||||
|
||||
return states
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
# if torch.cuda.is_available():
|
||||
# device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||
params.blank_id = token_table["<blk>"]
|
||||
params.sos_id = params.eos_id = token_table["<sos/eos>"]
|
||||
params.vocab_size = num_tokens(token_table) + 1
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.eval()
|
||||
|
||||
if params.jit is True:
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptabe.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
|
||||
# Wrap encoder and encoder_embed as a module
|
||||
if params.causal:
|
||||
model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed)
|
||||
chunk_size = model.encoder.chunk_size
|
||||
left_context_len = model.encoder.left_context_len
|
||||
filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt"
|
||||
else:
|
||||
model.encoder = EncoderModel(model.encoder, model.encoder_embed)
|
||||
filename = "jit_script.pt"
|
||||
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
model.save(str(params.exp_dir / filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
else:
|
||||
logging.info("Not using torchscript. Export model.state_dict()")
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = params.exp_dir / "pretrained.pt"
|
||||
torch.save({"model": model.state_dict()}, str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
66
egs/librispeech/ASR/zipformer_adam/joiner.py
Normal file
66
egs/librispeech/ASR/zipformer_adam/joiner.py
Normal file
@ -0,0 +1,66 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Joiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder_dim: int,
|
||||
decoder_dim: int,
|
||||
joiner_dim: int,
|
||||
vocab_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
|
||||
self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
|
||||
self.output_linear = nn.Linear(joiner_dim, vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
project_input: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
Output from the encoder. Its shape is (N, T, s_range, C).
|
||||
decoder_out:
|
||||
Output from the decoder. Its shape is (N, T, s_range, C).
|
||||
project_input:
|
||||
If true, apply input projections encoder_proj and decoder_proj.
|
||||
If this is false, it is the user's responsibility to do this
|
||||
manually.
|
||||
Returns:
|
||||
Return a tensor of shape (N, T, s_range, C).
|
||||
"""
|
||||
assert encoder_out.ndim == decoder_out.ndim, (
|
||||
encoder_out.shape,
|
||||
decoder_out.shape,
|
||||
)
|
||||
|
||||
if project_input:
|
||||
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
|
||||
else:
|
||||
logit = encoder_out + decoder_out
|
||||
|
||||
logit = self.output_linear(torch.tanh(logit))
|
||||
|
||||
return logit
|
1
egs/librispeech/ASR/zipformer_adam/label_smoothing.py
Symbolic link
1
egs/librispeech/ASR/zipformer_adam/label_smoothing.py
Symbolic link
@ -0,0 +1 @@
|
||||
../zipformer/label_smoothing.py
|
379
egs/librispeech/ASR/zipformer_adam/model.py
Normal file
379
egs/librispeech/ASR/zipformer_adam/model.py
Normal file
@ -0,0 +1,379 @@
|
||||
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Wei Kang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
|
||||
from icefall.utils import add_sos, make_pad_mask
|
||||
|
||||
|
||||
class AsrModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder_embed: nn.Module,
|
||||
encoder: EncoderInterface,
|
||||
decoder: Optional[nn.Module] = None,
|
||||
joiner: Optional[nn.Module] = None,
|
||||
attention_decoder: Optional[nn.Module] = None,
|
||||
encoder_dim: int = 384,
|
||||
decoder_dim: int = 512,
|
||||
vocab_size: int = 500,
|
||||
use_transducer: bool = True,
|
||||
use_ctc: bool = False,
|
||||
use_attention_decoder: bool = False,
|
||||
):
|
||||
"""A joint CTC & Transducer ASR model.
|
||||
|
||||
- Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
|
||||
- Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
|
||||
- Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)
|
||||
|
||||
Args:
|
||||
encoder_embed:
|
||||
It is a Convolutional 2D subsampling module. It converts
|
||||
an input of shape (N, T, idim) to an output of of shape
|
||||
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
|
||||
encoder:
|
||||
It is the transcription network in the paper. Its accepts
|
||||
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
|
||||
`logit_lens` of shape (N,).
|
||||
decoder:
|
||||
It is the prediction network in the paper. Its input shape
|
||||
is (N, U) and its output shape is (N, U, decoder_dim).
|
||||
It should contain one attribute: `blank_id`.
|
||||
It is used when use_transducer is True.
|
||||
joiner:
|
||||
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
|
||||
Its output shape is (N, T, U, vocab_size). Note that its output contains
|
||||
unnormalized probs, i.e., not processed by log-softmax.
|
||||
It is used when use_transducer is True.
|
||||
use_transducer:
|
||||
Whether use transducer head. Default: True.
|
||||
use_ctc:
|
||||
Whether use CTC head. Default: False.
|
||||
use_attention_decoder:
|
||||
Whether use attention-decoder head. Default: False.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
assert (
|
||||
use_transducer or use_ctc
|
||||
), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
|
||||
|
||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||
|
||||
self.encoder_embed = encoder_embed
|
||||
self.encoder = encoder
|
||||
|
||||
self.use_transducer = use_transducer
|
||||
if use_transducer:
|
||||
# Modules for Transducer head
|
||||
assert decoder is not None
|
||||
assert hasattr(decoder, "blank_id")
|
||||
assert joiner is not None
|
||||
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
|
||||
self.simple_am_proj = nn.Linear(
|
||||
encoder_dim, vocab_size
|
||||
)
|
||||
self.simple_lm_proj = nn.Linear(
|
||||
decoder_dim, vocab_size
|
||||
)
|
||||
else:
|
||||
assert decoder is None
|
||||
assert joiner is None
|
||||
|
||||
self.use_ctc = use_ctc
|
||||
if use_ctc:
|
||||
# Modules for CTC head
|
||||
self.ctc_output = nn.Sequential(
|
||||
nn.Dropout(p=0.1),
|
||||
nn.Linear(encoder_dim, vocab_size),
|
||||
nn.LogSoftmax(dim=-1),
|
||||
)
|
||||
|
||||
self.use_attention_decoder = use_attention_decoder
|
||||
if use_attention_decoder:
|
||||
self.attention_decoder = attention_decoder
|
||||
else:
|
||||
assert attention_decoder is None
|
||||
|
||||
def forward_encoder(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute encoder outputs.
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C).
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||
before padding.
|
||||
|
||||
Returns:
|
||||
encoder_out:
|
||||
Encoder output, of shape (N, T, C).
|
||||
encoder_out_lens:
|
||||
Encoder output lengths, of shape (N,).
|
||||
"""
|
||||
# logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
|
||||
x, x_lens = self.encoder_embed(x, x_lens)
|
||||
# logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
|
||||
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
|
||||
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
def forward_ctc(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
targets: torch.Tensor,
|
||||
target_lengths: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute CTC loss.
|
||||
Args:
|
||||
encoder_out:
|
||||
Encoder output, of shape (N, T, C).
|
||||
encoder_out_lens:
|
||||
Encoder output lengths, of shape (N,).
|
||||
targets:
|
||||
Target Tensor of shape (sum(target_lengths)). The targets are assumed
|
||||
to be un-padded and concatenated within 1 dimension.
|
||||
"""
|
||||
# Compute CTC log-prob
|
||||
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
|
||||
|
||||
ctc_loss = torch.nn.functional.ctc_loss(
|
||||
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
|
||||
targets=targets.cpu(),
|
||||
input_lengths=encoder_out_lens.cpu(),
|
||||
target_lengths=target_lengths.cpu(),
|
||||
reduction="sum",
|
||||
)
|
||||
return ctc_loss
|
||||
|
||||
def forward_transducer(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
y_lens: torch.Tensor,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute Transducer loss.
|
||||
Args:
|
||||
encoder_out:
|
||||
Encoder output, of shape (N, T, C).
|
||||
encoder_out_lens:
|
||||
Encoder output lengths, of shape (N,).
|
||||
y:
|
||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||
utterance.
|
||||
prune_range:
|
||||
The prune range for rnnt loss, it means how many symbols(context)
|
||||
we are considering for each frame to compute the loss.
|
||||
am_scale:
|
||||
The scale to smooth the loss with am (output of encoder network)
|
||||
part
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
"""
|
||||
# Now for the decoder, i.e., the prediction network
|
||||
blank_id = self.decoder.blank_id
|
||||
sos_y = add_sos(y, sos_id=blank_id)
|
||||
|
||||
# sos_y_padded: [B, S + 1], start with SOS.
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
|
||||
# decoder_out: [B, S + 1, decoder_dim]
|
||||
decoder_out = self.decoder(sos_y_padded)
|
||||
|
||||
# Note: y does not start with SOS
|
||||
# y_padded : [B, S]
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
y_padded = y_padded.to(torch.int64)
|
||||
boundary = torch.zeros(
|
||||
(encoder_out.size(0), 4),
|
||||
dtype=torch.int64,
|
||||
device=encoder_out.device,
|
||||
)
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = encoder_out_lens
|
||||
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
# if self.training and random.random() < 0.25:
|
||||
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
|
||||
# if self.training and random.random() < 0.25:
|
||||
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
symbols=y_padded,
|
||||
termination_symbol=blank_id,
|
||||
lm_only_scale=lm_scale,
|
||||
am_only_scale=am_scale,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
return_grad=True,
|
||||
)
|
||||
|
||||
# ranges : [B, T, prune_range]
|
||||
ranges = k2.get_rnnt_prune_ranges(
|
||||
px_grad=px_grad,
|
||||
py_grad=py_grad,
|
||||
boundary=boundary,
|
||||
s_range=prune_range,
|
||||
)
|
||||
|
||||
# am_pruned : [B, T, prune_range, encoder_dim]
|
||||
# lm_pruned : [B, T, prune_range, decoder_dim]
|
||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||
am=self.joiner.encoder_proj(encoder_out),
|
||||
lm=self.joiner.decoder_proj(decoder_out),
|
||||
ranges=ranges,
|
||||
)
|
||||
|
||||
# logits : [B, T, prune_range, vocab_size]
|
||||
|
||||
# project_input=False since we applied the decoder's input projections
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
ranges=ranges,
|
||||
termination_symbol=blank_id,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
return simple_loss, pruned_loss
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C).
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||
before padding.
|
||||
y:
|
||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||
utterance.
|
||||
prune_range:
|
||||
The prune range for rnnt loss, it means how many symbols(context)
|
||||
we are considering for each frame to compute the loss.
|
||||
am_scale:
|
||||
The scale to smooth the loss with am (output of encoder network)
|
||||
part
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
Returns:
|
||||
Return the transducer losses and CTC loss,
|
||||
in form of (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss)
|
||||
|
||||
Note:
|
||||
Regarding am_scale & lm_scale, it will make the loss-function one of
|
||||
the form:
|
||||
lm_scale * lm_probs + am_scale * am_probs +
|
||||
(1-lm_scale-am_scale) * combined_probs
|
||||
"""
|
||||
assert x.ndim == 3, x.shape
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
assert y.num_axes == 2, y.num_axes
|
||||
|
||||
assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
|
||||
|
||||
device = x.device
|
||||
|
||||
# Compute encoder outputs
|
||||
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
|
||||
|
||||
row_splits = y.shape.row_splits(1)
|
||||
y_lens = row_splits[1:] - row_splits[:-1]
|
||||
|
||||
if self.use_transducer:
|
||||
# Compute transducer loss
|
||||
simple_loss, pruned_loss = self.forward_transducer(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
y=y.to(device),
|
||||
y_lens=y_lens,
|
||||
prune_range=prune_range,
|
||||
am_scale=am_scale,
|
||||
lm_scale=lm_scale,
|
||||
)
|
||||
else:
|
||||
simple_loss = torch.empty(0)
|
||||
pruned_loss = torch.empty(0)
|
||||
|
||||
if self.use_ctc:
|
||||
# Compute CTC loss
|
||||
targets = y.values
|
||||
ctc_loss = self.forward_ctc(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
targets=targets,
|
||||
target_lengths=y_lens,
|
||||
)
|
||||
else:
|
||||
ctc_loss = torch.empty(0)
|
||||
|
||||
if self.use_attention_decoder:
|
||||
attention_decoder_loss = self.attention_decoder.calc_att_loss(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
ys=y.to(device),
|
||||
ys_lens=y_lens.to(device),
|
||||
)
|
||||
else:
|
||||
attention_decoder_loss = torch.empty(0)
|
||||
|
||||
return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss
|
136
egs/librispeech/ASR/zipformer_adam/optim.py
Normal file
136
egs/librispeech/ASR/zipformer_adam/optim.py
Normal file
@ -0,0 +1,136 @@
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
|
||||
#
|
||||
# See ../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Union
|
||||
from torch.optim import Optimizer
|
||||
|
||||
|
||||
class LRScheduler(object):
|
||||
"""
|
||||
Base-class for learning rate schedulers where the learning-rate depends on both the
|
||||
batch and the epoch.
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer: Optimizer, verbose: bool = False):
|
||||
# Attach optimizer
|
||||
if not isinstance(optimizer, Optimizer):
|
||||
raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
|
||||
self.optimizer = optimizer
|
||||
self.verbose = verbose
|
||||
|
||||
for group in optimizer.param_groups:
|
||||
group.setdefault("base_lr", group["lr"])
|
||||
|
||||
self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]
|
||||
|
||||
self.epoch = 0
|
||||
self.batch = 0
|
||||
|
||||
def state_dict(self):
|
||||
"""Returns the state of the scheduler as a :class:`dict`.
|
||||
|
||||
It contains an entry for every variable in self.__dict__ which
|
||||
is not the optimizer.
|
||||
"""
|
||||
return {
|
||||
"base_lrs": self.base_lrs,
|
||||
"epoch": self.epoch,
|
||||
"batch": self.batch,
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""Loads the schedulers state.
|
||||
|
||||
Args:
|
||||
state_dict (dict): scheduler state. Should be an object returned
|
||||
from a call to :meth:`state_dict`.
|
||||
"""
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
def get_last_lr(self) -> List[float]:
|
||||
"""Return last computed learning rate by current scheduler. Will be a list of float."""
|
||||
return self._last_lr
|
||||
|
||||
def get_lr(self):
|
||||
# Compute list of learning rates from self.epoch and self.batch and
|
||||
# self.base_lrs; this must be overloaded by the user.
|
||||
# e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
|
||||
raise NotImplementedError
|
||||
|
||||
def step_batch(self, batch: Optional[int] = None) -> None:
|
||||
# Step the batch index, or just set it. If `batch` is specified, it
|
||||
# must be the batch index from the start of training, i.e. summed over
|
||||
# all epochs.
|
||||
# You can call this in any order; if you don't provide 'batch', it should
|
||||
# of course be called once per batch.
|
||||
if batch is not None:
|
||||
self.batch = batch
|
||||
else:
|
||||
self.batch = self.batch + 1
|
||||
self._set_lrs()
|
||||
|
||||
def step_epoch(self, epoch: Optional[int] = None):
|
||||
# Step the epoch index, or just set it. If you provide the 'epoch' arg,
|
||||
# you should call this at the start of the epoch; if you don't provide the 'epoch'
|
||||
# arg, you should call it at the end of the epoch.
|
||||
if epoch is not None:
|
||||
self.epoch = epoch
|
||||
else:
|
||||
self.epoch = self.epoch + 1
|
||||
self._set_lrs()
|
||||
|
||||
def _set_lrs(self):
|
||||
values = self.get_lr()
|
||||
assert len(values) == len(self.optimizer.param_groups)
|
||||
|
||||
for i, data in enumerate(zip(self.optimizer.param_groups, values)):
|
||||
param_group, lr = data
|
||||
param_group["lr"] = lr
|
||||
self.print_lr(self.verbose, i, lr)
|
||||
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
|
||||
|
||||
def print_lr(self, is_verbose, group, lr):
|
||||
"""Display the current learning rate."""
|
||||
if is_verbose:
|
||||
logging.warn(
|
||||
f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
|
||||
f" of group {group} to {lr:.4e}."
|
||||
)
|
||||
|
||||
class Noam(LRScheduler):
|
||||
"""
|
||||
The LR scheduler proposed by Noam
|
||||
Ref: "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: Optimizer,
|
||||
warmup_batches: Union[int, float] = 10000.0,
|
||||
verbose: bool = False,
|
||||
):
|
||||
super().__init__(optimizer, verbose)
|
||||
self.warmup_batches = warmup_batches
|
||||
self.normalize = self.warmup_batches ** (-0.5)
|
||||
|
||||
def get_lr(self):
|
||||
warmup_factor = 0 if self.batch == 0 else min(
|
||||
self.batch ** (-0.5),
|
||||
self.batch * self.warmup_batches ** (-1.5)
|
||||
) / self.normalize
|
||||
|
||||
return [x * warmup_factor for x in self.base_lrs]
|
904
egs/librispeech/ASR/zipformer_adam/scaling.py
Normal file
904
egs/librispeech/ASR/zipformer_adam/scaling.py
Normal file
@ -0,0 +1,904 @@
|
||||
# Copyright 2022-2023 Xiaomi Corp. (authors: Daniel Povey)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
import random
|
||||
from typing import Tuple, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
|
||||
def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
|
||||
max_value = torch.max(x, y)
|
||||
diff = torch.abs(x - y)
|
||||
return max_value + torch.log1p(torch.exp(-diff))
|
||||
|
||||
|
||||
# RuntimeError: Exporting the operator logaddexp to ONNX opset version
|
||||
# 14 is not supported. Please feel free to request support or submit
|
||||
# a pull request on PyTorch GitHub.
|
||||
#
|
||||
# The following function is to solve the above error when exporting
|
||||
# models to ONNX via torch.jit.trace()
|
||||
def logaddexp(x: Tensor, y: Tensor) -> Tensor:
|
||||
# Caution(fangjun): Put torch.jit.is_scripting() before
|
||||
# torch.onnx.is_in_onnx_export();
|
||||
# otherwise, it will cause errors for torch.jit.script().
|
||||
#
|
||||
# torch.logaddexp() works for both torch.jit.script() and
|
||||
# torch.jit.trace() but it causes errors for ONNX export.
|
||||
#
|
||||
if torch.jit.is_scripting():
|
||||
# Note: We cannot use torch.jit.is_tracing() here as it also
|
||||
# matches torch.onnx.export().
|
||||
return torch.logaddexp(x, y)
|
||||
elif torch.onnx.is_in_onnx_export():
|
||||
return logaddexp_onnx(x, y)
|
||||
else:
|
||||
# for torch.jit.trace()
|
||||
return torch.logaddexp(x, y)
|
||||
|
||||
|
||||
class PiecewiseLinear(object):
|
||||
"""
|
||||
Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with
|
||||
the x values in order. x values <[initial x] or >[final x] are map to [initial y], [final y]
|
||||
respectively.
|
||||
"""
|
||||
|
||||
def __init__(self, *args):
|
||||
assert len(args) >= 1, len(args)
|
||||
if len(args) == 1 and isinstance(args[0], PiecewiseLinear):
|
||||
self.pairs = list(args[0].pairs)
|
||||
else:
|
||||
self.pairs = [(float(x), float(y)) for x, y in args]
|
||||
for x, y in self.pairs:
|
||||
assert isinstance(x, (float, int)), type(x)
|
||||
assert isinstance(y, (float, int)), type(y)
|
||||
|
||||
for i in range(len(self.pairs) - 1):
|
||||
assert self.pairs[i + 1][0] > self.pairs[i][0], (
|
||||
i,
|
||||
self.pairs[i],
|
||||
self.pairs[i + 1],
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
# e.g. 'PiecewiseLinear((0., 10.), (100., 0.))'
|
||||
return f"PiecewiseLinear({str(self.pairs)[1:-1]})"
|
||||
|
||||
def __call__(self, x):
|
||||
if x <= self.pairs[0][0]:
|
||||
return self.pairs[0][1]
|
||||
elif x >= self.pairs[-1][0]:
|
||||
return self.pairs[-1][1]
|
||||
else:
|
||||
cur_x, cur_y = self.pairs[0]
|
||||
for i in range(1, len(self.pairs)):
|
||||
next_x, next_y = self.pairs[i]
|
||||
if x >= cur_x and x <= next_x:
|
||||
return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x)
|
||||
cur_x, cur_y = next_x, next_y
|
||||
assert False
|
||||
|
||||
def __mul__(self, alpha):
|
||||
return PiecewiseLinear(*[(x, y * alpha) for x, y in self.pairs])
|
||||
|
||||
def __add__(self, x):
|
||||
if isinstance(x, (float, int)):
|
||||
return PiecewiseLinear(*[(p[0], p[1] + x) for p in self.pairs])
|
||||
s, x = self.get_common_basis(x)
|
||||
return PiecewiseLinear(
|
||||
*[(sp[0], sp[1] + xp[1]) for sp, xp in zip(s.pairs, x.pairs)]
|
||||
)
|
||||
|
||||
def max(self, x):
|
||||
if isinstance(x, (float, int)):
|
||||
x = PiecewiseLinear((0, x))
|
||||
s, x = self.get_common_basis(x, include_crossings=True)
|
||||
return PiecewiseLinear(
|
||||
*[(sp[0], max(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]
|
||||
)
|
||||
|
||||
def min(self, x):
|
||||
if isinstance(x, float) or isinstance(x, int):
|
||||
x = PiecewiseLinear((0, x))
|
||||
s, x = self.get_common_basis(x, include_crossings=True)
|
||||
return PiecewiseLinear(
|
||||
*[(sp[0], min(sp[1], xp[1])) for sp, xp in zip(s.pairs, x.pairs)]
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.pairs == other.pairs
|
||||
|
||||
def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False):
|
||||
"""
|
||||
Returns (self_mod, p_mod) which are equivalent piecewise linear
|
||||
functions to self and p, but with the same x values.
|
||||
|
||||
p: the other piecewise linear function
|
||||
include_crossings: if true, include in the x values positions
|
||||
where the functions indicate by this and p cross.
|
||||
"""
|
||||
assert isinstance(p, PiecewiseLinear), type(p)
|
||||
|
||||
# get sorted x-values without repetition.
|
||||
x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs]))
|
||||
y_vals1 = [self(x) for x in x_vals]
|
||||
y_vals2 = [p(x) for x in x_vals]
|
||||
|
||||
if include_crossings:
|
||||
extra_x_vals = []
|
||||
for i in range(len(x_vals) - 1):
|
||||
if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]):
|
||||
# if the two lines in this subsegment potentially cross each other..
|
||||
diff_cur = abs(y_vals1[i] - y_vals2[i])
|
||||
diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1])
|
||||
# `pos`, between 0 and 1, gives the relative x position,
|
||||
# with 0 being x_vals[i] and 1 being x_vals[i+1].
|
||||
pos = diff_cur / (diff_cur + diff_next)
|
||||
extra_x_val = x_vals[i] + pos * (x_vals[i + 1] - x_vals[i])
|
||||
extra_x_vals.append(extra_x_val)
|
||||
if len(extra_x_vals) > 0:
|
||||
x_vals = sorted(set(x_vals + extra_x_vals))
|
||||
y_vals1 = [self(x) for x in x_vals]
|
||||
y_vals2 = [p(x) for x in x_vals]
|
||||
return (
|
||||
PiecewiseLinear(*zip(x_vals, y_vals1)),
|
||||
PiecewiseLinear(*zip(x_vals, y_vals2)),
|
||||
)
|
||||
|
||||
|
||||
class ScheduledFloat(torch.nn.Module):
|
||||
"""
|
||||
This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
|
||||
it does not have a working forward() function. You are supposed to cast it to float, as
|
||||
in, float(parent_module.whatever), and use it as something like a dropout prob.
|
||||
|
||||
It is a floating point value whose value changes depending on the batch count of the
|
||||
training loop. It is a piecewise linear function where you specify the (x,y) pairs
|
||||
in sorted order on x; x corresponds to the batch index. For batch-index values before the
|
||||
first x or after the last x, we just use the first or last y value.
|
||||
|
||||
Example:
|
||||
self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
|
||||
|
||||
`default` is used when self.batch_count is not set or not in training mode or in
|
||||
torch.jit scripting mode.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, default: float = 0.0):
|
||||
super().__init__()
|
||||
# self.batch_count and self.name will be written to in the training loop.
|
||||
self.batch_count = None
|
||||
self.name = None
|
||||
self.default = default
|
||||
self.schedule = PiecewiseLinear(*args)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return (
|
||||
f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}"
|
||||
)
|
||||
|
||||
def __float__(self):
|
||||
batch_count = self.batch_count
|
||||
if (
|
||||
batch_count is None
|
||||
or not self.training
|
||||
or torch.jit.is_scripting()
|
||||
or torch.jit.is_tracing()
|
||||
):
|
||||
return float(self.default)
|
||||
else:
|
||||
ans = self.schedule(self.batch_count)
|
||||
if random.random() < 0.0002:
|
||||
logging.info(
|
||||
f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}"
|
||||
)
|
||||
return ans
|
||||
|
||||
def __add__(self, x):
|
||||
if isinstance(x, float) or isinstance(x, int):
|
||||
return ScheduledFloat(self.schedule + x, default=self.default)
|
||||
else:
|
||||
return ScheduledFloat(
|
||||
self.schedule + x.schedule, default=self.default + x.default
|
||||
)
|
||||
|
||||
def max(self, x):
|
||||
if isinstance(x, float) or isinstance(x, int):
|
||||
return ScheduledFloat(self.schedule.max(x), default=self.default)
|
||||
else:
|
||||
return ScheduledFloat(
|
||||
self.schedule.max(x.schedule), default=max(self.default, x.default)
|
||||
)
|
||||
|
||||
|
||||
FloatLike = Union[float, ScheduledFloat]
|
||||
|
||||
|
||||
def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
|
||||
"""
|
||||
A randomized way of casting a floating point value to half precision.
|
||||
"""
|
||||
if x.dtype == torch.float16:
|
||||
return x
|
||||
x_abs = x.abs()
|
||||
is_too_small = x_abs < min_abs
|
||||
# for elements where is_too_small is true, random_val will contain +-min_abs with
|
||||
# probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
|
||||
# for those elements].
|
||||
random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
|
||||
return torch.where(is_too_small, random_val, x).to(torch.float16)
|
||||
|
||||
|
||||
class CutoffEstimator:
|
||||
"""
|
||||
Estimates cutoffs of an arbitrary numerical quantity such that a specified
|
||||
proportion of items will be above the cutoff on average.
|
||||
|
||||
p is the proportion of items that should be above the cutoff.
|
||||
"""
|
||||
|
||||
def __init__(self, p: float):
|
||||
self.p = p
|
||||
# total count of items
|
||||
self.count = 0
|
||||
# total count of items that were above the cutoff
|
||||
self.count_above = 0
|
||||
# initial cutoff value
|
||||
self.cutoff = 0
|
||||
|
||||
def __call__(self, x: float) -> bool:
|
||||
"""
|
||||
Returns true if x is above the cutoff.
|
||||
"""
|
||||
ans = x > self.cutoff
|
||||
self.count += 1
|
||||
if ans:
|
||||
self.count_above += 1
|
||||
cur_p = self.count_above / self.count
|
||||
delta_p = cur_p - self.p
|
||||
if (delta_p > 0) == ans:
|
||||
q = abs(delta_p)
|
||||
self.cutoff = x * q + self.cutoff * (1 - q)
|
||||
return ans
|
||||
|
||||
|
||||
class SoftmaxFunction(torch.autograd.Function):
|
||||
"""
|
||||
Tries to handle half-precision derivatives in a randomized way that should
|
||||
be more accurate for training than the default behavior.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x: Tensor, dim: int):
|
||||
ans = x.softmax(dim=dim)
|
||||
# if x dtype is float16, x.softmax() returns a float32 because
|
||||
# (presumably) that op does not support float16, and autocast
|
||||
# is enabled.
|
||||
if torch.is_autocast_enabled():
|
||||
ans = ans.to(torch.float16)
|
||||
ctx.save_for_backward(ans)
|
||||
ctx.x_dtype = x.dtype
|
||||
ctx.dim = dim
|
||||
return ans
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, ans_grad: Tensor):
|
||||
(ans,) = ctx.saved_tensors
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
ans_grad = ans_grad.to(torch.float32)
|
||||
ans = ans.to(torch.float32)
|
||||
x_grad = ans_grad * ans
|
||||
x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
|
||||
return x_grad, None
|
||||
|
||||
|
||||
def softmax(x: Tensor, dim: int):
|
||||
if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
return x.softmax(dim=dim)
|
||||
|
||||
return SoftmaxFunction.apply(x, dim)
|
||||
|
||||
|
||||
class MaxEigLimiterFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
x: Tensor,
|
||||
coeffs: Tensor,
|
||||
direction: Tensor,
|
||||
channel_dim: int,
|
||||
grad_scale: float,
|
||||
) -> Tensor:
|
||||
ctx.channel_dim = channel_dim
|
||||
ctx.grad_scale = grad_scale
|
||||
ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, x_grad, *args):
|
||||
with torch.enable_grad():
|
||||
(x_orig, coeffs, new_direction) = ctx.saved_tensors
|
||||
x_orig.requires_grad = True
|
||||
num_channels = x_orig.shape[ctx.channel_dim]
|
||||
x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
|
||||
new_direction.requires_grad = False
|
||||
x = x - x.mean(dim=0)
|
||||
x_var = (x**2).mean()
|
||||
x_residual = x - coeffs * new_direction
|
||||
x_residual_var = (x_residual**2).mean()
|
||||
# `variance_proportion` is the proportion of the variance accounted for
|
||||
# by the top eigen-direction. This is to be minimized.
|
||||
variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
|
||||
variance_proportion.backward()
|
||||
x_orig_grad = x_orig.grad
|
||||
x_extra_grad = (
|
||||
x_orig.grad
|
||||
* ctx.grad_scale
|
||||
* x_grad.norm()
|
||||
/ (x_orig_grad.norm() + 1.0e-20)
|
||||
)
|
||||
return x_grad + x_extra_grad.detach(), None, None, None, None
|
||||
|
||||
|
||||
class BiasNormFunction(torch.autograd.Function):
|
||||
# This computes:
|
||||
# scales = (torch.mean((x - bias) ** 2, keepdim=True)) ** -0.5 * log_scale.exp()
|
||||
# return x * scales
|
||||
# (after unsqueezing the bias), but it does it in a memory-efficient way so that
|
||||
# it can just store the returned value (chances are, this will also be needed for
|
||||
# some other reason, related to the next operation, so we can save memory).
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
x: Tensor,
|
||||
bias: Tensor,
|
||||
log_scale: Tensor,
|
||||
channel_dim: int,
|
||||
store_output_for_backprop: bool,
|
||||
) -> Tensor:
|
||||
assert bias.ndim == 1
|
||||
if channel_dim < 0:
|
||||
channel_dim = channel_dim + x.ndim
|
||||
ctx.store_output_for_backprop = store_output_for_backprop
|
||||
ctx.channel_dim = channel_dim
|
||||
for _ in range(channel_dim + 1, x.ndim):
|
||||
bias = bias.unsqueeze(-1)
|
||||
scales = (
|
||||
torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
|
||||
) * log_scale.exp()
|
||||
ans = x * scales
|
||||
ctx.save_for_backward(
|
||||
ans.detach() if store_output_for_backprop else x,
|
||||
scales.detach(),
|
||||
bias.detach(),
|
||||
log_scale.detach(),
|
||||
)
|
||||
return ans
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, ans_grad: Tensor) -> Tensor:
|
||||
ans_or_x, scales, bias, log_scale = ctx.saved_tensors
|
||||
if ctx.store_output_for_backprop:
|
||||
x = ans_or_x / scales
|
||||
else:
|
||||
x = ans_or_x
|
||||
x = x.detach()
|
||||
x.requires_grad = True
|
||||
bias.requires_grad = True
|
||||
log_scale.requires_grad = True
|
||||
with torch.enable_grad():
|
||||
# recompute scales from x, bias and log_scale.
|
||||
scales = (
|
||||
torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5
|
||||
) * log_scale.exp()
|
||||
ans = x * scales
|
||||
ans.backward(gradient=ans_grad)
|
||||
return x.grad, bias.grad.flatten(), log_scale.grad, None, None
|
||||
|
||||
|
||||
class BiasNorm(torch.nn.Module):
|
||||
"""
|
||||
This is intended to be a simpler, and hopefully cheaper, replacement for
|
||||
LayerNorm. The observation this is based on, is that Transformer-type
|
||||
networks, especially with pre-norm, sometimes seem to set one of the
|
||||
feature dimensions to a large constant value (e.g. 50), which "defeats"
|
||||
the LayerNorm because the output magnitude is then not strongly dependent
|
||||
on the other (useful) features. Presumably the weight and bias of the
|
||||
LayerNorm are required to allow it to do this.
|
||||
|
||||
Instead, we give the BiasNorm a trainable bias that it can use when
|
||||
computing the scale for normalization. We also give it a (scalar)
|
||||
trainable scale on the output.
|
||||
|
||||
|
||||
Args:
|
||||
num_channels: the number of channels, e.g. 512.
|
||||
channel_dim: the axis/dimension corresponding to the channel,
|
||||
interpreted as an offset from the input's ndim if negative.
|
||||
This is NOT the num_channels; it should typically be one of
|
||||
{-2, -1, 0, 1, 2, 3}.
|
||||
log_scale: the initial log-scale that we multiply the output by; this
|
||||
is learnable.
|
||||
log_scale_min: FloatLike, minimum allowed value of log_scale
|
||||
log_scale_max: FloatLike, maximum allowed value of log_scale
|
||||
store_output_for_backprop: only possibly affects memory use; recommend
|
||||
to set to True if you think the output of this module is more likely
|
||||
than the input of this module to be required to be stored for the
|
||||
backprop.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_channels: int,
|
||||
channel_dim: int = -1, # CAUTION: see documentation.
|
||||
log_scale: float = 1.0,
|
||||
log_scale_min: float = -1.5,
|
||||
log_scale_max: float = 1.5,
|
||||
store_output_for_backprop: bool = False,
|
||||
) -> None:
|
||||
super(BiasNorm, self).__init__()
|
||||
self.num_channels = num_channels
|
||||
self.channel_dim = channel_dim
|
||||
self.log_scale = nn.Parameter(torch.tensor(log_scale))
|
||||
self.bias = nn.Parameter(torch.empty(num_channels).normal_(mean=0, std=1e-4))
|
||||
|
||||
self.log_scale_min = log_scale_min
|
||||
self.log_scale_max = log_scale_max
|
||||
|
||||
self.store_output_for_backprop = store_output_for_backprop
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
assert x.shape[self.channel_dim] == self.num_channels
|
||||
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
channel_dim = self.channel_dim
|
||||
if channel_dim < 0:
|
||||
channel_dim += x.ndim
|
||||
bias = self.bias
|
||||
for _ in range(channel_dim + 1, x.ndim):
|
||||
bias = bias.unsqueeze(-1)
|
||||
scales = (
|
||||
torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
|
||||
) * self.log_scale.exp()
|
||||
return x * scales
|
||||
|
||||
log_scale = limit_param_value(
|
||||
self.log_scale,
|
||||
min=float(self.log_scale_min),
|
||||
max=float(self.log_scale_max),
|
||||
training=self.training,
|
||||
)
|
||||
|
||||
return BiasNormFunction.apply(
|
||||
x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop
|
||||
)
|
||||
|
||||
|
||||
class ChunkCausalDepthwiseConv1d(torch.nn.Module):
|
||||
"""
|
||||
Behaves like a depthwise 1d convolution, except that it is causal in
|
||||
a chunkwise way, as if we had a block-triangular attention mask.
|
||||
The chunk size is provided at test time (it should probably be
|
||||
kept in sync with the attention mask).
|
||||
|
||||
This has a little more than twice the parameters of a conventional
|
||||
depthwise conv1d module: we implement it by having one
|
||||
depthwise convolution, of half the width, that is causal (via
|
||||
right-padding); and one depthwise convolution that is applied only
|
||||
within chunks, that we multiply by a scaling factor which depends
|
||||
on the position within the chunk.
|
||||
|
||||
Args:
|
||||
Accepts the standard args and kwargs that nn.Linear accepts
|
||||
e.g. in_features, out_features, bias=False.
|
||||
|
||||
initial_scale: you can override this if you want to increase
|
||||
or decrease the initial magnitude of the module's output
|
||||
(affects the initialization of weight_scale and bias_scale).
|
||||
Another option, if you want to do something like this, is
|
||||
to re-initialize the parameters.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
kernel_size: int,
|
||||
initial_scale: float = 1.0,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
assert kernel_size % 2 == 1
|
||||
|
||||
half_kernel_size = (kernel_size + 1) // 2
|
||||
# will pad manually, on one side.
|
||||
self.causal_conv = nn.Conv1d(
|
||||
in_channels=channels,
|
||||
out_channels=channels,
|
||||
groups=channels,
|
||||
kernel_size=half_kernel_size,
|
||||
padding=0,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
self.chunkwise_conv = nn.Conv1d(
|
||||
in_channels=channels,
|
||||
out_channels=channels,
|
||||
groups=channels,
|
||||
kernel_size=kernel_size,
|
||||
padding=kernel_size // 2,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
# first row is correction factors added to the scale near the left edge of the chunk,
|
||||
# second row is correction factors added to the scale near the right edge of the chunk,
|
||||
# both of these are added to a default scale of 1.0.
|
||||
self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size))
|
||||
self.kernel_size = kernel_size
|
||||
|
||||
with torch.no_grad():
|
||||
self.causal_conv.weight[:] *= initial_scale
|
||||
self.chunkwise_conv.weight[:] *= initial_scale
|
||||
if bias:
|
||||
torch.nn.init.uniform_(
|
||||
self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor:
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
x: a Tensor of shape (batch_size, channels, seq_len)
|
||||
chunk_size: the chunk size, in frames; does not have to divide seq_len exactly.
|
||||
"""
|
||||
(batch_size, num_channels, seq_len) = x.shape
|
||||
|
||||
# half_kernel_size = self.kernel_size + 1 // 2
|
||||
# left_pad is half_kernel_size - 1 where half_kernel_size is the size used
|
||||
# in the causal conv. It's the amount by which we must pad on the left,
|
||||
# to make the convolution causal.
|
||||
left_pad = self.kernel_size // 2
|
||||
|
||||
if chunk_size < 0 or chunk_size > seq_len:
|
||||
chunk_size = seq_len
|
||||
right_pad = -seq_len % chunk_size
|
||||
|
||||
x = torch.nn.functional.pad(x, (left_pad, right_pad))
|
||||
|
||||
x_causal = self.causal_conv(x[..., : left_pad + seq_len])
|
||||
assert x_causal.shape == (batch_size, num_channels, seq_len)
|
||||
|
||||
x_chunk = x[..., left_pad:]
|
||||
num_chunks = x_chunk.shape[2] // chunk_size
|
||||
x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size)
|
||||
x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(
|
||||
batch_size * num_chunks, num_channels, chunk_size
|
||||
)
|
||||
x_chunk = self.chunkwise_conv(x_chunk) # does not change shape
|
||||
|
||||
chunk_scale = self._get_chunk_scale(chunk_size)
|
||||
|
||||
x_chunk = x_chunk * chunk_scale
|
||||
x_chunk = x_chunk.reshape(
|
||||
batch_size, num_chunks, num_channels, chunk_size
|
||||
).permute(0, 2, 1, 3)
|
||||
x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[
|
||||
..., :seq_len
|
||||
]
|
||||
|
||||
return x_chunk + x_causal
|
||||
|
||||
def _get_chunk_scale(self, chunk_size: int):
|
||||
"""Returns tensor of shape (num_channels, chunk_size) that will be used to
|
||||
scale the output of self.chunkwise_conv."""
|
||||
left_edge = self.chunkwise_conv_scale[0]
|
||||
right_edge = self.chunkwise_conv_scale[1]
|
||||
if chunk_size < self.kernel_size:
|
||||
left_edge = left_edge[:, :chunk_size]
|
||||
right_edge = right_edge[:, -chunk_size:]
|
||||
else:
|
||||
t = chunk_size - self.kernel_size
|
||||
channels = left_edge.shape[0]
|
||||
pad = torch.zeros(
|
||||
channels, t, device=left_edge.device, dtype=left_edge.dtype
|
||||
)
|
||||
left_edge = torch.cat((left_edge, pad), dim=-1)
|
||||
right_edge = torch.cat((pad, right_edge), dim=-1)
|
||||
return 1.0 + (left_edge + right_edge)
|
||||
|
||||
def streaming_forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
cache: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Streaming Forward function.
|
||||
|
||||
Args:
|
||||
x: a Tensor of shape (batch_size, channels, seq_len)
|
||||
cache: cached left context of shape (batch_size, channels, left_pad)
|
||||
"""
|
||||
(batch_size, num_channels, seq_len) = x.shape
|
||||
|
||||
# left_pad is half_kernel_size - 1 where half_kernel_size is the size used
|
||||
# in the causal conv. It's the amount by which we must pad on the left,
|
||||
# to make the convolution causal.
|
||||
left_pad = self.kernel_size // 2
|
||||
|
||||
# Pad cache
|
||||
assert cache.shape[-1] == left_pad, (cache.shape[-1], left_pad)
|
||||
x = torch.cat([cache, x], dim=2)
|
||||
# Update cache
|
||||
cache = x[..., -left_pad:]
|
||||
|
||||
x_causal = self.causal_conv(x)
|
||||
assert x_causal.shape == (batch_size, num_channels, seq_len)
|
||||
|
||||
x_chunk = x[..., left_pad:]
|
||||
x_chunk = self.chunkwise_conv(x_chunk) # does not change shape
|
||||
|
||||
chunk_scale = self._get_chunk_scale(chunk_size=seq_len)
|
||||
x_chunk = x_chunk * chunk_scale
|
||||
|
||||
return x_chunk + x_causal, cache
|
||||
|
||||
|
||||
def penalize_abs_values_gt(
|
||||
x: Tensor, limit: float, penalty: float, name: str = None
|
||||
) -> Tensor:
|
||||
"""
|
||||
Returns x unmodified, but in backprop will put a penalty for the excess of
|
||||
the absolute values of elements of x over the limit "limit". E.g. if
|
||||
limit == 10.0, then if x has any values over 10 it will get a penalty.
|
||||
|
||||
Caution: the value of this penalty will be affected by grad scaling used
|
||||
in automatic mixed precision training. For this reasons we use this,
|
||||
it shouldn't really matter, or may even be helpful; we just use this
|
||||
to disallow really implausible values of scores to be given to softmax.
|
||||
|
||||
The name is for randomly printed debug info.
|
||||
"""
|
||||
x_sign = x.sign()
|
||||
over_limit = (x.abs() - limit) > 0
|
||||
# The following is a memory efficient way to penalize the absolute values of
|
||||
# x that's over the limit. (The memory efficiency comes when you think
|
||||
# about which items torch needs to cache for the autograd, and which ones it
|
||||
# can throw away). The numerical value of aux_loss as computed here will
|
||||
# actually be larger than it should be, by limit * over_limit.sum(), but it
|
||||
# has the same derivative as the real aux_loss which is penalty * (x.abs() -
|
||||
# limit).relu().
|
||||
aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
|
||||
# note: we don't do sum() here on aux)_loss, but it's as if we had done
|
||||
# sum() due to how with_loss() works.
|
||||
x = with_loss(x, aux_loss, name)
|
||||
# you must use x for something, or this will be ineffective.
|
||||
return x
|
||||
|
||||
|
||||
class WithLoss(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x: Tensor, y: Tensor, name: str):
|
||||
ctx.y_shape = y.shape
|
||||
if random.random() < 0.002 and name is not None:
|
||||
loss_sum = y.sum().item()
|
||||
logging.info(f"WithLoss: name={name}, loss-sum={loss_sum:.3e}")
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, ans_grad: Tensor):
|
||||
return (
|
||||
ans_grad,
|
||||
torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device),
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def with_loss(x, y, name):
|
||||
# returns x but adds y.sum() to the loss function.
|
||||
return WithLoss.apply(x, y, name)
|
||||
|
||||
|
||||
class ScaleGradFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x: Tensor, alpha: float) -> Tensor:
|
||||
ctx.alpha = alpha
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad: Tensor):
|
||||
return grad * ctx.alpha, None
|
||||
|
||||
|
||||
def scale_grad(x: Tensor, alpha: float):
|
||||
return ScaleGradFunction.apply(x, alpha)
|
||||
|
||||
|
||||
class ScaleGrad(nn.Module):
|
||||
def __init__(self, alpha: float):
|
||||
super().__init__()
|
||||
self.alpha = alpha
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
|
||||
return x
|
||||
return scale_grad(x, self.alpha)
|
||||
|
||||
|
||||
class LimitParamValue(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x: Tensor, min: float, max: float):
|
||||
ctx.save_for_backward(x)
|
||||
assert max >= min
|
||||
ctx.min = min
|
||||
ctx.max = max
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, x_grad: Tensor):
|
||||
(x,) = ctx.saved_tensors
|
||||
# where x < ctx.min, ensure all grads are negative (this will tend to make
|
||||
# x more positive).
|
||||
x_grad = x_grad * torch.where(
|
||||
torch.logical_and(x_grad > 0, x < ctx.min), -1.0, 1.0
|
||||
)
|
||||
# where x > ctx.max, ensure all grads are positive (this will tend to make
|
||||
# x more negative).
|
||||
x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
|
||||
return x_grad, None, None
|
||||
|
||||
|
||||
def limit_param_value(
|
||||
x: Tensor, min: float, max: float, prob: float = 0.6, training: bool = True
|
||||
):
|
||||
# You apply this to (typically) an nn.Parameter during training to ensure that its
|
||||
# (elements mostly) stays within a supplied range. This is done by modifying the
|
||||
# gradients in backprop.
|
||||
# It's not necessary to do this on every batch: do it only some of the time,
|
||||
# to save a little time.
|
||||
if training and random.random() < prob:
|
||||
return LimitParamValue.apply(x, min, max)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
def _no_op(x: Tensor) -> Tensor:
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
return x
|
||||
else:
|
||||
# a no-op function that will have a node in the autograd graph,
|
||||
# to avoid certain bugs relating to backward hooks
|
||||
return x.chunk(1, dim=-1)[0]
|
||||
|
||||
|
||||
class Identity(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return _no_op(x)
|
||||
|
||||
|
||||
# Dropout2 is just like normal dropout, except it supports schedules on the dropout rates.
|
||||
class Dropout2(nn.Module):
|
||||
def __init__(self, p: FloatLike):
|
||||
super().__init__()
|
||||
self.p = p
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return torch.nn.functional.dropout(x, p=float(self.p), training=self.training)
|
||||
|
||||
|
||||
class MulForDropout3(torch.autograd.Function):
|
||||
# returns (x * y * alpha) where alpha is a float and y doesn't require
|
||||
# grad and is zero-or-one.
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(ctx, x, y, alpha):
|
||||
assert not y.requires_grad
|
||||
ans = x * y * alpha
|
||||
ctx.save_for_backward(ans)
|
||||
ctx.alpha = alpha
|
||||
return ans
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, ans_grad):
|
||||
(ans,) = ctx.saved_tensors
|
||||
x_grad = ctx.alpha * ans_grad * (ans != 0)
|
||||
return x_grad, None, None
|
||||
|
||||
|
||||
# Dropout3 is just like normal dropout, except it supports schedules on the dropout rates,
|
||||
# and it lets you choose one dimension to share the dropout mask over
|
||||
class Dropout3(nn.Module):
|
||||
def __init__(self, p: FloatLike, shared_dim: int):
|
||||
super().__init__()
|
||||
self.p = p
|
||||
self.shared_dim = shared_dim
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
p = float(self.p)
|
||||
if not self.training or p == 0:
|
||||
return _no_op(x)
|
||||
scale = 1.0 / (1 - p)
|
||||
rand_shape = list(x.shape)
|
||||
rand_shape[self.shared_dim] = 1
|
||||
mask = torch.rand(*rand_shape, device=x.device) > p
|
||||
ans = MulForDropout3.apply(x, mask, scale)
|
||||
return ans
|
||||
|
||||
|
||||
def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
|
||||
if num_channels <= x.shape[-1]:
|
||||
return x[..., :num_channels]
|
||||
else:
|
||||
shape = list(x.shape)
|
||||
shape[-1] = num_channels - shape[-1]
|
||||
zeros = torch.zeros(shape, dtype=x.dtype, device=x.device)
|
||||
return torch.cat((x, zeros), dim=-1)
|
||||
|
||||
|
||||
def _test_softmax():
|
||||
a = torch.randn(2, 10, dtype=torch.float64)
|
||||
b = a.clone()
|
||||
a.requires_grad = True
|
||||
b.requires_grad = True
|
||||
a.softmax(dim=1)[:, 0].sum().backward()
|
||||
print("a grad = ", a.grad)
|
||||
softmax(b, dim=1)[:, 0].sum().backward()
|
||||
print("b grad = ", b.grad)
|
||||
assert torch.allclose(a.grad, b.grad)
|
||||
|
||||
|
||||
def _test_piecewise_linear():
|
||||
p = PiecewiseLinear((0, 10.0))
|
||||
for x in [-100, 0, 100]:
|
||||
assert p(x) == 10.0
|
||||
p = PiecewiseLinear((0, 10.0), (1, 0.0))
|
||||
for x, y in [(-100, 10.0), (0, 10.0), (0.5, 5.0), (1, 0.0), (2, 0.0)]:
|
||||
print("x, y = ", x, y)
|
||||
assert p(x) == y, (x, p(x), y)
|
||||
|
||||
q = PiecewiseLinear((0.5, 15.0), (0.6, 1.0))
|
||||
x_vals = [-1.0, 0.0, 0.1, 0.2, 0.5, 0.6, 0.7, 0.9, 1.0, 2.0]
|
||||
pq = p.max(q)
|
||||
for x in x_vals:
|
||||
y1 = max(p(x), q(x))
|
||||
y2 = pq(x)
|
||||
assert abs(y1 - y2) < 0.001
|
||||
pq = p.min(q)
|
||||
for x in x_vals:
|
||||
y1 = min(p(x), q(x))
|
||||
y2 = pq(x)
|
||||
assert abs(y1 - y2) < 0.001
|
||||
pq = p + q
|
||||
for x in x_vals:
|
||||
y1 = p(x) + q(x)
|
||||
y2 = pq(x)
|
||||
assert abs(y1 - y2) < 0.001
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
_test_piecewise_linear()
|
||||
_test_softmax()
|
93
egs/librispeech/ASR/zipformer_adam/scaling_converter.py
Normal file
93
egs/librispeech/ASR/zipformer_adam/scaling_converter.py
Normal file
@ -0,0 +1,93 @@
|
||||
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This file replaces various modules in a model.
|
||||
Specifically, BasicNorm is replaced by a module with `exp` removed.
|
||||
"""
|
||||
|
||||
import copy
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling import (
|
||||
Dropout3,
|
||||
ScaleGrad,
|
||||
)
|
||||
from zipformer import CompactRelPositionalEncoding
|
||||
|
||||
|
||||
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
|
||||
# get_submodule was added to nn.Module at v1.9.0
|
||||
def get_submodule(model, target):
|
||||
if target == "":
|
||||
return model
|
||||
atoms: List[str] = target.split(".")
|
||||
mod: torch.nn.Module = model
|
||||
for item in atoms:
|
||||
if not hasattr(mod, item):
|
||||
raise AttributeError(
|
||||
mod._get_name() + " has no " "attribute `" + item + "`"
|
||||
)
|
||||
mod = getattr(mod, item)
|
||||
if not isinstance(mod, torch.nn.Module):
|
||||
raise AttributeError("`" + item + "` is not " "an nn.Module")
|
||||
return mod
|
||||
|
||||
|
||||
def convert_scaled_to_non_scaled(
|
||||
model: nn.Module,
|
||||
inplace: bool = False,
|
||||
is_pnnx: bool = False,
|
||||
is_onnx: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
model:
|
||||
The model to be converted.
|
||||
inplace:
|
||||
If True, the input model is modified inplace.
|
||||
If False, the input model is copied and we modify the copied version.
|
||||
is_pnnx:
|
||||
True if we are going to export the model for PNNX.
|
||||
is_onnx:
|
||||
True if we are going to export the model for ONNX.
|
||||
Return:
|
||||
Return a model without scaled layers.
|
||||
"""
|
||||
if not inplace:
|
||||
model = copy.deepcopy(model)
|
||||
|
||||
d = {}
|
||||
for name, m in model.named_modules():
|
||||
if isinstance(m, (Dropout3, ScaleGrad)):
|
||||
d[name] = nn.Identity()
|
||||
elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
|
||||
# We want to recreate the positional encoding vector when
|
||||
# the input changes, so we have to use torch.jit.script()
|
||||
# to replace torch.jit.trace()
|
||||
d[name] = torch.jit.script(m)
|
||||
|
||||
for k, v in d.items():
|
||||
if "." in k:
|
||||
parent, child = k.rsplit(".", maxsplit=1)
|
||||
setattr(get_submodule(model, parent), child, v)
|
||||
else:
|
||||
setattr(model, k, v)
|
||||
|
||||
return model
|
357
egs/librispeech/ASR/zipformer_adam/subsampling.py
Normal file
357
egs/librispeech/ASR/zipformer_adam/subsampling.py
Normal file
@ -0,0 +1,357 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from scaling import (
|
||||
BiasNorm,
|
||||
Dropout3,
|
||||
FloatLike,
|
||||
Optional,
|
||||
ScaleGrad,
|
||||
ScheduledFloat,
|
||||
)
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class ConvNeXt(nn.Module):
|
||||
"""
|
||||
Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
hidden_ratio: int = 3,
|
||||
kernel_size: Tuple[int, int] = (7, 7),
|
||||
layerdrop_rate: FloatLike = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
|
||||
hidden_channels = channels * hidden_ratio
|
||||
if layerdrop_rate is None:
|
||||
layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015))
|
||||
self.layerdrop_rate = layerdrop_rate
|
||||
|
||||
self.depthwise_conv = nn.Conv2d(
|
||||
in_channels=channels,
|
||||
out_channels=channels,
|
||||
groups=channels,
|
||||
kernel_size=kernel_size,
|
||||
padding=self.padding,
|
||||
)
|
||||
|
||||
self.pointwise_conv1 = nn.Conv2d(
|
||||
in_channels=channels, out_channels=hidden_channels, kernel_size=1
|
||||
)
|
||||
|
||||
self.activation = nn.SiLU(inplace=True)
|
||||
self.pointwise_conv2 = nn.Conv2d(
|
||||
in_channels=hidden_channels,
|
||||
out_channels=channels,
|
||||
kernel_size=1,
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
|
||||
return self.forward_internal(x)
|
||||
layerdrop_rate = float(self.layerdrop_rate)
|
||||
|
||||
if layerdrop_rate != 0.0:
|
||||
batch_size = x.shape[0]
|
||||
mask = (
|
||||
torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
|
||||
> layerdrop_rate
|
||||
)
|
||||
else:
|
||||
mask = None
|
||||
# turns out this caching idea does not work with --world-size > 1
|
||||
# return caching_eval(self.forward_internal, x, mask)
|
||||
return self.forward_internal(x, mask)
|
||||
|
||||
def forward_internal(
|
||||
self, x: Tensor, layer_skip_mask: Optional[Tensor] = None
|
||||
) -> Tensor:
|
||||
"""
|
||||
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
|
||||
|
||||
The returned value has the same shape as x.
|
||||
"""
|
||||
bypass = x
|
||||
x = self.depthwise_conv(x)
|
||||
x = self.pointwise_conv1(x)
|
||||
x = self.activation(x)
|
||||
x = self.pointwise_conv2(x)
|
||||
|
||||
if layer_skip_mask is not None:
|
||||
x = x * layer_skip_mask
|
||||
|
||||
x = bypass + x
|
||||
|
||||
if x.requires_grad:
|
||||
x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last
|
||||
x = x.transpose(1, 3) # (N, C, H, W)
|
||||
|
||||
return x
|
||||
|
||||
def streaming_forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
cached_left_pad: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
|
||||
cached_left_pad: (batch_size, num_channels, left_pad, num_freqs)
|
||||
|
||||
Returns:
|
||||
- The returned value has the same shape as x.
|
||||
- Updated cached_left_pad.
|
||||
"""
|
||||
padding = self.padding
|
||||
|
||||
# The length without right padding for depth-wise conv
|
||||
T = x.size(2) - padding[0]
|
||||
|
||||
bypass = x[:, :, :T, :]
|
||||
|
||||
# Pad left side
|
||||
assert cached_left_pad.size(2) == padding[0], (
|
||||
cached_left_pad.size(2),
|
||||
padding[0],
|
||||
)
|
||||
x = torch.cat([cached_left_pad, x], dim=2)
|
||||
# Update cached left padding
|
||||
cached_left_pad = x[:, :, T : padding[0] + T, :]
|
||||
|
||||
# depthwise_conv
|
||||
x = torch.nn.functional.conv2d(
|
||||
x,
|
||||
weight=self.depthwise_conv.weight,
|
||||
bias=self.depthwise_conv.bias,
|
||||
padding=(0, padding[1]),
|
||||
groups=self.depthwise_conv.groups,
|
||||
)
|
||||
x = self.pointwise_conv1(x)
|
||||
x = self.activation(x)
|
||||
x = self.pointwise_conv2(x)
|
||||
|
||||
x = bypass + x
|
||||
return x, cached_left_pad
|
||||
|
||||
|
||||
class Conv2dSubsampling(nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/2 length).
|
||||
|
||||
Convert an input of shape (N, T, idim) to an output
|
||||
with shape (N, T', odim), where
|
||||
T' = (T-3)//2 - 2 == (T-7)//2
|
||||
|
||||
It is based on
|
||||
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
layer1_channels: int = 8,
|
||||
layer2_channels: int = 32,
|
||||
layer3_channels: int = 128,
|
||||
dropout: FloatLike = 0.1,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
in_channels:
|
||||
Number of channels in. The input shape is (N, T, in_channels).
|
||||
Caution: It requires: T >=7, in_channels >=7
|
||||
out_channels
|
||||
Output dim. The output shape is (N, (T-3)//2, out_channels)
|
||||
layer1_channels:
|
||||
Number of channels in layer1
|
||||
layer1_channels:
|
||||
Number of channels in layer2
|
||||
bottleneck:
|
||||
bottleneck dimension for 1d squeeze-excite
|
||||
"""
|
||||
assert in_channels >= 7
|
||||
super().__init__()
|
||||
|
||||
# The ScaleGrad module is there to prevent the gradients
|
||||
# w.r.t. the weight or bias of the first Conv2d module in self.conv from
|
||||
# exceeding the range of fp16 when using automatic mixed precision (amp)
|
||||
# training. (The second one is necessary to stop its bias from getting
|
||||
# a too-large gradient).
|
||||
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels=1,
|
||||
out_channels=layer1_channels,
|
||||
kernel_size=3,
|
||||
padding=(0, 1), # (time, freq)
|
||||
),
|
||||
ScaleGrad(0.2),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(
|
||||
in_channels=layer1_channels,
|
||||
out_channels=layer2_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=0,
|
||||
),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(
|
||||
in_channels=layer2_channels,
|
||||
out_channels=layer3_channels,
|
||||
kernel_size=3,
|
||||
stride=(1, 2), # (time, freq)
|
||||
),
|
||||
nn.SiLU()
|
||||
)
|
||||
|
||||
# just one convnext layer
|
||||
self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))
|
||||
|
||||
# (in_channels-3)//4
|
||||
self.out_width = (((in_channels - 1) // 2) - 1) // 2
|
||||
self.layer3_channels = layer3_channels
|
||||
|
||||
self.out = nn.Linear(self.out_width * layer3_channels, out_channels)
|
||||
|
||||
self.out_norm = BiasNorm(out_channels)
|
||||
self.dropout = Dropout3(dropout, shared_dim=1)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Subsample x.
|
||||
|
||||
Args:
|
||||
x:
|
||||
Its shape is (N, T, idim).
|
||||
x_lens:
|
||||
A tensor of shape (batch_size,) containing the number of frames in
|
||||
|
||||
Returns:
|
||||
- a tensor of shape (N, (T-7)//2, odim)
|
||||
- output lengths, of shape (batch_size,)
|
||||
"""
|
||||
# On entry, x is (N, T, idim)
|
||||
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
|
||||
# scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
|
||||
# training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
|
||||
# gradients.
|
||||
x = self.conv(x)
|
||||
x = self.convnext(x)
|
||||
|
||||
# Now x is of shape (N, odim, (T-7)//2, (idim-3)//4)
|
||||
b, c, t, f = x.size()
|
||||
|
||||
x = x.transpose(1, 2).reshape(b, t, c * f)
|
||||
# now x: (N, (T-7)//2, out_width * layer3_channels))
|
||||
|
||||
x = self.out(x)
|
||||
# Now x is of shape (N, (T-7)//2, odim)
|
||||
x = self.out_norm(x)
|
||||
x = self.dropout(x)
|
||||
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
x_lens = (x_lens - 7) // 2
|
||||
else:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
x_lens = (x_lens - 7) // 2
|
||||
assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max())
|
||||
|
||||
return x, x_lens
|
||||
|
||||
def streaming_forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
cached_left_pad: Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Subsample x.
|
||||
|
||||
Args:
|
||||
x:
|
||||
Its shape is (N, T, idim).
|
||||
x_lens:
|
||||
A tensor of shape (batch_size,) containing the number of frames in
|
||||
|
||||
Returns:
|
||||
- a tensor of shape (N, (T-7)//2, odim)
|
||||
- output lengths, of shape (batch_size,)
|
||||
- updated cache
|
||||
"""
|
||||
# On entry, x is (N, T, idim)
|
||||
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
|
||||
|
||||
# T' = (T-7)//2
|
||||
x = self.conv(x)
|
||||
|
||||
# T' = (T-7)//2-3
|
||||
x, cached_left_pad = self.convnext.streaming_forward(
|
||||
x, cached_left_pad=cached_left_pad
|
||||
)
|
||||
|
||||
# Now x is of shape (N, odim, T', ((idim-1)//2 - 1)//2)
|
||||
b, c, t, f = x.size()
|
||||
|
||||
x = x.transpose(1, 2).reshape(b, t, c * f)
|
||||
# now x: (N, T', out_width * layer3_channels))
|
||||
|
||||
x = self.out(x)
|
||||
# Now x is of shape (N, T', odim)
|
||||
x = self.out_norm(x)
|
||||
|
||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
assert self.convnext.padding[0] == 3
|
||||
# The ConvNeXt module needs 3 frames of right padding after subsampling
|
||||
x_lens = (x_lens - 7) // 2 - 3
|
||||
else:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
# The ConvNeXt module needs 3 frames of right padding after subsampling
|
||||
assert self.convnext.padding[0] == 3
|
||||
x_lens = (x_lens - 7) // 2 - 3
|
||||
|
||||
assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max())
|
||||
|
||||
return x, x_lens, cached_left_pad
|
||||
|
||||
@torch.jit.export
|
||||
def get_init_states(
|
||||
self,
|
||||
batch_size: int = 1,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> Tensor:
|
||||
"""Get initial states for Conv2dSubsampling module.
|
||||
It is the cached left padding for ConvNeXt module,
|
||||
of shape (batch_size, num_channels, left_pad, num_freqs)
|
||||
"""
|
||||
left_pad = self.convnext.padding[0]
|
||||
freq = self.out_width
|
||||
channels = self.layer3_channels
|
||||
cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to(
|
||||
device
|
||||
)
|
||||
|
||||
return cached_embed_left_pad
|
1479
egs/librispeech/ASR/zipformer_adam/train.py
Executable file
1479
egs/librispeech/ASR/zipformer_adam/train.py
Executable file
File diff suppressed because it is too large
Load Diff
2250
egs/librispeech/ASR/zipformer_adam/zipformer.py
Normal file
2250
egs/librispeech/ASR/zipformer_adam/zipformer.py
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user