Gradient filter for training lstm model (#564)

* init files

* add gradient filter module

* refact getting median value

* add cutoff for grad filter

* delete comments

* apply gradient filter in LSTM module, to filter both input and params

* fix typing and refactor

* filter with soft mask

* rename lstm_transducer_stateless2 to lstm_transducer_stateless3

* fix typos, and update RESULTS.md

* minor fix

* fix return typing

* fix typo
This commit is contained in:
Zengwei Yao 2022-09-29 11:15:43 +08:00 committed by GitHub
parent 923b60a7c6
commit f3ad32777a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 5448 additions and 28 deletions

View File

@ -9,7 +9,7 @@ per-file-ignores =
egs/*/ASR/pruned_transducer_stateless*/*.py: E501, egs/*/ASR/pruned_transducer_stateless*/*.py: E501,
egs/*/ASR/*/optim.py: E501, egs/*/ASR/*/optim.py: E501,
egs/*/ASR/*/scaling.py: E501, egs/*/ASR/*/scaling.py: E501,
egs/librispeech/ASR/lstm_transducer_stateless/*.py: E501, E203 egs/librispeech/ASR/lstm_transducer_stateless*/*.py: E501, E203
egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203 egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203
egs/librispeech/ASR/conformer_ctc2/*py: E501, egs/librispeech/ASR/conformer_ctc2/*py: E501,
egs/librispeech/ASR/RESULTS.md: E999, egs/librispeech/ASR/RESULTS.md: E999,

View File

@ -1,12 +1,100 @@
## Results ## Results
#### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T + multi-dataset) ### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T + gradient filter)
[lstm_transducer_stateless2](./lstm_transducer_stateless2) #### [lstm_transducer_stateless3](./lstm_transducer_stateless3)
It implements LSTM model with mechanisms in reworked model for streaming ASR.
Gradient filter is applied inside each lstm module to stabilize the training.
See <https://github.com/k2-fsa/icefall/pull/564> for more details.
##### training on full librispeech
This model contains 12 encoder layers (LSTM module + Feedforward module). The number of model parameters is 84689496.
The WERs are:
| | test-clean | test-other | comment | decoding mode |
|-------------------------------------|------------|------------|----------------------|----------------------|
| greedy search (max sym per frame 1) | 3.66 | 9.51 | --epoch 40 --avg 15 | simulated streaming |
| greedy search (max sym per frame 1) | 3.66 | 9.48 | --epoch 40 --avg 15 | streaming |
| fast beam search | 3.55 | 9.33 | --epoch 40 --avg 15 | simulated streaming |
| fast beam search | 3.57 | 9.25 | --epoch 40 --avg 15 | streaming |
| modified beam search | 3.55 | 9.28 | --epoch 40 --avg 15 | simulated streaming |
| modified beam search | 3.54 | 9.25 | --epoch 40 --avg 15 | streaming |
Note: `simulated streaming` indicates feeding full utterance during decoding, while `streaming` indicates feeding certain number of frames at each time.
The training command is:
```bash
./lstm_transducer_stateless3/train.py \
--world-size 4 \
--num-epochs 40 \
--start-epoch 1 \
--exp-dir lstm_transducer_stateless3/exp \
--full-libri 1 \
--max-duration 500 \
--master-port 12325 \
--num-encoder-layers 12 \
--grad-norm-threshold 25.0 \
--rnn-hidden-size 1024
```
The tensorboard log can be found at
<https://tensorboard.dev/experiment/caNPyr5lT8qAl9qKsXEeEQ/>
The simulated streaming decoding command using greedy search, fast beam search, and modified beam search is:
```bash
for decoding_method in greedy_search fast_beam_search modified_beam_search; do
./lstm_transducer_stateless3/decode.py \
--epoch 40 \
--avg 15 \
--exp-dir lstm_transducer_stateless3/exp \
--max-duration 600 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method $decoding_method \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--beam-size 4
done
```
The streaming decoding command using greedy search, fast beam search, and modified beam search is:
```bash
for decoding_method in greedy_search fast_beam_search modified_beam_search; do
./lstm_transducer_stateless3/streaming_decode.py \
--epoch 40 \
--avg 15 \
--exp-dir lstm_transducer_stateless3/exp \
--max-duration 600 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method $decoding_method \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--beam-size 4
done
```
Pretrained models, training logs, decoding logs, and decoding results
are available at
<https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless3-2022-09-28>
### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T + multi-dataset)
#### [lstm_transducer_stateless2](./lstm_transducer_stateless2)
See <https://github.com/k2-fsa/icefall/pull/558> for more details. See <https://github.com/k2-fsa/icefall/pull/558> for more details.
The WERs are: The WERs are:
| | test-clean | test-other | comment | | | test-clean | test-other | comment |
@ -18,6 +106,7 @@ The WERs are:
| modified_beam_search | 2.75 | 7.08 | --iter 472000 --avg 18 | | modified_beam_search | 2.75 | 7.08 | --iter 472000 --avg 18 |
| fast_beam_search | 2.77 | 7.29 | --iter 472000 --avg 18 | | fast_beam_search | 2.77 | 7.29 | --iter 472000 --avg 18 |
The training command is: The training command is:
```bash ```bash
@ -70,15 +159,16 @@ Pretrained models, training logs, decoding logs, and decoding results
are available at are available at
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03> <https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03>
#### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T)
[lstm_transducer_stateless](./lstm_transducer_stateless) ### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T)
#### [lstm_transducer_stateless](./lstm_transducer_stateless)
It implements LSTM model with mechanisms in reworked model for streaming ASR. It implements LSTM model with mechanisms in reworked model for streaming ASR.
See <https://github.com/k2-fsa/icefall/pull/479> for more details. See <https://github.com/k2-fsa/icefall/pull/479> for more details.
#### training on full librispeech ##### training on full librispeech
This model contains 12 encoder layers (LSTM module + Feedforward module). The number of model parameters is 84689496. This model contains 12 encoder layers (LSTM module + Feedforward module). The number of model parameters is 84689496.
@ -165,7 +255,7 @@ It is modified from [torchaudio](https://github.com/pytorch/audio).
See <https://github.com/k2-fsa/icefall/pull/440> for more details. See <https://github.com/k2-fsa/icefall/pull/440> for more details.
#### With lower latency setup, training on full librispeech ##### With lower latency setup, training on full librispeech
In this model, the lengths of chunk and right context are 32 frames (i.e., 0.32s) and 8 frames (i.e., 0.08s), respectively. In this model, the lengths of chunk and right context are 32 frames (i.e., 0.32s) and 8 frames (i.e., 0.08s), respectively.
@ -316,7 +406,7 @@ Pretrained models, training logs, decoding logs, and decoding results
are available at are available at
<https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05> <https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05>
#### With higher latency setup, training on full librispeech ##### With higher latency setup, training on full librispeech
In this model, the lengths of chunk and right context are 64 frames (i.e., 0.64s) and 16 frames (i.e., 0.16s), respectively. In this model, the lengths of chunk and right context are 64 frames (i.e., 0.64s) and 16 frames (i.e., 0.16s), respectively.
@ -851,14 +941,14 @@ Pre-trained models, training and decoding logs, and decoding results are availab
### LibriSpeech BPE training results (Pruned Stateless Conv-Emformer RNN-T) ### LibriSpeech BPE training results (Pruned Stateless Conv-Emformer RNN-T)
[conv_emformer_transducer_stateless](./conv_emformer_transducer_stateless) #### [conv_emformer_transducer_stateless](./conv_emformer_transducer_stateless)
It implements [Emformer](https://arxiv.org/abs/2010.10759) augmented with convolution module for streaming ASR. It implements [Emformer](https://arxiv.org/abs/2010.10759) augmented with convolution module for streaming ASR.
It is modified from [torchaudio](https://github.com/pytorch/audio). It is modified from [torchaudio](https://github.com/pytorch/audio).
See <https://github.com/k2-fsa/icefall/pull/389> for more details. See <https://github.com/k2-fsa/icefall/pull/389> for more details.
#### Training on full librispeech ##### Training on full librispeech
In this model, the lengths of chunk and right context are 32 frames (i.e., 0.32s) and 8 frames (i.e., 0.08s), respectively. In this model, the lengths of chunk and right context are 32 frames (i.e., 0.32s) and 8 frames (i.e., 0.08s), respectively.
@ -1011,7 +1101,7 @@ are available at
### LibriSpeech BPE training results (Pruned Stateless Emformer RNN-T) ### LibriSpeech BPE training results (Pruned Stateless Emformer RNN-T)
[pruned_stateless_emformer_rnnt2](./pruned_stateless_emformer_rnnt2) #### [pruned_stateless_emformer_rnnt2](./pruned_stateless_emformer_rnnt2)
Use <https://github.com/k2-fsa/icefall/pull/390>. Use <https://github.com/k2-fsa/icefall/pull/390>.
@ -1079,7 +1169,7 @@ results at:
### LibriSpeech BPE training results (Pruned Stateless Transducer 5) ### LibriSpeech BPE training results (Pruned Stateless Transducer 5)
[pruned_transducer_stateless5](./pruned_transducer_stateless5) #### [pruned_transducer_stateless5](./pruned_transducer_stateless5)
Same as `Pruned Stateless Transducer 2` but with more layers. Same as `Pruned Stateless Transducer 2` but with more layers.
@ -1092,7 +1182,7 @@ The notations `large` and `medium` below are from the [Conformer](https://arxiv.
paper, where the large model has about 118 M parameters and the medium model paper, where the large model has about 118 M parameters and the medium model
has 30.8 M parameters. has 30.8 M parameters.
#### Large ##### Large
Number of model parameters 118129516 (i.e, 118.13 M). Number of model parameters 118129516 (i.e, 118.13 M).
@ -1152,7 +1242,7 @@ results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-2022-07-07> <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-2022-07-07>
#### Medium ##### Medium
Number of model parameters 30896748 (i.e, 30.9 M). Number of model parameters 30896748 (i.e, 30.9 M).
@ -1212,7 +1302,7 @@ results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-M-2022-07-07> <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-M-2022-07-07>
#### Baseline-2 ##### Baseline-2
It has 88.98 M parameters. Compared to the model in pruned_transducer_stateless2, its has more It has 88.98 M parameters. Compared to the model in pruned_transducer_stateless2, its has more
layers (24 v.s 12) but a narrower model (1536 feedforward dim and 384 encoder dim vs 2048 feed forward dim and 512 encoder dim). layers (24 v.s 12) but a narrower model (1536 feedforward dim and 384 encoder dim vs 2048 feed forward dim and 512 encoder dim).
@ -1273,13 +1363,13 @@ results at:
### LibriSpeech BPE training results (Pruned Stateless Transducer 4) ### LibriSpeech BPE training results (Pruned Stateless Transducer 4)
[pruned_transducer_stateless4](./pruned_transducer_stateless4) #### [pruned_transducer_stateless4](./pruned_transducer_stateless4)
This version saves averaged model during training, and decodes with averaged model. This version saves averaged model during training, and decodes with averaged model.
See <https://github.com/k2-fsa/icefall/issues/337> for details about the idea of model averaging. See <https://github.com/k2-fsa/icefall/issues/337> for details about the idea of model averaging.
#### Training on full librispeech ##### Training on full librispeech
See <https://github.com/k2-fsa/icefall/pull/344> See <https://github.com/k2-fsa/icefall/pull/344>
@ -1355,7 +1445,7 @@ Pretrained models, training logs, decoding logs, and decoding results
are available at are available at
<https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless4-2022-06-03> <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless4-2022-06-03>
#### Training on train-clean-100 ##### Training on train-clean-100
See <https://github.com/k2-fsa/icefall/pull/344> See <https://github.com/k2-fsa/icefall/pull/344>
@ -1392,7 +1482,7 @@ The tensorboard log can be found at
### LibriSpeech BPE training results (Pruned Stateless Transducer 3, 2022-04-29) ### LibriSpeech BPE training results (Pruned Stateless Transducer 3, 2022-04-29)
[pruned_transducer_stateless3](./pruned_transducer_stateless3) #### [pruned_transducer_stateless3](./pruned_transducer_stateless3)
Same as `Pruned Stateless Transducer 2` but using the XL subset from Same as `Pruned Stateless Transducer 2` but using the XL subset from
[GigaSpeech](https://github.com/SpeechColab/GigaSpeech) as extra training data. [GigaSpeech](https://github.com/SpeechColab/GigaSpeech) as extra training data.
@ -1606,10 +1696,10 @@ can be found at
### LibriSpeech BPE training results (Pruned Transducer 2) ### LibriSpeech BPE training results (Pruned Transducer 2)
[pruned_transducer_stateless2](./pruned_transducer_stateless2) #### [pruned_transducer_stateless2](./pruned_transducer_stateless2)
This is with a reworked version of the conformer encoder, with many changes. This is with a reworked version of the conformer encoder, with many changes.
#### Training on fulll librispeech ##### Training on full librispeech
Using commit `34aad74a2c849542dd5f6359c9e6b527e8782fd6`. Using commit `34aad74a2c849542dd5f6359c9e6b527e8782fd6`.
See <https://github.com/k2-fsa/icefall/pull/288> See <https://github.com/k2-fsa/icefall/pull/288>
@ -1658,7 +1748,7 @@ can be found at
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29> <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29>
#### Training on train-clean-100: ##### Training on train-clean-100:
Trained with 1 job: Trained with 1 job:
``` ```

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/__init__.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/asr_datamodule.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/beam_search.py

View File

@ -0,0 +1,818 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./lstm_transducer_stateless3/decode.py \
--epoch 40 \
--avg 20 \
--exp-dir ./lstm_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./lstm_transducer_stateless2/decode.py \
--epoch 40 \
--avg 20 \
--exp-dir ./lstm_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./lstm_transducer_stateless3/decode.py \
--epoch 40 \
--avg 20 \
--exp-dir ./lstm_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search (one best)
./lstm_transducer_stateless3/decode.py \
--epoch 40 \
--avg 20 \
--exp-dir ./lstm_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest)
./lstm_transducer_stateless3/decode.py \
--epoch 40 \
--avg 20 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./lstm_transducer_stateless3/decode.py \
--epoch 40 \
--avg 20 \
--exp-dir ./lstm_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./lstm_transducer_stateless3/decode.py \
--epoch 40 \
--avg 20 \
--exp-dir ./lstm_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="lstm_transducer_stateless/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
# tail padding here to alleviate the tail deletion problem
num_tail_padded_frames = 35
feature = torch.nn.functional.pad(
feature,
(0, 0, 0, num_tail_padded_frames),
mode="constant",
value=LOG_EPS,
)
feature_lens += num_tail_padded_frames
encoder_out, encoder_out_lens, _ = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/decoder.py

View File

@ -0,0 +1 @@
../transducer_stateless/encoder_interface.py

View File

@ -0,0 +1,388 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
(1) Export to torchscript model using torch.jit.trace()
./lstm_transducer_stateless3/export.py \
--exp-dir ./lstm_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 40 \
--avg 20 \
--jit-trace 1
It will generate 3 files: `encoder_jit_trace.pt`,
`decoder_jit_trace.pt`, and `joiner_jit_trace.pt`.
(2) Export `model.state_dict()`
./lstm_transducer_stateless3/export.py \
--exp-dir ./lstm_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 40 \
--avg 20
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
To use the generated file with `lstm_transducer_stateless3/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./lstm_transducer_stateless3/decode.py \
--exp-dir ./lstm_transducer_stateless3/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18
# You will find the pre-trained model in icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18/exp
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless3/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit-trace",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.trace.
It will generate 3 files:
- encoder_jit_trace.pt
- decoder_jit_trace.pt
- joiner_jit_trace.pt
Check ./jit_pretrained.py for how to use them.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
add_model_arguments(parser)
return parser
def export_encoder_model_jit_trace(
encoder_model: nn.Module,
encoder_filename: str,
) -> None:
"""Export the given encoder model with torch.jit.trace()
Note: The warmup argument is fixed to 1.
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported model.
"""
x = torch.zeros(1, 100, 80, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64)
states = encoder_model.get_init_states()
traced_model = torch.jit.trace(encoder_model, (x, x_lens, states))
traced_model.save(encoder_filename)
logging.info(f"Saved to {encoder_filename}")
def export_decoder_model_jit_trace(
decoder_model: nn.Module,
decoder_filename: str,
) -> None:
"""Export the given decoder model with torch.jit.trace()
Note: The argument need_pad is fixed to False.
Args:
decoder_model:
The input decoder model
decoder_filename:
The filename to save the exported model.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
need_pad = torch.tensor([False])
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
traced_model.save(decoder_filename)
logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_jit_trace(
joiner_model: nn.Module,
joiner_filename: str,
) -> None:
"""Export the given joiner model with torch.jit.trace()
Note: The argument project_input is fixed to True. A user should not
project the encoder_out/decoder_out by himself/herself. The exported joiner
will do that for the user.
Args:
joiner_model:
The input joiner model
joiner_filename:
The filename to save the exported model.
"""
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
traced_model.save(joiner_filename)
logging.info(f"Saved to {joiner_filename}")
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
if params.jit_trace is True:
convert_scaled_to_non_scaled(model, inplace=True)
logging.info("Using torch.jit.trace()")
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
export_encoder_model_jit_trace(model.encoder, encoder_filename)
decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
export_decoder_model_jit_trace(model.decoder, decoder_filename)
joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
export_joiner_model_jit_trace(model.joiner, joiner_filename)
else:
logging.info("Not using torchscript")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,322 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models, either exported by `torch.jit.trace()`
or by `torch.jit.script()`, and uses them to decode waves.
You can use the following command to get the exported models:
./lstm_transducer_stateless3/export.py \
--exp-dir ./lstm_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 40 \
--avg 15 \
--jit-trace 1
Usage of this script:
./lstm_transducer_stateless3/jit_pretrained.py \
--encoder-model-filename ./lstm_transducer_stateless3/exp/encoder_jit_trace.pt \
--decoder-model-filename ./lstm_transducer_stateless3/exp/decoder_jit_trace.pt \
--joiner-model-filename ./lstm_transducer_stateless3/exp/joiner_jit_trace.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from typing import List
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder torchscript model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder torchscript model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner torchscript model. ",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="Context size of the decoder model",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def greedy_search(
decoder: torch.jit.ScriptModule,
joiner: torch.jit.ScriptModule,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
context_size: int,
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
decoder:
The decoder model.
joiner:
The joiner model.
encoder_out:
A 3-D tensor of shape (N, T, C)
encoder_out_lens:
A 1-D tensor of shape (N,).
context_size:
The context size of the decoder model.
Returns:
Return the decoded results for each utterance.
"""
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
device = encoder_out.device
blank_id = 0 # hard-code to 0
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
hyps = [[blank_id] * context_size for _ in range(N)]
decoder_input = torch.tensor(
hyps,
device=device,
dtype=torch.int64,
) # (N, context_size)
decoder_out = decoder(
decoder_input,
need_pad=torch.tensor([False]),
).squeeze(1)
offset = 0
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = packed_encoder_out.data[start:end]
current_encoder_out = current_encoder_out
# current_encoder_out's shape: (batch_size, encoder_out_dim)
offset = end
decoder_out = decoder_out[:batch_size]
logits = joiner(
current_encoder_out,
decoder_out,
)
# logits'shape (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
hyps[i].append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
device=device,
dtype=torch.int64,
)
decoder_out = decoder(
decoder_input,
need_pad=torch.tensor([False]),
)
decoder_out = decoder_out.squeeze(1)
sorted_ans = [h[context_size:] for h in hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
encoder = torch.jit.load(args.encoder_model_filename)
decoder = torch.jit.load(args.decoder_model_filename)
joiner = torch.jit.load(args.joiner_model_filename)
encoder.eval()
decoder.eval()
joiner.eval()
encoder.to(device)
decoder.to(device)
joiner.to(device)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files,
expected_sample_rate=args.sample_rate,
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)
feature_lengths = torch.tensor(feature_lengths, device=device)
states = encoder.get_init_states(batch_size=features.size(0), device=device)
encoder_out, encoder_out_lens, _ = encoder(
x=features,
x_lens=feature_lengths,
states=states,
)
hyps = greedy_search(
decoder=decoder,
joiner=joiner,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
context_size=args.context_size,
)
s = "\n"
for filename, hyp in zip(args.sound_files, hyps):
words = sp.decode(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/joiner.py

View File

@ -0,0 +1,860 @@
# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import math
from typing import List, Optional, Tuple
import torch
from encoder_interface import EncoderInterface
from scaling import (
ActivationBalancer,
BasicNorm,
DoubleSwish,
ScaledConv2d,
ScaledLinear,
ScaledLSTM,
)
from torch import nn
LOG_EPSILON = math.log(1e-10)
def unstack_states(
states: Tuple[torch.Tensor, torch.Tensor]
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
"""
Unstack the lstm states corresponding to a batch of utterances into a list
of states, where the i-th entry is the state from the i-th utterance.
Args:
states:
A tuple of 2 elements.
``states[0]`` is the lstm hidden states, of a batch of utterance.
``states[1]`` is the lstm cell states, of a batch of utterances.
Returns:
A list of states.
``states[i]`` is a tuple of 2 elememts of i-th utterance.
``states[i][0]`` is the lstm hidden states of i-th utterance.
``states[i][1]`` is the lstm cell states of i-th utterance.
"""
hidden_states, cell_states = states
list_hidden_states = hidden_states.unbind(dim=1)
list_cell_states = cell_states.unbind(dim=1)
ans = [
(h.unsqueeze(1), c.unsqueeze(1))
for (h, c) in zip(list_hidden_states, list_cell_states)
]
return ans
def stack_states(
states_list: List[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Stack list of lstm states corresponding to separate utterances into a single
lstm state so that it can be used as an input for lstm when those utterances
are formed into a batch.
Args:
state_list:
Each element in state_list corresponds to the lstm state for a single
utterance.
``states[i]`` is a tuple of 2 elememts of i-th utterance.
``states[i][0]`` is the lstm hidden states of i-th utterance.
``states[i][1]`` is the lstm cell states of i-th utterance.
Returns:
A new state corresponding to a batch of utterances.
It is a tuple of 2 elements.
``states[0]`` is the lstm hidden states, of a batch of utterance.
``states[1]`` is the lstm cell states, of a batch of utterances.
"""
hidden_states = torch.cat([s[0] for s in states_list], dim=1)
cell_states = torch.cat([s[1] for s in states_list], dim=1)
ans = (hidden_states, cell_states)
return ans
class RNN(EncoderInterface):
"""
Args:
num_features (int):
Number of input features.
subsampling_factor (int):
Subsampling factor of encoder (convolution layers before lstm layers) (default=4). # noqa
d_model (int):
Output dimension (default=512).
dim_feedforward (int):
Feedforward dimension (default=2048).
rnn_hidden_size (int):
Hidden dimension for lstm layers (default=1024).
grad_norm_threshold:
For each sequence element in batch, its gradient will be
filtered out if the gradient norm is larger than
`grad_norm_threshold * median`, where `median` is the median
value of gradient norms of all elememts in batch.
num_encoder_layers (int):
Number of encoder layers (default=12).
dropout (float):
Dropout rate (default=0.1).
layer_dropout (float):
Dropout value for model-level warmup (default=0.075).
aux_layer_period (int):
Period of auxiliary layers used for random combiner during training.
If set to 0, will not use the random combiner (Default).
You can set a positive integer to use the random combiner, e.g., 3.
"""
def __init__(
self,
num_features: int,
subsampling_factor: int = 4,
d_model: int = 512,
dim_feedforward: int = 2048,
rnn_hidden_size: int = 1024,
grad_norm_threshold: float = 10.0,
num_encoder_layers: int = 12,
dropout: float = 0.1,
layer_dropout: float = 0.075,
aux_layer_period: int = 0,
) -> None:
super(RNN, self).__init__()
self.num_features = num_features
self.subsampling_factor = subsampling_factor
if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.")
# self.encoder_embed converts the input of shape (N, T, num_features)
# to the shape (N, T//subsampling_factor, d_model).
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_features -> d_model
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.num_encoder_layers = num_encoder_layers
self.d_model = d_model
self.rnn_hidden_size = rnn_hidden_size
encoder_layer = RNNEncoderLayer(
d_model=d_model,
dim_feedforward=dim_feedforward,
rnn_hidden_size=rnn_hidden_size,
grad_norm_threshold=grad_norm_threshold,
dropout=dropout,
layer_dropout=layer_dropout,
)
self.encoder = RNNEncoder(
encoder_layer,
num_encoder_layers,
aux_layers=list(
range(
num_encoder_layers // 3,
num_encoder_layers - 1,
aux_layer_period,
)
)
if aux_layer_period > 0
else None,
)
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
warmup: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Args:
x:
The input tensor. Its shape is (N, T, C), where N is the batch size,
T is the sequence length, C is the feature dimension.
x_lens:
A tensor of shape (N,), containing the number of frames in `x`
before padding.
states:
A tuple of 2 tensors (optional). It is for streaming inference.
states[0] is the hidden states of all layers,
with shape of (num_layers, N, d_model);
states[1] is the cell states of all layers,
with shape of (num_layers, N, rnn_hidden_size).
warmup:
A floating point value that gradually increases from 0 throughout
training; when it is >= 1.0 we are "fully warmed up". It is used
to turn modules on sequentially.
Returns:
A tuple of 3 tensors:
- embeddings: its shape is (N, T', d_model), where T' is the output
sequence lengths.
- lengths: a tensor of shape (batch_size,) containing the number of
frames in `embeddings` before padding.
- updated states, whose shape is the same as the input states.
"""
x = self.encoder_embed(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
# lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning
#
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
lengths = (((x_lens - 3) >> 1) - 1) >> 1
if not torch.jit.is_tracing():
assert x.size(0) == lengths.max().item()
if states is None:
x = self.encoder(x, warmup=warmup)[0]
# torch.jit.trace requires returned types to be the same as annotated # noqa
new_states = (torch.empty(0), torch.empty(0))
else:
assert not self.training
assert len(states) == 2
if not torch.jit.is_tracing():
# for hidden state
assert states[0].shape == (
self.num_encoder_layers,
x.size(1),
self.d_model,
)
# for cell state
assert states[1].shape == (
self.num_encoder_layers,
x.size(1),
self.rnn_hidden_size,
)
x, new_states = self.encoder(x, states)
x = x.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
return x, lengths, new_states
@torch.jit.export
def get_init_states(
self, batch_size: int = 1, device: torch.device = torch.device("cpu")
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get model initial states."""
# for rnn hidden states
hidden_states = torch.zeros(
(self.num_encoder_layers, batch_size, self.d_model), device=device
)
cell_states = torch.zeros(
(self.num_encoder_layers, batch_size, self.rnn_hidden_size),
device=device,
)
return (hidden_states, cell_states)
class RNNEncoderLayer(nn.Module):
"""
RNNEncoderLayer is made up of lstm and feedforward networks.
For stable training, in each lstm module, gradient filter
is applied to filter out extremely large elements in batch gradients
and also the module parameters with soft masks.
Args:
d_model:
The number of expected features in the input (required).
dim_feedforward:
The dimension of feedforward network model (default=2048).
rnn_hidden_size:
The hidden dimension of rnn layer.
grad_norm_threshold:
For each sequence element in batch, its gradient will be
filtered out if the gradient norm is larger than
`grad_norm_threshold * median`, where `median` is the median
value of gradient norms of all elememts in batch.
dropout:
The dropout value (default=0.1).
layer_dropout:
The dropout value for model-level warmup (default=0.075).
"""
def __init__(
self,
d_model: int,
dim_feedforward: int,
rnn_hidden_size: int,
grad_norm_threshold: float = 10.0,
dropout: float = 0.1,
layer_dropout: float = 0.075,
) -> None:
super(RNNEncoderLayer, self).__init__()
self.layer_dropout = layer_dropout
self.d_model = d_model
self.rnn_hidden_size = rnn_hidden_size
assert rnn_hidden_size >= d_model, (rnn_hidden_size, d_model)
self.lstm = ScaledLSTM(
input_size=d_model,
hidden_size=rnn_hidden_size,
proj_size=d_model if rnn_hidden_size > d_model else 0,
num_layers=1,
dropout=0.0,
grad_norm_threshold=grad_norm_threshold,
)
self.feed_forward = nn.Sequential(
ScaledLinear(d_model, dim_feedforward),
ActivationBalancer(channel_dim=-1),
DoubleSwish(),
nn.Dropout(dropout),
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
)
self.norm_final = BasicNorm(d_model)
# try to ensure the output is close to zero-mean (or at least, zero-median). # noqa
self.balancer = ActivationBalancer(
channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
)
self.dropout = nn.Dropout(dropout)
def forward(
self,
src: torch.Tensor,
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
warmup: float = 1.0,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Pass the input through the encoder layer.
Args:
src:
The sequence to the encoder layer (required).
Its shape is (S, N, E), where S is the sequence length,
N is the batch size, and E is the feature number.
states:
A tuple of 2 tensors (optional). It is for streaming inference.
states[0] is the hidden states of all layers,
with shape of (1, N, d_model);
states[1] is the cell states of all layers,
with shape of (1, N, rnn_hidden_size).
warmup:
It controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently.
"""
src_orig = src
warmup_scale = min(0.1 + warmup, 1.0)
# alpha = 1.0 means fully use this encoder layer, 0.0 would mean
# completely bypass it.
if self.training:
alpha = (
warmup_scale
if torch.rand(()).item() <= (1.0 - self.layer_dropout)
else 0.1
)
else:
alpha = 1.0
# lstm module
if states is None:
src_lstm = self.lstm(src)[0]
# torch.jit.trace requires returned types be the same as annotated
new_states = (torch.empty(0), torch.empty(0))
else:
assert not self.training
assert len(states) == 2
if not torch.jit.is_tracing():
# for hidden state
assert states[0].shape == (1, src.size(1), self.d_model)
# for cell state
assert states[1].shape == (1, src.size(1), self.rnn_hidden_size)
src_lstm, new_states = self.lstm(src, states)
src = src + self.dropout(src_lstm)
# feed forward module
src = src + self.dropout(self.feed_forward(src))
src = self.norm_final(self.balancer(src))
if alpha != 1.0:
src = alpha * src + (1 - alpha) * src_orig
return src, new_states
class RNNEncoder(nn.Module):
"""
RNNEncoder is a stack of N encoder layers.
Args:
encoder_layer:
An instance of the RNNEncoderLayer() class (required).
num_layers:
The number of sub-encoder-layers in the encoder (required).
"""
def __init__(
self,
encoder_layer: nn.Module,
num_layers: int,
aux_layers: Optional[List[int]] = None,
) -> None:
super(RNNEncoder, self).__init__()
self.layers = nn.ModuleList(
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
)
self.num_layers = num_layers
self.d_model = encoder_layer.d_model
self.rnn_hidden_size = encoder_layer.rnn_hidden_size
self.aux_layers: List[int] = []
self.combiner: Optional[nn.Module] = None
if aux_layers is not None:
assert len(set(aux_layers)) == len(aux_layers)
assert num_layers - 1 not in aux_layers
self.aux_layers = aux_layers + [num_layers - 1]
self.combiner = RandomCombine(
num_inputs=len(self.aux_layers),
final_weight=0.5,
pure_prob=0.333,
stddev=2.0,
)
def forward(
self,
src: torch.Tensor,
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
warmup: float = 1.0,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Pass the input through the encoder layer in turn.
Args:
src:
The sequence to the encoder layer (required).
Its shape is (S, N, E), where S is the sequence length,
N is the batch size, and E is the feature number.
states:
A tuple of 2 tensors (optional). It is for streaming inference.
states[0] is the hidden states of all layers,
with shape of (num_layers, N, d_model);
states[1] is the cell states of all layers,
with shape of (num_layers, N, rnn_hidden_size).
warmup:
It controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently.
"""
if states is not None:
assert not self.training
assert len(states) == 2
if not torch.jit.is_tracing():
# for hidden state
assert states[0].shape == (
self.num_layers,
src.size(1),
self.d_model,
)
# for cell state
assert states[1].shape == (
self.num_layers,
src.size(1),
self.rnn_hidden_size,
)
output = src
outputs = []
new_hidden_states = []
new_cell_states = []
for i, mod in enumerate(self.layers):
if states is None:
output = mod(output, warmup=warmup)[0]
else:
layer_state = (
states[0][i : i + 1, :, :], # h: (1, N, d_model)
states[1][i : i + 1, :, :], # c: (1, N, rnn_hidden_size)
)
output, (h, c) = mod(output, layer_state)
new_hidden_states.append(h)
new_cell_states.append(c)
if self.combiner is not None and i in self.aux_layers:
outputs.append(output)
if self.combiner is not None:
output = self.combiner(outputs)
if states is None:
new_states = (torch.empty(0), torch.empty(0))
else:
new_states = (
torch.cat(new_hidden_states, dim=0),
torch.cat(new_cell_states, dim=0),
)
return output, new_states
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
Convert an input of shape (N, T, idim) to an output
with shape (N, T', odim), where
T' = ((T-3)//2-1)//2, which approximates T' == T//4
It is based on
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
"""
def __init__(
self,
in_channels: int,
out_channels: int,
layer1_channels: int = 8,
layer2_channels: int = 32,
layer3_channels: int = 128,
) -> None:
"""
Args:
in_channels:
Number of channels in. The input shape is (N, T, in_channels).
Caution: It requires: T >= 9, in_channels >= 9.
out_channels
Output dim. The output shape is (N, ((T-3)//2-1)//2, out_channels)
layer1_channels:
Number of channels in layer1
layer1_channels:
Number of channels in layer2
"""
assert in_channels >= 9
super().__init__()
self.conv = nn.Sequential(
ScaledConv2d(
in_channels=1,
out_channels=layer1_channels,
kernel_size=3,
padding=0,
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
ScaledConv2d(
in_channels=layer1_channels,
out_channels=layer2_channels,
kernel_size=3,
stride=2,
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
ScaledConv2d(
in_channels=layer2_channels,
out_channels=layer3_channels,
kernel_size=3,
stride=2,
),
ActivationBalancer(channel_dim=1),
DoubleSwish(),
)
self.out = ScaledLinear(
layer3_channels * (((in_channels - 3) // 2 - 1) // 2), out_channels
)
# set learn_eps=False because out_norm is preceded by `out`, and `out`
# itself has learned scale, so the extra degree of freedom is not
# needed.
self.out_norm = BasicNorm(out_channels, learn_eps=False)
# constrain median of output to be close to zero.
self.out_balancer = ActivationBalancer(
channel_dim=-1, min_positive=0.45, max_positive=0.55
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
Returns:
Return a tensor of shape (N, ((T-3)//2-1)//2, odim)
"""
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
x = self.conv(x)
# Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
# Now x is of shape (N, ((T-3)//2-1))//2, odim)
x = self.out_norm(x)
x = self.out_balancer(x)
return x
class RandomCombine(nn.Module):
"""
This module combines a list of Tensors, all with the same shape, to
produce a single output of that same shape which, in training time,
is a random combination of all the inputs; but which in test time
will be just the last input.
The idea is that the list of Tensors will be a list of outputs of multiple
conformer layers. This has a similar effect as iterated loss. (See:
DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
NETWORKS).
"""
def __init__(
self,
num_inputs: int,
final_weight: float = 0.5,
pure_prob: float = 0.5,
stddev: float = 2.0,
) -> None:
"""
Args:
num_inputs:
The number of tensor inputs, which equals the number of layers'
outputs that are fed into this module. E.g. in an 18-layer neural
net if we output layers 16, 12, 18, num_inputs would be 3.
final_weight:
The amount of weight or probability we assign to the
final layer when randomly choosing layers or when choosing
continuous layer weights.
pure_prob:
The probability, on each frame, with which we choose
only a single layer to output (rather than an interpolation)
stddev:
A standard deviation that we add to log-probs for computing
randomized weights.
The method of choosing which layers, or combinations of layers, to use,
is conceptually as follows::
With probability `pure_prob`::
With probability `final_weight`: choose final layer,
Else: choose random non-final layer.
Else::
Choose initial log-weights that correspond to assigning
weight `final_weight` to the final layer and equal
weights to other layers; then add Gaussian noise
with variance `stddev` to these log-weights, and normalize
to weights (note: the average weight assigned to the
final layer here will not be `final_weight` if stddev>0).
"""
super().__init__()
assert 0 <= pure_prob <= 1, pure_prob
assert 0 < final_weight < 1, final_weight
assert num_inputs >= 1
self.num_inputs = num_inputs
self.final_weight = final_weight
self.pure_prob = pure_prob
self.stddev = stddev
self.final_log_weight = (
torch.tensor(
(final_weight / (1 - final_weight)) * (self.num_inputs - 1)
)
.log()
.item()
)
def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
"""Forward function.
Args:
inputs:
A list of Tensor, e.g. from various layers of a transformer.
All must be the same shape, of (*, num_channels)
Returns:
A Tensor of shape (*, num_channels). In test mode
this is just the final input.
"""
num_inputs = self.num_inputs
assert len(inputs) == num_inputs
if not self.training or torch.jit.is_scripting():
return inputs[-1]
# Shape of weights: (*, num_inputs)
num_channels = inputs[0].shape[-1]
num_frames = inputs[0].numel() // num_channels
ndim = inputs[0].ndim
# stacked_inputs: (num_frames, num_channels, num_inputs)
stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
(num_frames, num_channels, num_inputs)
)
# weights: (num_frames, num_inputs)
weights = self._get_random_weights(
inputs[0].dtype, inputs[0].device, num_frames
)
weights = weights.reshape(num_frames, num_inputs, 1)
# ans: (num_frames, num_channels, 1)
ans = torch.matmul(stacked_inputs, weights)
# ans: (*, num_channels)
ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,))
# The following if causes errors for torch script in torch 1.6.0
# if __name__ == "__main__":
# # for testing only...
# print("Weights = ", weights.reshape(num_frames, num_inputs))
return ans
def _get_random_weights(
self, dtype: torch.dtype, device: torch.device, num_frames: int
) -> torch.Tensor:
"""Return a tensor of random weights, of shape
`(num_frames, self.num_inputs)`,
Args:
dtype:
The data-type desired for the answer, e.g. float, double.
device:
The device needed for the answer.
num_frames:
The number of sets of weights desired
Returns:
A tensor of shape (num_frames, self.num_inputs), such that
`ans.sum(dim=1)` is all ones.
"""
pure_prob = self.pure_prob
if pure_prob == 0.0:
return self._get_random_mixed_weights(dtype, device, num_frames)
elif pure_prob == 1.0:
return self._get_random_pure_weights(dtype, device, num_frames)
else:
p = self._get_random_pure_weights(dtype, device, num_frames)
m = self._get_random_mixed_weights(dtype, device, num_frames)
return torch.where(
torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m
)
def _get_random_pure_weights(
self, dtype: torch.dtype, device: torch.device, num_frames: int
):
"""Return a tensor of random one-hot weights, of shape
`(num_frames, self.num_inputs)`,
Args:
dtype:
The data-type desired for the answer, e.g. float, double.
device:
The device needed for the answer.
num_frames:
The number of sets of weights desired.
Returns:
A one-hot tensor of shape `(num_frames, self.num_inputs)`, with
exactly one weight equal to 1.0 on each frame.
"""
final_prob = self.final_weight
# final contains self.num_inputs - 1 in all elements
final = torch.full((num_frames,), self.num_inputs - 1, device=device)
# nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa
nonfinal = torch.randint(
self.num_inputs - 1, (num_frames,), device=device
)
indexes = torch.where(
torch.rand(num_frames, device=device) < final_prob, final, nonfinal
)
ans = torch.nn.functional.one_hot(
indexes, num_classes=self.num_inputs
).to(dtype=dtype)
return ans
def _get_random_mixed_weights(
self, dtype: torch.dtype, device: torch.device, num_frames: int
):
"""Return a tensor of random one-hot weights, of shape
`(num_frames, self.num_inputs)`,
Args:
dtype:
The data-type desired for the answer, e.g. float, double.
device:
The device needed for the answer.
num_frames:
The number of sets of weights desired.
Returns:
A tensor of shape (num_frames, self.num_inputs), which elements
in [0..1] that sum to one over the second axis, i.e.
`ans.sum(dim=1)` is all ones.
"""
logprobs = (
torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device)
* self.stddev
)
logprobs[:, -1] += self.final_log_weight
return logprobs.softmax(dim=1)
def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
print(
f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}" # noqa
)
num_inputs = 3
num_channels = 50
m = RandomCombine(
num_inputs=num_inputs,
final_weight=final_weight,
pure_prob=pure_prob,
stddev=stddev,
)
x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
y = m(x)
assert y.shape == x[0].shape
assert torch.allclose(y, x[0]) # .. since actually all ones.
def _test_random_combine_main():
_test_random_combine(0.999, 0, 0.0)
_test_random_combine(0.5, 0, 0.0)
_test_random_combine(0.999, 0, 0.0)
_test_random_combine(0.5, 0, 0.3)
_test_random_combine(0.5, 1, 0.3)
_test_random_combine(0.5, 0.5, 0.3)
feature_dim = 50
c = RNN(num_features=feature_dim, d_model=128)
batch_size = 5
seq_len = 20
# Just make sure the forward pass runs.
f = c(
torch.randn(batch_size, seq_len, feature_dim),
torch.full((batch_size,), seq_len, dtype=torch.int64),
)
f # to remove flake8 warnings
if __name__ == "__main__":
feature_dim = 80
m = RNN(
num_features=feature_dim,
d_model=512,
rnn_hidden_size=1024,
dim_feedforward=2048,
num_encoder_layers=12,
)
batch_size = 5
seq_len = 20
# Just make sure the forward pass runs.
f = m(
torch.randn(batch_size, seq_len, feature_dim),
torch.full((batch_size,), seq_len, dtype=torch.int64),
warmup=0.5,
)
num_param = sum([p.numel() for p in m.parameters()])
print(f"Number of model parameters: {num_param}")
_test_random_combine_main()

View File

@ -0,0 +1 @@
../lstm_transducer_stateless/model.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/optim.py

View File

@ -0,0 +1,352 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./lstm_transducer_stateless3/pretrained.py \
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) beam search
./lstm_transducer_stateless3/pretrained.py \
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) modified beam search
./lstm_transducer_stateless3/pretrained.py \
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./lstm_transducer_stateless3/pretrained.py \
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./lstm_transducer_stateless3/exp/epoch-xx.pt`.
Note: ./lstm_transducer_stateless3/exp/pretrained.pt is generated by
./lstm_transducer_stateless3/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens, _ = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0)
hyps = []
msg = f"Using {params.method}"
if params.method == "beam_search":
msg += f" with beam size {params.beam_size}"
logging.info(msg)
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/scaling.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless3/scaling_converter.py

View File

@ -0,0 +1 @@
../lstm_transducer_stateless/stream.py

View File

@ -0,0 +1,968 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./lstm_transducer_stateless3/streaming_decode.py \
--epoch 40 \
--avg 20 \
--exp-dir lstm_transducer_stateless3/exp \
--num-decode-streams 2000 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method greedy_search \
--use-averaged-model True
(2) modified beam search
./lstm_transducer_stateless3/streaming_decode.py \
--epoch 40 \
--avg 20 \
--exp-dir lstm_transducer_stateless3/exp \
--num-decode-streams 2000 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method modified_beam_search \
--use-averaged-model True \
--beam-size 4
(3) fast beam search
./lstm_transducer_stateless3/streaming_decode.py \
--epoch 40 \
--avg 20 \
--exp-dir lstm_transducer_stateless3/exp \
--num-decode-streams 2000 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method fast_beam_search \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
import argparse
import logging
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from lstm import LOG_EPSILON, stack_states, unstack_states
from stream import Stream
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.decode import one_best_decoding
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=40,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=20,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=False,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="lstm_transducer_stateless3/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--sampling-rate",
type=float,
default=16000,
help="Sample rate of the audio",
)
parser.add_argument(
"--num-decode-streams",
type=int,
default=2000,
help="The number of streams that can be decoded in parallel",
)
add_model_arguments(parser)
return parser
def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[Stream],
) -> None:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C), where N >= 1.
streams:
A list of Stream objects.
"""
assert len(streams) == encoder_out.size(0)
assert encoder_out.ndim == 3
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = next(model.parameters()).device
T = encoder_out.size(1)
encoder_out = model.joiner.encoder_proj(encoder_out)
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (batch_size, 1, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
for t in range(T):
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
# logits'shape (batch_size, vocab_size)
logits = logits.squeeze(1).squeeze(1)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
streams[i].hyp.append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
)
decoder_out = model.joiner.decoder_proj(decoder_out)
def modified_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[Stream],
beam: int = 4,
):
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Args:
model:
The RNN-T model.
encoder_out:
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
the encoder model.
streams:
A list of stream objects.
beam:
Number of active paths during the beam search.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert len(streams) == encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size = len(streams)
T = encoder_out.size(1)
B = [stream.hyps for stream in streams]
encoder_out = model.joiner.encoder_proj(encoder_out)
for t in range(T):
current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.stack(
[hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, encoder_out_dim)
logits = model.joiner(
current_encoder_out, decoder_out, project_input=False
)
# logits is of shape (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(
shape=log_probs_shape, value=log_probs
)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
if new_token != blank_id:
new_ys.append(new_token)
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp)
for i in range(batch_size):
streams[i].hyps = B[i]
def fast_beam_search_one_best(
model: nn.Module,
streams: List[Stream],
encoder_out: torch.Tensor,
processed_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
) -> None:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using modified beam search, and then
the shortest path within the lattice is used as the final output.
Args:
model:
An instance of `Transducer`.
streams:
A list of stream objects.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
processed_lens:
A tensor of shape (N,) containing the number of processed frames
in `encoder_out` before padding.
beam:
Beam value, similar to the beam used in Kaldi..
max_states:
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
"""
assert encoder_out.ndim == 3
context_size = model.decoder.context_size
vocab_size = model.decoder.vocab_size
B, T, C = encoder_out.shape
assert B == len(streams)
config = k2.RnntDecodingConfig(
vocab_size=vocab_size,
decoder_history_len=context_size,
beam=beam,
max_contexts=max_contexts,
max_states=max_states,
)
individual_streams = []
for i in range(B):
individual_streams.append(streams[i].rnnt_decoding_stream)
decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
encoder_out = model.joiner.encoder_proj(encoder_out)
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# current_encoder_out is of shape
# (shape.NumElements(), 1, joiner_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(processed_lens.tolist())
best_path = one_best_decoding(lattice)
hyps = get_texts(best_path)
for i in range(B):
streams[i].hyp = hyps[i]
def decode_one_chunk(
model: nn.Module,
streams: List[Stream],
params: AttributeDict,
decoding_graph: Optional[k2.Fsa] = None,
) -> List[int]:
"""
Args:
model:
The Transducer model.
streams:
A list of Stream objects.
params:
It is returned by :func:`get_params`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search.
Returns:
A list of indexes indicating the finished streams.
"""
device = next(model.parameters()).device
feature_list = []
feature_len_list = []
state_list = []
num_processed_frames_list = []
for stream in streams:
# We should first get `stream.num_processed_frames`
# before calling `stream.get_feature_chunk()`
# since `stream.num_processed_frames` would be updated
num_processed_frames_list.append(stream.num_processed_frames)
feature = stream.get_feature_chunk()
feature_len = feature.size(0)
feature_list.append(feature)
feature_len_list.append(feature_len)
state_list.append(stream.states)
features = pad_sequence(
feature_list, batch_first=True, padding_value=LOG_EPSILON
).to(device)
feature_lens = torch.tensor(feature_len_list, device=device)
num_processed_frames = torch.tensor(
num_processed_frames_list, device=device
)
# Make sure it has at least 1 frame after subsampling
tail_length = params.subsampling_factor + 5
if features.size(1) < tail_length:
pad_length = tail_length - features.size(1)
feature_lens += pad_length
features = torch.nn.functional.pad(
features,
(0, 0, 0, pad_length),
mode="constant",
value=LOG_EPSILON,
)
# Stack states of all streams
states = stack_states(state_list)
encoder_out, encoder_out_lens, states = model.encoder(
x=features,
x_lens=feature_lens,
states=states,
)
if params.decoding_method == "greedy_search":
greedy_search(
model=model,
streams=streams,
encoder_out=encoder_out,
)
elif params.decoding_method == "modified_beam_search":
modified_beam_search(
model=model,
streams=streams,
encoder_out=encoder_out,
beam=params.beam_size,
)
elif params.decoding_method == "fast_beam_search":
# feature_len is needed to get partial results.
# The rnnt_decoding_stream for fast_beam_search.
with warnings.catch_warnings():
warnings.simplefilter("ignore")
processed_lens = (
num_processed_frames // params.subsampling_factor
+ encoder_out_lens
)
fast_beam_search_one_best(
model=model,
streams=streams,
encoder_out=encoder_out,
processed_lens=processed_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
# Update cached states of each stream
state_list = unstack_states(states)
for i, s in enumerate(state_list):
streams[i].states = s
finished_streams = [i for i, stream in enumerate(streams) if stream.done]
return finished_streams
def create_streaming_feature_extractor() -> Fbank:
"""Create a CPU streaming feature extractor.
At present, we assume it returns a fbank feature extractor with
fixed options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
return Fbank(opts)
def decode_dataset(
cuts: CutSet,
model: nn.Module,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
):
"""Decode dataset.
Args:
cuts:
Lhotse Cutset containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The Transducer model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
device = next(model.parameters()).device
log_interval = 300
fbank = create_streaming_feature_extractor()
decode_results = []
streams = []
for num, cut in enumerate(cuts):
# Each utterance has a Stream.
stream = Stream(
params=params,
cut_id=cut.id,
decoding_graph=decoding_graph,
device=device,
LOG_EPS=LOG_EPSILON,
)
stream.states = model.encoder.get_init_states(device=device)
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
samples = torch.from_numpy(audio).squeeze(0)
feature = fbank(samples)
stream.set_feature(feature)
stream.ground_truth = cut.supervisions[0].text
streams.append(stream)
while len(streams) >= params.num_decode_streams:
finished_streams = decode_one_chunk(
model=model,
streams=streams,
params=params,
decoding_graph=decoding_graph,
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
streams[i].id,
streams[i].ground_truth.split(),
sp.decode(streams[i].decoding_result()).split(),
)
)
del streams[i]
if num % log_interval == 0:
logging.info(f"Cuts processed until now is {num}.")
while len(streams) > 0:
finished_streams = decode_one_chunk(
model=model,
streams=streams,
params=params,
decoding_graph=decoding_graph,
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
streams[i].id,
streams[i].ground_truth.split(),
sp.decode(streams[i].decoding_result()).split(),
)
)
del streams[i]
if params.decoding_method == "greedy_search":
key = "greedy_search"
elif params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
else:
key = f"beam_size_{params.beam_size}"
return {key: decode_results}
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=sorted(results))
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"fast_beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-streaming-decode")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.device = device
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.eval()
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_sets = ["test-clean", "test-other"]
test_cuts = [test_clean_cuts, test_other_cuts]
for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset(
cuts=test_cut,
model=model,
params=params,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
torch.manual_seed(20220810)
main()

View File

@ -0,0 +1,92 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./lstm_transducer_stateless/test_model.py
"""
import os
from pathlib import Path
import torch
from export import (
export_decoder_model_jit_trace,
export_encoder_model_jit_trace,
export_joiner_model_jit_trace,
)
from lstm import stack_states, unstack_states
from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model
def test_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.unk_id = 2
params.encoder_dim = 512
params.rnn_hidden_size = 1024
params.num_encoder_layers = 12
params.aux_layer_period = 0
params.exp_dir = Path("exp_test_model")
model = get_transducer_model(params)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
convert_scaled_to_non_scaled(model, inplace=True)
if not os.path.exists(params.exp_dir):
os.path.mkdir(params.exp_dir)
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
export_encoder_model_jit_trace(model.encoder, encoder_filename)
decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
export_decoder_model_jit_trace(model.decoder, decoder_filename)
joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
export_joiner_model_jit_trace(model.joiner, joiner_filename)
print("The model has been successfully exported using jit.trace.")
def test_states_stack_and_unstack():
layer, batch, hidden, cell = 12, 100, 512, 1024
states = (
torch.randn(layer, batch, hidden),
torch.randn(layer, batch, cell),
)
states2 = stack_states(unstack_states(states))
assert torch.allclose(states[0], states2[0])
assert torch.allclose(states[1], states2[1])
def main():
test_model()
test_states_stack_and_unstack()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,257 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./lstm_transducer_stateless/test_scaling_converter.py
"""
import copy
import torch
from scaling import (
ScaledConv1d,
ScaledConv2d,
ScaledEmbedding,
ScaledLinear,
ScaledLSTM,
)
from scaling_converter import (
convert_scaled_to_non_scaled,
scaled_conv1d_to_conv1d,
scaled_conv2d_to_conv2d,
scaled_embedding_to_embedding,
scaled_linear_to_linear,
scaled_lstm_to_lstm,
)
from train import get_params, get_transducer_model
def get_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.unk_id = 2
params.encoder_dim = 512
params.rnn_hidden_size = 1024
params.num_encoder_layers = 12
params.aux_layer_period = -1
model = get_transducer_model(params)
return model
def test_scaled_linear_to_linear():
N = 5
in_features = 10
out_features = 20
for bias in [True, False]:
scaled_linear = ScaledLinear(
in_features=in_features,
out_features=out_features,
bias=bias,
)
linear = scaled_linear_to_linear(scaled_linear)
x = torch.rand(N, in_features)
y1 = scaled_linear(x)
y2 = linear(x)
assert torch.allclose(y1, y2)
jit_scaled_linear = torch.jit.script(scaled_linear)
jit_linear = torch.jit.script(linear)
y3 = jit_scaled_linear(x)
y4 = jit_linear(x)
assert torch.allclose(y3, y4)
assert torch.allclose(y1, y4)
def test_scaled_conv1d_to_conv1d():
in_channels = 3
for bias in [True, False]:
scaled_conv1d = ScaledConv1d(
in_channels,
6,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
conv1d = scaled_conv1d_to_conv1d(scaled_conv1d)
x = torch.rand(20, in_channels, 10)
y1 = scaled_conv1d(x)
y2 = conv1d(x)
assert torch.allclose(y1, y2)
jit_scaled_conv1d = torch.jit.script(scaled_conv1d)
jit_conv1d = torch.jit.script(conv1d)
y3 = jit_scaled_conv1d(x)
y4 = jit_conv1d(x)
assert torch.allclose(y3, y4)
assert torch.allclose(y1, y4)
def test_scaled_conv2d_to_conv2d():
in_channels = 1
for bias in [True, False]:
scaled_conv2d = ScaledConv2d(
in_channels=in_channels,
out_channels=3,
kernel_size=3,
padding=1,
bias=bias,
)
conv2d = scaled_conv2d_to_conv2d(scaled_conv2d)
x = torch.rand(20, in_channels, 10, 20)
y1 = scaled_conv2d(x)
y2 = conv2d(x)
assert torch.allclose(y1, y2)
jit_scaled_conv2d = torch.jit.script(scaled_conv2d)
jit_conv2d = torch.jit.script(conv2d)
y3 = jit_scaled_conv2d(x)
y4 = jit_conv2d(x)
assert torch.allclose(y3, y4)
assert torch.allclose(y1, y4)
def test_scaled_embedding_to_embedding():
scaled_embedding = ScaledEmbedding(
num_embeddings=500,
embedding_dim=10,
padding_idx=0,
)
embedding = scaled_embedding_to_embedding(scaled_embedding)
for s in [10, 100, 300, 500, 800, 1000]:
x = torch.randint(low=0, high=500, size=(s,))
scaled_y = scaled_embedding(x)
y = embedding(x)
assert torch.equal(scaled_y, y)
def test_scaled_lstm_to_lstm():
input_size = 512
batch_size = 20
for bias in [True, False]:
for hidden_size in [512, 1024]:
scaled_lstm = ScaledLSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=1,
bias=bias,
proj_size=0 if hidden_size == input_size else input_size,
)
lstm = scaled_lstm_to_lstm(scaled_lstm)
x = torch.rand(200, batch_size, input_size)
h0 = torch.randn(1, batch_size, input_size)
c0 = torch.randn(1, batch_size, hidden_size)
y1, (h1, c1) = scaled_lstm(x, (h0, c0))
y2, (h2, c2) = lstm(x, (h0, c0))
assert torch.allclose(y1, y2)
assert torch.allclose(h1, h2)
assert torch.allclose(c1, c2)
jit_scaled_lstm = torch.jit.trace(lstm, (x, (h0, c0)))
y3, (h3, c3) = jit_scaled_lstm(x, (h0, c0))
assert torch.allclose(y1, y3)
assert torch.allclose(h1, h3)
assert torch.allclose(c1, c3)
def test_convert_scaled_to_non_scaled():
for inplace in [False, True]:
model = get_model()
model.eval()
orig_model = copy.deepcopy(model)
converted_model = convert_scaled_to_non_scaled(model, inplace=inplace)
model = orig_model
# test encoder
N = 2
T = 100
vocab_size = model.decoder.vocab_size
x = torch.randn(N, T, 80, dtype=torch.float32)
x_lens = torch.full((N,), x.size(1))
e1, e1_lens, _ = model.encoder(x, x_lens)
e2, e2_lens, _ = converted_model.encoder(x, x_lens)
assert torch.all(torch.eq(e1_lens, e2_lens))
assert torch.allclose(e1, e2), (e1 - e2).abs().max()
# test decoder
U = 50
y = torch.randint(low=1, high=vocab_size - 1, size=(N, U))
d1 = model.decoder(y)
d2 = model.decoder(y)
assert torch.allclose(d1, d2)
# test simple projection
lm1 = model.simple_lm_proj(d1)
am1 = model.simple_am_proj(e1)
lm2 = converted_model.simple_lm_proj(d2)
am2 = converted_model.simple_am_proj(e2)
assert torch.allclose(lm1, lm2)
assert torch.allclose(am1, am2)
# test joiner
e = torch.rand(2, 3, 4, 512)
d = torch.rand(2, 3, 4, 512)
j1 = model.joiner(e, d)
j2 = converted_model.joiner(e, d)
assert torch.allclose(j1, j2)
@torch.no_grad()
def main():
test_scaled_linear_to_linear()
test_scaled_conv1d_to_conv1d()
test_scaled_conv2d_to_conv2d()
test_scaled_embedding_to_embedding()
test_scaled_lstm_to_lstm()
test_convert_scaled_to_non_scaled()
if __name__ == "__main__":
torch.manual_seed(20220730)
main()

File diff suppressed because it is too large Load Diff

View File

@ -111,6 +111,76 @@ class ActivationBalancerFunction(torch.autograd.Function):
return x_grad - neg_delta_grad, None, None, None, None, None, None return x_grad - neg_delta_grad, None, None, None, None, None, None
class GradientFilterFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x: Tensor,
batch_dim: int, # e.g., 1
threshold: float, # e.g., 10.0
*params: Tensor, # module parameters
) -> Tuple[Tensor, ...]:
if x.requires_grad:
if batch_dim < 0:
batch_dim += x.ndim
ctx.batch_dim = batch_dim
ctx.threshold = threshold
return (x,) + params
@staticmethod
def backward(
ctx,
x_grad: Tensor,
*param_grads: Tensor,
) -> Tuple[Tensor, ...]:
eps = 1.0e-20
dim = ctx.batch_dim
norm_dims = [d for d in range(x_grad.ndim) if d != dim]
norm_of_batch = (x_grad ** 2).mean(dim=norm_dims, keepdim=True).sqrt()
median_norm = norm_of_batch.median()
cutoff = median_norm * ctx.threshold
inv_mask = (cutoff + norm_of_batch) / (cutoff + eps)
mask = 1.0 / (inv_mask + eps)
x_grad = x_grad * mask
avg_mask = 1.0 / (inv_mask.mean() + eps)
param_grads = [avg_mask * g for g in param_grads]
return (x_grad, None, None) + tuple(param_grads)
class GradientFilter(torch.nn.Module):
"""This is used to filter out elements that have extremely large gradients
in batch and the module parameters with soft masks.
Args:
batch_dim (int):
The batch dimension.
threshold (float):
For each element in batch, its gradient will be
filtered out if the gradient norm is larger than
`grad_norm_threshold * median`, where `median` is the median
value of gradient norms of all elememts in batch.
"""
def __init__(self, batch_dim: int = 1, threshold: float = 10.0):
super(GradientFilter, self).__init__()
self.batch_dim = batch_dim
self.threshold = threshold
def forward(self, x: Tensor, *params: Tensor) -> Tuple[Tensor, ...]:
if torch.jit.is_scripting() or is_jit_tracing():
return (x,) + params
else:
return GradientFilterFunction.apply(
x,
self.batch_dim,
self.threshold,
*params,
)
class BasicNorm(torch.nn.Module): class BasicNorm(torch.nn.Module):
""" """
This is intended to be a simpler, and hopefully cheaper, replacement for This is intended to be a simpler, and hopefully cheaper, replacement for
@ -195,7 +265,7 @@ class ScaledLinear(nn.Linear):
*args, *args,
initial_scale: float = 1.0, initial_scale: float = 1.0,
initial_speed: float = 1.0, initial_speed: float = 1.0,
**kwargs **kwargs,
): ):
super(ScaledLinear, self).__init__(*args, **kwargs) super(ScaledLinear, self).__init__(*args, **kwargs)
initial_scale = torch.tensor(initial_scale).log() initial_scale = torch.tensor(initial_scale).log()
@ -242,7 +312,7 @@ class ScaledConv1d(nn.Conv1d):
*args, *args,
initial_scale: float = 1.0, initial_scale: float = 1.0,
initial_speed: float = 1.0, initial_speed: float = 1.0,
**kwargs **kwargs,
): ):
super(ScaledConv1d, self).__init__(*args, **kwargs) super(ScaledConv1d, self).__init__(*args, **kwargs)
initial_scale = torch.tensor(initial_scale).log() initial_scale = torch.tensor(initial_scale).log()
@ -314,7 +384,7 @@ class ScaledConv2d(nn.Conv2d):
*args, *args,
initial_scale: float = 1.0, initial_scale: float = 1.0,
initial_speed: float = 1.0, initial_speed: float = 1.0,
**kwargs **kwargs,
): ):
super(ScaledConv2d, self).__init__(*args, **kwargs) super(ScaledConv2d, self).__init__(*args, **kwargs)
initial_scale = torch.tensor(initial_scale).log() initial_scale = torch.tensor(initial_scale).log()
@ -389,7 +459,8 @@ class ScaledLSTM(nn.LSTM):
*args, *args,
initial_scale: float = 1.0, initial_scale: float = 1.0,
initial_speed: float = 1.0, initial_speed: float = 1.0,
**kwargs grad_norm_threshold: float = 10.0,
**kwargs,
): ):
if "bidirectional" in kwargs: if "bidirectional" in kwargs:
assert kwargs["bidirectional"] is False assert kwargs["bidirectional"] is False
@ -404,6 +475,10 @@ class ScaledLSTM(nn.LSTM):
setattr(self, scale_name, param) setattr(self, scale_name, param)
self._scales.append(param) self._scales.append(param)
self.grad_filter = GradientFilter(
batch_dim=1, threshold=grad_norm_threshold
)
self._reset_parameters( self._reset_parameters(
initial_speed initial_speed
) # Overrides the reset_parameters in base class ) # Overrides the reset_parameters in base class
@ -513,10 +588,14 @@ class ScaledLSTM(nn.LSTM):
hx = (h_zeros, c_zeros) hx = (h_zeros, c_zeros)
self.check_forward_args(input, hx, None) self.check_forward_args(input, hx, None)
flat_weights = self._get_flat_weights()
input, *flat_weights = self.grad_filter(input, *flat_weights)
result = _VF.lstm( result = _VF.lstm(
input, input,
hx, hx,
self._get_flat_weights(), flat_weights,
self.bias, self.bias,
self.num_layers, self.num_layers,
self.dropout, self.dropout,
@ -891,9 +970,54 @@ def _test_scaled_lstm():
assert c.shape == (1, N, dim_hidden) assert c.shape == (1, N, dim_hidden)
def _test_grad_filter():
threshold = 50.0
time, batch, channel = 200, 5, 128
grad_filter = GradientFilter(batch_dim=1, threshold=threshold)
for i in range(2):
x = torch.randn(time, batch, channel, requires_grad=True)
w = nn.Parameter(torch.ones(5))
b = nn.Parameter(torch.zeros(5))
x_out, w_out, b_out = grad_filter(x, w, b)
w_out_grad = torch.randn_like(w)
b_out_grad = torch.randn_like(b)
x_out_grad = torch.rand_like(x)
if i % 2 == 1:
# The gradient norm of the first element must be larger than
# `threshold * median`, where `median` is the median value
# of gradient norms of all elements in batch.
x_out_grad[:, 0, :] = torch.full((time, channel), threshold)
torch.autograd.backward(
[x_out, w_out, b_out], [x_out_grad, w_out_grad, b_out_grad]
)
print(
"_test_grad_filter: for gradient norms, the first element > median * threshold ", # noqa
i % 2 == 1,
)
print(
"_test_grad_filter: x_out_grad norm = ",
(x_out_grad ** 2).mean(dim=(0, 2)).sqrt(),
)
print(
"_test_grad_filter: x.grad norm = ",
(x.grad ** 2).mean(dim=(0, 2)).sqrt(),
)
print("_test_grad_filter: w_out_grad = ", w_out_grad)
print("_test_grad_filter: w.grad = ", w.grad)
print("_test_grad_filter: b_out_grad = ", b_out_grad)
print("_test_grad_filter: b.grad = ", b.grad)
if __name__ == "__main__": if __name__ == "__main__":
_test_activation_balancer_sign() _test_activation_balancer_sign()
_test_activation_balancer_magnitude() _test_activation_balancer_magnitude()
_test_basic_norm() _test_basic_norm()
_test_double_swish_deriv() _test_double_swish_deriv()
_test_scaled_lstm() _test_scaled_lstm()
_test_grad_filter()