Add pruned RNN-T for aishell. (#436)

* Add pruned RNN-T for aishell.

* support torch script.

* Update CI.

* Minor fixes.

* Add links to sherpa.
This commit is contained in:
Fangjun Kuang 2022-06-21 21:17:22 +08:00 committed by GitHub
parent d3daeaf5cd
commit 7100c33820
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 3055 additions and 18 deletions

View File

@ -0,0 +1,86 @@
#!/usr/bin/env bash
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
cd egs/aishell/ASR
git lfs install
fbank_url=https://huggingface.co/csukuangfj/aishell-test-dev-manifests
log "Downloading pre-commputed fbank from $fbank_url"
git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests
ln -s $PWD/aishell-test-dev-manifests/data .
log "Downloading pre-trained model from $repo_url"
repo_url=https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20
git clone $repo_url
repo=$(basename $repo_url)
log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
ln -s pretrained-epoch-29-avg-5-torch-1.10.pt pretrained.pt
popd
for sym in 1 2 3; do
log "Greedy search with --max-sym-per-frame $sym"
./pruned_transducer_stateless3/pretrained.py \
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--lang-dir $repo/data/lang_char \
$repo/test_wavs/BAC009S0764W0121.wav \
$repo/test_wavs/BAC009S0764W0122.wav \
$rep/test_wavs/BAC009S0764W0123.wav
done
for method in modified_beam_search beam_search fast_beam_search; do
log "$method"
./pruned_transducer_stateless3/pretrained.py \
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--lang-dir $repo/data/lang_char \
$repo/test_wavs/BAC009S0764W0121.wav \
$repo/test_wavs/BAC009S0764W0122.wav \
$rep/test_wavs/BAC009S0764W0123.wav
done
echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then
mkdir -p pruned_transducer_stateless3/exp
ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt
ln -s $PWD/$repo/data/lang_char data/
ls -lh data
ls -lh pruned_transducer_stateless3/exp
log "Decoding test and dev"
# use a small value for decoding with CPU
max_duration=100
for method in greedy_search fast_beam_search modified_beam_search; do
log "Decoding with $method"
./pruned_transducer_stateless3/decode.py \
--decoding-method $method \
--epoch 999 \
--avg 1 \
--max-duration $max_duration \
--exp-dir pruned_transducer_stateless3/exp
done
rm pruned_transducer_stateless3/exp/*.pt
fi

View File

@ -0,0 +1,119 @@
# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-aishell-2022-06-20
# pruned RNN-T + reworked model with random combiner
# https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20
on:
push:
branches:
- master
pull_request:
types: [labeled]
schedule:
# minute (0-59)
# hour (0-23)
# day of the month (1-31)
# month (1-12)
# day of the week (0-6)
# nightly build at 15:50 UTC time every day
- cron: "50 15 * * *"
jobs:
run_aishell_2022_06_20:
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
python-version: [3.7, 3.8, 3.9]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
cache-dependency-path: '**/requirements-ci.txt'
- name: Install Python dependencies
run: |
grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
pip uninstall -y protobuf
pip install --no-binary protobuf protobuf
- name: Cache kaldifeat
id: my-cache
uses: actions/cache@v2
with:
path: |
~/tmp/kaldifeat
key: cache-tmp-${{ matrix.python-version }}
- name: Install kaldifeat
if: steps.my-cache.outputs.cache-hit != 'true'
shell: bash
run: |
.github/scripts/install-kaldifeat.sh
- name: Inference with pre-trained model
shell: bash
env:
GITHUB_EVENT_NAME: ${{ github.event_name }}
GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
run: |
sudo apt-get -qq install git-lfs tree sox
export PYTHONPATH=$PWD:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh
- name: Display decoding results for aishell pruned_transducer_stateless3
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
shell: bash
run: |
cd egs/aishell/ASR/
tree ./pruned_transducer_stateless3/exp
cd pruned_transducer_stateless3
echo "results for pruned_transducer_stateless3"
echo "===greedy search==="
find exp/greedy_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2
find exp/greedy_search -name "log-*" -exec grep -n --color "best for dev" {} + | sort -n -k2
echo "===fast_beam_search==="
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2
find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for dev" {} + | sort -n -k2
echo "===modified beam search==="
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2
find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for dev" {} + | sort -n -k2
- name: Upload decoding results for aishell pruned_transducer_stateless3
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: aishell-torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-06-20
path: egs/aishell/ASR/pruned_transducer_stateless3/exp/

View File

@ -4,6 +4,8 @@
Please refer to <https://icefall.readthedocs.io/en/latest/recipes/aishell/index.html>
for how to run models in this recipe.
# Transducers
There are various folders containing the name `transducer` in this folder.
@ -14,6 +16,7 @@ The following table lists the differences among them.
| `transducer_stateless` | Conformer | Embedding + Conv1d | with `k2.rnnt_loss` |
| `transducer_stateless_modified` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` |
| `transducer_stateless_modified-2` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` + extra data |
| `pruned_transducer_stateless3` | Conformer (reworked) | Embedding + Conv1d | pruned RNN-T + reworked model with random combiner + using aidatatang_20zh as extra data|
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).

View File

@ -1,10 +1,93 @@
## Results
### Aishell training result(Transducer-stateless)
### Aishell training result(Stateless Transducer)
#### Pruned transducer stateless 3
See <https://github.com/k2-fsa/icefall/pull/436>
[./pruned_transducer_stateless3](./pruned_transducer_stateless3)
It uses pruned RNN-T.
| | test | dev | comment |
|------------------------|------|------|---------------------------------------|
| greedy search | 5.39 | 5.09 | --epoch 29 --avg 5 --max-duration 600 |
| modified beam search | 5.05 | 4.79 | --epoch 29 --avg 5 --max-duration 600 |
| fast beam search | 5.13 | 4.91 | --epoch 29 --avg 5 --max-duration 600 |
Training command is:
```bash
./prepare.sh
./prepare_aidatatang_200zh.sh
export CUDA_VISIBLE_DEVICES="4,5,6,7"
./pruned_transducer_stateless3/train.py \
--exp-dir ./pruned_transducer_stateless3/exp-context-size-1 \
--world-size 4 \
--max-duration 200 \
--datatang-prob 0.5 \
--start-epoch 1 \
--num-epochs 30 \
--use-fp16 1 \
--num-encoder-layers 12 \
--dim-feedforward 2048 \
--nhead 8 \
--encoder-dim 512 \
--context-size 1 \
--decoder-dim 512 \
--joiner-dim 512 \
--master-port 12356
```
**Caution**: It uses `--context-size=1`.
The tensorboard log is available at
<https://tensorboard.dev/experiment/OKKacljwR6ik7rbDr5gMqQ>
The decoding command is:
```bash
for epoch in 29; do
for avg in 5; do
for m in greedy_search modified_beam_search fast_beam_search; do
./pruned_transducer_stateless3/decode.py \
--exp-dir ./pruned_transducer_stateless3/exp-context-size-1 \
--epoch $epoch \
--avg $avg \
--use-averaged-model 1 \
--max-duration 600 \
--decoding-method $m \
--num-encoder-layers 12 \
--dim-feedforward 2048 \
--nhead 8 \
--context-size 1 \
--encoder-dim 512 \
--decoder-dim 512 \
--joiner-dim 512
done
done
done
```
Pretrained models, training logs, decoding logs, and decoding results
are available at
<https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20>
We have a tutorial in [sherpa](https://github.com/k2-fsa/sherpa) about how
to use the pre-trained model for non-streaming ASR. See
<https://k2-fsa.github.io/sherpa/offline_asr/conformer/aishell.html>
#### 2022-03-01
[./transducer_stateless_modified-2](./transducer_stateless_modified-2)
It uses [optimized_transducer](https://github.com/csukuangfj/optimized_transducer)
for computing RNN-T loss.
Stateless transducer + modified transducer + using [aidatatang_200zh](http://www.openslr.org/62/) as extra training data.

View File

@ -0,0 +1 @@
../transducer_stateless_modified-2/aidatatang_200zh.py

View File

@ -0,0 +1 @@
../transducer_stateless_modified-2/aishell.py

View File

@ -0,0 +1 @@
../transducer_stateless_modified-2/asr_datamodule.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py

View File

@ -0,0 +1,637 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./pruned_transducer_stateless3/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from aishell import AIShell
from asr_datamodule import AsrDataModule
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=False,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless3/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="The lang dir",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
token_table: k2.SymbolTable,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
token_table:
It maps token ID to a string.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
else:
hyp_tokens = []
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyp_tokens.append(hyp)
hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens]
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
token_table: k2.SymbolTable,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
token_table:
It maps a token ID to a string.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
token_table=token_table,
decoding_graph=decoding_graph,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((list("".join(res[0])), list("".join(res[1]))))
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
),
strict=False,
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
),
strict=False,
)
model.to(device)
model.eval()
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
asr_datamodule = AsrDataModule(args)
aishell = AIShell(manifest_dir=args.manifest_dir)
test_cuts = aishell.test_cuts()
dev_cuts = aishell.valid_cuts()
test_dl = asr_datamodule.test_dataloaders(test_cuts)
dev_dl = asr_datamodule.test_dataloaders(dev_cuts)
test_sets = ["test", "dev"]
test_dls = [test_dl, dev_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
token_table=lexicon.token_table,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py

View File

@ -0,0 +1 @@
/ceph-fj/fangjun/open-source/icefall-aishell/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1

View File

@ -0,0 +1,277 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--jit 0 \
--epoch 29 \
--avg 5
It will generate a file exp_dir/pretrained-epoch-29-avg-5.pt
To use the generated file with `pruned_transducer_stateless3/decode.py`,
you can do::
cd /path/to/exp_dir
ln -s pretrained-epoch-29-avg-5.pt epoch-9999.pt
cd /path/to/egs/aishell/ASR
./pruned_transducer_stateless3/decode.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--epoch 9999 \
--avg 1 \
--max-duration 100 \
--lang-dir data/lang_char
"""
import argparse
import logging
from pathlib import Path
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=29,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=Path,
default=Path("pruned_transducer_stateless3/exp"),
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
""",
)
parser.add_argument(
"--lang-dir",
type=Path,
default=Path("data/lang_char"),
help="The lang dir",
)
parser.add_argument(
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
add_model_arguments(parser)
return parser
def main():
args = get_parser().parse_args()
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = (
params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
)
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = (
params.exp_dir
/ f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
)
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py

View File

@ -0,0 +1,236 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
from icefall.utils import add_sos
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
"""
def __init__(
self,
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
decoder_datatang: Optional[nn.Module] = None,
joiner_datatang: Optional[nn.Module] = None,
):
"""
Args:
encoder:
It is the transcription network in the paper. Its accepts
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and
(N, U, decoder_dim). Its output shape is (N, T, U, vocab_size).
Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
encoder_dim:
Output dimension of the encoder network.
decoder_dim:
Output dimension of the decoder network.
joiner_dim:
Input dimension of the joiner network.
vocab_size:
Output dimension of the joiner network.
decoder_datatang:
Optional. The decoder network for the aidatatang_200zh dataset.
joiner_datatang:
Optional. The joiner network for the aidatatang_200zh dataset.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
assert hasattr(decoder, "blank_id")
self.encoder = encoder
self.decoder = decoder
self.joiner = joiner
self.decoder_datatang = decoder_datatang
self.joiner_datatang = joiner_datatang
self.simple_am_proj = ScaledLinear(
encoder_dim, vocab_size, initial_speed=0.5
)
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
if decoder_datatang is not None:
self.simple_am_proj_datatang = ScaledLinear(
encoder_dim, vocab_size, initial_speed=0.5
)
self.simple_lm_proj_datatang = ScaledLinear(decoder_dim, vocab_size)
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
aishell: bool = True,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
warmup: float = 1.0,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
aishell:
True to use the decoder and joiner for the aishell dataset.
False to use the decoder and joiner for the aidatatang_200zh
dataset.
prune_range:
The prune range for rnnt loss, it means how many symbols(context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
warmup:
A value warmup >= 0 that determines which modules are active, values
warmup > 1 "are fully warmed up" and all modules will be active.
Returns:
Return the transducer loss.
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, encoder_out_lens = self.encoder(x, x_lens, warmup=warmup)
assert torch.all(encoder_out_lens > 0)
if aishell:
decoder = self.decoder
simple_lm_proj = self.simple_lm_proj
simple_am_proj = self.simple_am_proj
joiner = self.joiner
else:
decoder = self.decoder_datatang
simple_lm_proj = self.simple_lm_proj_datatang
simple_am_proj = self.simple_am_proj_datatang
joiner = self.joiner_datatang
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(x.size(0), 4), dtype=torch.int64, device=x.device
)
boundary[:, 2] = y_lens
boundary[:, 3] = encoder_out_lens
lm = simple_lm_proj(decoder_out)
am = simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=joiner.encoder_proj(encoder_out),
lm=joiner.decoder_proj(decoder_out),
ranges=ranges,
)
# logits : [B, T, prune_range, vocab_size]
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
return (simple_loss, pruned_loss)

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/optim.py

View File

@ -0,0 +1,337 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless3/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) beam search
./pruned_transducer_stateless3/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) modified beam search
./pruned_transducer_stateless3/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./pruned_transducer_stateless3/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from pathlib import Path
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.lexicon import Lexicon
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--lang-dir",
type=Path,
default=Path("data/lang_char"),
help="The lang dir",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="Maximum number of symbols per frame. "
"Use only when --method is greedy_search",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lens = [f.size(0) for f in features]
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lens
)
num_waves = encoder_out.size(0)
hyp_list = []
logging.info(f"Using {params.method}")
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_list = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_list = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
elif params.method == "modified_beam_search":
hyp_list = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
else:
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.method}"
)
hyp_list.append(hyp)
hyps = []
for hyp in hyp_list:
hyps.append([lexicon.token_table[i] for i in hyp])
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py

File diff suppressed because it is too large Load Diff

View File

@ -405,7 +405,7 @@ def compute_loss(
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
Compute RNN-T loss given the model and its inputs.
Args:
params:

View File

@ -73,6 +73,9 @@ class Decoder(nn.Module):
groups=decoder_dim,
bias=False,
)
else:
# It is to support torch script
self.conv = nn.Identity()
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""

View File

@ -117,10 +117,7 @@ class Conformer(EncoderInterface):
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# Caution: We assume the subsampling factor is 4!
lengths = ((x_lens - 1) // 2 - 1) // 2
lengths = (((x_lens - 1) >> 1) - 1) >> 1
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
@ -293,8 +290,10 @@ class ConformerEncoder(nn.Module):
)
self.num_layers = num_layers
assert len(set(aux_layers)) == len(aux_layers)
assert num_layers - 1 not in aux_layers
self.aux_layers = set(aux_layers + [num_layers - 1])
self.aux_layers = aux_layers + [num_layers - 1]
num_channels = encoder_layer.norm_final.num_channels
self.combiner = RandomCombine(
@ -1154,7 +1153,7 @@ class RandomCombine(nn.Module):
"""
num_inputs = self.num_inputs
assert len(inputs) == num_inputs
if not self.training:
if not self.training or torch.jit.is_scripting():
return inputs[-1]
# Shape of weights: (*, num_inputs)
@ -1162,8 +1161,22 @@ class RandomCombine(nn.Module):
num_frames = inputs[0].numel() // num_channels
mod_inputs = []
for i in range(num_inputs - 1):
mod_inputs.append(self.linear[i](inputs[i]))
if False:
# It throws the following error for torch 1.6.0 when using
# torch script.
#
# Expected integer literal for index. ModuleList/Sequential
# indexing is only supported with integer literals. Enumeration is
# supported, e.g. 'for index, v in enumerate(self): ...':
# for i in range(num_inputs - 1):
# mod_inputs.append(self.linear[i](inputs[i]))
assert False
else:
for i, linear in enumerate(self.linear):
if i < num_inputs - 1:
mod_inputs.append(linear(inputs[i]))
mod_inputs.append(inputs[num_inputs - 1])
ndim = inputs[0].ndim
@ -1181,11 +1194,13 @@ class RandomCombine(nn.Module):
# ans: (num_frames, num_channels, 1)
ans = torch.matmul(stacked_inputs, weights)
# ans: (*, num_channels)
ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels)
if __name__ == "__main__":
# for testing only...
print("Weights = ", weights.reshape(num_frames, num_inputs))
ans = ans.reshape(inputs[0].shape[:-1] + [num_channels])
# The following if causes errors for torch script in torch 1.6.0
# if __name__ == "__main__":
# # for testing only...
# print("Weights = ", weights.reshape(num_frames, num_inputs))
return ans
def _get_random_weights(

View File

@ -146,8 +146,6 @@ def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params()
params.update(vars(args))
@ -246,12 +244,15 @@ def main():
)
)
model.eval()
model.to("cpu")
model.eval()
if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"