Merge branch 'master' of https://github.com/k2-fsa/icefall into surt

This commit is contained in:
Desh Raj 2023-02-04 14:53:48 -05:00
commit b3d0d34eb6
20 changed files with 1369 additions and 102 deletions

View File

@ -299,11 +299,11 @@ to run the training part first.
- (1) ``epoch-1.pt``, ``epoch-2.pt``, ..., which are saved at the end
of each epoch. You can pass ``--epoch`` to
``pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py`` to use them.
``pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py`` to use them.
- (2) ``checkpoint-436000.pt``, ``checkpoint-438000.pt``, ..., which are saved
every ``--save-every-n`` batches. You can pass ``--iter`` to
``pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py`` to use them.
``pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py`` to use them.
We suggest that you try both types of checkpoints and choose the one
that produces the lowest WERs.
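For example, passing ``--epoch 30 --avg 13 --use-averaged-model 0`` makes the decoding script average ``epoch-18.pt`` through ``epoch-30.pt`` before decoding. The following is a minimal sketch of that selection and plain averaging (illustrative only; the actual implementation is ``icefall.checkpoint.average_checkpoints``, and it assumes each checkpoint stores its weights under a ``model`` key):
.. code-block:: python

   import torch

   def average_epoch_checkpoints(exp_dir: str, epoch: int, avg: int) -> dict:
       # --epoch 30 --avg 13 selects epoch-18.pt ... epoch-30.pt (13 files)
       filenames = [f"{exp_dir}/epoch-{i}.pt" for i in range(epoch - avg + 1, epoch + 1)]
       avg_state = None
       for name in filenames:
           state = torch.load(name, map_location="cpu")["model"]
           if avg_state is None:
               avg_state = {k: v.clone().float() for k, v in state.items()}
           else:
               for k in avg_state:
                   avg_state[k] += state[k].float()
       for k in avg_state:
           avg_state[k] /= len(filenames)
       return avg_state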
@ -311,7 +311,7 @@ to run the training part first.
.. code-block:: bash
$ cd egs/librispeech/ASR
$ ./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py --help
$ ./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py --help
shows the options for decoding.
@ -320,7 +320,7 @@ The following shows the example using ``epoch-*.pt``:
.. code-block:: bash
for m in greedy_search fast_beam_search modified_beam_search; do
./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \
./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \
--epoch 30 \
--avg 13 \
--exp-dir pruned_transducer_stateless7_ctc_bs/exp \
@ -333,7 +333,7 @@ To test CTC branch, you can use the following command:
.. code-block:: bash
for m in ctc-decoding 1best; do
./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \
./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \
--epoch 30 \
--avg 13 \
--exp-dir pruned_transducer_stateless7_ctc_bs/exp \
@ -367,7 +367,7 @@ It will generate a file ``./pruned_transducer_stateless7_ctc_bs/exp/pretrained.p
.. hint::
To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py``,
To use the generated ``pretrained.pt`` for ``pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py``,
you can run:
.. code-block:: bash
@ -376,7 +376,7 @@ It will generate a file ``./pruned_transducer_stateless7_ctc_bs/exp/pretrained.p
ln -s pretrained.pt epoch-9999.pt
And then pass ``--epoch 9999 --avg 1 --use-averaged-model 0`` to
``./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py``.
``./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py``.
To use the exported model with ``./pruned_transducer_stateless7_ctc_bs/pretrained.py``, you
can run:
@ -447,7 +447,8 @@ Download pretrained models
If you don't want to train from scratch, you can download the pretrained models
by visiting the following links:
- `<https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2022-12-14>`_
- trained on LibriSpeech 100h: `<https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2022-12-14>`_
- trained on LibriSpeech 960h: `<https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29>`_
See `<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md>`_
for the details of the above pretrained models.

View File

@ -1,7 +1,7 @@
# Introduction
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/aishell/index.html>
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/Non-streaming-ASR/aishell/index.html>
for how to run models in this recipe.

View File

@ -1,6 +1,6 @@
# Introduction
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/librispeech/index.html> for how to run models in this recipe.
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/Non-streaming-ASR/librispeech/index.html> for how to run models in this recipe.
[./RESULTS.md](./RESULTS.md) contains the latest results.

View File

@ -93,13 +93,13 @@ results at:
Number of model parameters: 69136519, i.e., 69.14 M
| | test-clean | test-other | comment |
|--------------------------|------------|-------------|---------------------|
| 1best | 2.54 | 5.65 | --epoch 30 --avg 10 |
| nbest | 2.54 | 5.66 | --epoch 30 --avg 10 |
| nbest-rescoring-LG | 2.49 | 5.42 | --epoch 30 --avg 10 |
| nbest-rescoring-3-gram | 2.52 | 5.62 | --epoch 30 --avg 10 |
| nbest-rescoring-4-gram | 2.5 | 5.51 | --epoch 30 --avg 10 |
| | test-clean | test-other | comment |
| ---------------------- | ---------- | ---------- | ------------------- |
| 1best | 2.54 | 5.65 | --epoch 30 --avg 10 |
| nbest | 2.54 | 5.66 | --epoch 30 --avg 10 |
| nbest-rescoring-LG | 2.49 | 5.42 | --epoch 30 --avg 10 |
| nbest-rescoring-3-gram | 2.52 | 5.62 | --epoch 30 --avg 10 |
| nbest-rescoring-4-gram | 2.5 | 5.51 | --epoch 30 --avg 10 |
The training commands are:
```bash
@ -134,6 +134,97 @@ for m in nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram;
done
```
### pruned_transducer_stateless7_ctc_bs (zipformer with transducer loss and ctc loss using blank skip)
See https://github.com/k2-fsa/icefall/pull/730 for more details.
[pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs)
The tensorboard log can be found at
<https://tensorboard.dev/experiment/rrNZ7l83Qu6RKoD7y49wiA/>
You can find a pretrained model, training logs, decoding logs, and decoding
results at:
<https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29>
Number of model parameters: 76804822, i.e., 76.80 M
Tested on an 8-card V100 cluster, with 4 cards busy and 4 cards idle.
#### greedy_search
| model | test-clean | test-other | decoding time (s) | comment |
| ------------------------------------------------------------ | ---------- | ---------- | ---------------- | ------------------- |
| [pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs) | 2.28 | 5.53 | 48.939 | --epoch 30 --avg 13 |
| [pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc) | 2.24 | 5.18 | 91.900 | --epoch 30 --avg 8 |
- [pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs) applies blank skip in both training and decoding, while [pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc) does not apply blank skip.
- Applying blank skip in both training and decoding makes decoding **1.88 times** faster than the model without blank skip, with no obvious performance loss.
#### modified_beam_search
| model | test-clean | test-other | decoding time (s) | comment |
| ------------------------------------------------------------ | ---------- | ---------- | ---------------- | ------------------- |
| [pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs) | 2.26 | 5.44 | 80.446 | --epoch 30 --avg 13 |
| [pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc) | 2.20 | 5.12 | 283.676 | --epoch 30 --avg 8 |
- [pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs) applies blank skip in both training and decoding, while [pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc) does not apply blank skip.
- Applying blank skip in both training and decoding makes decoding **3.53 times** faster than the model without blank skip, with no obvious performance loss.
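The blank skip idea in a nutshell: frames whose CTC blank posterior is very high are dropped before the transducer decoder and joiner have to process them, so fewer frames are decoded. Below is a minimal sketch of the frame selection (illustrative; the recipe's `frame_reducer.py` implements a vectorized version that also respects the padding mask, using the same `0.9` threshold):
```python
import math
import torch
from torch.nn.utils.rnn import pad_sequence


def keep_non_blank_frames(encoder_out, ctc_log_probs, blank_id=0, threshold=0.9):
    # encoder_out: (N, T, C); ctc_log_probs: (N, T, vocab_size), log-probabilities.
    # Keep only frames whose blank log-probability is below log(threshold),
    # i.e. frames that are unlikely to be pure blank.
    non_blank_mask = ctc_log_probs[:, :, blank_id] < math.log(threshold)  # (N, T)
    out_lens = non_blank_mask.sum(dim=1)  # number of frames kept per utterance
    kept = [encoder_out[i][non_blank_mask[i]] for i in range(encoder_out.size(0))]
    out = pad_sequence(kept, batch_first=True)  # (N, T', C)
    return out, out_lens
```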
The training commands for the model using blank skip ([pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs)) are:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless7_ctc_bs/train.py \
--world-size 4 \
--num-epochs 30 \
--full-libri 1 \
--use-fp16 1 \
--max-duration 750 \
--exp-dir pruned_transducer_stateless7_ctc_bs/exp \
--feedforward-dims "1024,1024,2048,2048,1024" \
--ctc-loss-scale 0.2 \
--master-port 12535
```
The decoding commands for the transducer branch of the model using blank skip ([pruned_transducer_stateless7_ctc_bs](./pruned_transducer_stateless7_ctc_bs)) are:
```bash
for m in greedy_search modified_beam_search fast_beam_search; do
for epoch in 30; do
for avg in 15; do
./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \
--epoch $epoch \
--avg $avg \
--use-averaged-model 1 \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--feedforward-dims "1024,1024,2048,2048,1024" \
--max-duration 600 \
--decoding-method $m
done
done
done
```
The decoding commands for the transducer branch of the model without blank skip ([pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc)) are:
```bash
for m in greedy_search modified_beam_search fast_beam_search; do
for epoch in 30; do
for avg in 15; do
./pruned_transducer_stateless7_ctc/decode.py \
--epoch $epoch \
--avg $avg \
--use-averaged-model 1 \
--exp-dir ./pruned_transducer_stateless7_ctc/exp \
--feedforward-dims "1024,1024,2048,2048,1024" \
--max-duration 600 \
--decoding-method $m
done
done
done
```
### pruned_transducer_stateless7_ctc (zipformer with transducer loss and ctc loss)

View File

@ -374,21 +374,6 @@ def get_parser():
help="Whether to use half precision training.",
)
parser.add_argument(
"--filter-uneven-sized-batch",
type=str2bool,
default=True,
help="""Whether to filter uneven-sized minibatch.
For the uneven-sized batch, the total duration after padding would possibly
cause OOM. Hence, for each batch, which is sorted descendingly by length,
we simply drop the last few shortest samples, so that the retained total frames
(after padding) would not exceed `allowed_max_frames`:
`allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
where `max_frames = max_duration * 1000 // frame_shift_ms`.
We set allowed_excess_duration_ratio=0.1.
""",
)
add_model_arguments(parser)
return parser
@ -442,7 +427,6 @@ def get_params() -> AttributeDict:
params = AttributeDict(
{
"frame_shift_ms": 10.0,
# only used when params.filter_uneven_sized_batch is True
"allowed_excess_duration_ratio": 0.1,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
@ -666,12 +650,16 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
if params.filter_uneven_sized_batch:
max_frames = params.max_duration * 1000 // params.frame_shift_ms
allowed_max_frames = int(
max_frames * (1.0 + params.allowed_excess_duration_ratio)
)
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
# For the uneven-sized batch, the total duration after padding would possibly
# cause OOM. Hence, for each batch, which is sorted descendingly by length,
# we simply drop the last few shortest samples, so that the retained total frames
# (after padding) would not exceed `allowed_max_frames`:
# `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
# where `max_frames = max_duration * 1000 // frame_shift_ms`.
# We set allowed_excess_duration_ratio=0.1.
max_frames = params.max_duration * 1000 // params.frame_shift_ms
allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
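The filtering described in the comment above keeps only as many of the longest samples as fit into the frame budget. A minimal sketch of the idea (illustrative names, not the actual `filter_uneven_sized_batch` implementation), assuming the batch features are sorted by descending duration:
```python
import torch


def drop_trailing_short_samples(
    features: torch.Tensor,      # (N, T, C), sorted by descending length
    feature_lens: torch.Tensor,  # (N,)
    allowed_max_frames: int,
):
    # After padding, a batch of k retained samples occupies k * longest_len frames,
    # because only the shortest (trailing) samples are dropped.
    longest_len = int(feature_lens[0])
    keep = max(1, min(features.size(0), allowed_max_frames // longest_len))
    return features[:keep], feature_lens[:keep]


# e.g. with --max-duration 750 and a 10 ms frame shift:
# max_frames = 750 * 1000 // 10 = 75000, allowed_max_frames = int(75000 * 1.1) = 82500
```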
@ -1055,10 +1043,10 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
train_cuts = librispeech.train_all_shuf_cuts()
else:
train_cuts = librispeech.train_clean_100_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds

View File

@ -197,13 +197,13 @@ class Zipformer(EncoderInterface):
"""
In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of
randomized feature masks, one per encoder.
On e.g. 15% of frames, these masks will zero out all enocder dims larger than
On e.g. 15% of frames, these masks will zero out all encoder dims larger than
some supplied number, e.g. >256, so in effect on those frames we are using
a smaller encoer dim.
a smaller encoder dim.
We generate the random masks at this level because we want the 2 masks to 'agree'
all the way up the encoder stack. This will mean that the 1st mask will have
mask values repeated self.zipformer_subsampling_factor times.
mask values repeated self.zipformer_downsampling_factors times.
Args:
x: the embeddings (needed for the shape and dtype and device), of shape
@ -1009,10 +1009,10 @@ class RelPositionMultiheadAttention(nn.Module):
# the initial_scale is supposed to take over the "scaling" factor of
# head_dim ** -0.5, dividing it between the query and key.
in_proj_dim = (
2 * attention_dim
+ attention_dim // 2 # query, key
+ pos_dim * num_heads # value
) # positional encoding query
2 * attention_dim # query, key
+ attention_dim // 2 # value
+ pos_dim * num_heads # positional encoding query
)
self.in_proj = ScaledLinear(
embed_dim, in_proj_dim, bias=True, initial_scale=self.head_dim**-0.25
@ -1509,7 +1509,7 @@ class FeedforwardModule(nn.Module):
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Zipformer model.
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
Args:
channels (int): The number of channels of conv layers.

View File

@ -1072,10 +1072,10 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
train_cuts = librispeech.train_all_shuf_cuts()
else:
train_cuts = librispeech.train_clean_100_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds

View File

@ -21,7 +21,7 @@
"""
Usage:
(1) greedy search
./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \
./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
@ -29,7 +29,7 @@ Usage:
--decoding-method greedy_search
(2) beam search (not recommended)
./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \
./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
@ -38,7 +38,7 @@ Usage:
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \
./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
@ -47,7 +47,7 @@ Usage:
--beam-size 4
(4) fast beam search (one best)
./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \
./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
@ -58,7 +58,7 @@ Usage:
--max-states 64
(5) fast beam search (nbest)
./pruned_transducer_stateless7_ctc/ctc_guild_decode_bs.py \
./pruned_transducer_stateless7_ctc/ctc_guide_decode_bs.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_ctc/exp \
@ -71,7 +71,7 @@ Usage:
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \
./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
@ -84,7 +84,7 @@ Usage:
--nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless7_ctc_bs/ctc_guild_decode_bs.py \
./pruned_transducer_stateless7_ctc_bs/ctc_guide_decode_bs.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \

View File

@ -72,14 +72,14 @@ Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp
git clone https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29
# You will find the pre-trained model in icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29/exp
"""
import argparse

View File

@ -0,0 +1,665 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang,
# Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script averages several saved checkpoints
# and exports the resulting model to ONNX format.
"""
Usage:
(1) Export to ONNX format
./pruned_transducer_stateless7_ctc_bs/export_onnx.py \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 30 \
--avg 13
It will generate the following files in the given `exp_dir`.
Check `onnx_check.py` for how to use them.
- encoder.onnx
- decoder.onnx
- joiner.onnx
- joiner_encoder_proj.onnx
- joiner_decoder_proj.onnx
- lconv.onnx
- frame_reducer.onnx
Please see ./onnx_pretrained.py for usage of the generated files
Check
https://github.com/k2-fsa/sherpa-onnx
for how to use the exported models outside of icefall.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29
# You will find the pre-trained model in icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29/exp
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7_ctc_bs/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--onnx",
type=str2bool,
default=True,
help="""If True, --jit is ignored and it exports the model
to onnx format. It will generate the following files:
- encoder.onnx
- decoder.onnx
- joiner.onnx
- joiner_encoder_proj.onnx
- joiner_decoder_proj.onnx
- lconv.onnx
- frame_reducer.onnx
Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
def export_encoder_model_onnx(
encoder_model: nn.Module,
encoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the given encoder model to ONNX format.
The exported model has two inputs:
- x, a tensor of shape (N, T, C); dtype is torch.float32
- x_lens, a tensor of shape (N,); dtype is torch.int64
and it has two outputs:
- encoder_out, a tensor of shape (N, T, C)
- encoder_out_lens, a tensor of shape (N,)
Note: The warmup argument is fixed to 1.
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
x = torch.zeros(15, 2000, 80, dtype=torch.float32)
x_lens = torch.tensor([2000] * 15, dtype=torch.int64)
# encoder_model = torch.jit.script(encoder_model)
# It throws the following error for the above statement
#
# RuntimeError: Exporting the operator __is_ to ONNX opset version
# 11 is not supported. Please feel free to request support or
# submit a pull request on PyTorch GitHub.
#
# I cannot find which statement causes the above error.
# torch.onnx.export() will use torch.jit.trace() internally, which
# works well for the current reworked model
torch.onnx.export(
encoder_model,
(x, x_lens),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens"],
output_names=["encoder_out", "encoder_out_lens"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"encoder_out": {0: "N", 1: "T"},
"encoder_out_lens": {0: "N"},
},
)
logging.info(f"Saved to {encoder_filename}")
def export_decoder_model_onnx(
decoder_model: nn.Module,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
Note: The argument need_pad is fixed to False.
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
y = torch.zeros(15, decoder_model.context_size, dtype=torch.int64)
need_pad = False # Always False, so we can use torch.jit.trace() here
# Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
# in this case
torch.onnx.export(
decoder_model,
(y, need_pad),
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y", "need_pad"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- projected_encoder_out: a tensor of shape (N, joiner_dim)
- projected_decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
The exported encoder_proj model has one input:
- encoder_out: a tensor of shape (N, encoder_out_dim)
and produces one output:
- projected_encoder_out: a tensor of shape (N, joiner_dim)
The exported decoder_proj model has one input:
- decoder_out: a tensor of shape (N, decoder_out_dim)
and produces one output:
- projected_decoder_out: a tensor of shape (N, joiner_dim)
"""
encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
joiner_dim = joiner_model.decoder_proj.weight.shape[0]
projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
project_input = False
# Note: It uses torch.jit.trace() internally
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out, project_input),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"encoder_out",
"decoder_out",
"project_input",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
logging.info(f"Saved to {joiner_filename}")
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model.encoder_proj,
encoder_out,
encoder_proj_filename,
verbose=False,
opset_version=opset_version,
input_names=["encoder_out"],
output_names=["projected_encoder_out"],
dynamic_axes={
"encoder_out": {0: "N"},
"projected_encoder_out": {0: "N"},
},
)
logging.info(f"Saved to {encoder_proj_filename}")
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model.decoder_proj,
decoder_out,
decoder_proj_filename,
verbose=False,
opset_version=opset_version,
input_names=["decoder_out"],
output_names=["projected_decoder_out"],
dynamic_axes={
"decoder_out": {0: "N"},
"projected_decoder_out": {0: "N"},
},
)
logging.info(f"Saved to {decoder_proj_filename}")
def export_lconv_onnx(
lconv: nn.Module,
lconv_filename: str,
opset_version: int = 11,
) -> None:
"""Export the lconv to ONNX format.
The exported lconv has two inputs:
- lconv_input: a tensor of shape (N, T, C)
- src_key_padding_mask: a tensor of shape (N, T)
and has one output:
- lconv_out: a tensor of shape (N, T, C)
Args:
lconv:
The lconv to be exported.
lconv_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
lconv_input = torch.zeros(15, 498, 384, dtype=torch.float32)
src_key_padding_mask = torch.zeros(15, 498, dtype=torch.bool)
torch.onnx.export(
lconv,
(lconv_input, src_key_padding_mask),
lconv_filename,
verbose=False,
opset_version=opset_version,
input_names=["lconv_input", "src_key_padding_mask"],
output_names=["lconv_out"],
dynamic_axes={
"lconv_input": {0: "N", 1: "T"},
"src_key_padding_mask": {0: "N", 1: "T"},
"lconv_out": {0: "N", 1: "T"},
},
)
logging.info(f"Saved to {lconv_filename}")
def export_frame_reducer_onnx(
frame_reducer: nn.Module,
frame_reducer_filename: str,
opset_version: int = 11,
) -> None:
"""Export the frame_reducer to ONNX format.
The exported frame_reducer has three inputs:
- x: a tensor of shape (N, T, C)
- x_lens: a tensor of shape (N,)
- ctc_output: a tensor of shape (N, T, vocab_size)
and has two outputs:
- out: a tensor of shape (N, T', C)
- out_lens: a tensor of shape (N,)
Args:
frame_reducer:
The frame_reducer to be exported.
frame_reducer_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
x = torch.zeros(15, 498, 384, dtype=torch.float32)
x_lens = torch.tensor([498] * 15, dtype=torch.int64)
ctc_output = torch.randn(15, 498, 500, dtype=torch.float32)
torch.onnx.export(
frame_reducer,
(x, x_lens, ctc_output),
frame_reducer_filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens", "ctc_output"],
output_names=["out", "out_lens"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"ctc_output": {0: "N", 1: "T"},
"out": {0: "N", 1: "T"},
"out_lens": {0: "N"},
},
)
logging.info(f"Saved to {frame_reducer_filename}")
def export_ctc_output_onnx(
ctc_output: nn.Module,
ctc_output_filename: str,
opset_version: int = 11,
) -> None:
"""Export the frame_reducer to ONNX format.
The exported frame_reducer has one inputs:
- encoder_out: a tensor of shape (N, T, C)
and has one output:
- ctc_output: a tensor of shape (N, T, vocab_size)
Args:
ctc_output:
The ctc_output to be exported.
ctc_output_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
encoder_out = torch.zeros(15, 498, 384, dtype=torch.float32)
torch.onnx.export(
ctc_output,
(encoder_out),
ctc_output_filename,
verbose=False,
opset_version=opset_version,
input_names=["encoder_out"],
output_names=["ctc_output"],
dynamic_axes={
"encoder_out": {0: "N", 1: "T"},
"ctc_output": {0: "N", 1: "T"},
},
)
logging.info(f"Saved to {ctc_output_filename}")
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True)
opset_version = 13
logging.info("Exporting to onnx format")
encoder_filename = params.exp_dir / "encoder.onnx"
export_encoder_model_onnx(
model.encoder,
encoder_filename,
opset_version=opset_version,
)
decoder_filename = params.exp_dir / "decoder.onnx"
export_decoder_model_onnx(
model.decoder,
decoder_filename,
opset_version=opset_version,
)
joiner_filename = params.exp_dir / "joiner.onnx"
export_joiner_model_onnx(
model.joiner,
joiner_filename,
opset_version=opset_version,
)
lconv_filename = params.exp_dir / "lconv.onnx"
export_lconv_onnx(
model.lconv,
lconv_filename,
opset_version=opset_version,
)
frame_reducer_filename = params.exp_dir / "frame_reducer.onnx"
export_frame_reducer_onnx(
model.frame_reducer,
frame_reducer_filename,
opset_version=opset_version,
)
ctc_output_filename = params.exp_dir / "ctc_output.onnx"
export_ctc_output_onnx(
model.ctc_output,
ctc_output_filename,
opset_version=opset_version,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -22,7 +22,8 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from icefall.utils import make_pad_mask
@ -55,25 +56,69 @@ class FrameReducer(nn.Module):
ctc_output:
The CTC output with shape [N, T, vocab_size].
blank_id:
The ID of the blank symbol.
The blank id of ctc_output.
Returns:
x_fr:
out:
The frame reduced encoder output with shape [N, T', C].
x_lens_fr:
out_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x_fr` before padding.
`out` before padding.
"""
N, T, C = x.size()
padding_mask = make_pad_mask(x_lens)
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
frames_list: List[torch.Tensor] = []
lens_list: List[int] = []
for i in range(x.shape[0]):
frames = x[i][non_blank_mask[i]]
frames_list.append(frames)
lens_list.append(frames.shape[0])
x_fr = pad_sequence(frames_list, batch_first=True)
x_lens_fr = torch.tensor(lens_list).to(device=x.device)
out_lens = non_blank_mask.sum(dim=1)
max_len = out_lens.max()
pad_lens_list = torch.full_like(out_lens, max_len.item()) - out_lens
max_pad_len = pad_lens_list.max()
return x_fr, x_lens_fr
out = F.pad(x, (0, 0, 0, max_pad_len))
valid_pad_mask = ~make_pad_mask(pad_lens_list)
total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1)
out = out[total_valid_mask].reshape(N, -1, C)
return out.to(device=x.device), out_lens.to(device=x.device)
if __name__ == "__main__":
import time
from torch.nn.utils.rnn import pad_sequence
test_times = 10000
frame_reducer = FrameReducer()
# non zero case
x = torch.ones(15, 498, 384, dtype=torch.float32)
x_lens = torch.tensor([498] * 15, dtype=torch.int64)
ctc_output = torch.log(torch.randn(15, 498, 500, dtype=torch.float32))
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
avg_time = 0
for i in range(test_times):
delta_time = time.time()
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
delta_time = time.time() - delta_time
avg_time += delta_time
print(x_fr.shape)
print(x_lens_fr)
print(avg_time / test_times)
# all zero case
x = torch.zeros(15, 498, 384, dtype=torch.float32)
x_lens = torch.tensor([498] * 15, dtype=torch.int64)
ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32)
avg_time = 0
for i in range(test_times):
delta_time = time.time()
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
delta_time = time.time() - delta_time
avg_time += delta_time
print(x_fr.shape)
print(x_lens_fr)
print(avg_time / test_times)

View File

@ -62,7 +62,7 @@ class LConv(nn.Module):
kernel_size=kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
groups=channels,
groups=2 * channels,
bias=bias,
)

View File

@ -0,0 +1,461 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang,
# Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads ONNX models and uses them to decode waves.
You can use the following command to get the exported models:
./pruned_transducer_stateless7_ctc_bs/export_onnx.py \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 30 \
--avg 13
Usage of this script:
./pruned_transducer_stateless7_ctc_bs/onnx_pretrained.py \
--encoder-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/encoder.onnx \
--decoder-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/decoder.onnx \
--joiner-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/joiner.onnx \
--joiner-encoder-proj-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/joiner_encoder_proj.onnx \
--joiner-decoder-proj-model-filename ./pruned_transducer_stateless7_ctc_bs/exp/joiner_decoder_proj.onnx \
--lconv-filename ./pruned_transducer_stateless7_ctc_bs/exp/lconv.onnx \
--frame-reducer-filename ./pruned_transducer_stateless7_ctc_bs/exp/frame_reducer.onnx \
--ctc-output-filename ./pruned_transducer_stateless7_ctc_bs/exp/ctc_output.onnx \
--bpe-model ./data/lang_bpe_500/bpe.model \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from typing import List
import kaldifeat
import numpy as np
import onnxruntime as ort
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from icefall.utils import make_pad_mask
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder onnx model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner onnx model. ",
)
parser.add_argument(
"--joiner-encoder-proj-model-filename",
type=str,
required=True,
help="Path to the joiner encoder_proj onnx model. ",
)
parser.add_argument(
"--joiner-decoder-proj-model-filename",
type=str,
required=True,
help="Path to the joiner decoder_proj onnx model. ",
)
parser.add_argument(
"--lconv-filename",
type=str,
required=True,
help="Path to the lconv onnx model. ",
)
parser.add_argument(
"--frame-reducer-filename",
type=str,
required=True,
help="Path to the frame reducer onnx model. ",
)
parser.add_argument(
"--ctc-output-filename",
type=str,
required=True,
help="Path to the ctc_output onnx model. ",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="Context size of the decoder model",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
def greedy_search(
decoder: ort.InferenceSession,
joiner: ort.InferenceSession,
joiner_encoder_proj: ort.InferenceSession,
joiner_decoder_proj: ort.InferenceSession,
encoder_out: np.ndarray,
encoder_out_lens: np.ndarray,
context_size: int,
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
decoder:
The decoder model.
joiner:
The joiner model.
joiner_encoder_proj:
The joiner encoder projection model.
joiner_decoder_proj:
The joiner decoder projection model.
encoder_out:
A 3-D tensor of shape (N, T, C)
encoder_out_lens:
A 1-D tensor of shape (N,).
context_size:
The context size of the decoder model.
Returns:
Return the decoded results for each utterance.
"""
encoder_out = torch.from_numpy(encoder_out)
encoder_out_lens = torch.from_numpy(encoder_out_lens)
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
projected_encoder_out = joiner_encoder_proj.run(
[joiner_encoder_proj.get_outputs()[0].name],
{joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()},
)[0]
blank_id = 0 # hard-code to 0
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
hyps = [[blank_id] * context_size for _ in range(N)]
decoder_input_nodes = decoder.get_inputs()
decoder_output_nodes = decoder.get_outputs()
joiner_input_nodes = joiner.get_inputs()
joiner_output_nodes = joiner.get_outputs()
decoder_input = torch.tensor(
hyps,
dtype=torch.int64,
) # (N, context_size)
decoder_out = decoder.run(
[decoder_output_nodes[0].name],
{
decoder_input_nodes[0].name: decoder_input.numpy(),
},
)[0].squeeze(1)
projected_decoder_out = joiner_decoder_proj.run(
[joiner_decoder_proj.get_outputs()[0].name],
{joiner_decoder_proj.get_inputs()[0].name: decoder_out},
)[0]
projected_decoder_out = torch.from_numpy(projected_decoder_out)
offset = 0
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = projected_encoder_out[start:end]
# current_encoder_out's shape: (batch_size, encoder_out_dim)
offset = end
projected_decoder_out = projected_decoder_out[:batch_size]
logits = joiner.run(
[joiner_output_nodes[0].name],
{
joiner_input_nodes[0].name: np.expand_dims(
np.expand_dims(current_encoder_out, axis=1), axis=1
),
joiner_input_nodes[1]
.name: projected_decoder_out.unsqueeze(1)
.unsqueeze(1)
.numpy(),
},
)[0]
logits = torch.from_numpy(logits).squeeze(1).squeeze(1)
# logits' shape: (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
hyps[i].append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
dtype=torch.int64,
)
decoder_out = decoder.run(
[decoder_output_nodes[0].name],
{
decoder_input_nodes[0].name: decoder_input.numpy(),
},
)[0].squeeze(1)
projected_decoder_out = joiner_decoder_proj.run(
[joiner_decoder_proj.get_outputs()[0].name],
{joiner_decoder_proj.get_inputs()[0].name: decoder_out},
)[0]
projected_decoder_out = torch.from_numpy(projected_decoder_out)
sorted_ans = [h[context_size:] for h in hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
encoder = ort.InferenceSession(
args.encoder_model_filename,
sess_options=session_opts,
)
decoder = ort.InferenceSession(
args.decoder_model_filename,
sess_options=session_opts,
)
joiner = ort.InferenceSession(
args.joiner_model_filename,
sess_options=session_opts,
)
joiner_encoder_proj = ort.InferenceSession(
args.joiner_encoder_proj_model_filename,
sess_options=session_opts,
)
joiner_decoder_proj = ort.InferenceSession(
args.joiner_decoder_proj_model_filename,
sess_options=session_opts,
)
lconv = ort.InferenceSession(
args.lconv_filename,
sess_options=session_opts,
)
frame_reducer = ort.InferenceSession(
args.frame_reducer_filename,
sess_options=session_opts,
)
ctc_output = ort.InferenceSession(
args.ctc_output_filename,
sess_options=session_opts,
)
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = args.sample_rate
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {args.sound_files}")
waves = read_sound_files(
filenames=args.sound_files,
expected_sample_rate=args.sample_rate,
)
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features,
batch_first=True,
padding_value=math.log(1e-10),
)
feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
encoder_input_nodes = encoder.get_inputs()
encoder_out_nodes = encoder.get_outputs()
encoder_out, encoder_out_lens = encoder.run(
[encoder_out_nodes[0].name, encoder_out_nodes[1].name],
{
encoder_input_nodes[0].name: features.numpy(),
encoder_input_nodes[1].name: feature_lengths.numpy(),
},
)
ctc_output_input_nodes = ctc_output.get_inputs()
ctc_output_out_nodes = ctc_output.get_outputs()
ctc_out = ctc_output.run(
[ctc_output_out_nodes[0].name],
{
ctc_output_input_nodes[0].name: encoder_out,
},
)[0]
lconv_input_nodes = lconv.get_inputs()
lconv_out_nodes = lconv.get_outputs()
encoder_out = lconv.run(
[lconv_out_nodes[0].name],
{
lconv_input_nodes[0].name: encoder_out,
lconv_input_nodes[1]
.name: make_pad_mask(torch.from_numpy(encoder_out_lens))
.numpy(),
},
)[0]
frame_reducer_input_nodes = frame_reducer.get_inputs()
frame_reducer_out_nodes = frame_reducer.get_outputs()
encoder_out_fr, encoder_out_lens_fr = frame_reducer.run(
[frame_reducer_out_nodes[0].name, frame_reducer_out_nodes[1].name],
{
frame_reducer_input_nodes[0].name: encoder_out,
frame_reducer_input_nodes[1].name: encoder_out_lens,
frame_reducer_input_nodes[2].name: ctc_out,
},
)
hyps = greedy_search(
decoder=decoder,
joiner=joiner,
joiner_encoder_proj=joiner_encoder_proj,
joiner_decoder_proj=joiner_decoder_proj,
encoder_out=encoder_out_fr,
encoder_out_lens=encoder_out_lens_fr,
context_size=args.context_size,
)
s = "\n"
for filename, hyp in zip(args.sound_files, hyps):
words = sp.decode(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -55,9 +55,9 @@ import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from decoder import Decoder
from frame_reducer import FrameReducer
from joiner import Joiner
from lconv import LConv
from frame_reducer import FrameReducer
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
@ -1063,10 +1063,10 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
train_cuts = librispeech.train_all_shuf_cuts()
else:
train_cuts = librispeech.train_clean_100_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds

View File

@ -1049,10 +1049,10 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
train_cuts = librispeech.train_all_shuf_cuts()
else:
train_cuts = librispeech.train_clean_100_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds

View File

@ -421,13 +421,13 @@ class Zipformer(EncoderInterface):
"""
In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of
randomized feature masks, one per encoder.
On e.g. 15% of frames, these masks will zero out all enocder dims larger than
On e.g. 15% of frames, these masks will zero out all encoder dims larger than
some supplied number, e.g. >256, so in effect on those frames we are using
a smaller encoer dim.
a smaller encoder dim.
We generate the random masks at this level because we want the 2 masks to 'agree'
all the way up the encoder stack. This will mean that the 1st mask will have
mask values repeated self.zipformer_subsampling_factor times.
mask values repeated self.zipformer_downsampling_factors times.
Args:
x: the embeddings (needed for the shape and dtype and device), of shape
@ -1687,8 +1687,8 @@ class RelPositionalEncoding(torch.nn.Module):
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vecotr and `j` means the
# position of key vector. We use position relative positions when keys
# Suppose `i` means to the position of query vector and `j` means the
# position of key vector. We use positive relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x_size_left, self.d_model)
pe_negative = torch.zeros(x_size_left, self.d_model)
@ -1778,10 +1778,10 @@ class RelPositionMultiheadAttention(nn.Module):
# the initial_scale is supposed to take over the "scaling" factor of
# head_dim ** -0.5, dividing it between the query and key.
in_proj_dim = (
2 * attention_dim
+ attention_dim // 2 # query, key
+ pos_dim * num_heads # value
) # positional encoding query
2 * attention_dim # query, key
+ attention_dim // 2 # value
+ pos_dim * num_heads # positional encoding query
)
self.in_proj = ScaledLinear(
embed_dim, in_proj_dim, bias=True, initial_scale=self.head_dim**-0.25
@ -2536,7 +2536,7 @@ class FeedforwardModule(nn.Module):
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Zipformer model.
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/zipformer/convolution.py
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
Args:
channels (int): The number of channels of conv layers.

View File

@ -1154,10 +1154,10 @@ def run(rank, world_size, args):
librispeech = LibriSpeech(manifest_dir=args.manifest_dir)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
train_cuts = librispeech.train_all_shuf_cuts()
else:
train_cuts = librispeech.train_clean_100_cuts()
train_cuts = filter_short_and_long_utterances(train_cuts, sp)

View File

@ -30,6 +30,7 @@ import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
@ -645,8 +646,23 @@ def run(rank, world_size, args):
optimizer.load_state_dict(checkpoints["optimizer"])
librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders()
if params.full_libri:
train_cuts = librispeech.train_all_shuf_cuts()
else:
train_cuts = librispeech.train_clean_100_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
return 1.0 <= c.duration <= 20.0
train_cuts = train_cuts.filter(remove_short_and_long_utt)
train_dl = librispeech.train_dataloaders(train_cuts)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
scan_pessimistic_batches_for_oom(
model=model,

View File

@ -1,3 +1,3 @@
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/timit/index.html>
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/Non-streaming-ASR/timit/index.html>
for how to run models in this recipe.

View File

@ -10,5 +10,5 @@ get the following WER:
```
Please refer to
<https://icefall.readthedocs.io/en/latest/recipes/yesno/index.html>
<https://icefall.readthedocs.io/en/latest/recipes/Non-streaming-ASR/yesno/index.html>
for detailed instructions.