Merge branch 'k2-fsa:master' into gigaspeech_streaming

This commit is contained in:
Guanbo Wang 2022-08-22 16:10:24 -04:00 committed by GitHub
commit 5ea4d94ac0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
136 changed files with 11218 additions and 388 deletions

View File

@ -9,7 +9,8 @@ per-file-ignores =
egs/*/ASR/pruned_transducer_stateless*/*.py: E501,
egs/*/ASR/*/optim.py: E501,
egs/*/ASR/*/scaling.py: E501,
egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203,
egs/librispeech/ASR/lstm_transducer_stateless/*.py: E501, E203
egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203
egs/librispeech/ASR/conformer_ctc2/*py: E501,
egs/librispeech/ASR/RESULTS.md: E999,

View File

@ -22,8 +22,80 @@ ls -lh $repo/test_wavs/*.wav
pushd $repo/exp
ln -s pretrained-iter-1224000-avg-14.pt pretrained.pt
ln -s pretrained-iter-1224000-avg-14.pt epoch-99.pt
popd
log "Test exporting to ONNX format"
./pruned_transducer_stateless3/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--onnx 1
log "Export to torchscript model"
./pruned_transducer_stateless3/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--jit 1
./pruned_transducer_stateless3/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--jit-trace 1
ls -lh $repo/exp/*.onnx
ls -lh $repo/exp/*.pt
log "Decode with ONNX models"
./pruned_transducer_stateless3/onnx_check.py \
--jit-filename $repo/exp/cpu_jit.pt \
--onnx-encoder-filename $repo/exp/encoder.onnx \
--onnx-decoder-filename $repo/exp/decoder.onnx \
--onnx-joiner-filename $repo/exp/joiner.onnx
./pruned_transducer_stateless3/onnx_check_all_in_one.py \
--jit-filename $repo/exp/cpu_jit.pt \
--onnx-all-in-one-filename $repo/exp/all_in_one.onnx
./pruned_transducer_stateless3/onnx_pretrained.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--encoder-model-filename $repo/exp/encoder.onnx \
--decoder-model-filename $repo/exp/decoder.onnx \
--joiner-model-filename $repo/exp/joiner.onnx \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
log "Decode with models exported by torch.jit.trace()"
./pruned_transducer_stateless3/jit_pretrained.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--encoder-model-filename $repo/exp/encoder_jit_trace.pt \
--decoder-model-filename $repo/exp/decoder_jit_trace.pt \
--joiner-model-filename $repo/exp/joiner_jit_trace.pt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
log "Decode with models exported by torch.jit.script()"
./pruned_transducer_stateless3/jit_pretrained.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--encoder-model-filename $repo/exp/encoder_jit_script.pt \
--decoder-model-filename $repo/exp/decoder_jit_script.pt \
--joiner-model-filename $repo/exp/joiner_jit_script.pt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
for sym in 1 2 3; do
log "Greedy search with --max-sym-per-frame $sym"

View File

@ -35,7 +35,7 @@ on:
jobs:
run_librispeech_pruned_transducer_stateless3_2022_05_13:
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:

View File

@ -79,6 +79,7 @@ RUN git clone https://github.com/k2-fsa/k2.git /opt/k2 && \
cd -
# install lhotse
RUN pip install torchaudio==0.7.2
RUN pip install git+https://github.com/lhotse-speech/lhotse
#RUN pip install lhotse

View File

@ -367,6 +367,7 @@ def decode_dataset(
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
texts = [list(str(text).replace(" ", "")) for text in texts]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -379,8 +380,8 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
this_batch.append((ref_text, hyp_words))
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
this_batch.append((cut_id, ref_text, hyp_words))
results[name].extend(this_batch)
@ -405,6 +406,7 @@ def save_results(
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
@ -528,6 +530,8 @@ def main():
from lhotse import CutSet
from lhotse.dataset.webdataset import export_to_webdataset
# we need cut ids to display recognition results.
args.return_cuts = True
aidatatang_200zh = Aidatatang_200zhAsrDataModule(args)
dev = "dev"

View File

@ -81,6 +81,58 @@ We have a tutorial in [sherpa](https://github.com/k2-fsa/sherpa) about how
to use the pre-trained model for non-streaming ASR. See
<https://k2-fsa.github.io/sherpa/offline_asr/conformer/aishell.html>
#### Pruned transducer stateless 2
See https://github.com/k2-fsa/icefall/pull/536
[./pruned_transducer_stateless2](./pruned_transducer_stateless2)
It uses pruned RNN-T.
| | test | dev | comment |
| -------------------- | ---- | ---- | -------------------------------------- |
| greedy search | 5.20 | 4.78 | --epoch 72 --avg 14 --max-duration 200 |
| modified beam search | 5.07 | 4.63 | --epoch 72 --avg 14 --max-duration 200 |
| fast beam search | 5.13 | 4.70 | --epoch 72 --avg 14 --max-duration 200 |
Training command is:
```bash
./prepare.sh
export CUDA_VISIBLE_DEVICES="0,1"
./pruned_transducer_stateless2/train.py \
--world-size 2 \
--num-epochs 90 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless2/exp \
--max-duration 200
```
The tensorboard log is available at
https://tensorboard.dev/experiment/QI3PVzrGRrebxpbWUPwmkA/
The decoding command is:
```bash
for m in greedy_search modified_beam_search fast_beam_search ; do
./pruned_transducer_stateless2/decode.py \
--epoch 72 \
--avg 14 \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--max-duration 200 \
--decoding-method $m
done
```
Pretrained models, training logs, decoding logs, and decoding results
are available at
<https://huggingface.co/teapoly/icefall-aishell-pruned-transducer-stateless2-2022-08-18>
#### 2022-03-01
[./transducer_stateless_modified-2](./transducer_stateless_modified-2)

View File

@ -374,6 +374,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -389,9 +390,9 @@ def decode_dataset(
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
results[lm_scale].extend(this_batch)
@ -419,6 +420,7 @@ def save_results(
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
@ -429,7 +431,9 @@ def save_results(
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((list("".join(res[0])), list("".join(res[1]))))
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
@ -537,6 +541,8 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
aishell = AishellAsrDataModule(args)
test_cuts = aishell.test_cuts()
test_dl = aishell.test_dataloaders(test_cuts)

View File

@ -386,6 +386,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -401,9 +402,9 @@ def decode_dataset(
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
results[lm_scale].extend(this_batch)
@ -431,6 +432,7 @@ def save_results(
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
@ -441,7 +443,9 @@ def save_results(
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((list("".join(res[0])), list("".join(res[1]))))
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log
@ -556,6 +560,8 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
aishell = AishellAsrDataModule(args)
test_cuts = aishell.test_cuts()
test_dl = aishell.test_dataloaders(test_cuts)

View File

@ -0,0 +1 @@
../tdnn_lstm_ctc/asr_datamodule.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/conformer.py

View File

@ -0,0 +1,573 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless2/decode.py \
--epoch 84 \
--avg 25 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./pruned_transducer_stateless2/decode.py \
--epoch 84 \
--avg 25 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless2/decode.py \
--epoch 84 \
--avg 25 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./pruned_transducer_stateless2/decode.py \
--epoch 84 \
--avg 25 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
)
def get_parser():
    """Create the argument parser for the decoding script.

    The parser covers checkpoint selection (--epoch / --iter / --avg),
    the experiment and lang directories, the decoding method and the
    options specific to each method. Model-related options are appended
    by ``add_model_arguments``.

    Returns:
      The configured ``argparse.ArgumentParser``.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    add = cli.add_argument

    # ---- which checkpoint(s) to load ----
    add(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
    )
    add(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
    )
    add(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    # ---- directories ----
    add(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless2/exp",
        help="The experiment dir",
    )
    add(
        "--lang-dir",
        type=str,
        default="data/lang_char",
        help="The lang dir",
    )

    # ---- decoding method and its knobs ----
    add(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
    )
    add(
        "--beam-size",
        type=int,
        default=4,
        help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
    )
    add(
        "--beam",
        type=float,
        default=4,
        help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
    )
    add(
        "--max-contexts",
        type=int,
        default=4,
        help="""Used only when --decoding-method is
fast_beam_search""",
    )
    add(
        "--max-states",
        type=int,
        default=8,
        help="""Used only when --decoding-method is
fast_beam_search""",
    )
    add(
        "--context-size",
        type=int,
        default=1,
        help="The context size in the decoder. 1 means bigram; "
        "2 means tri-gram",
    )
    add(
        "--max-sym-per-frame",
        type=int,
        default=1,
        help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
    )

    # Options describing the model architecture come from train.py.
    add_model_arguments(cli)

    return cli
def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    token_table: k2.SymbolTable,
    batch: dict,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Decode a single batch and return the hypotheses keyed by setting.

    The returned dict holds exactly one entry:

      - key: names the decoding setting. It is "greedy_search" for greedy
             search, "beam_<beam>_max_contexts_<c>_max_states_<s>" for
             fast beam search and "beam_size_<k>" for the remaining
             methods.
      - value: the decoding result. `len(value)` equals the batch size and
               `value[i]` is the list of token strings for the i-th
               utterance in the batch.

    Args:
      params:
        It's the return value of :func:`get_params`.
      model:
        The neural model.
      token_table:
        It maps token ID to a string.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
        only when --decoding-method is fast_beam_search.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
    Raises:
      ValueError: if an utterance has to be decoded one at a time and
        `params.decoding_method` is neither greedy_search nor beam_search.
    """
    method = params.decoding_method
    device = next(model.parameters()).device

    # at entry, feature is (N, T, C)
    feature = batch["inputs"]
    assert feature.ndim == 3
    feature = feature.to(device)

    feature_lens = batch["supervisions"]["num_frames"].to(device)

    encoder_out, encoder_out_lens = model.encoder(
        x=feature, x_lens=feature_lens
    )

    # Greedy search has a batched implementation only for one symbol per
    # frame; other values fall through to the per-utterance loop below.
    batched_greedy = (
        method == "greedy_search" and params.max_sym_per_frame == 1
    )

    if method == "fast_beam_search":
        hyp_tokens = fast_beam_search_one_best(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
        )
    elif batched_greedy:
        hyp_tokens = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
    elif method == "modified_beam_search":
        hyp_tokens = modified_beam_search(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
        )
    else:
        # No batched implementation: decode the utterances one by one.
        hyp_tokens = []
        for i in range(encoder_out.size(0)):
            # Keep the batch dimension and drop the padded frames.
            # fmt: off
            single_out = encoder_out[i:i+1, :encoder_out_lens[i]]
            # fmt: on
            if method == "beam_search":
                hyp = beam_search(
                    model=model,
                    encoder_out=single_out,
                    beam=params.beam_size,
                )
            elif method == "greedy_search":
                hyp = greedy_search(
                    model=model,
                    encoder_out=single_out,
                    max_sym_per_frame=params.max_sym_per_frame,
                )
            else:
                raise ValueError(
                    f"Unsupported decoding method: {params.decoding_method}"
                )
            hyp_tokens.append(hyp)

    # Map token IDs to their string form.
    hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens]

    if method == "greedy_search":
        setting = "greedy_search"
    elif method == "fast_beam_search":
        setting = (
            f"beam_{params.beam}_"
            f"max_contexts_{params.max_contexts}_"
            f"max_states_{params.max_states}"
        )
    else:
        setting = f"beam_size_{params.beam_size}"

    return {setting: hyps}
def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    token_table: k2.SymbolTable,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode. Each batch
        must carry its cuts in `batch["supervisions"]["cut"]`, since the
        cut ids are stored alongside the results.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      token_table:
        It maps a token ID to a string.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
        only when --decoding-method is fast_beam_search.
    Returns:
      Return a dict whose keys are the setting names produced by
      :func:`decode_one_batch` (e.g. "greedy_search").
      Its value is a list of tuples. Each tuple contains three elements:
      the cut id, the reference transcript (the reference text split on
      whitespace) and the predicted result.
    """
    num_cuts = 0

    # Some dataloaders do not define a length; show "?" in the log then.
    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    if params.decoding_method == "greedy_search":
        log_interval = 50
    else:
        log_interval = 20

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            token_table=token_table,
            decoding_graph=decoding_graph,
            batch=batch,
        )

        for name, hyps in hyps_dict.items():
            this_batch = []
            assert len(hyps) == len(texts)
            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
                ref_words = ref_text.split()
                this_batch.append((cut_id, ref_words, hyp_words))

            results[name].extend(this_batch)

        num_cuts += len(texts)

        if batch_idx % log_interval == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(
                f"batch {batch_str}, cuts processed until now is {num_cuts}"
            )
    return results
def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    """Write transcripts, error statistics and an error-rate summary.

    For every decoding setting in `results_dict` this stores the sorted
    transcripts and a detailed error report under `params.res_dir`, then
    writes one summary file listing the error rate of every setting and
    logs the same summary.

    Args:
      params:
        Must provide `res_dir` (output directory) and `suffix` (appended to
        every output file name).
      test_set_name:
        Name of the decoded set; used in file names and log messages.
      results_dict:
        Maps a setting name to a list of (cut id, reference, hypothesis)
        tuples, as returned by :func:`decode_dataset`.
    """
    if not results_dict:
        # Nothing was decoded. Without this guard the summary file name
        # below would read the loop variable `key` before assignment.
        logging.warning(
            f"No decoding results for {test_set_name}; nothing to save"
        )
        return

    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = (
            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        # Tuples start with the cut id, so this orders the output by cut id.
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        # we compute CER for aishell dataset.
        results_char = [
            (cut_id, list("".join(ref)), list("".join(hyp)))
            for cut_id, ref, hyp in results
        ]
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results_char, enable_log=True
            )
            test_set_wers[key] = wer

        logging.info(f"Wrote detailed error stats to {errs_filename}")

    # The summary file name embeds the last setting iterated above; this
    # matches the behaviour of the code before the guard was added.
    last_key = key

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = (
        params.res_dir
        / f"wer-summary-{test_set_name}-{last_key}-{params.suffix}.txt"
    )
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for setting, val in test_set_wers:
            print(f"{setting}\t{val}", file=f)

    s = f"\nFor {test_set_name}, WER of different settings are:\n"
    # Only the first (lowest error rate) line carries the "best" note.
    note = f"\tbest for {test_set_name}"
    for setting, val in test_set_wers:
        s += f"{setting}\t{val}{note}\n"
        note = ""
    logging.info(s)
@torch.no_grad()
def main():
    """Decode the aishell test and dev sets with a trained transducer.

    Parses the command line, loads (and optionally averages) checkpoints
    from the experiment directory, decodes both sets with the selected
    decoding method and writes the results via :func:`save_results`.

    Raises:
      ValueError: if --iter is given and fewer than --avg checkpoints are
        found.
    """
    parser = get_parser()
    AishellAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)
    args.lang_dir = Path(args.lang_dir)

    params = get_params()
    params.update(vars(args))

    assert params.decoding_method in (
        "greedy_search",
        "beam_search",
        "fast_beam_search",
        "modified_beam_search",
    )
    params.res_dir = params.exp_dir / params.decoding_method

    # The suffix identifies the checkpoint(s) and the decoding setting; it
    # is appended to the log file name here.
    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if "fast_beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam}"
        params.suffix += f"-max-contexts-{params.max_contexts}"
        params.suffix += f"-max-states-{params.max_states}"
    elif "beam_search" in params.decoding_method:
        params.suffix += (
            f"-{params.decoding_method}-beam-size-{params.beam_size}"
        )
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")

    lexicon = Lexicon(params.lang_dir)
    params.blank_id = 0
    params.vocab_size = max(lexicon.tokens) + 1

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    if params.iter > 0:
        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
            : params.avg
        ]
        if len(filenames) == 0:
            raise ValueError(
                f"No checkpoints found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        elif len(filenames) < params.avg:
            raise ValueError(
                f"Not enough checkpoints ({len(filenames)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(
            average_checkpoints(filenames, device=device), strict=False
        )
    elif params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        start = params.epoch - params.avg + 1
        filenames = []
        for i in range(start, params.epoch + 1):
            # Epoch counts from 1, so skip non-positive indices.
            if i >= 1:
                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(
            average_checkpoints(filenames, device=device), strict=False
        )

    model.to(device)
    model.eval()

    if params.decoding_method == "fast_beam_search":
        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
    else:
        decoding_graph = None

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display recognition results.
    # decode_dataset() reads batch["supervisions"]["cut"], so force the
    # data module to return cuts regardless of the command-line value.
    args.return_cuts = True
    aishell = AishellAsrDataModule(args)
    test_cuts = aishell.test_cuts()
    dev_cuts = aishell.valid_cuts()
    test_dl = aishell.test_dataloaders(test_cuts)
    dev_dl = aishell.test_dataloaders(dev_cuts)

    test_sets = ["test", "dev"]
    test_dls = [test_dl, dev_dl]

    for test_set, dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(
            dl=dl,
            params=params,
            model=model,
            token_table=lexicon.token_table,
            decoding_graph=decoding_graph,
        )

        save_results(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")


if __name__ == "__main__":
    main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py

View File

@ -0,0 +1,217 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
--jit 0 \
--epoch 29 \
--avg 5
It will generate a file exp_dir/pretrained-epoch-29-avg-5.pt
To use the generated file with `pruned_transducer_stateless2/decode.py`,
you can do::
cd /path/to/exp_dir
ln -s pretrained-epoch-29-avg-5.pt epoch-9999.pt
cd /path/to/egs/aishell/ASR
./pruned_transducer_stateless2/decode.py \
--exp-dir ./pruned_transducer_stateless2/exp \
--epoch 9999 \
--avg 1 \
--max-duration 100 \
--lang-dir data/lang_char
"""
import argparse
import logging
from pathlib import Path
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
def get_parser():
    """Create the argument parser for the export script.

    The parser covers checkpoint selection (--epoch / --iter / --avg), the
    experiment and lang directories and whether to export with
    torch.jit.script. Model-related options are appended by
    ``add_model_arguments``.

    Returns:
      The configured ``argparse.ArgumentParser``.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    add = cli.add_argument

    # ---- which checkpoint(s) to average ----
    add(
        "--epoch",
        type=int,
        default=29,
        help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
    )
    add(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
    )
    add(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    # ---- input / output locations and format ----
    add(
        "--exp-dir",
        type=Path,
        default=Path("pruned_transducer_stateless2/exp"),
        help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
    )
    add(
        "--jit",
        type=str2bool,
        default=False,
        help="""True to save a model after applying torch.jit.script.
""",
    )
    add(
        "--lang-dir",
        type=Path,
        default=Path("data/lang_char"),
        help="The lang dir",
    )
    add(
        "--context-size",
        type=int,
        default=1,
        help="The context size in the decoder. 1 means bigram; "
        "2 means tri-gram",
    )

    # Options describing the model architecture come from train.py.
    add_model_arguments(cli)

    return cli
def main():
    """Average checkpoints and export the resulting model.

    Loads either a single epoch checkpoint or an average of several
    checkpoints, moves the model to CPU and saves it into the experiment
    directory, either as a torch.jit.script module (--jit true) or as a
    state dict that :func:`load_checkpoint` can read.

    Raises:
      ValueError: if --iter is given and fewer than --avg checkpoints are
        found.
    """
    params = get_params()
    params.update(vars(get_parser().parse_args()))

    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    else:
        device = torch.device("cpu")

    logging.info(f"device: {device}")

    lexicon = Lexicon(params.lang_dir)

    params.blank_id = 0
    params.vocab_size = max(lexicon.tokens) + 1

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    if params.iter > 0:
        found = find_checkpoints(params.exp_dir, iteration=-params.iter)
        filenames = found[: params.avg]
        if len(filenames) == 0:
            raise ValueError(
                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
            )
        if len(filenames) < params.avg:
            raise ValueError(
                f"Not enough checkpoints ({len(filenames)}) found for"
                f" --iter {params.iter}, --avg {params.avg}"
            )
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))
    elif params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        first_epoch = params.epoch - params.avg + 1
        # Epoch counts from 1, so non-positive indices are dropped.
        filenames = [
            f"{params.exp_dir}/epoch-{i}.pt"
            for i in range(first_epoch, params.epoch + 1)
            if i >= 1
        ]
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))

    model.to("cpu")
    model.eval()

    # NOTE: both output names are built from --epoch, also when --iter
    # selected the checkpoints.
    if not params.jit:
        logging.info("Not using torch.jit.script")
        # Save it using a format so that it can be loaded
        # by :func:`load_checkpoint`
        filename = (
            params.exp_dir
            / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt"
        )
        torch.save({"model": model.state_dict()}, str(filename))
        logging.info(f"Saved to {filename}")
    else:
        # We won't use the forward() method of the model in C++, so just
        # ignore it here.
        # Otherwise, one of its arguments is a ragged tensor and is not
        # torch scriptable.
        model.__class__.forward = torch.jit.ignore(model.__class__.forward)
        logging.info("Using torch.jit.script")
        model = torch.jit.script(model)
        filename = (
            params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt"
        )
        model.save(str(filename))
        logging.info(f"Saved to {filename}")


if __name__ == "__main__":
    # Timestamped log lines that include the emitting file and line number.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )
    main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/optim.py

View File

@ -0,0 +1,337 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) beam search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) modified beam search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint /path/to/pretrained.pt \
--lang-dir /path/to/lang_char \
--method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from pathlib import Path
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.lexicon import Lexicon
def get_parser():
    """Create the command-line argument parser used by this script.

    Every option is described once in a table of ``(flags, keyword
    arguments)`` pairs and registered in table order, so the generated
    ``--help`` output lists the options in exactly that order. After the
    table has been processed, ``add_model_arguments`` is called to append
    its own options to the same parser.

    Returns:
      The fully configured ``argparse.ArgumentParser``.
    """
    fast_beam_search_only = "Used only when --method is fast_beam_search"

    option_table = [
        (
            ("--checkpoint",),
            dict(
                type=str,
                required=True,
                help="Path to the checkpoint. The checkpoint is assumed "
                "to be saved by icefall.checkpoint.save_checkpoint().",
            ),
        ),
        (
            ("--lang-dir",),
            dict(
                type=Path,
                default=Path("data/lang_char"),
                help="The lang dir",
            ),
        ),
        (
            ("--method",),
            dict(
                type=str,
                default="greedy_search",
                help="""Possible values are:
          - greedy_search
          - beam_search
          - modified_beam_search
          - fast_beam_search
        """,
            ),
        ),
        (
            # Positional argument: one or more files to transcribe.
            ("sound_files",),
            dict(
                type=str,
                nargs="+",
                help="The input sound file(s) to transcribe. Supported "
                "formats are those supported by torchaudio.load(). For "
                "example, wav and flac are supported. The sample rate has "
                "to be 16kHz.",
            ),
        ),
        (
            ("--sample-rate",),
            dict(
                type=int,
                default=16000,
                help="The sample rate of the input sound file",
            ),
        ),
        (
            ("--beam-size",),
            dict(
                type=int,
                default=4,
                help="""An integer indicating how many candidates we will keep for each
        frame. Used only when --method is beam_search or
        modified_beam_search.""",
            ),
        ),
        (
            ("--beam",),
            dict(
                type=float,
                default=4,
                help="""A floating point value to calculate the cutoff score during beam
        search (i.e., `cutoff = max-score - beam`), which is the same as the
        `beam` in Kaldi.
        Used only when --method is fast_beam_search""",
            ),
        ),
        (
            ("--max-contexts",),
            dict(type=int, default=4, help=fast_beam_search_only),
        ),
        (
            ("--max-states",),
            dict(type=int, default=8, help=fast_beam_search_only),
        ),
        (
            ("--context-size",),
            dict(
                type=int,
                default=1,
                help="The context size in the decoder. 1 means bigram; "
                "2 means tri-gram",
            ),
        ),
        (
            ("--max-sym-per-frame",),
            dict(
                type=int,
                default=1,
                help="Maximum number of symbols per frame. "
                "Use only when --method is greedy_search",
            ),
        ),
    ]

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)

    # Model-specific options are registered last, after the script options.
    add_model_arguments(parser)

    return parser
def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Load each sound file and keep only its first channel.

    Args:
      filenames:
        Paths of the sound files to load, read with ``torchaudio.load``.
      expected_sample_rate:
        The sample rate every file must have.

    Returns:
      A list with one tensor per input file, in the same order as
      ``filenames``; each entry is channel 0 of the loaded waveform.

    Raises:
      AssertionError: If a file's sample rate differs from
        ``expected_sample_rate``.
    """

    def _first_channel(path: str) -> torch.Tensor:
        samples, rate = torchaudio.load(path)
        assert rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. "
            f"Given: {rate}"
        )
        # Only the first channel is used; any other channels are dropped.
        return samples[0]

    return [_first_channel(path) for path in filenames]
@torch.no_grad()
def main():
    """Transcribe the sound files given on the command line.

    The function parses the command line, builds the transducer model and
    loads the checkpoint weights into it, computes fbank features for all
    input files, runs the encoder once on the padded batch, decodes with
    the method selected by ``--method`` and finally logs one transcript
    per input file.

    Raises:
      ValueError: Raised during decoding if ``--method`` is not one of the
        supported decoding methods.
    """
    args = get_parser().parse_args()

    params = get_params()
    params.update(vars(args))

    # Use the first GPU when CUDA is available, otherwise the CPU.
    device = (
        torch.device("cuda", 0)
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    logging.info(f"device: {device}")

    lexicon = Lexicon(params.lang_dir)
    params.blank_id = 0
    params.vocab_size = max(lexicon.tokens) + 1

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    # Load on the CPU first, then move the whole model to the device.
    state = torch.load(args.checkpoint, map_location="cpu")
    model.load_state_dict(state["model"], strict=False)
    model.to(device)
    model.eval()
    model.device = device

    logging.info("Constructing Fbank computer")
    fbank_opts = kaldifeat.FbankOptions()
    fbank_opts.device = device
    fbank_opts.frame_opts.dither = 0
    fbank_opts.frame_opts.snip_edges = False
    fbank_opts.frame_opts.samp_freq = params.sample_rate
    fbank_opts.mel_opts.num_bins = params.feature_dim
    fbank = kaldifeat.Fbank(fbank_opts)

    logging.info(f"Reading sound files: {params.sound_files}")
    loaded = read_sound_files(
        filenames=params.sound_files, expected_sample_rate=params.sample_rate
    )
    waves = [wave.to(device) for wave in loaded]

    logging.info("Decoding started")
    features = fbank(waves)
    # Record the per-file frame counts before padding to a common length.
    feature_lens = torch.tensor(
        [feat.size(0) for feat in features], device=device
    )
    features = pad_sequence(
        features, batch_first=True, padding_value=math.log(1e-10)
    )

    encoder_out, encoder_out_lens = model.encoder(
        x=features, x_lens=feature_lens
    )
    num_waves = encoder_out.size(0)

    method = params.method
    logging.info(f"Using {params.method}")

    hyp_list = []
    if method == "fast_beam_search":
        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
        hyp_list = fast_beam_search_one_best(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
        )
    elif method == "greedy_search" and params.max_sym_per_frame == 1:
        hyp_list = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
    elif method == "modified_beam_search":
        hyp_list = modified_beam_search(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
        )
    else:
        # The remaining methods decode one utterance at a time.
        for i in range(num_waves):
            # Keep the batch dimension and drop the padded frames.
            # fmt: off
            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
            # fmt: on
            if method == "greedy_search":
                hyp = greedy_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    max_sym_per_frame=params.max_sym_per_frame,
                )
            elif method == "beam_search":
                hyp = beam_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    beam=params.beam_size,
                )
            else:
                raise ValueError(
                    f"Unsupported decoding method: {params.method}"
                )
            hyp_list.append(hyp)

    # Map token IDs back to their symbols via the lexicon's token table.
    transcripts = [
        [lexicon.token_table[token_id] for token_id in hyp]
        for hyp in hyp_list
    ]

    pieces = ["\n"]
    for filename, tokens in zip(params.sound_files, transcripts):
        words = " ".join(tokens)
        pieces.append(f"{filename}:\n{words}\n\n")
    logging.info("".join(pieces))

    logging.info("Decoding Done")
if __name__ == "__main__":
    # Configure root logging once, before running the script entry point.
    logging.basicConfig(
        format=(
            "%(asctime)s %(levelname)s "
            "[%(filename)s:%(lineno)d] %(message)s"
        ),
        level=logging.INFO,
    )
    main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py

File diff suppressed because it is too large Load Diff

View File

@ -377,6 +377,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -389,9 +390,9 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
@ -416,6 +417,7 @@ def save_results(
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
@ -427,7 +429,9 @@ def save_results(
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((list("".join(res[0])), list("".join(res[1]))))
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
@ -606,6 +610,8 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
asr_datamodule = AsrDataModule(args)
aishell = AIShell(manifest_dir=args.manifest_dir)
test_cuts = aishell.test_cuts()

View File

@ -22,8 +22,12 @@
Usage:
./prepare.sh
# If you use a non-zero value for --datatang-prob, you also need to run
./prepare_aidatatang_200zh.sh
If you use --datatang-prob=0, then you don't need to run the above script.
export CUDA_VISIBLE_DEVICES="0,1,2,3"
@ -62,7 +66,6 @@ import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from aidatatang_200zh import AIDatatang200zh
from aishell import AIShell
from asr_datamodule import AsrDataModule
@ -344,7 +347,7 @@ def get_parser():
parser.add_argument(
"--datatang-prob",
type=float,
default=0.2,
default=0.0,
help="""The probability to select a batch from the
aidatatang_200zh dataset.
If it is set to 0, you don't need to download the data
@ -945,7 +948,10 @@ def train_one_epoch(
tb_writer, "train/valid_", params.batch_idx_train
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if datatang_train_dl is not None:
loss_value = tot_loss["loss"] / tot_loss["frames"]
else:
loss_value = aishell_tot_loss["loss"] / aishell_tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
@ -1032,7 +1038,16 @@ def run(rank, world_size, args):
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
if params.datatang_prob > 0:
find_unused_parameters = True
else:
find_unused_parameters = False
model = DDP(
model,
device_ids=[rank],
find_unused_parameters=find_unused_parameters,
)
optimizer = Eve(model.parameters(), lr=params.initial_lr)

View File

@ -241,6 +241,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -253,9 +254,9 @@ def decode_dataset(
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
results[lm_scale].extend(this_batch)
@ -278,6 +279,7 @@ def save_results(
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
@ -287,7 +289,9 @@ def save_results(
# We compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((list("".join(res[0])), list("".join(res[1]))))
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}-{key}", results_char)
test_set_wers[key] = wer
@ -365,6 +369,8 @@ def main():
model.to(device)
model.eval()
# we need cut ids to display recognition results.
args.return_cuts = True
aishell = AishellAsrDataModule(args)
test_cuts = aishell.test_cuts()
test_dl = aishell.test_dataloaders(test_cuts)

View File

@ -38,8 +38,8 @@ from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
write_error_stats,
str2bool,
write_error_stats,
)
@ -296,6 +296,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -307,9 +308,9 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
@ -334,6 +335,7 @@ def save_results(
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
# The following prints out WERs, per-word error statistics and aligned
@ -344,7 +346,9 @@ def save_results(
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((list("".join(res[0])), list("".join(res[1]))))
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
@ -438,6 +442,8 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
aishell = AishellAsrDataModule(args)
test_cuts = aishell.test_cuts()
test_dl = aishell.test_dataloaders(test_cuts)

View File

@ -341,6 +341,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -353,9 +354,9 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
@ -380,6 +381,7 @@ def save_results(
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
@ -391,7 +393,9 @@ def save_results(
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((list("".join(res[0])), list("".join(res[1]))))
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
@ -496,6 +500,8 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
asr_datamodule = AsrDataModule(args)
aishell = AIShell(manifest_dir=args.manifest_dir)
test_cuts = aishell.test_cuts()

View File

@ -345,6 +345,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -357,9 +358,9 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
@ -384,6 +385,7 @@ def save_results(
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
@ -395,7 +397,9 @@ def save_results(
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((list("".join(res[0])), list("".join(res[1]))))
results_char.append(
(res[0], list("".join(res[1])), list("".join(res[2])))
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=True
@ -498,6 +502,8 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
aishell = AishellAsrDataModule(args)
test_cuts = aishell.test_cuts()
test_dl = aishell.test_dataloaders(test_cuts)

View File

@ -514,6 +514,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -527,8 +528,8 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
this_batch.append((ref_text, hyp_words))
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
this_batch.append((cut_id, ref_text, hyp_words))
results[name].extend(this_batch)
@ -553,6 +554,7 @@ def save_results(
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
@ -756,6 +758,8 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
aishell2 = AiShell2AsrDataModule(args)
valid_cuts = aishell2.valid_cuts()

View File

@ -378,6 +378,7 @@ def decode_dataset(
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
texts = [list(str(text).replace(" ", "")) for text in texts]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -390,8 +391,8 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
this_batch.append((ref_text, hyp_words))
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
this_batch.append((cut_id, ref_text, hyp_words))
results[name].extend(this_batch)
@ -416,6 +417,7 @@ def save_results(
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
@ -607,6 +609,8 @@ def main():
c.supervisions[0].text = text_normalize(text)
return c
# we need cut ids to display recognition results.
args.return_cuts = True
aishell4 = Aishell4AsrDataModule(args)
test_cuts = aishell4.test_cuts()
test_cuts = test_cuts.map(text_normalize_for_cut)

View File

@ -367,6 +367,7 @@ def decode_dataset(
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
texts = [list(str(text).replace(" ", "")) for text in texts]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -379,8 +380,8 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
this_batch.append((ref_text, hyp_words))
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
this_batch.append((cut_id, ref_text, hyp_words))
results[name].extend(this_batch)
@ -405,6 +406,7 @@ def save_results(
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
@ -535,6 +537,8 @@ def main():
from lhotse import CutSet
from lhotse.dataset.webdataset import export_to_webdataset
# we need cut ids to display recognition results.
args.return_cuts = True
alimeeting = AlimeetingAsrDataModule(args)
dev = "eval"

View File

@ -451,6 +451,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -469,9 +470,9 @@ def decode_dataset(
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
results[lm_scale].extend(this_batch)
else:
@ -512,6 +513,7 @@ def save_results(
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results = post_processing(results)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
@ -676,6 +678,8 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
gigaspeech = GigaSpeechAsrDataModule(args)
dev_cuts = gigaspeech.dev_cuts()

View File

@ -374,6 +374,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -386,9 +387,9 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
@ -414,6 +415,7 @@ def save_results(
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = post_processing(results)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
@ -544,6 +546,8 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
gigaspeech = GigaSpeechAsrDataModule(args)
dev_cuts = gigaspeech.dev_cuts()

View File

@ -25,6 +25,7 @@ The following table lists the differences among them.
| `pruned_stateless_emformer_rnnt2` | Emformer(from torchaudio) | Embedding + Conv1d | Using Emformer from torchaudio for streaming ASR|
| `conv_emformer_transducer_stateless` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer for streaming ASR + mechanisms in reworked model |
| `conv_emformer_transducer_stateless2` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer with simplified memory for streaming ASR + mechanisms in reworked model |
| `lstm_transducer_stateless` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model |
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).

View File

@ -1,5 +1,91 @@
## Results
#### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T)
[lstm_transducer_stateless](./lstm_transducer_stateless)
It implements an LSTM model with the mechanisms from the reworked model for streaming ASR.
See <https://github.com/k2-fsa/icefall/pull/479> for more details.
#### training on full librispeech
This model contains 12 encoder layers (LSTM module + Feedforward module). The number of model parameters is 84689496.
The WERs are:
| | test-clean | test-other | comment | decoding mode |
|-------------------------------------|------------|------------|----------------------|----------------------|
| greedy search (max sym per frame 1) | 3.81 | 9.73 | --epoch 35 --avg 15 | simulated streaming |
| greedy search (max sym per frame 1) | 3.78 | 9.79 | --epoch 35 --avg 15 | streaming |
| fast beam search | 3.74 | 9.59 | --epoch 35 --avg 15 | simulated streaming |
| fast beam search | 3.73 | 9.61 | --epoch 35 --avg 15 | streaming |
| modified beam search | 3.64 | 9.55 | --epoch 35 --avg 15 | simulated streaming |
| modified beam search | 3.65 | 9.51 | --epoch 35 --avg 15 | streaming |
Note: `simulated streaming` indicates feeding the full utterance during decoding, while `streaming` indicates feeding a certain number of frames at a time.
The training command is:
```bash
./lstm_transducer_stateless/train.py \
--world-size 4 \
--num-epochs 35 \
--start-epoch 1 \
--exp-dir lstm_transducer_stateless/exp \
--full-libri 1 \
--max-duration 500 \
--master-port 12321 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024
```
The tensorboard log can be found at
<https://tensorboard.dev/experiment/FWrM20mjTeWo6dTpFYOsYQ/>
The simulated streaming decoding command using greedy search, fast beam search, and modified beam search is:
```bash
for decoding_method in greedy_search fast_beam_search modified_beam_search; do
./lstm_transducer_stateless/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir lstm_transducer_stateless/exp \
--max-duration 600 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method $decoding_method \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--beam-size 4
done
```
The streaming decoding command using greedy search, fast beam search, and modified beam search is:
```bash
for decoding_method in greedy_search fast_beam_search modified_beam_search; do
./lstm_transducer_stateless/streaming_decode.py \
--epoch 35 \
--avg 15 \
--exp-dir lstm_transducer_stateless/exp \
--max-duration 600 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method $decoding_method \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--beam-size 4
done
```
Pretrained models, training logs, decoding logs, and decoding results
are available at
<https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18>
#### LibriSpeech BPE training results (Pruned Stateless Conv-Emformer RNN-T 2)
[conv_emformer_transducer_stateless2](./conv_emformer_transducer_stateless2)

View File

@ -525,6 +525,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -544,9 +545,9 @@ def decode_dataset(
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
results[lm_scale].extend(this_batch)
else:
@ -586,6 +587,7 @@ def save_results(
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
@ -779,6 +781,8 @@ def main():
)
rnn_lm_model.eval()
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()

View File

@ -447,6 +447,17 @@ def compute_loss(
info["loss"] = loss.detach().cpu().item()
# `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
info["utterances"] = feature.size(0)
# averaged input duration in frames over utterances
info["utt_duration"] = supervisions["num_frames"].sum().item()
# averaged padding proportion over utterances
info["utt_pad_proportion"] = (
((feature.size(1) - supervisions["num_frames"]) / feature.size(1))
.sum()
.item()
)
return loss, info

View File

@ -31,14 +31,13 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.decode import (
get_lattice,
nbest_decoding,
@ -633,6 +632,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -652,9 +652,9 @@ def decode_dataset(
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
results[lm_scale].extend(this_batch)
else:
@ -694,6 +694,7 @@ def save_results(
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
@ -956,6 +957,8 @@ def main():
)
rnn_lm_model.eval()
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()

View File

@ -605,6 +605,15 @@ def compute_loss(
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
# `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
info["utterances"] = feature.size(0)
# averaged input duration in frames over utterances
info["utt_duration"] = feature_lens.sum().item()
# averaged padding proportion over utterances
info["utt_pad_proportion"] = (
((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
)
return loss, info

View File

@ -449,6 +449,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -466,9 +467,9 @@ def decode_dataset(
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
results[lm_scale].extend(this_batch)
@ -496,6 +497,7 @@ def save_results(
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
@ -661,6 +663,8 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
# CAUTION: `test_sets` is for displaying only.
# If you want to skip test-clean, you have to skip

View File

@ -403,6 +403,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -415,9 +416,9 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
@ -442,6 +443,7 @@ def save_results(
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
@ -624,6 +626,8 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()

View File

@ -29,6 +29,7 @@ class Stream(object):
def __init__(
self,
params: AttributeDict,
cut_id: str,
decoding_graph: Optional[k2.Fsa] = None,
device: torch.device = torch.device("cpu"),
LOG_EPS: float = math.log(1e-10),
@ -44,6 +45,7 @@ class Stream(object):
The device to run this stream.
"""
self.LOG_EPS = LOG_EPS
self.cut_id = cut_id
# Containing attention caches and convolution caches
self.states: Optional[
@ -138,6 +140,10 @@ class Stream(object):
"""Return True if all feature frames are processed."""
return self._done
@property
def id(self) -> str:
return self.cut_id
def decoding_result(self) -> List[int]:
"""Obtain current decoding result."""
if self.decoding_method == "greedy_search":

View File

@ -74,7 +74,6 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
from lhotse import CutSet
import numpy as np
import sentencepiece as spm
import torch
@ -83,6 +82,7 @@ from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
from emformer import LOG_EPSILON, stack_states, unstack_states
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from stream import Stream
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
@ -678,6 +678,7 @@ def decode_dataset(
# Each utterance has a Stream.
stream = Stream(
params=params,
cut_id=cut.id,
decoding_graph=decoding_graph,
device=device,
LOG_EPS=LOG_EPSILON,
@ -711,6 +712,7 @@ def decode_dataset(
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
streams[i].id,
streams[i].ground_truth.split(),
sp.decode(streams[i].decoding_result()).split(),
)
@ -731,6 +733,7 @@ def decode_dataset(
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
streams[i].id,
streams[i].ground_truth.split(),
sp.decode(streams[i].decoding_result()).split(),
)

View File

@ -403,6 +403,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -415,9 +416,9 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
@ -442,6 +443,7 @@ def save_results(
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
@ -624,6 +626,8 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()

View File

@ -74,7 +74,6 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
from lhotse import CutSet
import numpy as np
import sentencepiece as spm
import torch
@ -83,6 +82,7 @@ from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
from emformer import LOG_EPSILON, stack_states, unstack_states
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from stream import Stream
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
@ -678,6 +678,7 @@ def decode_dataset(
# Each utterance has a Stream.
stream = Stream(
params=params,
cut_id=cut.id,
decoding_graph=decoding_graph,
device=device,
LOG_EPS=LOG_EPSILON,
@ -711,6 +712,7 @@ def decode_dataset(
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
streams[i].id,
streams[i].ground_truth.split(),
sp.decode(streams[i].decoding_result()).split(),
)
@ -731,6 +733,7 @@ def decode_dataset(
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
streams[i].id,
streams[i].ground_truth.split(),
sp.decode(streams[i].decoding_result()).split(),
)

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/__init__.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/asr_datamodule.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/beam_search.py

View File

@ -0,0 +1,818 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./lstm_transducer_stateless/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./lstm_transducer_stateless/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./lstm_transducer_stateless/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./lstm_transducer_stateless/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./lstm_transducer_stateless/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./lstm_transducer_stateless/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search (one best)
./lstm_transducer_stateless/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./lstm_transducer_stateless/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest)
./lstm_transducer_stateless/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./lstm_transducer_stateless/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./lstm_transducer_stateless/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./lstm_transducer_stateless/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./lstm_transducer_stateless/decode.py \
--epoch 35 \
--avg 15 \
--exp-dir ./lstm_transducer_stateless/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
# math.log(1e-10); decode_one_batch() uses it as the constant fill value for
# the extra frames it appends to the end of every utterance's features.
LOG_EPS = math.log(1e-10)
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this decoding script.

    It registers, in this order: checkpoint selection (--epoch, --iter,
    --avg, --use-averaged-model), paths (--exp-dir, --bpe-model, --lang-dir),
    the decoding method and its search options, and finally the model options
    added by `add_model_arguments`.

    Returns:
      The configured parser. Defaults are shown in `--help` because
      `ArgumentDefaultsHelpFormatter` is used.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # ---- which checkpoint(s) to load ----
    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    # ---- paths ----
    # Parsed as str here; main() converts it to a pathlib.Path.
    parser.add_argument(
        "--exp-dir",
        type=str,
        default="lstm_transducer_stateless/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--lang-dir",
        type=Path,
        default="data/lang_bpe_500",
        help="The lang dir containing word table and LG graph",
    )

    # ---- decoding method and search options ----
    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="""Possible values are:
          - greedy_search
          - beam_search
          - modified_beam_search
          - fast_beam_search
          - fast_beam_search_nbest
          - fast_beam_search_nbest_oracle
          - fast_beam_search_nbest_LG
        If you use fast_beam_search_nbest_LG, you have to specify
        `--lang-dir`, which should contain `LG.pt`.
        """,
    )

    parser.add_argument(
        "--beam-size",
        type=int,
        default=4,
        help="""An integer indicating how many candidates we will keep for each
        frame. Used only when --decoding-method is beam_search or
        modified_beam_search.""",
    )

    parser.add_argument(
        "--beam",
        type=float,
        default=20.0,
        help="""A floating point value to calculate the cutoff score during beam
        search (i.e., `cutoff = max-score - beam`), which is the same as the
        `beam` in Kaldi.
        Used only when --decoding-method is fast_beam_search,
        fast_beam_search_nbest, fast_beam_search_nbest_LG,
        and fast_beam_search_nbest_oracle
        """,
    )

    parser.add_argument(
        "--ngram-lm-scale",
        type=float,
        default=0.01,
        help="""
        Used only when --decoding_method is fast_beam_search_nbest_LG.
        It specifies the scale for n-gram LM scores.
        """,
    )

    parser.add_argument(
        "--max-contexts",
        type=int,
        default=8,
        help="""Used only when --decoding-method is
        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
        and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--max-states",
        type=int,
        default=64,
        help="""Used only when --decoding-method is
        fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
        and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; "
        "2 means tri-gram",
    )

    parser.add_argument(
        "--max-sym-per-frame",
        type=int,
        default=1,
        help="""Maximum number of symbols per frame.
        Used only when --decoding_method is greedy_search""",
    )

    parser.add_argument(
        "--num-paths",
        type=int,
        default=200,
        help="""Number of paths for nbest decoding.
        Used only when the decoding method is fast_beam_search_nbest,
        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--nbest-scale",
        type=float,
        default=0.5,
        help="""Scale applied to lattice scores when computing nbest paths.
        Used only when the decoding method is fast_beam_search_nbest,
        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
    )

    # Model-architecture options are defined next to the training code.
    add_model_arguments(parser)

    return parser
def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    batch: dict,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

        - key: It indicates the setting used for decoding. For example,
               if greedy_search is used, it would be "greedy_search".
               If beam_search or modified_beam_search with a beam size of 4
               is used, it would be "beam_size_4". For the fast_beam_search
               family the key is built from beam, max_contexts and max_states
               (plus num_paths, nbest_scale and ngram_lm_scale where they
               apply).
        - value: It contains the decoding result. `len(value)` equals to
                 batch size. `value[i]` is the decoding result for the i-th
                 utterance in the given batch.

    The given `batch` is not modified.

    Args:
      params:
        It's the return value of :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      word_table:
        The word symbol table. Used only when --decoding_method is
        fast_beam_search_nbest_LG.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
    Raises:
      ValueError: If `params.decoding_method` is not supported.
    """
    device = next(model.parameters()).device
    feature = batch["inputs"]
    assert feature.ndim == 3

    feature = feature.to(device)
    # at entry, feature is (N, T, C)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

    # tail padding here to alleviate the tail deletion problem
    num_tail_padded_frames = 35
    feature = torch.nn.functional.pad(
        feature,
        (0, 0, 0, num_tail_padded_frames),
        mode="constant",
        value=LOG_EPS,
    )
    # Out-of-place add on purpose: Tensor.to() hands back the very same
    # tensor when it already lives on `device` (e.g. CPU decoding), so an
    # in-place `+=` would silently change batch["supervisions"]["num_frames"].
    feature_lens = feature_lens + num_tail_padded_frames

    encoder_out, encoder_out_lens, _ = model.encoder(
        x=feature, x_lens=feature_lens
    )

    hyps = []

    if params.decoding_method == "fast_beam_search":
        hyp_tokens = fast_beam_search_one_best(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "fast_beam_search_nbest_LG":
        hyp_tokens = fast_beam_search_nbest_LG(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
            num_paths=params.num_paths,
            nbest_scale=params.nbest_scale,
        )
        # With LG the search yields word IDs, so they are mapped through the
        # word table instead of being decoded by the BPE model.
        for hyp in hyp_tokens:
            hyps.append([word_table[i] for i in hyp])
    elif params.decoding_method == "fast_beam_search_nbest":
        hyp_tokens = fast_beam_search_nbest(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
            num_paths=params.num_paths,
            nbest_scale=params.nbest_scale,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "fast_beam_search_nbest_oracle":
        hyp_tokens = fast_beam_search_nbest_oracle(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
            num_paths=params.num_paths,
            ref_texts=sp.encode(supervisions["text"]),
            nbest_scale=params.nbest_scale,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif (
        params.decoding_method == "greedy_search"
        and params.max_sym_per_frame == 1
    ):
        # Batched greedy search is only used for one symbol per frame;
        # other values fall through to the per-utterance loop below.
        hyp_tokens = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "modified_beam_search":
        hyp_tokens = modified_beam_search(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    else:
        # Non-batched methods: decode the utterances one by one, each
        # trimmed to its own valid length.
        batch_size = encoder_out.size(0)

        for i in range(batch_size):
            # fmt: off
            encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]]
            # fmt: on
            if params.decoding_method == "greedy_search":
                hyp = greedy_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    max_sym_per_frame=params.max_sym_per_frame,
                )
            elif params.decoding_method == "beam_search":
                hyp = beam_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    beam=params.beam_size,
                )
            else:
                raise ValueError(
                    f"Unsupported decoding method: {params.decoding_method}"
                )
            hyps.append(sp.decode(hyp).split())

    # Build the key that names this decoding setting in the result files.
    if params.decoding_method == "greedy_search":
        return {"greedy_search": hyps}
    elif "fast_beam_search" in params.decoding_method:
        key = f"beam_{params.beam}_"
        key += f"max_contexts_{params.max_contexts}_"
        key += f"max_states_{params.max_states}"
        if "nbest" in params.decoding_method:
            key += f"_num_paths_{params.num_paths}_"
            key += f"nbest_scale_{params.nbest_scale}"
            if "LG" in params.decoding_method:
                key += f"_ngram_lm_scale_{params.ngram_lm_scale}"

        return {key: hyps}
    else:
        return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode. Each batch
        must carry the cuts in `batch["supervisions"]["cut"]` (main() sets
        `args.return_cuts = True` for that purpose).
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      word_table:
        The word symbol table.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
    Returns:
      Return a dict, whose key may be "greedy_search" if greedy search
      is used, or it may be "beam_size_4" if a beam size of 4 is used.
      Its value is a list of tuples. Each tuple contains three elements:
      The first is the cut ID, the second is the reference transcript
      (a list of words), and the third is the predicted result (a list of
      words).
    """
    num_cuts = 0

    # Some dataloaders have no length; fall back to "?" in the progress log.
    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    if params.decoding_method == "greedy_search":
        log_interval = 50
    else:
        log_interval = 20

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            sp=sp,
            decoding_graph=decoding_graph,
            word_table=word_table,
            batch=batch,
        )

        for name, hyps in hyps_dict.items():
            assert len(hyps) == len(texts)
            results[name].extend(
                (cut_id, ref_text.split(), hyp_words)
                for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts)
            )

        num_cuts += len(texts)

        if batch_idx % log_interval == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(
                f"batch {batch_str}, cuts processed until now is {num_cuts}"
            )
    return results
def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    """Write transcripts, error statistics and a WER summary for one test set.

    For every decoding setting in `results_dict`, a `recogs-*.txt` and an
    `errs-*.txt` file are written to `params.res_dir`. A single
    `wer-summary-*.txt` listing all settings sorted by WER is written
    afterwards and the same table is logged.

    Args:
      params:
        Must provide `res_dir` (output directory) and `suffix` (appended to
        every file name).
      test_set_name:
        Name of the test set; it becomes part of every file name.
      results_dict:
        Maps a decoding-setting key to a list of
        `(cut_id, ref_words, hyp_words)` tuples, as returned by
        :func:`decode_dataset`.
    """
    if not results_dict:
        # Nothing was decoded; without this guard the summary file name
        # below would fail with a NameError on the unset loop variable.
        logging.warning(
            f"No decoding results for {test_set_name}; nothing is saved"
        )
        return

    test_set_wers = dict()
    last_key = None
    for key, results in results_dict.items():
        recog_path = (
            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        # Tuples start with the cut ID, so this orders the output by cut ID.
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )
            test_set_wers[key] = wer

        logging.info("Wrote detailed error stats to {}".format(errs_filename))
        last_key = key

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    # The summary file name embeds the key of the last decoding setting;
    # this is kept as is so existing file names do not change.
    errs_info = (
        params.res_dir
        / f"wer-summary-{test_set_name}-{last_key}-{params.suffix}.txt"
    )
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for setting, val in test_set_wers:
            print("{}\t{}".format(setting, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    # Only the first (lowest-WER) row carries the "best for" note.
    note = "\tbest for {}".format(test_set_name)
    for setting, val in test_set_wers:
        s += "{}\t{}{}\n".format(setting, val, note)
        note = ""
    logging.info(s)
@torch.no_grad()
def main() -> None:
    """Parse the options, load the (averaged) model and decode the
    LibriSpeech test-clean and test-other sets.

    The whole function runs under `torch.no_grad()`. Logs and result files
    go to `<exp_dir>/<decoding_method>/`; result files are written by
    :func:`save_results`.
    """
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    # Command-line values are layered on top of the defaults of get_params().
    params = get_params()
    params.update(vars(args))

    assert params.decoding_method in (
        "greedy_search",
        "beam_search",
        "fast_beam_search",
        "fast_beam_search_nbest",
        "fast_beam_search_nbest_LG",
        "fast_beam_search_nbest_oracle",
        "modified_beam_search",
    )
    params.res_dir = params.exp_dir / params.decoding_method

    # `suffix` encodes the checkpoint selection and the search settings; it
    # is appended to the log file name and to every result file name.
    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if "fast_beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam}"
        params.suffix += f"-max-contexts-{params.max_contexts}"
        params.suffix += f"-max-states-{params.max_states}"
        if "nbest" in params.decoding_method:
            params.suffix += f"-nbest-scale-{params.nbest_scale}"
            params.suffix += f"-num-paths-{params.num_paths}"
            if "LG" in params.decoding_method:
                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
    elif "beam_search" in params.decoding_method:
        params.suffix += (
            f"-{params.decoding_method}-beam-size-{params.beam_size}"
        )
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    # Use the first CUDA device when one is available, otherwise the CPU.
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> and <unk> are defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    if not params.use_averaged_model:
        # Plain averaging over the selected checkpoint files.
        if params.iter > 0:
            # NOTE(review): find_checkpoints() is presumably returning the
            # checkpoints up to --iter, newest first - confirm against
            # icefall.checkpoint.find_checkpoints.
            filenames = find_checkpoints(
                params.exp_dir, iteration=-params.iter
            )[: params.avg]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            # Epochs epoch-avg+1 .. epoch; indices below 1 are skipped
            # because epochs count from 1.
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        # Averaging computed from only two checkpoint files: the one just
        # before the range (excluded) and the last one of the range.
        if params.iter > 0:
            # One extra checkpoint is needed to serve as the excluded start.
            filenames = find_checkpoints(
                params.exp_dir, iteration=-params.iter
            )[: params.avg + 1]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            # epoch-<start>.pt must exist, hence start >= 1.
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    # Needed for the `avg == 1` branch above, which never moves the model.
    model.to(device)
    model.eval()

    if "fast_beam_search" in params.decoding_method:
        if params.decoding_method == "fast_beam_search_nbest_LG":
            lexicon = Lexicon(params.lang_dir)
            word_table = lexicon.word_table
            lg_filename = params.lang_dir / "LG.pt"
            logging.info(f"Loading {lg_filename}")
            decoding_graph = k2.Fsa.from_dict(
                torch.load(lg_filename, map_location=device)
            )
            # Scale the graph scores by --ngram-lm-scale.
            decoding_graph.scores *= params.ngram_lm_scale
        else:
            word_table = None
            decoding_graph = k2.trivial_graph(
                params.vocab_size - 1, device=device
            )
    else:
        decoding_graph = None
        word_table = None

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display recognition results.
    args.return_cuts = True
    librispeech = LibriSpeechAsrDataModule(args)

    test_clean_cuts = librispeech.test_clean_cuts()
    test_other_cuts = librispeech.test_other_cuts()

    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
    test_other_dl = librispeech.test_dataloaders(test_other_cuts)

    test_sets = ["test-clean", "test-other"]
    test_dl = [test_clean_dl, test_other_dl]

    # NOTE: the loop variable reuses the name `test_dl`. zip() has already
    # captured the list, so iteration is unaffected, but after the loop
    # `test_dl` refers to the last dataloader rather than the list.
    for test_set, test_dl in zip(test_sets, test_dl):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            sp=sp,
            word_table=word_table,
            decoding_graph=decoding_graph,
        )

        save_results(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")
# Run the decoding only when this file is executed as a script.
if __name__ == "__main__":
    main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/decoder.py

View File

@ -0,0 +1 @@
../transducer_stateless/encoder_interface.py

View File

@ -0,0 +1,388 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
(1) Export to torchscript model using torch.jit.trace()
./lstm_transducer_stateless/export.py \
--exp-dir ./lstm_transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 35 \
--avg 10 \
--jit-trace 1
It will generate 3 files: `encoder_jit_trace.pt`,
`decoder_jit_trace.pt`, and `joiner_jit_trace.pt`.
(2) Export `model.state_dict()`
./lstm_transducer_stateless/export.py \
--exp-dir ./lstm_transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 35 \
--avg 10
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
To use the generated file with `lstm_transducer_stateless/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./lstm_transducer_stateless/decode.py \
--exp-dir ./lstm_transducer_stateless/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18
# You will find the pre-trained model in icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18/exp
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this export script.

    It registers the checkpoint-selection options (--epoch, --iter, --avg,
    --use-averaged-model), the paths (--exp-dir, --bpe-model), the export
    mode (--jit-trace), --context-size, and finally the model options added
    by `add_model_arguments`.

    Returns:
      The configured parser. Defaults are shown in `--help` because
      `ArgumentDefaultsHelpFormatter` is used.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # The help text used to say "counts from 0"; main() only ever builds
    # file names for epochs >= 1, so epochs count from 1.
    parser.add_argument(
        "--epoch",
        type=int,
        default=28,
        help="""It specifies the checkpoint to use for averaging.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    # NOTE(review): this default points at pruned_transducer_stateless3
    # although the usage text of this script uses
    # ./lstm_transducer_stateless/exp - looks like a copy-paste leftover.
    # Left unchanged to keep existing invocations working; confirm.
    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless3/exp",
        help="""It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
        """,
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--jit-trace",
        type=str2bool,
        default=False,
        help="""True to save a model after applying torch.jit.trace.
        It will generate 3 files:
         - encoder_jit_trace.pt
         - decoder_jit_trace.pt
         - joiner_jit_trace.pt

        Check ./jit_pretrained.py for how to use them.
        """,
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; "
        "2 means tri-gram",
    )

    # Model-architecture options are defined next to the training code.
    add_model_arguments(parser)

    return parser
def export_encoder_model_jit_trace(
    encoder_model: nn.Module,
    encoder_filename: str,
) -> None:
    """Trace the encoder with torch.jit.trace() and save it to disk.

    The encoder is traced on a dummy all-zero input of one utterance with
    100 frames of 80-dim features, together with the initial states obtained
    from `encoder_model.get_init_states()`.

    Note: Only `(x, x_lens, states)` are passed while tracing, so any other
    argument of the encoder's forward (the original note mentions `warmup`,
    fixed to 1) keeps its default value.

    Args:
      encoder_model:
        The input encoder model
      encoder_filename:
        The filename to save the exported model.
    """
    num_frames = 100
    feature_dim = 80

    dummy_features = torch.zeros(1, num_frames, feature_dim, dtype=torch.float32)
    dummy_lengths = torch.tensor([num_frames], dtype=torch.int64)
    initial_states = encoder_model.get_init_states()

    example_inputs = (dummy_features, dummy_lengths, initial_states)
    torch.jit.trace(encoder_model, example_inputs).save(encoder_filename)
    logging.info(f"Saved to {encoder_filename}")
def export_decoder_model_jit_trace(
    decoder_model: nn.Module,
    decoder_filename: str,
) -> None:
    """Trace the decoder with torch.jit.trace() and save it to disk.

    The decoder is traced on an all-zero int64 token tensor of shape
    `(10, decoder_model.context_size)`.

    Note: The argument need_pad is fixed to False.

    Args:
      decoder_model:
        The input decoder model
      decoder_filename:
        The filename to save the exported model.
    """
    dummy_tokens = torch.zeros(
        10, decoder_model.context_size, dtype=torch.int64
    )
    no_padding = torch.tensor([False])

    example_inputs = (dummy_tokens, no_padding)
    torch.jit.trace(decoder_model, example_inputs).save(decoder_filename)
    logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_jit_trace(
    joiner_model: nn.Module,
    joiner_filename: str,
) -> None:
    """Trace the joiner with torch.jit.trace() and save it to disk.

    The joiner is traced on one random encoder frame and one random decoder
    frame; their sizes are read from the second dimension of the weights of
    `joiner_model.encoder_proj` and `joiner_model.decoder_proj`.

    Note: The argument project_input is fixed to True. A user should not
    project the encoder_out/decoder_out by himself/herself. The exported joiner
    will do that for the user.

    Args:
      joiner_model:
        The input joiner model
      joiner_filename:
        The filename to save the exported model.
    """
    enc_dim = joiner_model.encoder_proj.weight.size(1)
    dec_dim = joiner_model.decoder_proj.weight.size(1)

    # Encoder sample first, then decoder, as in the original code.
    dummy_encoder_out = torch.rand(1, enc_dim, dtype=torch.float32)
    dummy_decoder_out = torch.rand(1, dec_dim, dtype=torch.float32)

    example_inputs = (dummy_encoder_out, dummy_decoder_out)
    torch.jit.trace(joiner_model, example_inputs).save(joiner_filename)
    logging.info(f"Saved to {joiner_filename}")
@torch.no_grad()
def main():
    """Load (and optionally average) checkpoints, then export the model.

    The steps are:
      1. Parse the command line and merge it into the default params.
      2. Build the transducer model and fill it with weights, chosen by
         ``--use-averaged-model``, ``--iter``, ``--epoch`` and ``--avg``.
      3. Move the model to CPU, switch to eval mode, and either export the
         encoder/decoder/joiner with ``torch.jit.trace()`` (``--jit-trace``)
         or save a plain ``state_dict`` to ``exp_dir/pretrained.pt``.
    """
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    device = (
        torch.device("cuda", 0)
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    logging.info(f"device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    if params.use_averaged_model:
        # Average using the running averaged model stored in two checkpoints:
        # the one just before the range (excluded) and the last one.
        if params.iter > 0:
            checkpoints = find_checkpoints(
                params.exp_dir, iteration=-params.iter
            )[: params.avg + 1]
            if not checkpoints:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            if len(checkpoints) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(checkpoints)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = checkpoints[-1]
            filename_end = checkpoints[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )

        model.to(device)
        model.load_state_dict(
            average_checkpoints_with_averaged_model(
                filename_start=filename_start,
                filename_end=filename_end,
                device=device,
            )
        )
    elif params.iter <= 0 and params.avg == 1:
        # A single epoch checkpoint: nothing to average.
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        # Plain parameter averaging over several checkpoints.
        if params.iter > 0:
            filenames = find_checkpoints(
                params.exp_dir, iteration=-params.iter
            )[: params.avg]
            if not filenames:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            if len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
        else:
            first_epoch = params.epoch - params.avg + 1
            filenames = [
                f"{params.exp_dir}/epoch-{i}.pt"
                for i in range(first_epoch, params.epoch + 1)
                if i >= 1
            ]

        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(average_checkpoints(filenames, device=device))

    model.to("cpu")
    model.eval()

    if params.jit_trace is True:
        convert_scaled_to_non_scaled(model, inplace=True)
        logging.info("Using torch.jit.trace()")

        exports = (
            (
                export_encoder_model_jit_trace,
                model.encoder,
                "encoder_jit_trace.pt",
            ),
            (
                export_decoder_model_jit_trace,
                model.decoder,
                "decoder_jit_trace.pt",
            ),
            (
                export_joiner_model_jit_trace,
                model.joiner,
                "joiner_jit_trace.pt",
            ),
        )
        for export_fn, submodule, basename in exports:
            export_fn(submodule, params.exp_dir / basename)
    else:
        logging.info("Not using torchscript")
        # Save it using a format so that it can be loaded
        # by :func:`load_checkpoint`
        filename = params.exp_dir / "pretrained.pt"
        torch.save({"model": model.state_dict()}, str(filename))
        logging.info(f"Saved to {filename}")
if __name__ == "__main__":
    # Log at INFO level with timestamp, level and source location, then run
    # the export.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",  # noqa
        level=logging.INFO,
    )
    main()

View File

@ -0,0 +1,322 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models, either exported by `torch.jit.trace()`
or by `torch.jit.script()`, and uses them to decode waves.
You can use the following command to get the exported models:
./lstm_transducer_stateless/export.py \
--exp-dir ./lstm_transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10 \
--jit-trace 1
Usage of this script:
./lstm_transducer_stateless/jit_pretrained.py \
--encoder-model-filename ./lstm_transducer_stateless/exp/encoder_jit_trace.pt \
--decoder-model-filename ./lstm_transducer_stateless/exp/decoder_jit_trace.pt \
--joiner-model-filename ./lstm_transducer_stateless/exp/joiner_jit_trace.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from typing import List
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
def get_parser():
    """Build the command-line parser for decoding with torchscript models.

    Returns:
      An ``argparse.ArgumentParser`` with the three required model filenames,
      ``--bpe-model``, one or more positional ``sound_files``,
      ``--sample-rate`` (default 16000) and ``--context-size`` (default 2).
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # The three sub-models share the same option layout.
    for component in ("encoder", "decoder", "joiner"):
        parser.add_argument(
            f"--{component}-model-filename",
            type=str,
            required=True,
            help=f"Path to the {component} torchscript model. ",
        )

    parser.add_argument(
        "--bpe-model",
        type=str,
        help="""Path to bpe.model.""",
    )

    sound_files_help = (
        "The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz."
    )
    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help=sound_files_help,
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="Context size of the decoder model",
    )

    return parser
def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Load sound files and return the first channel of each one.

    Args:
      filenames:
        A list of sound filenames, each loadable by ``torchaudio.load()``.
      expected_sample_rate:
        The sample rate every file must have; a mismatch trips an assertion.
    Returns:
      A list with one tensor per input file, namely ``wave[0]`` of what
      ``torchaudio.load()`` returned. The list is empty if ``filenames`` is.
    """
    waves = []
    for path in filenames:
        samples, rate = torchaudio.load(path)
        assert rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. "
            f"Given: {rate}"
        )
        # Any channel beyond the first is dropped.
        waves.append(samples[0])
    return waves
def greedy_search(
    decoder: torch.jit.ScriptModule,
    joiner: torch.jit.ScriptModule,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    context_size: int,
) -> List[List[int]]:
    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.

    The utterances are packed with ``pack_padded_sequence`` so that each time
    step only processes the utterances that still have frames left. At most
    one symbol is emitted per frame per utterance, and the blank ID is
    hard-coded to 0.

    Args:
      decoder:
        The decoder model. It is called as ``decoder(tokens, need_pad=...)``
        and its output is squeezed along dim 1.
      joiner:
        The joiner model. It is called as ``joiner(encoder_out, decoder_out)``
        and must return 2-D logits of shape (batch_size, vocab_size).
      encoder_out:
        A 3-D tensor of shape (N, T, C)
      encoder_out_lens:
        A 1-D tensor of shape (N,). All entries must be positive.
      context_size:
        The context size of the decoder model.
    Returns:
      Return the decoded results for each utterance, in the same order as
      the rows of ``encoder_out``.
    """
    assert encoder_out.ndim == 3
    assert encoder_out.size(0) >= 1, encoder_out.size(0)

    # enforce_sorted=False sorts utterances by decreasing length internally;
    # `unsorted_indices` is used at the end to restore the input order.
    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    device = encoder_out.device
    blank_id = 0  # hard-code to 0

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)

    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == batch_size_list[0], (N, batch_size_list)

    # Every hypothesis starts with `context_size` blanks as decoder context.
    hyps = [[blank_id] * context_size for _ in range(N)]

    # The same flag is passed on every decoder call, so build it only once.
    need_pad = torch.tensor([False])

    decoder_input = torch.tensor(
        hyps,
        device=device,
        dtype=torch.int64,
    )  # (N, context_size)
    decoder_out = decoder(decoder_input, need_pad=need_pad).squeeze(1)

    offset = 0
    for batch_size in batch_size_list:
        # current_encoder_out's shape: (batch_size, encoder_out_dim)
        current_encoder_out = packed_encoder_out.data[
            offset : offset + batch_size
        ]
        offset += batch_size

        # Utterances that have ended sit at the tail of the sorted batch.
        decoder_out = decoder_out[:batch_size]

        logits = joiner(current_encoder_out, decoder_out)
        # logits' shape: (batch_size, vocab_size)
        assert logits.ndim == 2, logits.shape

        emitted = False
        for i, token in enumerate(logits.argmax(dim=1).tolist()):
            if token != blank_id:
                hyps[i].append(token)
                emitted = True

        if emitted:
            # At least one hypothesis grew, so recompute the decoder output
            # for the utterances that are still active.
            decoder_input = torch.tensor(
                [h[-context_size:] for h in hyps[:batch_size]],
                device=device,
                dtype=torch.int64,
            )
            decoder_out = decoder(decoder_input, need_pad=need_pad).squeeze(1)

    # Strip the initial blank context, then undo the length-based sorting.
    sorted_ans = [h[context_size:] for h in hyps]
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
    return [sorted_ans[i] for i in unsorted_indices]
@torch.no_grad()
def main():
    """Decode the given sound files with exported torchscript models.

    Loads the encoder, decoder and joiner, computes 80-bin fbank features
    for every input file, runs the encoder on the padded batch starting
    from its initial states, decodes with greedy search, and logs the
    transcript of each file.
    """
    args = get_parser().parse_args()
    logging.info(vars(args))

    device = (
        torch.device("cuda", 0)
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    logging.info(f"device: {device}")

    encoder = torch.jit.load(args.encoder_model_filename)
    decoder = torch.jit.load(args.decoder_model_filename)
    joiner = torch.jit.load(args.joiner_model_filename)

    for module in (encoder, decoder, joiner):
        module.eval()
    for module in (encoder, decoder, joiner):
        module.to(device)

    sp = spm.SentencePieceProcessor()
    sp.load(args.bpe_model)

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = args.sample_rate
    opts.mel_opts.num_bins = 80

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {args.sound_files}")
    waves = [
        wave.to(device)
        for wave in read_sound_files(
            filenames=args.sound_files,
            expected_sample_rate=args.sample_rate,
        )
    ]

    logging.info("Decoding started")
    features = fbank(waves)
    # Record the true lengths before padding the batch.
    feature_lengths = torch.tensor(
        [f.size(0) for f in features], device=device
    )
    features = pad_sequence(
        features,
        batch_first=True,
        padding_value=math.log(1e-10),
    )

    init_states = encoder.get_init_states(
        batch_size=features.size(0), device=device
    )
    encoder_out, encoder_out_lens, _ = encoder(
        x=features,
        x_lens=feature_lengths,
        states=init_states,
    )

    hyps = greedy_search(
        decoder=decoder,
        joiner=joiner,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        context_size=args.context_size,
    )

    # One "<filename>:\n<text>\n\n" entry per input file.
    report = "\n" + "".join(
        f"{filename}:\n{sp.decode(hyp)}\n\n"
        for filename, hyp in zip(args.sound_files, hyps)
    )
    logging.info(report)

    logging.info("Decoding Done")
if __name__ == "__main__":
    # Log at INFO level with timestamp, level and source location, then run
    # the decoder.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",  # noqa
        level=logging.INFO,
    )
    main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/joiner.py

View File

@ -0,0 +1,842 @@
# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import math
from typing import List, Optional, Tuple
import torch
from encoder_interface import EncoderInterface
from scaling import (
ActivationBalancer,
BasicNorm,
DoubleSwish,
ScaledConv2d,
ScaledLinear,
ScaledLSTM,
)
from torch import nn
# Natural logarithm of 1e-10 (about -23.03).
# NOTE(review): not referenced in the visible part of this file; presumably a
# log-domain floor/padding value for features -- confirm against callers.
LOG_EPSILON = math.log(1e-10)
def unstack_states(
    states: Tuple[torch.Tensor, torch.Tensor]
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    """
    Split batched lstm states into one state per utterance.

    The batch is taken to be dimension 1 of both tensors; the split keeps
    that dimension with size 1, so each per-utterance state has the same
    number of dimensions as the batched one.

    Args:
      states:
        A tuple of 2 elements.
        ``states[0]`` is the lstm hidden states, of a batch of utterances.
        ``states[1]`` is the lstm cell states, of a batch of utterances.

    Returns:
      A list of states.
      ``states[i]`` is a tuple of 2 elements of i-th utterance.
      ``states[i][0]`` is the lstm hidden states of i-th utterance.
      ``states[i][1]`` is the lstm cell states of i-th utterance.
    """
    batched_hidden, batched_cell = states
    per_utterance = zip(
        batched_hidden.unbind(dim=1), batched_cell.unbind(dim=1)
    )
    # unbind() drops the batch dimension; unsqueeze(1) puts it back.
    return [
        (hidden.unsqueeze(1), cell.unsqueeze(1))
        for hidden, cell in per_utterance
    ]
def stack_states(
    states_list: List[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Merge per-utterance lstm states into one batched lstm state.

    The hidden states and the cell states are each concatenated along
    dimension 1, so the result can be fed to the lstm when the utterances
    are processed together as a batch.

    Args:
      states_list:
        Each element in states_list corresponds to the lstm state for a single
        utterance.
        ``states_list[i]`` is a tuple of 2 elements of i-th utterance.
        ``states_list[i][0]`` is the lstm hidden states of i-th utterance.
        ``states_list[i][1]`` is the lstm cell states of i-th utterance.

    Returns:
      A new state corresponding to a batch of utterances.
      It is a tuple of 2 elements.
      ``states[0]`` is the lstm hidden states, of a batch of utterances.
      ``states[1]`` is the lstm cell states, of a batch of utterances.
    """
    all_hidden = [state[0] for state in states_list]
    all_cell = [state[1] for state in states_list]
    return torch.cat(all_hidden, dim=1), torch.cat(all_cell, dim=1)
class RNN(EncoderInterface):
    """
    LSTM-based encoder: a ``Conv2dSubsampling`` front-end followed by a stack
    of ``RNNEncoderLayer`` modules wrapped in an ``RNNEncoder``.

    Args:
      num_features (int):
        Number of input features.
      subsampling_factor (int):
        Subsampling factor of encoder (convolution layers before lstm layers)
        (default=4). Only 4 is supported; any other value raises
        ``NotImplementedError``.
      d_model (int):
        Output dimension (default=512).
      dim_feedforward (int):
        Feedforward dimension (default=2048).
      rnn_hidden_size (int):
        Hidden dimension for lstm layers (default=1024).
      num_encoder_layers (int):
        Number of encoder layers (default=12).
      dropout (float):
        Dropout rate (default=0.1).
      layer_dropout (float):
        Dropout value for model-level warmup (default=0.075).
      aux_layer_period (int):
        Period of auxiliary layers used for random combiner during training.
        If set to 0, will not use the random combiner (Default).
        You can set a positive integer to use the random combiner, e.g., 3.
    """

    def __init__(
        self,
        num_features: int,
        subsampling_factor: int = 4,
        d_model: int = 512,
        dim_feedforward: int = 2048,
        rnn_hidden_size: int = 1024,
        num_encoder_layers: int = 12,
        dropout: float = 0.1,
        layer_dropout: float = 0.075,
        aux_layer_period: int = 0,
    ) -> None:
        super(RNN, self).__init__()

        self.num_features = num_features
        self.subsampling_factor = subsampling_factor
        if subsampling_factor != 4:
            raise NotImplementedError("Support only 'subsampling_factor=4'.")

        # self.encoder_embed converts the input of shape (N, T, num_features)
        # to the shape (N, T//subsampling_factor, d_model).
        # That is, it does two things simultaneously:
        #   (1) subsampling: T -> T//subsampling_factor
        #   (2) embedding: num_features -> d_model
        self.encoder_embed = Conv2dSubsampling(num_features, d_model)

        self.num_encoder_layers = num_encoder_layers
        self.d_model = d_model
        self.rnn_hidden_size = rnn_hidden_size

        # Template layer; it is passed to RNNEncoder together with the
        # number of layers to build.
        encoder_layer = RNNEncoderLayer(
            d_model=d_model,
            dim_feedforward=dim_feedforward,
            rnn_hidden_size=rnn_hidden_size,
            dropout=dropout,
            layer_dropout=layer_dropout,
        )
        # With aux_layer_period > 0, the auxiliary layers are every
        # aux_layer_period-th layer starting at num_encoder_layers // 3 and
        # stopping before the last layer; otherwise no auxiliary layers are
        # passed (None).
        self.encoder = RNNEncoder(
            encoder_layer,
            num_encoder_layers,
            aux_layers=list(
                range(
                    num_encoder_layers // 3,
                    num_encoder_layers - 1,
                    aux_layer_period,
                )
            )
            if aux_layer_period > 0
            else None,
        )

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        warmup: float = 1.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
          x:
            The input tensor. Its shape is (N, T, C), where N is the batch size,
            T is the sequence length, C is the feature dimension.
          x_lens:
            A tensor of shape (N,), containing the number of frames in `x`
            before padding.
          states:
            A tuple of 2 tensors (optional). It is for streaming inference.
            states[0] is the hidden states of all layers,
              with shape of (num_layers, N, d_model);
            states[1] is the cell states of all layers,
              with shape of (num_layers, N, rnn_hidden_size).
            Passing states is only allowed when the module is not in
            training mode.
          warmup:
            A floating point value that gradually increases from 0 throughout
            training; when it is >= 1.0 we are "fully warmed up".  It is used
            to turn modules on sequentially. It is only forwarded to the
            encoder when `states` is None.

        Returns:
          A tuple of 3 tensors:
            - embeddings: its shape is (N, T', d_model), where T' is the output
              sequence lengths.
            - lengths: a tensor of shape (batch_size,) containing the number of
              frames in `embeddings` before padding.
            - updated states, whose shape is the same as the input states.
              When `states` is None, this is a pair of empty tensors.
        """
        x = self.encoder_embed(x)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        # lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning
        #
        # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
        lengths = (((x_lens - 3) >> 1) - 1) >> 1
        # The consistency check is skipped while tracing.
        if not torch.jit.is_tracing():
            assert x.size(0) == lengths.max().item()

        if states is None:
            x = self.encoder(x, warmup=warmup)[0]
            # torch.jit.trace requires returned types to be the same as annotated # noqa
            new_states = (torch.empty(0), torch.empty(0))
        else:
            assert not self.training
            assert len(states) == 2
            if not torch.jit.is_tracing():
                # for hidden state
                assert states[0].shape == (
                    self.num_encoder_layers,
                    x.size(1),
                    self.d_model,
                )
                # for cell state
                assert states[1].shape == (
                    self.num_encoder_layers,
                    x.size(1),
                    self.rnn_hidden_size,
                )
            x, new_states = self.encoder(x, states)

        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
        return x, lengths, new_states

    @torch.jit.export
    def get_init_states(
        self, batch_size: int = 1, device: torch.device = torch.device("cpu")
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get model initial states.

        Args:
          batch_size:
            The batch size N of the returned states.
          device:
            The device on which the state tensors are created.

        Returns:
          A tuple of 2 all-zero tensors:
            - hidden states, of shape (num_encoder_layers, N, d_model);
            - cell states, of shape (num_encoder_layers, N, rnn_hidden_size).
        """
        # for rnn hidden states
        hidden_states = torch.zeros(
            (self.num_encoder_layers, batch_size, self.d_model), device=device
        )
        # for rnn cell states
        cell_states = torch.zeros(
            (self.num_encoder_layers, batch_size, self.rnn_hidden_size),
            device=device,
        )
        return (hidden_states, cell_states)
class RNNEncoderLayer(nn.Module):
    """
    RNNEncoderLayer is made up of lstm and feedforward networks.

    Args:
      d_model:
        The number of expected features in the input (required).
      dim_feedforward:
        The dimension of feedforward network model (required).
      rnn_hidden_size:
        The hidden dimension of rnn layer. It must be >= d_model.
      dropout:
        The dropout value (default=0.1).
      layer_dropout:
        The dropout value for model-level warmup (default=0.075).
    """

    def __init__(
        self,
        d_model: int,
        dim_feedforward: int,
        rnn_hidden_size: int,
        dropout: float = 0.1,
        layer_dropout: float = 0.075,
    ) -> None:
        super(RNNEncoderLayer, self).__init__()
        self.layer_dropout = layer_dropout
        self.d_model = d_model
        self.rnn_hidden_size = rnn_hidden_size

        assert rnn_hidden_size >= d_model, (rnn_hidden_size, d_model)
        # A projection to d_model is requested only when the hidden size is
        # larger than d_model; proj_size=0 is passed when they are equal.
        self.lstm = ScaledLSTM(
            input_size=d_model,
            hidden_size=rnn_hidden_size,
            proj_size=d_model if rnn_hidden_size > d_model else 0,
            num_layers=1,
            dropout=0.0,
        )
        self.feed_forward = nn.Sequential(
            ScaledLinear(d_model, dim_feedforward),
            ActivationBalancer(channel_dim=-1),
            DoubleSwish(),
            nn.Dropout(dropout),
            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
        )
        self.norm_final = BasicNorm(d_model)

        # try to ensure the output is close to zero-mean (or at least, zero-median).  # noqa
        self.balancer = ActivationBalancer(
            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
        )
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        src: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        warmup: float = 1.0,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Pass the input through the encoder layer.

        Args:
          src:
            The sequence to the encoder layer (required).
            Its shape is (S, N, E), where S is the sequence length,
            N is the batch size, and E is the feature number.
          states:
            A tuple of 2 tensors (optional). It is for streaming inference.
            states[0] is the hidden states of this layer,
              with shape of (1, N, d_model);
            states[1] is the cell states of this layer,
              with shape of (1, N, rnn_hidden_size).
            Passing states is only allowed when the module is not in
            training mode.
          warmup:
            It controls selective bypass of layers; if < 1.0, we will
            bypass layers more frequently.

        Returns:
          A tuple of 2 elements:
            - the layer output;
            - the updated lstm states returned by the lstm, or a pair of
              empty tensors when `states` is None.
        """
        src_orig = src

        warmup_scale = min(0.1 + warmup, 1.0)
        # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
        # completely bypass it.
        # In training, alpha is warmup_scale with probability
        # (1 - layer_dropout) and 0.1 otherwise; outside training it is 1.0.
        if self.training:
            alpha = (
                warmup_scale
                if torch.rand(()).item() <= (1.0 - self.layer_dropout)
                else 0.1
            )
        else:
            alpha = 1.0

        # lstm module
        if states is None:
            src_lstm = self.lstm(src)[0]
            # torch.jit.trace requires returned types be the same as annotated
            new_states = (torch.empty(0), torch.empty(0))
        else:
            assert not self.training
            assert len(states) == 2
            # The shape checks are skipped while tracing.
            if not torch.jit.is_tracing():
                # for hidden state
                assert states[0].shape == (1, src.size(1), self.d_model)
                # for cell state
                assert states[1].shape == (1, src.size(1), self.rnn_hidden_size)
            src_lstm, new_states = self.lstm(src, states)
        # residual connection around the lstm
        src = src + self.dropout(src_lstm)

        # feed forward module
        src = src + self.dropout(self.feed_forward(src))

        src = self.norm_final(self.balancer(src))

        # Interpolate between the layer output and its input when the layer
        # is partially bypassed.
        if alpha != 1.0:
            src = alpha * src + (1 - alpha) * src_orig

        return src, new_states
class RNNEncoder(nn.Module):
    """
    RNNEncoder is a stack of N encoder layers.

    Args:
      encoder_layer:
        An instance of the RNNEncoderLayer() class (required). It is
        deep-copied `num_layers` times; it must expose `d_model` and
        `rnn_hidden_size` attributes.
      num_layers:
        The number of sub-encoder-layers in the encoder (required).
      aux_layers:
        Optional list of distinct layer indexes, not containing the last
        layer (num_layers - 1). If given, the outputs of these layers and of
        the last layer are merged by a RandomCombine module. If None, no
        combiner is used and the last layer's output is returned.
    """

    def __init__(
        self,
        encoder_layer: nn.Module,
        num_layers: int,
        aux_layers: Optional[List[int]] = None,
    ) -> None:
        super(RNNEncoder, self).__init__()
        # Independent copies, so the layers do not share parameters.
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
        )
        self.num_layers = num_layers
        self.d_model = encoder_layer.d_model
        self.rnn_hidden_size = encoder_layer.rnn_hidden_size

        self.aux_layers: List[int] = []
        self.combiner: Optional[nn.Module] = None
        if aux_layers is not None:
            # Indexes must be unique and must not include the last layer,
            # which is always appended below.
            assert len(set(aux_layers)) == len(aux_layers)
            assert num_layers - 1 not in aux_layers
            self.aux_layers = aux_layers + [num_layers - 1]
            self.combiner = RandomCombine(
                num_inputs=len(self.aux_layers),
                final_weight=0.5,
                pure_prob=0.333,
                stddev=2.0,
            )

    def forward(
        self,
        src: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        warmup: float = 1.0,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Pass the input through the encoder layer in turn.

        Args:
          src:
            The sequence to the encoder layer (required).
            Its shape is (S, N, E), where S is the sequence length,
            N is the batch size, and E is the feature number.
          states:
            A tuple of 2 tensors (optional). It is for streaming inference.
            states[0] is the hidden states of all layers,
              with shape of (num_layers, N, d_model);
            states[1] is the cell states of all layers,
              with shape of (num_layers, N, rnn_hidden_size).
            Passing states is only allowed when the module is not in
            training mode.
          warmup:
            It controls selective bypass of layers; if < 1.0, we will
            bypass layers more frequently. It is only forwarded to the layers
            when `states` is None.

        Returns:
          A tuple of 2 elements:
            - the output sequence;
            - the updated states of all layers, concatenated along dim 0, or
              a pair of empty tensors when `states` is None.
        """
        if states is not None:
            assert not self.training
            assert len(states) == 2
            # The shape checks are skipped while tracing.
            if not torch.jit.is_tracing():
                # for hidden state
                assert states[0].shape == (
                    self.num_layers,
                    src.size(1),
                    self.d_model,
                )
                # for cell state
                assert states[1].shape == (
                    self.num_layers,
                    src.size(1),
                    self.rnn_hidden_size,
                )

        output = src

        # Outputs of the layers listed in self.aux_layers, for the combiner.
        outputs = []

        new_hidden_states = []
        new_cell_states = []

        for i, mod in enumerate(self.layers):
            if states is None:
                output = mod(output, warmup=warmup)[0]
            else:
                # Slice with i : i + 1 to keep the leading layer dimension.
                layer_state = (
                    states[0][i : i + 1, :, :],  # h: (1, N, d_model)
                    states[1][i : i + 1, :, :],  # c: (1, N, rnn_hidden_size)
                )
                output, (h, c) = mod(output, layer_state)
                new_hidden_states.append(h)
                new_cell_states.append(c)

            if self.combiner is not None and i in self.aux_layers:
                outputs.append(output)

        if self.combiner is not None:
            output = self.combiner(outputs)

        if states is None:
            # Placeholder so that the return type matches the annotation.
            new_states = (torch.empty(0), torch.empty(0))
        else:
            new_states = (
                torch.cat(new_hidden_states, dim=0),
                torch.cat(new_cell_states, dim=0),
            )

        return output, new_states
class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).

    Convert an input of shape (N, T, idim) to an output
    with shape (N, T', odim), where
    T' = ((T-3)//2-1)//2, which approximates T' == T//4

    It is based on
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        layer1_channels: int = 8,
        layer2_channels: int = 32,
        layer3_channels: int = 128,
    ) -> None:
        """
        Args:
          in_channels:
            Number of channels in. The input shape is (N, T, in_channels).
            Caution: It requires: T >= 9, in_channels >= 9.
          out_channels:
            Output dim. The output shape is (N, ((T-3)//2-1)//2, out_channels)
          layer1_channels:
            Number of channels in layer1
          layer2_channels:
            Number of channels in layer2
          layer3_channels:
            Number of channels in layer3
        """
        assert in_channels >= 9
        super().__init__()

        # Three conv layers, each followed by a balancer and DoubleSwish:
        # the first uses the default stride, the other two use stride 2.
        self.conv = nn.Sequential(
            ScaledConv2d(
                in_channels=1,
                out_channels=layer1_channels,
                kernel_size=3,
                padding=0,
            ),
            ActivationBalancer(channel_dim=1),
            DoubleSwish(),
            ScaledConv2d(
                in_channels=layer1_channels,
                out_channels=layer2_channels,
                kernel_size=3,
                stride=2,
            ),
            ActivationBalancer(channel_dim=1),
            DoubleSwish(),
            ScaledConv2d(
                in_channels=layer2_channels,
                out_channels=layer3_channels,
                kernel_size=3,
                stride=2,
            ),
            ActivationBalancer(channel_dim=1),
            DoubleSwish(),
        )
        # The linear layer consumes the flattened (channel, feature) axes
        # left after the convolutions.
        self.out = ScaledLinear(
            layer3_channels * (((in_channels - 3) // 2 - 1) // 2), out_channels
        )
        # set learn_eps=False because out_norm is preceded by `out`, and `out`
        # itself has learned scale, so the extra degree of freedom is not
        # needed.
        self.out_norm = BasicNorm(out_channels, learn_eps=False)
        # constrain median of output to be close to zero.
        self.out_balancer = ActivationBalancer(
            channel_dim=-1, min_positive=0.45, max_positive=0.55
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x.

        Args:
          x:
            Its shape is (N, T, idim).

        Returns:
          Return a tensor of shape (N, ((T-3)//2-1)//2, odim)
        """
        # On entry, x is (N, T, idim)
        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
        x = self.conv(x)
        # Now x is of shape
        # (N, layer3_channels, ((T-3)//2-1)//2, ((idim-3)//2-1)//2)
        b, c, t, f = x.size()
        # Move time next to batch and flatten channels and features together.
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        # Now x is of shape (N, ((T-3)//2-1)//2, odim)
        x = self.out_norm(x)
        x = self.out_balancer(x)
        return x
class RandomCombine(nn.Module):
"""
This module combines a list of Tensors, all with the same shape, to
produce a single output of that same shape which, in training time,
is a random combination of all the inputs; but which in test time
will be just the last input.
The idea is that the list of Tensors will be a list of outputs of multiple
conformer layers. This has a similar effect as iterated loss. (See:
DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
NETWORKS).
"""
def __init__(
self,
num_inputs: int,
final_weight: float = 0.5,
pure_prob: float = 0.5,
stddev: float = 2.0,
) -> None:
"""
Args:
num_inputs:
The number of tensor inputs, which equals the number of layers'
outputs that are fed into this module. E.g. in an 18-layer neural
net if we output layers 16, 12, 18, num_inputs would be 3.
final_weight:
The amount of weight or probability we assign to the
final layer when randomly choosing layers or when choosing
continuous layer weights.
pure_prob:
The probability, on each frame, with which we choose
only a single layer to output (rather than an interpolation)
stddev:
A standard deviation that we add to log-probs for computing
randomized weights.
The method of choosing which layers, or combinations of layers, to use,
is conceptually as follows::
With probability `pure_prob`::
With probability `final_weight`: choose final layer,
Else: choose random non-final layer.
Else::
Choose initial log-weights that correspond to assigning
weight `final_weight` to the final layer and equal
weights to other layers; then add Gaussian noise
with variance `stddev` to these log-weights, and normalize
to weights (note: the average weight assigned to the
final layer here will not be `final_weight` if stddev>0).
"""
super().__init__()
assert 0 <= pure_prob <= 1, pure_prob
assert 0 < final_weight < 1, final_weight
assert num_inputs >= 1
self.num_inputs = num_inputs
self.final_weight = final_weight
self.pure_prob = pure_prob
self.stddev = stddev
self.final_log_weight = (
torch.tensor(
(final_weight / (1 - final_weight)) * (self.num_inputs - 1)
)
.log()
.item()
)
def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
"""Forward function.
Args:
inputs:
A list of Tensor, e.g. from various layers of a transformer.
All must be the same shape, of (*, num_channels)
Returns:
A Tensor of shape (*, num_channels). In test mode
this is just the final input.
"""
num_inputs = self.num_inputs
assert len(inputs) == num_inputs
if not self.training or torch.jit.is_scripting():
return inputs[-1]
# Shape of weights: (*, num_inputs)
num_channels = inputs[0].shape[-1]
num_frames = inputs[0].numel() // num_channels
ndim = inputs[0].ndim
# stacked_inputs: (num_frames, num_channels, num_inputs)
stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
(num_frames, num_channels, num_inputs)
)
# weights: (num_frames, num_inputs)
weights = self._get_random_weights(
inputs[0].dtype, inputs[0].device, num_frames
)
weights = weights.reshape(num_frames, num_inputs, 1)
# ans: (num_frames, num_channels, 1)
ans = torch.matmul(stacked_inputs, weights)
# ans: (*, num_channels)
ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,))
# The following if causes errors for torch script in torch 1.6.0
# if __name__ == "__main__":
# # for testing only...
# print("Weights = ", weights.reshape(num_frames, num_inputs))
return ans
def _get_random_weights(
self, dtype: torch.dtype, device: torch.device, num_frames: int
) -> torch.Tensor:
"""Return a tensor of random weights, of shape
`(num_frames, self.num_inputs)`,
Args:
dtype:
The data-type desired for the answer, e.g. float, double.
device:
The device needed for the answer.
num_frames:
The number of sets of weights desired
Returns:
A tensor of shape (num_frames, self.num_inputs), such that
`ans.sum(dim=1)` is all ones.
"""
pure_prob = self.pure_prob
if pure_prob == 0.0:
return self._get_random_mixed_weights(dtype, device, num_frames)
elif pure_prob == 1.0:
return self._get_random_pure_weights(dtype, device, num_frames)
else:
p = self._get_random_pure_weights(dtype, device, num_frames)
m = self._get_random_mixed_weights(dtype, device, num_frames)
return torch.where(
torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m
)
def _get_random_pure_weights(
self, dtype: torch.dtype, device: torch.device, num_frames: int
):
"""Return a tensor of random one-hot weights, of shape
`(num_frames, self.num_inputs)`,
Args:
dtype:
The data-type desired for the answer, e.g. float, double.
device:
The device needed for the answer.
num_frames:
The number of sets of weights desired.
Returns:
A one-hot tensor of shape `(num_frames, self.num_inputs)`, with
exactly one weight equal to 1.0 on each frame.
"""
final_prob = self.final_weight
# final contains self.num_inputs - 1 in all elements
final = torch.full((num_frames,), self.num_inputs - 1, device=device)
# nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa
nonfinal = torch.randint(
self.num_inputs - 1, (num_frames,), device=device
)
indexes = torch.where(
torch.rand(num_frames, device=device) < final_prob, final, nonfinal
)
ans = torch.nn.functional.one_hot(
indexes, num_classes=self.num_inputs
).to(dtype=dtype)
return ans
def _get_random_mixed_weights(
    self, dtype: torch.dtype, device: torch.device, num_frames: int
):
    """Draw soft (non one-hot) weights, one row per frame.

    Gaussian noise scaled by `self.stddev` is used as log-probabilities,
    the last input is biased by `self.final_log_weight`, and a softmax
    over the input axis normalizes each row.

    Args:
      dtype:
        Data type of the returned tensor, e.g. float or double.
      device:
        Device on which the returned tensor is created.
      num_frames:
        Number of weight sets (rows) to generate.
    Returns:
      A tensor of shape `(num_frames, self.num_inputs)` with elements in
      [0..1] that sum to one over the second axis, i.e.
      `ans.sum(dim=1)` is all ones.
    """
    noise = torch.randn(
        num_frames, self.num_inputs, dtype=dtype, device=device
    )
    scores = noise * self.stddev
    # Bias only the last column, i.e. the final input.
    scores[:, -1] += self.final_log_weight
    return scores.softmax(dim=1)
def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
    """Smoke-test `RandomCombine` with one hyper-parameter setting.

    All inputs are tensors of ones, so whatever convex combination the
    module picks, the output must again be all ones with the same shape.

    Args:
      final_weight:
        Forwarded to `RandomCombine(final_weight=...)`.
      pure_prob:
        Forwarded to `RandomCombine(pure_prob=...)`.
      stddev:
        Forwarded to `RandomCombine(stddev=...)`.
    """
    print(
        f"_test_random_combine: final_weight={final_weight}, "
        f"pure_prob={pure_prob}, stddev={stddev}"
    )
    num_inputs = 3
    num_channels = 50

    combiner = RandomCombine(
        num_inputs=num_inputs,
        final_weight=final_weight,
        pure_prob=pure_prob,
        stddev=stddev,
    )

    inputs = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
    combined = combiner(inputs)

    reference = inputs[0]
    assert combined.shape == reference.shape
    # Holds because every input is all ones.
    assert torch.allclose(combined, reference)
def _test_random_combine_main():
    """Run the `RandomCombine` smoke tests, then one small `RNN` forward pass.

    Nothing is returned; the function only checks that no call raises.
    """
    # (final_weight, pure_prob, stddev) triples, in the order they are run.
    settings = [
        (0.999, 0, 0.0),
        (0.5, 0, 0.0),
        (0.999, 0, 0.0),
        (0.5, 0, 0.3),
        (0.5, 1, 0.3),
        (0.5, 0.5, 0.3),
    ]
    for final_weight, pure_prob, stddev in settings:
        _test_random_combine(final_weight, pure_prob, stddev)

    feature_dim = 50
    batch_size = 5
    seq_len = 20

    model = RNN(num_features=feature_dim, d_model=128)

    # Just make sure the forward pass runs; the output is not inspected.
    features = torch.randn(batch_size, seq_len, feature_dim)
    lengths = torch.full((batch_size,), seq_len, dtype=torch.int64)
    model(features, lengths)
if __name__ == "__main__":
    # Script entry point: build a full-size model, run a single forward
    # pass on random input, report the parameter count, then run the
    # RandomCombine smoke tests.
    feature_dim = 80
    batch_size = 5
    seq_len = 20

    model = RNN(
        num_features=feature_dim,
        d_model=512,
        rnn_hidden_size=1024,
        dim_feedforward=2048,
        num_encoder_layers=12,
    )

    # Just make sure the forward pass runs.
    features = torch.randn(batch_size, seq_len, feature_dim)
    lengths = torch.full((batch_size,), seq_len, dtype=torch.int64)
    model(features, lengths, warmup=0.5)

    num_param = sum(p.numel() for p in model.parameters())
    print(f"Number of model parameters: {num_param}")

    _test_random_combine_main()

View File

@ -0,0 +1,202 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
from icefall.utils import add_sos
class Transducer(nn.Module):
    """Transducer model trained with the pruned RNN-T loss.

    It implements https://arxiv.org/pdf/1211.3711.pdf
    "Sequence Transduction with Recurrent Neural Networks"
    """

    def __init__(
        self,
        encoder: EncoderInterface,
        decoder: nn.Module,
        joiner: nn.Module,
        encoder_dim: int,
        decoder_dim: int,
        joiner_dim: int,
        vocab_size: int,
    ):
        """
        Args:
          encoder:
            The transcription network in the paper. In `forward` it is
            called as `encoder(x, x_lens, warmup=warmup)` and its result is
            unpacked into three values: the encoder output, the output
            lengths of shape (N,), and a third value that is ignored here.
          decoder:
            The prediction network in the paper. Its input shape is (N, U)
            and its output shape is (N, U, decoder_dim). It must have the
            attribute `blank_id`.
          joiner:
            It has two inputs with shapes (N, T, encoder_dim) and
            (N, U, decoder_dim); its output shape is (N, T, U, vocab_size).
            Note that its output contains unnormalized probs, i.e., not
            processed by log-softmax. `forward` also uses its
            `encoder_proj` and `decoder_proj` sub-modules and calls it
            with `project_input=False`.
          encoder_dim:
            Input dimension of the simple acoustic projection.
          decoder_dim:
            Input dimension of the simple language-model projection.
          joiner_dim:
            Accepted for interface compatibility; not used in this class.
          vocab_size:
            Output dimension of both simple projections.
        """
        super().__init__()

        assert isinstance(encoder, EncoderInterface), type(encoder)
        assert hasattr(decoder, "blank_id")

        self.encoder = encoder
        self.decoder = decoder
        self.joiner = joiner

        # Linear projections that feed the "simple" (unpruned) loss.
        self.simple_am_proj = ScaledLinear(
            encoder_dim, vocab_size, initial_speed=0.5
        )
        self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: k2.RaggedTensor,
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
        warmup: float = 1.0,
        reduction: str = "sum",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the simple and the pruned transducer losses.

        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in
            `x` before padding.
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of
            each utterance.
          prune_range:
            The prune range for rnnt loss, it means how many symbols
            (context) we are considering for each frame to compute the
            loss.
          am_scale:
            The scale to smooth the loss with am (output of encoder
            network) part.
          lm_scale:
            The scale to smooth the loss with lm (output of predictor
            network) part.
          warmup:
            A value warmup >= 0 that determines which modules are active,
            values warmup > 1 "are fully warmed up" and all modules will be
            active. It is forwarded to the encoder.
          reduction:
            "sum" to sum the losses over all utterances in the batch.
            "none" to return the loss in a 1-D tensor for each utterance
            in the batch.
        Returns:
          A tuple `(simple_loss, pruned_loss)`.

        Note:
          Regarding am_scale & lm_scale, it will make the loss-function one
          of the form:
             lm_scale * lm_probs + am_scale * am_probs +
             (1-lm_scale-am_scale) * combined_probs
        """
        assert reduction in ("sum", "none"), reduction
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.num_axes == 2, y.num_axes
        assert x.size(0) == x_lens.size(0) == y.dim0

        encoder_out, x_lens, _ = self.encoder(x, x_lens, warmup=warmup)
        assert torch.all(x_lens > 0)

        # Number of labels per utterance, from the ragged row splits.
        splits = y.shape.row_splits(1)
        y_lens = splits[1:] - splits[:-1]

        # Prediction network. The blank symbol doubles as SOS and as the
        # padding value of the decoder input.
        blank_id = self.decoder.blank_id
        # sos_y_padded: [B, S + 1], starts with SOS.
        sos_y_padded = add_sos(y, sos_id=blank_id).pad(
            mode="constant", padding_value=blank_id
        )
        # decoder_out: [B, S + 1, decoder_dim]
        decoder_out = self.decoder(sos_y_padded)

        # y_padded: [B, S]; note that y does not start with SOS.
        y_padded = y.pad(mode="constant", padding_value=0).to(torch.int64)

        # Columns 2 and 3 hold the label and frame counts of each
        # utterance; columns 0 and 1 stay zero.
        boundary = torch.zeros(
            (x.size(0), 4), dtype=torch.int64, device=x.device
        )
        boundary[:, 2] = y_lens
        boundary[:, 3] = x_lens

        lm = self.simple_lm_proj(decoder_out)
        am = self.simple_am_proj(encoder_out)

        # The losses are computed in float32 with autocast disabled.
        with torch.cuda.amp.autocast(enabled=False):
            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                lm=lm.float(),
                am=am.float(),
                symbols=y_padded,
                termination_symbol=blank_id,
                lm_only_scale=lm_scale,
                am_only_scale=am_scale,
                boundary=boundary,
                reduction=reduction,
                return_grad=True,
            )

        # ranges : [B, T, prune_range]
        ranges = k2.get_rnnt_prune_ranges(
            px_grad=px_grad,
            py_grad=py_grad,
            boundary=boundary,
            s_range=prune_range,
        )

        # The joiner's input projections are applied before pruning (an
        # optimization for speed), hence project_input=False below.
        projected_am = self.joiner.encoder_proj(encoder_out)
        projected_lm = self.joiner.decoder_proj(decoder_out)
        am_pruned, lm_pruned = k2.do_rnnt_pruning(
            am=projected_am,
            lm=projected_lm,
            ranges=ranges,
        )

        # logits : [B, T, prune_range, vocab_size]
        logits = self.joiner(am_pruned, lm_pruned, project_input=False)

        with torch.cuda.amp.autocast(enabled=False):
            pruned_loss = k2.rnnt_loss_pruned(
                logits=logits.float(),
                symbols=y_padded,
                ranges=ranges,
                termination_symbol=blank_id,
                boundary=boundary,
                reduction=reduction,
            )

        return (simple_loss, pruned_loss)

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/optim.py

View File

@ -0,0 +1,352 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./lstm_transducer_stateless/pretrained.py \
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) beam search
./lstm_transducer_stateless/pretrained.py \
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) modified beam search
./lstm_transducer_stateless/pretrained.py \
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./lstm_transducer_stateless/pretrained.py \
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./lstm_transducer_stateless/exp/epoch-xx.pt`.
Note: ./lstm_transducer_stateless/exp/pretrained.pt is generated by
./lstm_transducer_stateless/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
def get_parser():
    """Build the command-line parser of this script.

    Model-related options are appended by `add_model_arguments`.

    Returns:
      The configured `argparse.ArgumentParser`.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    add = parser.add_argument

    add(
        "--checkpoint",
        required=True,
        type=str,
        help="Path to the checkpoint. "
        "The checkpoint is assumed to be saved by "
        "icefall.checkpoint.save_checkpoint().",
    )

    add("--bpe-model", type=str, help="Path to bpe.model.")

    add(
        "--method",
        default="greedy_search",
        type=str,
        help="Possible values are:\n"
        "- greedy_search\n"
        "- beam_search\n"
        "- modified_beam_search\n"
        "- fast_beam_search\n",
    )

    add(
        "sound_files",
        nargs="+",
        type=str,
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    add(
        "--sample-rate",
        default=16000,
        type=int,
        help="The sample rate of the input sound file",
    )

    add(
        "--beam-size",
        default=4,
        type=int,
        help="An integer indicating how many candidates we will keep for each\n"
        "frame. Used only when --method is beam_search or\n"
        "modified_beam_search.",
    )

    add(
        "--beam",
        default=4,
        type=float,
        help="A floating point value to calculate the cutoff score during beam\n"
        "search (i.e., `cutoff = max-score - beam`), which is the same as the\n"
        "`beam` in Kaldi.\n"
        "Used only when --method is fast_beam_search",
    )

    add(
        "--max-contexts",
        default=4,
        type=int,
        help="Used only when --method is fast_beam_search",
    )

    add(
        "--max-states",
        default=8,
        type=int,
        help="Used only when --method is fast_beam_search",
    )

    add(
        "--context-size",
        default=2,
        type=int,
        help="The context size in the decoder. 1 means bigram; "
        "2 means tri-gram",
    )

    add(
        "--max-sym-per-frame",
        default=1,
        type=int,
        help="Maximum number of symbols per frame. Used only when\n"
        "--method is greedy_search.\n",
    )

    add_model_arguments(parser)

    return parser
def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Load sound files and keep the first channel of each.

    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The sample rate every file must have; a mismatch triggers an
        assertion error.
    Returns:
      A list with one tensor per file, in the same order as `filenames`.
      Each tensor is `wave[0]` of what `torchaudio.load()` returned, i.e.
      the first channel only.
    """

    def _first_channel(path: str) -> torch.Tensor:
        wave, sample_rate = torchaudio.load(path)
        assert sample_rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. "
            f"Given: {sample_rate}"
        )
        # We use only the first channel
        return wave[0]

    return [_first_channel(path) for path in filenames]
@torch.no_grad()
def main():
    """Transcribe the sound files given on the command line.

    The steps are: parse arguments, load the BPE model and the checkpoint,
    compute fbank features for all files as one padded batch, run the
    encoder, decode with the method selected by --method, and log one
    transcript per file.
    """
    args = get_parser().parse_args()

    params = get_params()
    params.update(vars(args))

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(f"{params}")

    # Use the first GPU when one is available, otherwise the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    else:
        device = torch.device("cpu")
    logging.info(f"device: {device}")

    logging.info("Creating model")
    model = get_transducer_model(params)

    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    # Load on CPU first, then move the whole model to the target device.
    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    model.load_state_dict(checkpoint["model"], strict=False)
    model.to(device)
    model.eval()
    model.device = device

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = params.sample_rate
    opts.mel_opts.num_bins = params.feature_dim
    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {params.sound_files}")
    waves = [
        wave.to(device)
        for wave in read_sound_files(
            filenames=params.sound_files,
            expected_sample_rate=params.sample_rate,
        )
    ]

    logging.info("Decoding started")
    features = fbank(waves)
    # Record the true lengths before padding the batch.
    feature_lengths = torch.tensor(
        [feat.size(0) for feat in features], device=device
    )
    features = pad_sequence(
        features, batch_first=True, padding_value=math.log(1e-10)
    )

    encoder_out, encoder_out_lens, _ = model.encoder(
        x=features, x_lens=feature_lengths
    )

    msg = f"Using {params.method}"
    if params.method == "beam_search":
        msg += f" with beam size {params.beam_size}"
    logging.info(msg)

    def _to_words(batch_tokens):
        # One list of whitespace-separated words per decoded utterance.
        return [text.split() for text in sp.decode(batch_tokens)]

    if params.method == "fast_beam_search":
        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
        hyps = _to_words(
            fast_beam_search_one_best(
                model=model,
                decoding_graph=decoding_graph,
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                beam=params.beam,
                max_contexts=params.max_contexts,
                max_states=params.max_states,
            )
        )
    elif params.method == "modified_beam_search":
        hyps = _to_words(
            modified_beam_search(
                model=model,
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                beam=params.beam_size,
            )
        )
    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
        hyps = _to_words(
            greedy_search_batch(
                model=model,
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
            )
        )
    else:
        # The remaining methods decode one utterance at a time.
        hyps = []
        for i in range(encoder_out.size(0)):
            # Drop the padded frames of utterance i.
            # fmt: off
            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
            # fmt: on
            if params.method == "greedy_search":
                hyp = greedy_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    max_sym_per_frame=params.max_sym_per_frame,
                )
            elif params.method == "beam_search":
                hyp = beam_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    beam=params.beam_size,
                )
            else:
                raise ValueError(f"Unsupported method: {params.method}")

            hyps.append(sp.decode(hyp).split())

    # One "<filename>:\n<words>\n\n" entry per input file.
    entries = []
    for filename, hyp in zip(params.sound_files, hyps):
        words = " ".join(hyp)
        entries.append(f"{filename}:\n{words}\n\n")
    logging.info("\n" + "".join(entries))

    logging.info("Decoding Done")
if __name__ == "__main__":
    # Timestamp, level and source location in front of every log message.
    formatter = (
        "%(asctime)s %(levelname)s "
        "[%(filename)s:%(lineno)d] %(message)s"
    )
    logging.basicConfig(level=logging.INFO, format=formatter)

    main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/scaling.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless3/scaling_converter.py

View File

@ -0,0 +1,148 @@
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple
import k2
import torch
from beam_search import Hypothesis, HypothesisList
from icefall.utils import AttributeDict
class Stream(object):
    """Per-utterance bookkeeping for chunk-by-chunk (streaming) decoding.

    A stream owns the padded feature matrix of one cut, tracks how many
    frames have been handed out so far, and stores the partial decoding
    result in attributes that depend on the decoding method.
    """

    def __init__(
        self,
        params: AttributeDict,
        cut_id: str,
        decoding_graph: Optional[k2.Fsa] = None,
        device: torch.device = torch.device("cpu"),
        LOG_EPS: float = math.log(1e-10),
    ) -> None:
        """
        Args:
          params:
            It's the return value of :func:`get_params`. The attributes
            read here are `context_size`, `decoding_method`, `blank_id`
            and `subsampling_factor`.
          cut_id:
            The cut id of the current stream.
          decoding_graph:
            The decoding graph. Can be either a `k2.trivial_graph` or HLG.
            Used only when the decoding method is fast_beam_search.
          device:
            The device on which the initial log-prob of
            modified_beam_search is created.
          LOG_EPS:
            A float value used for padding the features.
        Raises:
          ValueError: If `params.decoding_method` is not one of
            greedy_search, modified_beam_search or fast_beam_search.
        """
        self.LOG_EPS = LOG_EPS
        self.cut_id = cut_id

        # Encoder states for this stream; None until something assigns
        # them. NOTE(review): presumably the recurrent states carried from
        # one chunk to the next - confirm against the decoding loop.
        self.states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

        self.context_size = params.context_size
        self.decoding_method = params.decoding_method

        # Each decoding method keeps its partial result in its own
        # attribute(s).
        method = params.decoding_method
        if method == "greedy_search":
            # Start from `context_size` blanks.
            self.hyp = [params.blank_id] * params.context_size
        elif method == "modified_beam_search":
            initial = Hypothesis(
                ys=[params.blank_id] * params.context_size,
                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
            )
            self.hyps = HypothesisList()
            self.hyps.add(initial)
        elif method == "fast_beam_search":
            # The rnnt_decoding_stream for fast_beam_search.
            self.rnnt_decoding_stream: k2.RnntDecodingStream = (
                k2.RnntDecodingStream(decoding_graph)
            )
            self.hyp: Optional[List[int]] = None
        else:
            raise ValueError(
                f"Unsupported decoding method: {params.decoding_method}"
            )

        self.ground_truth: str = ""

        self.feature: Optional[torch.Tensor] = None
        # Make sure all feature frames can be used.
        # We aim to obtain 1 frame after subsampling.
        self.chunk_length = params.subsampling_factor
        # Extra frames returned with every chunk, beyond `chunk_length`.
        self.pad_length = 5
        self.num_frames = 0
        self.num_processed_frames = 0

        # After all feature frames are processed, we set this flag to True
        self._done = False

    def set_feature(self, feature: torch.Tensor) -> None:
        """Attach the feature matrix of this stream.

        Args:
          feature:
            A 2-D tensor. Rows filled with `LOG_EPS` are appended to its
            first dimension before it is stored.
        """
        assert feature.dim() == 2, feature.dim()

        # tail padding here to alleviate the tail deletion problem
        num_tail_padded_frames = 35
        self.num_frames = feature.size(0) + num_tail_padded_frames

        # `pad_length` further rows are appended so that the last chunk
        # can still return `pad_length` frames past `num_frames`.
        extra_rows = self.pad_length + num_tail_padded_frames
        self.feature = torch.nn.functional.pad(
            feature,
            (0, 0, 0, extra_rows),
            mode="constant",
            value=self.LOG_EPS,
        )

    def get_feature_chunk(self) -> torch.Tensor:
        """Get a chunk of feature frames.

        The read position advances by at most `chunk_length` frames, and
        the stream is marked done once it reaches `num_frames`.

        Returns:
          A tensor of shape (ret_length, feature_dim).
        """
        start = self.num_processed_frames
        update_length = min(self.num_frames - start, self.chunk_length)
        ret_length = update_length + self.pad_length

        # The stored features are never trimmed; chunks are slices taken
        # at the current read position.
        chunk = self.feature[start : start + ret_length]  # noqa

        self.num_processed_frames = start + update_length
        if self.num_processed_frames >= self.num_frames:
            self._done = True

        return chunk

    @property
    def id(self) -> str:
        """The cut id of this stream."""
        return self.cut_id

    @property
    def done(self) -> bool:
        """Return True if all feature frames are processed."""
        return self._done

    def decoding_result(self) -> List[int]:
        """Obtain current decoding result.

        For greedy_search and modified_beam_search the leading
        `context_size` symbols are stripped; for fast_beam_search
        `self.hyp` is returned as is.
        """
        if self.decoding_method == "greedy_search":
            return self.hyp[self.context_size :]  # noqa

        if self.decoding_method == "modified_beam_search":
            best_hyp = self.hyps.get_most_probable(length_norm=True)
            return best_hyp.ys[self.context_size :]  # noqa

        assert self.decoding_method == "fast_beam_search"
        return self.hyp

View File

@ -0,0 +1,968 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./lstm_transducer_stateless/streaming_decode.py \
--epoch 35 \
--avg 10 \
--exp-dir lstm_transducer_stateless/exp \
--num-decode-streams 2000 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method greedy_search \
--use-averaged-model True
(2) modified beam search
./lstm_transducer_stateless/streaming_decode.py \
--epoch 35 \
--avg 10 \
--exp-dir lstm_transducer_stateless/exp \
--num-decode-streams 2000 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method modified_beam_search \
--use-averaged-model True \
--beam-size 4
(3) fast beam search
./lstm_transducer_stateless/streaming_decode.py \
--epoch 35 \
--avg 10 \
--exp-dir lstm_transducer_stateless/exp \
--num-decode-streams 2000 \
--num-encoder-layers 12 \
--rnn-hidden-size 1024 \
--decoding-method fast_beam_search \
--use-averaged-model True \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
import argparse
import logging
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from lstm import LOG_EPSILON, stack_states, unstack_states
from stream import Stream
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.decode import one_best_decoding
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser():
    """Build the command-line parser of the streaming decoding script.

    Compared with the previous version only help texts changed: missing
    spaces between concatenated sentences were added, "interger" was
    corrected, and references to the decoding option now use its actual
    spelling and supported values. Flags, types and defaults are unchanged.
    Model-related options are appended by `add_model_arguments`.

    Returns:
      The configured `argparse.ArgumentParser`.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=28,
        help="It specifies the checkpoint to use for decoding. "
        "Note: Epoch counts from 0.",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="If positive, --epoch is ignored and it\n"
        "will use the checkpoint exp_dir/checkpoint-iter.pt.\n"
        "You can specify --avg to use more checkpoints for model averaging.\n",
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'. ",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=False,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    # NOTE(review): this default names a different recipe's directory; the
    # usage examples always pass --exp-dir explicitly. It is left unchanged
    # so that existing invocations keep their behavior.
    parser.add_argument(
        "--exp-dir",
        type=str,
        default="transducer_emformer/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="Possible values are:\n"
        "- greedy_search\n"
        "- modified_beam_search\n"
        "- fast_beam_search\n",
    )

    parser.add_argument(
        "--beam-size",
        type=int,
        default=4,
        help="An integer indicating how many candidates we will keep for each\n"
        "frame. Used only when --decoding-method is\n"
        "modified_beam_search.",
    )

    parser.add_argument(
        "--beam",
        type=float,
        default=20.0,
        help="A floating point value to calculate the cutoff score during beam\n"
        "search (i.e., `cutoff = max-score - beam`), which is the same as the\n"
        "`beam` in Kaldi.\n"
        "Used only when --decoding-method is fast_beam_search",
    )

    parser.add_argument(
        "--max-contexts",
        type=int,
        default=8,
        help="Used only when --decoding-method is\nfast_beam_search",
    )

    parser.add_argument(
        "--max-states",
        type=int,
        default=64,
        help="Used only when --decoding-method is\nfast_beam_search",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; "
        "2 means tri-gram",
    )

    parser.add_argument(
        "--max-sym-per-frame",
        type=int,
        default=1,
        help="Maximum number of symbols per frame.\n"
        "Used only when --decoding-method is greedy_search",
    )

    parser.add_argument(
        "--sampling-rate",
        type=float,
        default=16000,
        help="Sample rate of the audio",
    )

    parser.add_argument(
        "--num-decode-streams",
        type=int,
        default=2000,
        help="The number of streams that can be decoded in parallel",
    )

    add_model_arguments(parser)

    return parser
def greedy_search(
    model: nn.Module,
    encoder_out: torch.Tensor,
    streams: List[Stream],
) -> None:
    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.

    The result is written in place: every non-blank symbol is appended to
    the `hyp` list of the corresponding stream.

    Args:
      model:
        The transducer model.
      encoder_out:
        Output from the encoder. Its shape is (N, T, C), where N >= 1.
      streams:
        A list of Stream objects, one per row of `encoder_out`.
    """
    assert len(streams) == encoder_out.size(0)
    assert encoder_out.ndim == 3

    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size
    device = next(model.parameters()).device
    T = encoder_out.size(1)

    def _project_decoder() -> torch.Tensor:
        # Run the decoder on the last `context_size` symbols of every
        # stream and apply the joiner's decoder projection.
        decoder_input = torch.tensor(
            [stream.hyp[-context_size:] for stream in streams],
            device=device,
            dtype=torch.int64,
        )
        out = model.decoder(decoder_input, need_pad=False)
        return model.joiner.decoder_proj(out)

    encoder_out = model.joiner.encoder_proj(encoder_out)
    # decoder_out is of shape (batch_size, 1, decoder_out_dim)
    decoder_out = _project_decoder()

    for t in range(T):
        # frame's shape: (batch_size, 1, encoder_out_dim)
        frame = encoder_out[:, t : t + 1, :]  # noqa

        logits = model.joiner(
            frame.unsqueeze(2),
            decoder_out.unsqueeze(1),
            project_input=False,
        )
        # logits' shape: (batch_size, vocab_size)
        logits = logits.squeeze(1).squeeze(1)
        assert logits.ndim == 2, logits.shape

        best_tokens = logits.argmax(dim=1).tolist()

        emitted = False
        for stream, token in zip(streams, best_tokens):
            if token != blank_id:
                stream.hyp.append(token)
                emitted = True

        # The decoder output only changes when some stream emitted a
        # non-blank symbol on this frame.
        if emitted:
            decoder_out = _project_decoder()
def modified_beam_search(
    model: nn.Module,
    encoder_out: torch.Tensor,
    streams: List[Stream],
    beam: int = 4,
):
    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.

    The result is written in place: after the last frame, `stream.hyps` of
    every stream is replaced by its surviving hypotheses.

    Args:
      model:
        The RNN-T model.
      encoder_out:
        A 3-D tensor of shape (N, T, encoder_out_dim) containing the output
        of the encoder model.
      streams:
        A list of stream objects.
      beam:
        Number of active paths during the beam search.
    """
    assert encoder_out.ndim == 3, encoder_out.shape
    assert len(streams) == encoder_out.size(0)

    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size
    device = next(model.parameters()).device
    batch_size = len(streams)
    T = encoder_out.size(1)

    # Active hypotheses of every stream.
    active = [stream.hyps for stream in streams]

    encoder_out = model.joiner.encoder_proj(encoder_out)

    for t in range(T):
        # frame's shape: (batch_size, 1, 1, encoder_out_dim)
        frame = encoder_out[:, t : t + 1].unsqueeze(1)  # noqa

        hyps_shape = get_hyps_shape(active).to(device)

        # Freeze the current hypotheses and start empty lists that will
        # receive their extensions.
        previous = [list(hyps) for hyps in active]
        active = [HypothesisList() for _ in range(batch_size)]

        # All hypotheses of all streams, flattened in stream order.
        flat_hyps = [hyp for hyps in previous for hyp in hyps]

        ys_log_probs = torch.stack(
            [hyp.log_prob.reshape(1) for hyp in flat_hyps], dim=0
        )  # (num_hyps, 1)

        decoder_input = torch.tensor(
            [hyp.ys[-context_size:] for hyp in flat_hyps],
            device=device,
            dtype=torch.int64,
        )  # (num_hyps, context_size)

        decoder_out = model.decoder(decoder_input, need_pad=False)
        decoder_out = model.joiner.decoder_proj(decoder_out.unsqueeze(1))
        # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)

        # Repeat every stream's encoder frame once per hypothesis.
        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
        # as index, so we use `to(torch.int64)` below.
        frame = torch.index_select(
            frame,
            dim=0,
            index=hyps_shape.row_ids(1).to(torch.int64),
        )  # (num_hyps, 1, 1, encoder_out_dim)

        logits = model.joiner(frame, decoder_out, project_input=False)
        # logits is of shape (num_hyps, 1, 1, vocab_size)
        logits = logits.squeeze(1).squeeze(1)

        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)
        # Add the score of the hypothesis being extended.
        log_probs.add_(ys_log_probs)

        vocab_size = log_probs.size(-1)
        log_probs = log_probs.reshape(-1)

        # Regroup the flat scores per stream: stream i owns
        # (number of its hypotheses) * vocab_size consecutive entries.
        row_splits = hyps_shape.row_splits(1) * vocab_size
        log_probs_shape = k2.ragged.create_ragged_shape2(
            row_splits=row_splits, cached_tot_size=log_probs.numel()
        )
        ragged_log_probs = k2.RaggedTensor(
            shape=log_probs_shape, value=log_probs
        )

        for i in range(batch_size):
            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)

            # Split each flat index into (hypothesis index, token id).
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
                topk_token_indexes = (topk_indexes % vocab_size).tolist()

            candidates = zip(topk_hyp_indexes, topk_token_indexes)
            for k, (hyp_idx, new_token) in enumerate(candidates):
                parent = previous[i][hyp_idx]

                # A blank keeps the symbol sequence unchanged.
                new_ys = parent.ys[:]
                if new_token != blank_id:
                    new_ys.append(new_token)

                active[i].add(
                    Hypothesis(ys=new_ys, log_prob=topk_log_probs[k])
                )

    for i in range(batch_size):
        streams[i].hyps = active[i]
def fast_beam_search_one_best(
    model: nn.Module,
    streams: List[Stream],
    encoder_out: torch.Tensor,
    processed_lens: torch.Tensor,
    beam: float,
    max_states: int,
    max_contexts: int,
) -> None:
    """It limits the maximum number of symbols per frame to 1.

    The streams are advanced frame by frame, a lattice is generated from
    them, and the best path of the lattice is stored in `stream.hyp` of
    every stream.

    Args:
      model:
        An instance of `Transducer`.
      streams:
        A list of stream objects.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder.
      processed_lens:
        A tensor of shape (N,) containing the number of processed frames
        in `encoder_out` before padding.
      beam:
        Beam value, similar to the beam used in Kaldi.
      max_states:
        Max states per stream per frame.
      max_contexts:
        Max contexts per stream per frame.
    """
    assert encoder_out.ndim == 3

    context_size = model.decoder.context_size
    vocab_size = model.decoder.vocab_size

    num_streams, num_frames, _ = encoder_out.shape
    assert num_streams == len(streams)

    config = k2.RnntDecodingConfig(
        vocab_size=vocab_size,
        decoder_history_len=context_size,
        beam=beam,
        max_contexts=max_contexts,
        max_states=max_states,
    )
    decoding_streams = k2.RnntDecodingStreams(
        [stream.rnnt_decoding_stream for stream in streams], config
    )

    encoder_out = model.joiner.encoder_proj(encoder_out)

    for t in range(num_frames):
        # shape is a RaggedShape of shape (B, context)
        # contexts is a Tensor of shape (shape.NumElements(), context_size)
        shape, contexts = decoding_streams.get_contexts()
        # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
        contexts = contexts.to(torch.int64)

        # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
        decoder_out = model.decoder(contexts, need_pad=False)
        decoder_out = model.joiner.decoder_proj(decoder_out)

        # Repeat each stream's frame t once per context of that stream;
        # the result has shape (shape.NumElements(), 1, joiner_dim).
        context_to_stream = shape.row_ids(1).to(torch.int64)
        frame = encoder_out[:, t : t + 1, :]  # noqa
        current_encoder_out = torch.index_select(frame, 0, context_to_stream)

        logits = model.joiner(
            current_encoder_out.unsqueeze(2),
            decoder_out.unsqueeze(1),
            project_input=False,
        )
        logits = logits.squeeze(1).squeeze(1)

        decoding_streams.advance(logits.log_softmax(dim=-1))

    decoding_streams.terminate_and_flush_to_streams()

    lattice = decoding_streams.format_output(processed_lens.tolist())
    best_path = one_best_decoding(lattice)
    hyps = get_texts(best_path)

    for i, stream in enumerate(streams):
        stream.hyp = hyps[i]
def decode_one_chunk(
    model: nn.Module,
    streams: List[Stream],
    params: AttributeDict,
    decoding_graph: Optional[k2.Fsa] = None,
) -> List[int]:
    """Feed the next feature chunk of every stream through the model.

    The chunks of all streams are padded into one batch, run through the
    encoder together with the stacked per-stream states, and then passed to
    the search routine selected by ``params.decoding_method``. The new
    encoder states are written back to each stream afterwards.

    Args:
      model:
        The Transducer model.
      streams:
        A list of Stream objects.
      params:
        It is returned by :func:`get_params`.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
        only when --decoding_method is fast_beam_search.
        NOTE(review): this argument is not read inside this function.
    Returns:
      A list of indexes indicating the finished streams.
    Raises:
      ValueError: if ``params.decoding_method`` is not one of
        "greedy_search", "modified_beam_search" or "fast_beam_search".
    """
    device = next(model.parameters()).device

    chunks = []
    chunk_lens = []
    per_stream_states = []
    frames_done = []
    for s in streams:
        # `num_processed_frames` has to be read *before* fetching the chunk,
        # because `get_feature_chunk()` updates it.
        frames_done.append(s.num_processed_frames)
        chunk = s.get_feature_chunk()
        chunks.append(chunk)
        chunk_lens.append(chunk.size(0))
        per_stream_states.append(s.states)

    features = pad_sequence(
        chunks, batch_first=True, padding_value=LOG_EPSILON
    )
    features = features.to(device)
    feature_lens = torch.tensor(chunk_lens, device=device)
    num_processed_frames = torch.tensor(frames_done, device=device)

    # Make sure it has at least 1 frame after subsampling
    min_frames = params.subsampling_factor + 5
    shortfall = min_frames - features.size(1)
    if shortfall > 0:
        feature_lens += shortfall
        features = torch.nn.functional.pad(
            features,
            (0, 0, 0, shortfall),
            mode="constant",
            value=LOG_EPSILON,
        )

    # The encoder consumes the states of all streams stacked into one batch.
    encoder_out, encoder_out_lens, new_states = model.encoder(
        x=features,
        x_lens=feature_lens,
        states=stack_states(per_stream_states),
    )

    method = params.decoding_method
    if method == "greedy_search":
        greedy_search(
            model=model,
            streams=streams,
            encoder_out=encoder_out,
        )
    elif method == "modified_beam_search":
        modified_beam_search(
            model=model,
            streams=streams,
            encoder_out=encoder_out,
            beam=params.beam_size,
        )
    elif method == "fast_beam_search":
        # The total number of processed (subsampled) frames is needed to get
        # partial results from the rnnt_decoding_stream.
        # NOTE(review): warnings are presumably silenced because of the
        # tensor floor division below -- confirm.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            processed_lens = (
                num_processed_frames // params.subsampling_factor
                + encoder_out_lens
            )
        fast_beam_search_one_best(
            model=model,
            streams=streams,
            encoder_out=encoder_out,
            processed_lens=processed_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
        )
    else:
        raise ValueError(
            f"Unsupported decoding method: {params.decoding_method}"
        )

    # Update cached states of each stream
    for idx, state in enumerate(unstack_states(new_states)):
        streams[idx].states = state

    return [idx for idx, s in enumerate(streams) if s.done]
def create_streaming_feature_extractor() -> Fbank:
    """Build the fbank extractor used for streaming decoding.

    The options are fixed here: CPU device, no dither, ``snip_edges``
    disabled, 16000 Hz sampling rate and 80 mel bins. Passing the options
    in from outside is not supported yet.

    Returns:
      Return a CPU streaming feature extractor.
    """
    options = FbankOptions()
    options.device = "cpu"

    frame_options = options.frame_opts
    frame_options.dither = 0
    frame_options.snip_edges = False
    frame_options.samp_freq = 16000

    options.mel_opts.num_bins = 80

    return Fbank(options)
def decode_dataset(
    cuts: CutSet,
    model: nn.Module,
    params: AttributeDict,
    sp: spm.SentencePieceProcessor,
    decoding_graph: Optional[k2.Fsa] = None,
):
    """Decode dataset.

    Every cut gets its own Stream. As soon as at least
    ``params.num_decode_streams`` streams are pending, chunks are decoded
    until enough of them have finished; the remaining streams are drained
    after the last cut has been added.

    Args:
      cuts:
        Lhotse Cutset containing the dataset to decode.
      model:
        The Transducer model.
      params:
        It is returned by :func:`get_params`.
      sp:
        The BPE model.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
        only when --decoding_method is fast_beam_search.
    Returns:
      Return a dict with a single entry. Its key is "greedy_search" for
      greedy search, "beam_{beam}_max_contexts_{..}_max_states_{..}" for
      fast_beam_search, and "beam_size_{beam_size}" otherwise.
      Its value is a list of tuples. Each tuple contains three elements:
      the cut id, the reference transcript (a list of words), and the
      predicted result (a list of words).
    """
    device = next(model.parameters()).device

    log_interval = 300

    fbank = create_streaming_feature_extractor()

    decode_results = []
    streams = []

    def _decode_chunk_and_collect() -> None:
        # Decode one chunk for all pending streams, then move the results of
        # the streams that have finished into `decode_results`.
        finished_streams = decode_one_chunk(
            model=model,
            streams=streams,
            params=params,
            decoding_graph=decoding_graph,
        )
        # Delete from the highest index down so that the remaining indexes
        # stay valid while `streams` shrinks.
        for i in sorted(finished_streams, reverse=True):
            decode_results.append(
                (
                    streams[i].id,
                    streams[i].ground_truth.split(),
                    sp.decode(streams[i].decoding_result()).split(),
                )
            )
            del streams[i]

    for num, cut in enumerate(cuts):
        # Each utterance has a Stream.
        stream = Stream(
            params=params,
            cut_id=cut.id,
            decoding_graph=decoding_graph,
            device=device,
            LOG_EPS=LOG_EPSILON,
        )

        stream.states = model.encoder.get_init_states(device=device)

        audio: np.ndarray = cut.load_audio()
        # audio.shape: (1, num_samples)
        assert len(audio.shape) == 2
        assert audio.shape[0] == 1, "Should be single channel"
        assert audio.dtype == np.float32, audio.dtype
        # The trained model is using normalized samples
        assert audio.max() <= 1, "Should be normalized to [-1, 1])"

        samples = torch.from_numpy(audio).squeeze(0)
        feature = fbank(samples)
        stream.set_feature(feature)
        stream.ground_truth = cut.supervisions[0].text

        streams.append(stream)

        while len(streams) >= params.num_decode_streams:
            _decode_chunk_and_collect()

        if num % log_interval == 0:
            logging.info(f"Cuts processed until now is {num}.")

    # All cuts have been added; drain whatever is still pending.
    while len(streams) > 0:
        _decode_chunk_and_collect()

    if params.decoding_method == "greedy_search":
        key = "greedy_search"
    elif params.decoding_method == "fast_beam_search":
        key = (
            f"beam_{params.beam}_"
            f"max_contexts_{params.max_contexts}_"
            f"max_states_{params.max_states}"
        )
    else:
        key = f"beam_size_{params.beam_size}"

    return {key: decode_results}
def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    """Write transcripts, error statistics and a WER summary to disk.

    Args:
      params:
        Provides ``res_dir`` (output directory) and ``suffix`` (appended to
        every file name).
      test_set_name:
        Name of the test set; embedded in file names and log messages.
      results_dict:
        Maps a decoding setting to its results. In this file the results
        come from :func:`decode_dataset`, i.e. each entry is a tuple
        (cut id, reference words, hypothesis words).
    """
    wer_by_setting = dict()
    for key, results in results_dict.items():
        recog_path = (
            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        # Only the stored transcripts are sorted; the error statistics below
        # are computed from `results` in its original order.
        store_transcripts(filename=recog_path, texts=sorted(results))
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )
            wer_by_setting[key] = wer

        logging.info(f"Wrote detailed error stats to {errs_filename}")

    # Settings ordered from the lowest to the highest WER.
    ranked_wers = sorted(wer_by_setting.items(), key=lambda x: x[1])

    # NOTE: `key` still holds the last setting iterated over above, so the
    # summary file is named after that setting.
    errs_info = (
        params.res_dir
        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
    )
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for setting, val in ranked_wers:
            print(f"{setting}\t{val}", file=f)

    s = f"\nFor {test_set_name}, WER of different settings are:\n"
    # Only the first (i.e. best) line is tagged.
    note = f"\tbest for {test_set_name}"
    for setting, val in ranked_wers:
        s += f"{setting}\t{val}{note}\n"
        note = ""
    logging.info(s)
@torch.no_grad()
def main():
    """Parse the command line, load the model and decode the test sets.

    Steps:
      1. Build ``params`` from :func:`get_params` and the parsed arguments,
         and derive the result directory and file-name suffix from them.
      2. Create the model and load either a single checkpoint, an average
         of several checkpoints, or the averaged model stored in the
         checkpoints (``--use-averaged-model``).
      3. Decode test-clean and test-other with :func:`decode_dataset` and
         write the results with :func:`save_results`.
    """
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    assert params.decoding_method in (
        "greedy_search",
        "fast_beam_search",
        "modified_beam_search",
    )
    params.res_dir = params.exp_dir / "streaming" / params.decoding_method

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if "fast_beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam}"
        params.suffix += f"-max-contexts-{params.max_contexts}"
        params.suffix += f"-max-states-{params.max_states}"
    elif "beam_search" in params.decoding_method:
        params.suffix += (
            f"-{params.decoding_method}-beam-size-{params.beam_size}"
        )
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.res_dir}/log-streaming-decode")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> and <unk> are defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    params.device = device

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(
                params.exp_dir, iteration=-params.iter
            )[: params.avg]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            # One extra checkpoint is needed: the oldest one only marks the
            # (excluded) start of the averaging range.
            filenames = find_checkpoints(
                params.exp_dir, iteration=-params.iter
            )[: params.avg + 1]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    # The single-checkpoint branch above does not move the model, yet the
    # decoding graph below is created on `device` and decode_dataset() takes
    # its device from the model parameters. Moving the model here keeps both
    # on the same device; it is a no-op for the other branches.
    model.to(device)
    model.eval()

    if params.decoding_method == "fast_beam_search":
        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
    else:
        decoding_graph = None

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    librispeech = LibriSpeechAsrDataModule(args)

    test_clean_cuts = librispeech.test_clean_cuts()
    test_other_cuts = librispeech.test_other_cuts()

    test_sets = ["test-clean", "test-other"]
    test_cuts = [test_clean_cuts, test_other_cuts]

    for test_set, test_cut in zip(test_sets, test_cuts):
        results_dict = decode_dataset(
            cuts=test_cut,
            model=model,
            params=params,
            sp=sp,
            decoding_graph=decoding_graph,
        )

        save_results(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")
if __name__ == "__main__":
    # Fix the torch RNG seed before decoding starts.
    torch.manual_seed(20220810)
    main()

View File

@ -0,0 +1,92 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./lstm_transducer_stateless/test_model.py
"""
import os
from pathlib import Path
import torch
from export import (
export_decoder_model_jit_trace,
export_encoder_model_jit_trace,
export_joiner_model_jit_trace,
)
from lstm import stack_states, unstack_states
from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model
def test_model():
    """Build a transducer model and export its parts with jit.trace.

    The model is created from :func:`get_params` with a fixed test
    configuration, converted in place with
    :func:`convert_scaled_to_non_scaled`, and its encoder, decoder and
    joiner are each exported to a file inside ``params.exp_dir``.
    """
    params = get_params()
    params.vocab_size = 500
    params.blank_id = 0
    params.context_size = 2
    params.unk_id = 2
    params.encoder_dim = 512
    params.rnn_hidden_size = 1024
    params.num_encoder_layers = 12
    params.aux_layer_period = 0
    params.exp_dir = Path("exp_test_model")

    model = get_transducer_model(params)
    model.eval()

    num_param = sum([p.numel() for p in model.parameters()])
    print(f"Number of model parameters: {num_param}")

    convert_scaled_to_non_scaled(model, inplace=True)

    # `os.path` has no `mkdir`; create the output directory through the Path
    # object instead (no error if it already exists).
    params.exp_dir.mkdir(parents=True, exist_ok=True)

    encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
    export_encoder_model_jit_trace(model.encoder, encoder_filename)

    decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
    export_decoder_model_jit_trace(model.decoder, decoder_filename)

    joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
    export_joiner_model_jit_trace(model.joiner, joiner_filename)

    print("The model has been successfully exported using jit.trace.")
def test_states_stack_and_unstack():
    """Check that unstacking and re-stacking states is a round trip."""
    num_layers, batch_size, hidden_dim, cell_dim = 12, 100, 512, 1024

    hidden = torch.randn(num_layers, batch_size, hidden_dim)
    cell = torch.randn(num_layers, batch_size, cell_dim)

    restored = stack_states(unstack_states((hidden, cell)))

    # Both members of the state pair must survive the round trip.
    for original, recovered in zip((hidden, cell), (restored[0], restored[1])):
        assert torch.allclose(original, recovered)
def main():
    """Run all tests of this file, in order."""
    for run_test in (test_model, test_states_stack_and_unstack):
        run_test()
if __name__ == "__main__":
    # Run the tests when this file is executed as a script.
    main()

View File

@ -0,0 +1,257 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./lstm_transducer_stateless/test_scaling_converter.py
"""
import copy
import torch
from scaling import (
ScaledConv1d,
ScaledConv2d,
ScaledEmbedding,
ScaledLinear,
ScaledLSTM,
)
from scaling_converter import (
convert_scaled_to_non_scaled,
scaled_conv1d_to_conv1d,
scaled_conv2d_to_conv2d,
scaled_embedding_to_embedding,
scaled_linear_to_linear,
scaled_lstm_to_lstm,
)
from train import get_params, get_transducer_model
def get_model():
    """Create a transducer model with a fixed configuration for the tests.

    Returns:
      The model built by :func:`get_transducer_model` from the parameters
      of :func:`get_params`, with the values below overridden.
    """
    params = get_params()

    overrides = {
        "vocab_size": 500,
        "blank_id": 0,
        "context_size": 2,
        "unk_id": 2,
        "encoder_dim": 512,
        "rnn_hidden_size": 1024,
        "num_encoder_layers": 12,
        "aux_layer_period": -1,
    }
    for name, value in overrides.items():
        setattr(params, name, value)

    return get_transducer_model(params)
def test_scaled_linear_to_linear():
    """Check that a converted ScaledLinear gives the same output.

    The comparison is done with and without bias, both for the eager
    modules and for their torch.jit.script versions.
    """
    batch_size, dim_in, dim_out = 5, 10, 20

    for use_bias in (True, False):
        scaled = ScaledLinear(
            in_features=dim_in,
            out_features=dim_out,
            bias=use_bias,
        )
        plain = scaled_linear_to_linear(scaled)
        x = torch.rand(batch_size, dim_in)

        # Eager mode
        eager_scaled_out = scaled(x)
        eager_plain_out = plain(x)
        assert torch.allclose(eager_scaled_out, eager_plain_out)

        # Scripted mode
        scripted_scaled = torch.jit.script(scaled)
        scripted_plain = torch.jit.script(plain)
        scripted_scaled_out = scripted_scaled(x)
        scripted_plain_out = scripted_plain(x)
        assert torch.allclose(scripted_scaled_out, scripted_plain_out)

        # Eager vs. scripted
        assert torch.allclose(eager_scaled_out, scripted_plain_out)
def test_scaled_conv1d_to_conv1d():
    """Check that a converted ScaledConv1d gives the same output.

    The comparison is done with and without bias, both for the eager
    modules and for their torch.jit.script versions.
    """
    num_in_channels = 3

    for use_bias in (True, False):
        scaled = ScaledConv1d(
            num_in_channels,
            6,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=use_bias,
        )
        plain = scaled_conv1d_to_conv1d(scaled)
        x = torch.rand(20, num_in_channels, 10)

        # Eager mode
        eager_scaled_out = scaled(x)
        eager_plain_out = plain(x)
        assert torch.allclose(eager_scaled_out, eager_plain_out)

        # Scripted mode
        scripted_scaled = torch.jit.script(scaled)
        scripted_plain = torch.jit.script(plain)
        scripted_scaled_out = scripted_scaled(x)
        scripted_plain_out = scripted_plain(x)
        assert torch.allclose(scripted_scaled_out, scripted_plain_out)

        # Eager vs. scripted
        assert torch.allclose(eager_scaled_out, scripted_plain_out)
def test_scaled_conv2d_to_conv2d():
    """Check that a converted ScaledConv2d gives the same output.

    The comparison is done with and without bias, both for the eager
    modules and for their torch.jit.script versions.
    """
    num_in_channels = 1

    for use_bias in (True, False):
        scaled = ScaledConv2d(
            in_channels=num_in_channels,
            out_channels=3,
            kernel_size=3,
            padding=1,
            bias=use_bias,
        )
        plain = scaled_conv2d_to_conv2d(scaled)
        x = torch.rand(20, num_in_channels, 10, 20)

        # Eager mode
        eager_scaled_out = scaled(x)
        eager_plain_out = plain(x)
        assert torch.allclose(eager_scaled_out, eager_plain_out)

        # Scripted mode
        scripted_scaled = torch.jit.script(scaled)
        scripted_plain = torch.jit.script(plain)
        scripted_scaled_out = scripted_scaled(x)
        scripted_plain_out = scripted_plain(x)
        assert torch.allclose(scripted_scaled_out, scripted_plain_out)

        # Eager vs. scripted
        assert torch.allclose(eager_scaled_out, scripted_plain_out)
def test_scaled_embedding_to_embedding():
    """Check that a converted ScaledEmbedding returns identical vectors.

    Random index tensors of several lengths are looked up in both modules
    and the results are required to be exactly equal.
    """
    vocab = 500
    scaled = ScaledEmbedding(
        num_embeddings=vocab,
        embedding_dim=10,
        padding_idx=0,
    )
    plain = scaled_embedding_to_embedding(scaled)

    for num_indexes in (10, 100, 300, 500, 800, 1000):
        indexes = torch.randint(low=0, high=500, size=(num_indexes,))
        expected = scaled(indexes)
        actual = plain(indexes)
        assert torch.equal(expected, actual)
def test_scaled_lstm_to_lstm():
    """Check that a converted ScaledLSTM gives the same output and states.

    Covers with/without bias and hidden sizes 512 and 1024; a projection
    back to the input size is used whenever the hidden size differs from
    it. The converted module is additionally run through torch.jit.trace
    and compared against the scaled module.
    """
    dim_in = 512
    num_seqs = 20

    for use_bias in (True, False):
        for dim_hidden in (512, 1024):
            if dim_hidden == dim_in:
                proj = 0
            else:
                proj = dim_in

            scaled = ScaledLSTM(
                input_size=dim_in,
                hidden_size=dim_hidden,
                num_layers=1,
                bias=use_bias,
                proj_size=proj,
            )
            plain = scaled_lstm_to_lstm(scaled)

            x = torch.rand(200, num_seqs, dim_in)
            h0 = torch.randn(1, num_seqs, dim_in)
            c0 = torch.randn(1, num_seqs, dim_hidden)
            init_state = (h0, c0)

            # Eager scaled vs. eager converted
            y_scaled, (h_scaled, c_scaled) = scaled(x, init_state)
            y_plain, (h_plain, c_plain) = plain(x, init_state)
            assert torch.allclose(y_scaled, y_plain)
            assert torch.allclose(h_scaled, h_plain)
            assert torch.allclose(c_scaled, c_plain)

            # Eager scaled vs. traced converted
            traced_plain = torch.jit.trace(plain, (x, init_state))
            y_traced, (h_traced, c_traced) = traced_plain(x, init_state)
            assert torch.allclose(y_scaled, y_traced)
            assert torch.allclose(h_scaled, h_traced)
            assert torch.allclose(c_scaled, c_traced)
def test_convert_scaled_to_non_scaled():
    """Compare a model with its converted counterpart, module by module.

    For both ``inplace=False`` and ``inplace=True`` a deep copy of the
    model is kept as the reference, and the encoder, decoder, simple
    projections and joiner of the converted model are required to produce
    the same outputs as the reference.
    """
    for inplace in [False, True]:
        model = get_model()
        model.eval()

        # Keep an untouched copy, since `model` itself is handed to the
        # converter (with inplace=True as one of the two cases).
        orig_model = copy.deepcopy(model)

        converted_model = convert_scaled_to_non_scaled(model, inplace=inplace)

        model = orig_model

        # test encoder
        N = 2
        T = 100
        vocab_size = model.decoder.vocab_size

        x = torch.randn(N, T, 80, dtype=torch.float32)
        x_lens = torch.full((N,), x.size(1))

        e1, e1_lens, _ = model.encoder(x, x_lens)
        e2, e2_lens, _ = converted_model.encoder(x, x_lens)

        assert torch.all(torch.eq(e1_lens, e2_lens))
        assert torch.allclose(e1, e2), (e1 - e2).abs().max()

        # test decoder
        U = 50
        y = torch.randint(low=1, high=vocab_size - 1, size=(N, U))

        d1 = model.decoder(y)
        # The second output has to come from the converted model; otherwise
        # the reference decoder would only be compared with itself.
        d2 = converted_model.decoder(y)

        assert torch.allclose(d1, d2)

        # test simple projection
        lm1 = model.simple_lm_proj(d1)
        am1 = model.simple_am_proj(e1)

        lm2 = converted_model.simple_lm_proj(d2)
        am2 = converted_model.simple_am_proj(e2)

        assert torch.allclose(lm1, lm2)
        assert torch.allclose(am1, am2)

        # test joiner
        e = torch.rand(2, 3, 4, 512)
        d = torch.rand(2, 3, 4, 512)

        j1 = model.joiner(e, d)
        j2 = converted_model.joiner(e, d)
        assert torch.allclose(j1, j2)
@torch.no_grad()
def main():
    """Run all tests of this file, in order, with gradients disabled."""
    all_tests = (
        test_scaled_linear_to_linear,
        test_scaled_conv1d_to_conv1d,
        test_scaled_conv2d_to_conv2d,
        test_scaled_embedding_to_embedding,
        test_scaled_lstm_to_lstm,
        test_convert_scaled_to_non_scaled,
    )
    for run_test in all_tests:
        run_test()
if __name__ == "__main__":
    # Fix the torch RNG seed before the tests draw their random inputs.
    torch.manual_seed(20220730)
    main()

File diff suppressed because it is too large Load Diff

View File

@ -391,6 +391,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -403,9 +404,9 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
@ -430,6 +431,7 @@ def save_results(
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
@ -612,6 +614,8 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()

View File

@ -551,6 +551,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -564,9 +565,9 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
@ -591,6 +592,7 @@ def save_results(
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
@ -631,6 +633,8 @@ def main():
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
# we need cut ids to display recognition results.
args.return_cuts = True
params = get_params()
params.update(vars(args))
@ -754,6 +758,8 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()

View File

@ -28,6 +28,7 @@ class DecodeStream(object):
def __init__(
self,
params: AttributeDict,
cut_id: str,
initial_states: List[torch.Tensor],
decoding_graph: Optional[k2.Fsa] = None,
device: torch.device = torch.device("cpu"),
@ -48,6 +49,7 @@ class DecodeStream(object):
assert device == decoding_graph.device
self.params = params
self.cut_id = cut_id
self.LOG_EPS = math.log(1e-10)
self.states = initial_states
@ -102,6 +104,10 @@ class DecodeStream(object):
"""Return True if all the features are processed."""
return self._done
@property
def id(self) -> str:
return self.cut_id
def set_features(
self,
features: torch.Tensor,

View File

@ -15,6 +15,8 @@
# limitations under the License.
from typing import Tuple
import k2
import torch
import torch.nn as nn
@ -66,7 +68,8 @@ class Transducer(nn.Module):
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> torch.Tensor:
reduction: str = "sum",
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
@ -86,6 +89,10 @@ class Transducer(nn.Module):
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
reduction:
"sum" to sum the losses over all utterances in the batch.
"none" to return the loss in a 1-D tensor for each utterance
in the batch.
Returns:
Return the transducer loss.
@ -95,6 +102,7 @@ class Transducer(nn.Module):
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert reduction in ("sum", "none"), reduction
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
@ -136,7 +144,7 @@ class Transducer(nn.Module):
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
reduction=reduction,
return_grad=True,
)
@ -163,7 +171,7 @@ class Transducer(nn.Module):
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
reduction=reduction,
)
return (simple_loss, pruned_loss)

View File

@ -356,6 +356,7 @@ def decode_dataset(
# each utterance has a DecodeStream.
decode_stream = DecodeStream(
params=params,
cut_id=cut.id,
initial_states=initial_states,
decoding_graph=decoding_graph,
device=device,
@ -385,6 +386,7 @@ def decode_dataset(
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
@ -402,6 +404,7 @@ def decode_dataset(
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)

View File

@ -78,6 +78,7 @@ from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
MetricsTracker,
display_and_save_batch,
measure_gradient_norms,
measure_weight_norms,
optim_step_and_measure_param_change,
@ -457,9 +458,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params
@ -547,7 +545,34 @@ def compute_loss(
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
reduction="none",
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
is_finite = simple_loss_is_finite & pruned_loss_is_finite
if not torch.all(is_finite):
logging.info(
"Not all losses are finite!\n"
f"simple_loss: {simple_loss}\n"
f"pruned_loss: {pruned_loss}"
)
display_and_save_batch(batch, params=params, sp=sp)
simple_loss = simple_loss[simple_loss_is_finite]
pruned_loss = pruned_loss[pruned_loss_is_finite]
# If either all simple_loss or pruned_loss is inf or nan,
# we stop the training process by raising an exception
if torch.all(~simple_loss_is_finite) or torch.all(
~pruned_loss_is_finite
):
raise ValueError(
"There are too many utterances in this batch "
"leading to inf or nan losses."
)
simple_loss = simple_loss.sum()
pruned_loss = pruned_loss.sum()
loss = params.simple_loss_scale * simple_loss + pruned_loss
assert loss.requires_grad == is_training
@ -555,6 +580,10 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# info["frames"] is an approximate number for two reasons:
# (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
# (2) If some utterances in the batch lead to inf/nan loss, they
# are filtered out.
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
@ -674,13 +703,7 @@ def train_one_epoch(
global_step=params.batch_idx_train,
)
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl):
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
@ -728,7 +751,6 @@ def train_one_epoch(
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
@ -738,7 +760,6 @@ def train_one_epoch(
sampler=train_dl.sampler,
rank=rank,
)
del params.cur_batch_idx
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
@ -893,13 +914,14 @@ def run(rank, world_size, args):
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
params=params,
)
if params.start_batch <= 0:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
params=params,
)
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)

View File

@ -32,7 +32,7 @@ from scaling import (
)
from torch import Tensor, nn
from icefall.utils import make_pad_mask, subsequent_chunk_mask
from icefall.utils import is_jit_tracing, make_pad_mask, subsequent_chunk_mask
class Conformer(EncoderInterface):
@ -155,7 +155,8 @@ class Conformer(EncoderInterface):
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
lengths = (((x_lens - 1) >> 1) - 1) >> 1
assert x.size(0) == lengths.max().item()
if not is_jit_tracing():
assert x.size(0) == lengths.max().item()
src_key_padding_mask = make_pad_mask(lengths)
@ -787,6 +788,14 @@ class RelPositionalEncoding(torch.nn.Module):
) -> None:
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
if is_jit_tracing():
# 10k frames correspond to ~100k ms, e.g., 100 seconds, i.e.,
# It assumes that the maximum input won't have more than
# 10k frames.
#
# TODO(fangjun): Use torch.jit.script() for this module
max_len = 10000
self.d_model = d_model
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.pe = None
@ -992,7 +1001,7 @@ class RelPositionMultiheadAttention(nn.Module):
"""Compute relative positional encoding.
Args:
x: Input tensor (batch, head, time1, 2*time1-1).
x: Input tensor (batch, head, time1, 2*time1-1+left_context).
time1 means the length of query vector.
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
@ -1006,20 +1015,32 @@ class RelPositionMultiheadAttention(nn.Module):
(batch_size, num_heads, time1, n) = x.shape
time2 = time1 + left_context
assert (
n == left_context + 2 * time1 - 1
), f"{n} == {left_context} + 2 * {time1} - 1"
if not is_jit_tracing():
assert (
n == left_context + 2 * time1 - 1
), f"{n} == {left_context} + 2 * {time1} - 1"
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time1_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, time1, time2),
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
storage_offset=n_stride * (time1 - 1),
)
if is_jit_tracing():
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(time2)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
x = x.reshape(-1, n)
x = torch.gather(x, dim=1, index=indexes)
x = x.reshape(batch_size, num_heads, time1, time2)
return x
else:
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time1_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, time1, time2),
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
storage_offset=n_stride * (time1 - 1),
)
def multi_head_attention_forward(
self,
@ -1090,13 +1111,15 @@ class RelPositionMultiheadAttention(nn.Module):
"""
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == embed_dim_to_check
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
if not is_jit_tracing():
assert embed_dim == embed_dim_to_check
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
head_dim = embed_dim // num_heads
assert (
head_dim * num_heads == embed_dim
), "embed_dim must be divisible by num_heads"
if not is_jit_tracing():
assert (
head_dim * num_heads == embed_dim
), "embed_dim must be divisible by num_heads"
scaling = float(head_dim) ** -0.5
@ -1209,7 +1232,7 @@ class RelPositionMultiheadAttention(nn.Module):
src_len = k.size(0)
if key_padding_mask is not None:
if key_padding_mask is not None and not is_jit_tracing():
assert key_padding_mask.size(0) == bsz, "{} == {}".format(
key_padding_mask.size(0), bsz
)
@ -1220,7 +1243,9 @@ class RelPositionMultiheadAttention(nn.Module):
q = q.transpose(0, 1) # (batch, time1, head, d_k)
pos_emb_bsz = pos_emb.size(0)
assert pos_emb_bsz in (1, bsz) # actually it is 1
if not is_jit_tracing():
assert pos_emb_bsz in (1, bsz) # actually it is 1
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
# (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1)
p = p.permute(0, 2, 3, 1)
@ -1255,11 +1280,12 @@ class RelPositionMultiheadAttention(nn.Module):
bsz * num_heads, tgt_len, -1
)
assert list(attn_output_weights.size()) == [
bsz * num_heads,
tgt_len,
src_len,
]
if not is_jit_tracing():
assert list(attn_output_weights.size()) == [
bsz * num_heads,
tgt_len,
src_len,
]
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
@ -1318,7 +1344,14 @@ class RelPositionMultiheadAttention(nn.Module):
)
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
if not is_jit_tracing():
assert list(attn_output.size()) == [
bsz * num_heads,
tgt_len,
head_dim,
]
attn_output = (
attn_output.transpose(0, 1)
.contiguous()

View File

@ -574,6 +574,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -587,9 +588,9 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
@ -614,6 +615,7 @@ def save_results(
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
@ -777,6 +779,8 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()

View File

@ -19,6 +19,8 @@ import torch.nn as nn
import torch.nn.functional as F
from scaling import ScaledConv1d, ScaledEmbedding
from icefall.utils import is_jit_tracing
class Decoder(nn.Module):
"""This class modifies the stateless decoder from the following paper:
@ -77,7 +79,12 @@ class Decoder(nn.Module):
# It is to support torch script
self.conv = nn.Identity()
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
def forward(
self,
y: torch.Tensor,
need_pad: bool = True # Annotation should be Union[bool, torch.Tensor]
# but, torch.jit.script does not support Union.
) -> torch.Tensor:
"""
Args:
y:
@ -88,18 +95,24 @@ class Decoder(nn.Module):
Returns:
Return a tensor of shape (N, U, decoder_dim).
"""
if isinstance(need_pad, torch.Tensor):
# This is for torch.jit.trace(), which cannot handle the case
# when the input argument is not a tensor.
need_pad = bool(need_pad)
y = y.to(torch.int64)
embedding_out = self.embedding(y)
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
if need_pad:
embedding_out = F.pad(
embedding_out, pad=(self.context_size - 1, 0)
)
else:
# During inference time, there is no need to do extra padding
# as we only need one output
assert embedding_out.size(-1) == self.context_size
if not is_jit_tracing():
assert embedding_out.size(-1) == self.context_size
embedding_out = self.conv(embedding_out)
embedding_out = embedding_out.permute(0, 2, 1)
embedding_out = F.relu(embedding_out)

View File

@ -18,6 +18,8 @@ import torch
import torch.nn as nn
from scaling import ScaledLinear
from icefall.utils import is_jit_tracing
class Joiner(nn.Module):
def __init__(
@ -52,10 +54,10 @@ class Joiner(nn.Module):
Returns:
Return a tensor of shape (N, T, s_range, C).
"""
assert encoder_out.ndim == decoder_out.ndim
assert encoder_out.ndim in (2, 4)
assert encoder_out.shape == decoder_out.shape
if not is_jit_tracing():
assert encoder_out.ndim == decoder_out.ndim
assert encoder_out.ndim in (2, 4)
assert encoder_out.shape == decoder_out.shape
if project_input:
logit = self.encoder_proj(encoder_out) + self.decoder_proj(

View File

@ -15,6 +15,8 @@
# limitations under the License.
from typing import Tuple
import k2
import torch
import torch.nn as nn
@ -78,7 +80,8 @@ class Transducer(nn.Module):
am_scale: float = 0.0,
lm_scale: float = 0.0,
warmup: float = 1.0,
) -> torch.Tensor:
reduction: str = "sum",
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
@ -101,6 +104,10 @@ class Transducer(nn.Module):
warmup:
A value warmup >= 0 that determines which modules are active, values
warmup > 1 "are fully warmed up" and all modules will be active.
reduction:
"sum" to sum the losses over all utterances in the batch.
"none" to return the loss in a 1-D tensor for each utterance
in the batch.
Returns:
Return the transducer loss.
@ -110,6 +117,7 @@ class Transducer(nn.Module):
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert reduction in ("sum", "none"), reduction
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
@ -155,7 +163,7 @@ class Transducer(nn.Module):
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
reduction=reduction,
return_grad=True,
)
@ -188,7 +196,7 @@ class Transducer(nn.Module):
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
reduction=reduction,
)
return (simple_loss, pruned_loss)

View File

@ -1,4 +1,4 @@
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -20,8 +20,11 @@ from itertools import repeat
from typing import Optional, Tuple
import torch
import torch.backends.cudnn.rnn as rnn
import torch.nn as nn
from torch import Tensor
from torch import _VF, Tensor
from icefall.utils import is_jit_tracing
def _ntuple(n):
@ -152,7 +155,8 @@ class BasicNorm(torch.nn.Module):
self.register_buffer("eps", torch.tensor(eps).log().detach())
def forward(self, x: Tensor) -> Tensor:
assert x.shape[self.channel_dim] == self.num_channels
if not is_jit_tracing():
assert x.shape[self.channel_dim] == self.num_channels
scales = (
torch.mean(x ** 2, dim=self.channel_dim, keepdim=True)
+ self.eps.exp()
@ -376,6 +380,156 @@ class ScaledConv2d(nn.Conv2d):
return self._conv_forward(input, self.get_weight())
class ScaledLSTM(nn.LSTM):
    """LSTM in which every weight and bias carries a learnable scale.

    See docs for ScaledLinear.

    For each tensor listed in `self._flat_weights_names` a scalar parameter
    named `<name>_scale` is registered. It stores the *log* of the scale; the
    tensor that is actually handed to `torch._VF.lstm` is
    `tensor * tensor_scale.exp()`.

    The implementation follows `torch.nn.LSTM`, please refer to
    https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py

    Args:
      *args, **kwargs:
        Forwarded unchanged to `torch.nn.LSTM`. Passing `bidirectional=True`
        as a keyword argument is rejected.
      initial_scale:
        Initial value of every scale (before the correction applied to
        weights in `_reset_parameters`).
      initial_speed:
        The raw weights are initialised with a standard deviation of
        `0.1 / initial_speed`; the matching scale compensates so that the
        effective weights do not depend on it. Should be positive.
    """

    def __init__(
        self,
        *args,
        initial_scale: float = 1.0,
        initial_speed: float = 1.0,
        **kwargs
    ):
        if "bidirectional" in kwargs:
            assert kwargs["bidirectional"] is False
        super(ScaledLSTM, self).__init__(*args, **kwargs)
        # Scales are kept in log space.
        initial_scale = torch.tensor(initial_scale).log()
        self._scales_names = []
        self._scales = []
        for name in self._flat_weights_names:
            scale_name = name + "_scale"
            self._scales_names.append(scale_name)
            param = nn.Parameter(initial_scale.clone().detach())
            # setattr() registers the scale as a parameter of this module;
            # self._scales keeps the same objects in _flat_weights order.
            setattr(self, scale_name, param)
            self._scales.append(param)

        self._reset_parameters(
            initial_speed
        )  # Overrides the reset_parameters in base class

    def _reset_parameters(self, initial_speed: float):
        """Re-initialise weights, biases and the weight scales.

        Weights are drawn from U(-a, a) with `a = sqrt(3) * std`, i.e. with
        standard deviation `std = 0.1 / initial_speed`. `log(v)` with
        `v = hidden_size ** -0.5 / std` is then added to the corresponding
        log-scale. Biases are set to zero and their scales are left alone.
        """
        std = 0.1 / initial_speed
        a = (3 ** 0.5) * std
        scale = self.hidden_size ** -0.5
        v = scale / std
        for idx, name in enumerate(self._flat_weights_names):
            if "weight" in name:
                nn.init.uniform_(self._flat_weights[idx], -a, a)
                with torch.no_grad():
                    self._scales[idx] += torch.tensor(v).log()
            elif "bias" in name:
                nn.init.constant_(self._flat_weights[idx], 0.0)

    def _flatten_parameters(self, flat_weights) -> None:
        """Resets parameter data pointer so that they can use faster code paths.

        Right now, this works only if the module is on the GPU and cuDNN is
        enabled. Otherwise, it's a no-op.

        This function is modified from
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py

        Args:
          flat_weights:
            The (already scaled) tensors that will be passed to
            `torch._VF.lstm`, in the order of `self._flat_weights_names`.
        """
        # Short-circuits if _flat_weights is only partially instantiated
        if len(flat_weights) != len(self._flat_weights_names):
            return

        for w in flat_weights:
            if not isinstance(w, Tensor):
                return
        # Short-circuits if any tensor in flat_weights is not acceptable to
        # cuDNN or the tensors in flat_weights are of different dtypes
        first_fw = flat_weights[0]
        dtype = first_fw.dtype
        for fw in flat_weights:
            if (
                not isinstance(fw.data, Tensor)
                or not (fw.data.dtype == dtype)
                or not fw.data.is_cuda
                or not torch.backends.cudnn.is_acceptable(fw.data)
            ):
                return

        # If any parameters alias, we fall back to the slower, copying code
        # path. This is a sufficient check, because overlapping parameter
        # buffers that don't completely alias would break the assumptions of
        # the uniqueness check in Module.named_parameters().
        unique_data_ptrs = set(p.data_ptr() for p in flat_weights)
        if len(unique_data_ptrs) != len(flat_weights):
            return

        with torch.cuda.device_of(first_fw):
            # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
            # an inplace operation on the given tensors
            with torch.no_grad():
                if torch._use_cudnn_rnn_flatten_weight():
                    num_weights = 4 if self.bias else 2
                    if self.proj_size > 0:
                        num_weights += 1
                    torch._cudnn_rnn_flatten_weight(
                        flat_weights,
                        num_weights,
                        self.input_size,
                        rnn.get_cudnn_mode(self.mode),
                        self.hidden_size,
                        self.proj_size,
                        self.num_layers,
                        self.batch_first,
                        bool(self.bidirectional),
                    )

    def _get_flat_weights(self):
        """Get scaled weights, and resets their data pointer."""
        flat_weights = []
        for idx in range(len(self._flat_weights_names)):
            flat_weights.append(
                self._flat_weights[idx] * self._scales[idx].exp()
            )
        self._flatten_parameters(flat_weights)
        return flat_weights

    def forward(
        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
    ):
        """Run the LSTM using the scaled weights.

        This function is modified from
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py
        The change for calling `_VF.lstm()` is:
            self._flat_weights -> self._get_flat_weights()

        Args:
          input:
            A 3-D tensor; the batch axis is 0 if `batch_first` is set and 1
            otherwise.
          hx:
            Optional pair (h_0, c_0). If omitted, zero states are used.

        Returns:
          A pair (output, hidden) where hidden is the tuple of the remaining
          results of `_VF.lstm`, i.e. the final (h, c).
        """
        if hx is None:
            # The states are laid out as (num_layers, batch, size) whatever
            # batch_first says; only the position of the batch axis in
            # `input` depends on it.
            batch_size = input.size(0) if self.batch_first else input.size(1)
            h_zeros = torch.zeros(
                self.num_layers,
                batch_size,
                self.proj_size if self.proj_size > 0 else self.hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
            c_zeros = torch.zeros(
                self.num_layers,
                batch_size,
                self.hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
            hx = (h_zeros, c_zeros)

        self.check_forward_args(input, hx, None)
        result = _VF.lstm(
            input,
            hx,
            self._get_flat_weights(),
            self.bias,
            self.num_layers,
            self.dropout,
            self.training,
            self.bidirectional,
            self.batch_first,
        )

        output = result[0]
        hidden = result[1:]
        return output, hidden
class ActivationBalancer(torch.nn.Module):
"""
Modifies the backpropped derivatives of a function to try to encourage, for
@ -423,7 +577,7 @@ class ActivationBalancer(torch.nn.Module):
self.max_abs = max_abs
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting():
if torch.jit.is_scripting() or is_jit_tracing():
return x
else:
return ActivationBalancerFunction.apply(
@ -472,7 +626,7 @@ class DoubleSwish(torch.nn.Module):
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
that we approximate closely with x * sigmoid(x-1).
"""
if torch.jit.is_scripting():
if torch.jit.is_scripting() or is_jit_tracing():
return x * torch.sigmoid(x - 1.0)
else:
return DoubleSwishFunction.apply(x)
@ -494,9 +648,6 @@ class ScaledEmbedding(nn.Module):
embedding_dim (int): the size of each embedding vector
padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
(initialized to zeros) whenever it encounters the index.
max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
is renormalized to have norm :attr:`max_norm`.
norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
the words in the mini-batch. Default ``False``.
sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
@ -505,7 +656,7 @@ class ScaledEmbedding(nn.Module):
initial_speed (float, optional): This affects how fast the parameter will
learn near the start of training; you can set it to a value less than
one if you suspect that a module is contributing to instability near
the start of training. Nnote: regardless of the use of this option,
the start of training. Note: regardless of the use of this option,
it's best to use schedulers like Noam that have a warm-up period.
Alternatively you can set it to more than 1 if you want it to
initially train faster. Must be greater than 0.
@ -727,8 +878,22 @@ def _test_double_swish_deriv():
torch.autograd.gradcheck(m, x)
def _test_scaled_lstm():
    """Smoke test: run ScaledLSTM once and check the output/state shapes."""
    batch_size, num_frames = 2, 30
    input_dim, hidden_dim = 10, 20
    lstm = ScaledLSTM(input_size=input_dim, hidden_size=hidden_dim, bias=True)

    # Time-major input plus explicit initial states for the single layer.
    feats = torch.randn(num_frames, batch_size, input_dim)
    init_h = torch.randn(1, batch_size, hidden_dim)
    init_c = torch.randn(1, batch_size, hidden_dim)

    out, (last_h, last_c) = lstm(feats, (init_h, init_c))

    state_shape = (1, batch_size, hidden_dim)
    assert out.shape == (num_frames, batch_size, hidden_dim)
    assert last_h.shape == state_shape
    assert last_c.shape == state_shape
if __name__ == "__main__":
_test_activation_balancer_sign()
_test_activation_balancer_magnitude()
_test_basic_norm()
_test_double_swish_deriv()
_test_scaled_lstm()

View File

@ -358,6 +358,7 @@ def decode_dataset(
# each utterance has a DecodeStream.
decode_stream = DecodeStream(
params=params,
cut_id=cut.id,
initial_states=initial_states,
decoding_graph=decoding_graph,
device=device,
@ -388,6 +389,7 @@ def decode_dataset(
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
@ -405,6 +407,7 @@ def decode_dataset(
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)

View File

@ -88,7 +88,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
display_and_save_batch,
setup_logger,
str2bool,
)
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@ -503,9 +509,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params
@ -603,7 +606,33 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
is_finite = simple_loss_is_finite & pruned_loss_is_finite
if not torch.all(is_finite):
logging.info(
"Not all losses are finite!\n"
f"simple_loss: {simple_loss}\n"
f"pruned_loss: {pruned_loss}"
)
display_and_save_batch(batch, params=params, sp=sp)
simple_loss = simple_loss[simple_loss_is_finite]
pruned_loss = pruned_loss[pruned_loss_is_finite]
# If either all simple_loss or pruned_loss is inf or nan,
# we stop the training process by raising an exception
if torch.all(~simple_loss_is_finite) or torch.all(
~pruned_loss_is_finite
):
raise ValueError(
"There are too many utterances in this batch "
"leading to inf or nan losses."
)
simple_loss = simple_loss.sum()
pruned_loss = pruned_loss.sum()
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid
# overwhelming the simple_loss and causing it to diverge,
@ -623,6 +652,10 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# info["frames"] is an approximate number for two reasons:
# (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
# (2) If some utterances in the batch lead to inf/nan loss, they
# are filtered out.
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
@ -724,13 +757,7 @@ def train_one_epoch(
tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl):
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
@ -765,7 +792,6 @@ def train_one_epoch(
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
@ -777,7 +803,6 @@ def train_one_epoch(
scaler=scaler,
rank=rank,
)
del params.cur_batch_idx
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
@ -944,7 +969,7 @@ def run(rank, world_size, args):
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
if params.start_batch <= 0 and not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
@ -1004,38 +1029,6 @@ def run(rank, world_size, args):
cleanup_dist()
def display_and_save_batch(
    batch: dict,
    params: AttributeDict,
    sp: spm.SentencePieceProcessor,
) -> None:
    """Dump a batch to disk and log a few statistics about it.

    Args:
      batch:
        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
        for the content in it.
      params:
        Parameters for training. See :func:`get_params`. Only `exp_dir` is
        read here; the batch is written below it.
      sp:
        The BPE model, used to count the tokens of the transcripts.
    """
    from lhotse.utils import uuid4

    # A random id in the name keeps successive dumps from clobbering
    # each other.
    dump_path = f"{params.exp_dir}/batch-{uuid4()}.pt"
    logging.info(f"Saving batch to {dump_path}")
    torch.save(batch, dump_path)

    sups = batch["supervisions"]
    logging.info(f"features shape: {batch['inputs'].shape}")

    token_ids = sp.encode(sups["text"], out_type=int)
    total = 0
    for ids in token_ids:
        total += len(ids)
    logging.info(f"num tokens: {total}")
def scan_pessimistic_batches_for_oom(
model: nn.Module,
train_dl: torch.utils.data.DataLoader,

View File

@ -422,6 +422,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -434,9 +435,9 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
@ -610,6 +611,8 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
asr_datamodule = AsrDataModule(args)
gigaspeech = GigaSpeech(manifest_dir=args.manifest_dir)

View File

@ -745,6 +745,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -760,9 +761,9 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
@ -787,6 +788,7 @@ def save_results(
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
@ -1067,6 +1069,8 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
asr_datamodule = AsrDataModule(args)
librispeech = LibriSpeech(manifest_dir=args.manifest_dir)

View File

@ -19,14 +19,67 @@
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
(1) Export to torchscript model using torch.jit.script()
./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10 \
--jit 1
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("cpu_jit.pt")`.
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
It will also generate 3 other files: `encoder_jit_script.pt`,
`decoder_jit_script.pt`, and `joiner_jit_script.pt`.
(2) Export to torchscript model using torch.jit.trace()
./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10 \
--jit-trace 1
It will generate 3 files: `encoder_jit_trace.pt`,
`decoder_jit_trace.pt`, and `joiner_jit_trace.pt`.
(3) Export to ONNX format
./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10 \
--onnx 1
It will generate the following files in the given `exp_dir`.
Check `onnx_check.py` for how to use them.
- encoder.onnx
- decoder.onnx
- joiner.onnx
- all_in_one.onnx (the three models above merged into a single file)
(4) Export `model.state_dict()`
./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10
It will generate a file exp_dir/pretrained.pt
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
To use the generated file with `pruned_transducer_stateless3/decode.py`,
you can do:
@ -42,14 +95,31 @@ you can do:
--max-duration 600 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp
"""
import argparse
import logging
from pathlib import Path
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
@ -114,6 +184,42 @@ def get_parser():
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
It will generate 4 files:
- encoder_jit_script.pt
- decoder_jit_script.pt
- joiner_jit_script.pt
- cpu_jit.pt (which combines the above 3 files)
Check ./jit_pretrained.py for how to use them.
""",
)
parser.add_argument(
"--jit-trace",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.trace.
It will generate 3 files:
- encoder_jit_trace.pt
- decoder_jit_trace.pt
- joiner_jit_trace.pt
Check ./jit_pretrained.py for how to use them.
""",
)
parser.add_argument(
"--onnx",
type=str2bool,
default=False,
help="""If True, --jit is ignored and it exports the model
to onnx format. Three files will be generated:
- encoder.onnx
- decoder.onnx
- joiner.onnx
Check ./onnx_check.py and ./onnx_pretrained.py for how to use them.
""",
)
@ -139,6 +245,299 @@ def get_parser():
return parser
def export_encoder_model_jit_script(
    encoder_model: nn.Module,
    encoder_filename: str,
) -> None:
    """Compile the encoder with torch.jit.script() and write it to disk.

    Args:
      encoder_model:
        The encoder to be scripted.
      encoder_filename:
        Where to store the scripted module.
    """
    torch.jit.script(encoder_model).save(encoder_filename)
    logging.info(f"Saved to {encoder_filename}")
def export_decoder_model_jit_script(
    decoder_model: nn.Module,
    decoder_filename: str,
) -> None:
    """Compile the decoder with torch.jit.script() and write it to disk.

    Args:
      decoder_model:
        The decoder to be scripted.
      decoder_filename:
        Where to store the scripted module.
    """
    torch.jit.script(decoder_model).save(decoder_filename)
    logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_jit_script(
    joiner_model: nn.Module,
    joiner_filename: str,
) -> None:
    """Export the given joiner model with torch.jit.script()

    Args:
      joiner_model:
        The input joiner model
      joiner_filename:
        The filename to save the exported model.
    """
    script_model = torch.jit.script(joiner_model)
    script_model.save(joiner_filename)
    logging.info(f"Saved to {joiner_filename}")
def export_encoder_model_jit_trace(
    encoder_model: nn.Module,
    encoder_filename: str,
) -> None:
    """Trace the encoder with torch.jit.trace() and write it to disk.

    Only the features and their lengths are fed to the tracer, so any other
    argument of the encoder's forward() (e.g. warmup) keeps its default
    value in the traced module.

    Args:
      encoder_model:
        The encoder to be traced.
      encoder_filename:
        Where to store the traced module.
    """
    # Example inputs: one utterance of 100 frames with 80-dim features.
    dummy_feats = torch.zeros(1, 100, 80, dtype=torch.float32)
    dummy_lens = torch.tensor([100], dtype=torch.int64)

    torch.jit.trace(encoder_model, (dummy_feats, dummy_lens)).save(
        encoder_filename
    )
    logging.info(f"Saved to {encoder_filename}")
def export_decoder_model_jit_trace(
    decoder_model: nn.Module,
    decoder_filename: str,
) -> None:
    """Trace the decoder with torch.jit.trace() and write it to disk.

    The example `need_pad` given to the tracer is a tensor holding False.

    Args:
      decoder_model:
        The decoder to be traced. Its `context_size` attribute determines
        the width of the example input.
      decoder_filename:
        Where to store the traced module.
    """
    # Example inputs: 10 hypotheses, each with `context_size` token ids.
    dummy_tokens = torch.zeros(
        10, decoder_model.context_size, dtype=torch.int64
    )
    no_pad = torch.tensor([False])

    torch.jit.trace(decoder_model, (dummy_tokens, no_pad)).save(
        decoder_filename
    )
    logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_jit_trace(
    joiner_model: nn.Module,
    joiner_filename: str,
) -> None:
    """Trace the joiner with torch.jit.trace() and write it to disk.

    Note: The joiner is traced with only encoder_out and decoder_out, so
    the argument project_input keeps its default (True according to the
    original note). A user should not project the encoder_out/decoder_out
    by himself/herself. The exported joiner will do that for the user.

    Args:
      joiner_model:
        The joiner to be traced. The input dimensions are read from the
        weights of its `encoder_proj` and `decoder_proj` layers.
      joiner_filename:
        Where to store the traced module.
    """
    enc_dim = joiner_model.encoder_proj.weight.shape[1]
    dec_dim = joiner_model.decoder_proj.weight.shape[1]

    # Example inputs: a single (encoder frame, decoder state) pair.
    dummy_enc = torch.rand(1, enc_dim, dtype=torch.float32)
    dummy_dec = torch.rand(1, dec_dim, dtype=torch.float32)

    torch.jit.trace(joiner_model, (dummy_enc, dummy_dec)).save(joiner_filename)
    logging.info(f"Saved to {joiner_filename}")
def export_encoder_model_onnx(
    encoder_model: nn.Module,
    encoder_filename: str,
    opset_version: int = 11,
) -> None:
    """Write the encoder to `encoder_filename` in ONNX format.

    Inputs of the exported model:

        - x, a tensor of shape (N, T, C); dtype is torch.float32
        - x_lens, a tensor of shape (N,); dtype is torch.int64

    Outputs of the exported model:

        - encoder_out, a tensor of shape (N, T, C)
        - encoder_out_lens, a tensor of shape (N,)

    Note: The warmup argument is fixed to 1.

    Args:
      encoder_model:
        The encoder to be exported.
      encoder_filename:
        Where to store the ONNX model.
      opset_version:
        The ONNX opset version to export with.
    """
    # Example inputs: one utterance of 100 frames with 80-dim features.
    dummy_feats = torch.zeros(1, 100, 80, dtype=torch.float32)
    dummy_lens = torch.tensor([100], dtype=torch.int64)
    warmup = 1.0

    # Batch size and number of frames stay symbolic in the exported graph.
    dynamic_axes = {
        "x": {0: "N", 1: "T"},
        "x_lens": {0: "N"},
        "encoder_out": {0: "N", 1: "T"},
        "encoder_out_lens": {0: "N"},
    }

    # The model is deliberately NOT passed through torch.jit.script() first:
    # doing so made the export fail with
    #
    #   RuntimeError: Exporting the operator __is_ to ONNX opset version
    #   11 is not supported.
    #
    # and the offending statement could not be located.
    # torch.onnx.export() uses torch.jit.trace() internally, which works
    # well for the current reworked model.
    torch.onnx.export(
        encoder_model,
        (dummy_feats, dummy_lens, warmup),
        encoder_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["x", "x_lens", "warmup"],
        output_names=["encoder_out", "encoder_out_lens"],
        dynamic_axes=dynamic_axes,
    )
    logging.info(f"Saved to {encoder_filename}")
def export_decoder_model_onnx(
    decoder_model: nn.Module,
    decoder_filename: str,
    opset_version: int = 11,
) -> None:
    """Write the decoder to `decoder_filename` in ONNX format.

    Input of the exported model:

        - y: a torch.int64 tensor of shape (N, decoder_model.context_size)

    Output of the exported model:

        - decoder_out: a torch.float32 tensor of shape (N, 1, C)

    Note: The argument need_pad is fixed to False.

    Args:
      decoder_model:
        The decoder to be exported.
      decoder_filename:
        Where to store the ONNX model.
      opset_version:
        The ONNX opset version to export with.
    """
    # Example input: 10 hypotheses, each with `context_size` token ids.
    dummy_tokens = torch.zeros(
        10, decoder_model.context_size, dtype=torch.int64
    )
    # need_pad never changes, which is what makes tracing safe here.
    need_pad = False

    # Only the batch dimension stays symbolic in the exported graph.
    dynamic_axes = {
        "y": {0: "N"},
        "decoder_out": {0: "N"},
    }

    # Note(fangjun): torch.jit.trace() is more efficient than
    # torch.jit.script() in this case
    torch.onnx.export(
        decoder_model,
        (dummy_tokens, need_pad),
        decoder_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["y", "need_pad"],
        output_names=["decoder_out"],
        dynamic_axes=dynamic_axes,
    )
    logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_onnx(
    joiner_model: nn.Module,
    joiner_filename: str,
    opset_version: int = 11,
) -> None:
    """Write the joiner to `joiner_filename` in ONNX format.

    Inputs of the exported model:

        - encoder_out: a tensor of shape (N, encoder_out_dim)
        - decoder_out: a tensor of shape (N, decoder_out_dim)

    Output of the exported model:

        - joiner_out: a tensor of shape (N, vocab_size)

    Note: The argument project_input is fixed to True. A user should not
    project the encoder_out/decoder_out by himself/herself. The exported joiner
    will do that for the user.

    Args:
      joiner_model:
        The joiner to be exported. The input dimensions are read from the
        weights of its `encoder_proj` and `decoder_proj` layers.
      joiner_filename:
        Where to store the ONNX model.
      opset_version:
        The ONNX opset version to export with.
    """
    enc_dim = joiner_model.encoder_proj.weight.shape[1]
    dec_dim = joiner_model.decoder_proj.weight.shape[1]

    # Example inputs: a single (encoder frame, decoder state) pair.
    dummy_enc = torch.rand(1, enc_dim, dtype=torch.float32)
    dummy_dec = torch.rand(1, dec_dim, dtype=torch.float32)
    project_input = True

    # Only the batch dimension stays symbolic in the exported graph.
    dynamic_axes = {
        "encoder_out": {0: "N"},
        "decoder_out": {0: "N"},
        "logit": {0: "N"},
    }

    # Note: It uses torch.jit.trace() internally
    torch.onnx.export(
        joiner_model,
        (dummy_enc, dummy_dec, project_input),
        joiner_filename,
        verbose=False,
        opset_version=opset_version,
        input_names=["encoder_out", "decoder_out", "project_input"],
        output_names=["logit"],
        dynamic_axes=dynamic_axes,
    )
    logging.info(f"Saved to {joiner_filename}")
def export_all_in_one_onnx(
    encoder_filename: str,
    decoder_filename: str,
    joiner_filename: str,
    all_in_one_filename: str,
):
    """Merge the three exported ONNX models into a single ONNX file.

    Every name of a sub-model gets the prefix "encoder/", "decoder/" or
    "joiner/" so that nothing clashes after merging. The models are merged
    with an empty io_map, i.e. no output of one sub-model is wired to an
    input of another.

    Args:
      encoder_filename:
        Path of the exported encoder ONNX model.
      decoder_filename:
        Path of the exported decoder ONNX model.
      joiner_filename:
        Path of the exported joiner ONNX model.
      all_in_one_filename:
        Where to store the merged ONNX model.
    """
    sources = (
        (encoder_filename, "encoder/"),
        (decoder_filename, "decoder/"),
        (joiner_filename, "joiner/"),
    )

    # Load everything first, then prefix, then merge from left to right.
    loaded = [onnx.load(filename) for filename, _ in sources]
    prefixed = [
        onnx.compose.add_prefix(model, prefix=prefix)
        for model, (_, prefix) in zip(loaded, sources)
    ]

    combined_model = prefixed[0]
    for part in prefixed[1:]:
        combined_model = onnx.compose.merge_models(
            combined_model, part, io_map={}
        )

    onnx.save(combined_model, all_in_one_filename)
    logging.info(f"Saved to {all_in_one_filename}")
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
@ -165,7 +564,7 @@ def main():
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model = get_transducer_model(params, enable_giga=False)
model.to(device)
@ -185,7 +584,9 @@ def main():
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
@ -196,14 +597,48 @@ def main():
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.eval()
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=False
)
model.to("cpu")
model.eval()
if params.jit:
if params.onnx is True:
convert_scaled_to_non_scaled(model, inplace=True)
opset_version = 11
logging.info("Exporting to onnx format")
encoder_filename = params.exp_dir / "encoder.onnx"
export_encoder_model_onnx(
model.encoder,
encoder_filename,
opset_version=opset_version,
)
decoder_filename = params.exp_dir / "decoder.onnx"
export_decoder_model_onnx(
model.decoder,
decoder_filename,
opset_version=opset_version,
)
joiner_filename = params.exp_dir / "joiner.onnx"
export_joiner_model_onnx(
model.joiner,
joiner_filename,
opset_version=opset_version,
)
all_in_one_filename = params.exp_dir / "all_in_one.onnx"
export_all_in_one_onnx(
encoder_filename,
decoder_filename,
joiner_filename,
all_in_one_filename,
)
elif params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True)
logging.info("Using torch.jit.script()")
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
@ -214,8 +649,30 @@ def main():
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
# Also export encoder/decoder/joiner separately
encoder_filename = params.exp_dir / "encoder_jit_script.pt"
export_encoder_model_jit_trace(model.encoder, encoder_filename)
decoder_filename = params.exp_dir / "decoder_jit_script.pt"
export_decoder_model_jit_trace(model.decoder, decoder_filename)
joiner_filename = params.exp_dir / "joiner_jit_script.pt"
export_joiner_model_jit_trace(model.joiner, joiner_filename)
elif params.jit_trace is True:
convert_scaled_to_non_scaled(model, inplace=True)
logging.info("Using torch.jit.trace()")
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
export_encoder_model_jit_trace(model.encoder, encoder_filename)
decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
export_decoder_model_jit_trace(model.decoder, decoder_filename)
joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
export_joiner_model_jit_trace(model.joiner, joiner_filename)
else:
logging.info("Not using torch.jit.script")
logging.info("Not using torchscript")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"

View File

@ -0,0 +1,338 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models, either exported by `torch.jit.trace()`
or by `torch.jit.script()`, and uses them to decode waves.
You can use the following command to get the exported models:
./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10 \
--jit-trace 1
or
./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10 \
--jit 1
Usage of this script:
./pruned_transducer_stateless3/jit_pretrained.py \
--encoder-model-filename ./pruned_transducer_stateless3/exp/encoder_jit_trace.pt \
--decoder-model-filename ./pruned_transducer_stateless3/exp/decoder_jit_trace.pt \
--joiner-model-filename ./pruned_transducer_stateless3/exp/joiner_jit_trace.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
/path/to/foo.wav \
/path/to/bar.wav
or
./pruned_transducer_stateless3/jit_pretrained.py \
--encoder-model-filename ./pruned_transducer_stateless3/exp/encoder_jit_script.pt \
--decoder-model-filename ./pruned_transducer_stateless3/exp/decoder_jit_script.pt \
--joiner-model-filename ./pruned_transducer_stateless3/exp/joiner_jit_script.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from typing import List
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
def get_parser():
    """Build the command-line parser of this decoding script.

    Returns:
      An ``argparse.ArgumentParser`` accepting the three exported
      torchscript model files, the BPE model, the expected sample rate, the
      decoder context size, and one or more positional sound files.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--encoder-model-filename",
        type=str,
        required=True,
        help="Path to the encoder torchscript model. ",
    )

    parser.add_argument(
        "--decoder-model-filename",
        type=str,
        required=True,
        help="Path to the decoder torchscript model. ",
    )

    parser.add_argument(
        "--joiner-model-filename",
        type=str,
        required=True,
        help="Path to the joiner torchscript model. ",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        # main() loads this file unconditionally, so report a missing value
        # at argument-parsing time rather than deep inside sentencepiece.
        required=True,
        help="""Path to bpe.model.""",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="Context size of the decoder model",
    )

    return parser
def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Load sound files and keep the first channel of each.

    Args:
      filenames:
        Paths of the sound files to load with ``torchaudio.load``.
      expected_sample_rate:
        Sample rate every file must have; a mismatch trips an assertion.
    Returns:
      A list with one tensor per file, in the order of ``filenames``. Each
      entry is channel 0 of the loaded waveform.
    """

    def _first_channel(path: str) -> torch.Tensor:
        wave, sample_rate = torchaudio.load(path)
        assert sample_rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. "
            f"Given: {sample_rate}"
        )
        # We use only the first channel
        return wave[0]

    return [_first_channel(path) for path in filenames]
def greedy_search(
    decoder: torch.jit.ScriptModule,
    joiner: torch.jit.ScriptModule,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    context_size: int,
) -> List[List[int]]:
    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.

    The utterances are packed with ``pack_padded_sequence`` so that, at each
    time step, only the utterances that still have frames are fed to the
    joiner. The packed (length-sorted) order is undone before returning.

    Args:
      decoder:
        The decoder model. It is called as
        ``decoder(tokens, need_pad=torch.tensor([False]))`` and its output is
        squeezed on dim 1.
      joiner:
        The joiner model. It is called as ``joiner(encoder_frame,
        decoder_out)`` and must return 2-D logits (batch_size, vocab_size).
      encoder_out:
        A 3-D tensor of shape (N, T, C)
      encoder_out_lens:
        A 1-D tensor of shape (N,). All entries must be positive.
      context_size:
        The context size of the decoder model.
    Returns:
      Return the decoded results for each utterance, in the same order as
      the rows of ``encoder_out``. The leading blank context is stripped.
    """
    assert encoder_out.ndim == 3
    assert encoder_out.size(0) >= 1, encoder_out.size(0)

    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    device = encoder_out.device
    blank_id = 0  # hard-code to 0

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)

    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == batch_size_list[0], (N, batch_size_list)

    # Loop-invariant flag passed to every decoder call.
    need_pad = torch.tensor([False])

    # Every hypothesis starts with `context_size` blanks as decoder history.
    hyps = [[blank_id] * context_size for _ in range(N)]

    decoder_input = torch.tensor(
        hyps,
        device=device,
        dtype=torch.int64,
    )  # (N, context_size)

    decoder_out = decoder(decoder_input, need_pad=need_pad).squeeze(1)

    offset = 0
    for batch_size in batch_size_list:
        start = offset
        end = offset + batch_size
        current_encoder_out = packed_encoder_out.data[start:end]
        # current_encoder_out's shape: (batch_size, encoder_out_dim)
        offset = end

        # Utterances are sorted by decreasing length, so the ones that have
        # run out of frames are always at the tail and can be sliced off.
        decoder_out = decoder_out[:batch_size]

        logits = joiner(current_encoder_out, decoder_out)
        # logits'shape (batch_size, vocab_size)
        assert logits.ndim == 2, logits.shape

        y = logits.argmax(dim=1).tolist()
        emitted = False
        for i, v in enumerate(y):
            if v != blank_id:
                hyps[i].append(v)
                emitted = True

        if emitted:
            # update decoder output
            decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
            decoder_input = torch.tensor(
                decoder_input,
                device=device,
                dtype=torch.int64,
            )
            decoder_out = decoder(decoder_input, need_pad=need_pad).squeeze(1)

    sorted_ans = [h[context_size:] for h in hyps]

    # Map results from the packed (sorted) order back to the input order.
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
    return [sorted_ans[unsorted_indices[i]] for i in range(N)]
@torch.no_grad()
def main():
    """Decode the given sound files with exported torchscript models.

    Steps: parse the command line, load encoder/decoder/joiner onto the
    first CUDA device when available (CPU otherwise), compute 80-bin fbank
    features, run the encoder, greedy-search the result, and log one
    transcript per input file.
    """
    args = get_parser().parse_args()
    logging.info(vars(args))

    device = (
        torch.device("cuda", 0)
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    logging.info(f"device: {device}")

    encoder = torch.jit.load(args.encoder_model_filename)
    decoder = torch.jit.load(args.decoder_model_filename)
    joiner = torch.jit.load(args.joiner_model_filename)

    modules = (encoder, decoder, joiner)
    for module in modules:
        module.eval()
    for module in modules:
        module.to(device)

    sp = spm.SentencePieceProcessor()
    sp.load(args.bpe_model)

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = args.sample_rate
    opts.mel_opts.num_bins = 80
    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {args.sound_files}")
    waves = [
        wave.to(device)
        for wave in read_sound_files(
            filenames=args.sound_files,
            expected_sample_rate=args.sample_rate,
        )
    ]

    logging.info("Decoding started")
    features = fbank(waves)
    # Record the true lengths before padding the batch.
    feature_lengths = torch.tensor(
        [f.size(0) for f in features], device=device
    )
    features = pad_sequence(
        features,
        batch_first=True,
        padding_value=math.log(1e-10),
    )

    encoder_out, encoder_out_lens = encoder(
        x=features,
        x_lens=feature_lengths,
    )

    hyps = greedy_search(
        decoder=decoder,
        joiner=joiner,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        context_size=args.context_size,
    )

    report = "\n" + "".join(
        f"{filename}:\n{sp.decode(hyp)}\n\n"
        for filename, hyp in zip(args.sound_files, hyps)
    )
    logging.info(report)

    logging.info("Decoding Done")
if __name__ == "__main__":
    # Configure INFO-level logging with timestamp and source location,
    # then run the script.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )
    main()

View File

@ -15,7 +15,7 @@
# limitations under the License.
from typing import Optional
from typing import Optional, Tuple
import k2
import torch
@ -105,7 +105,8 @@ class Transducer(nn.Module):
am_scale: float = 0.0,
lm_scale: float = 0.0,
warmup: float = 1.0,
) -> torch.Tensor:
reduction: str = "sum",
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
@ -131,6 +132,10 @@ class Transducer(nn.Module):
warmup:
A value warmup >= 0 that determines which modules are active, values
warmup > 1 "are fully warmed up" and all modules will be active.
reduction:
"sum" to sum the losses over all utterances in the batch.
"none" to return the loss in a 1-D tensor for each utterance
in the batch.
Returns:
Return the transducer loss.
@ -140,6 +145,7 @@ class Transducer(nn.Module):
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert reduction in ("sum", "none"), reduction
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
@ -196,7 +202,7 @@ class Transducer(nn.Module):
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
reduction=reduction,
return_grad=True,
)
@ -229,7 +235,7 @@ class Transducer(nn.Module):
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
reduction=reduction,
)
return (simple_loss, pruned_loss)

View File

@ -0,0 +1,199 @@
#!/usr/bin/env python3
#
# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks that exported onnx models produce the same output
with the given torchscript model for the same input.
"""
import argparse
import logging
import onnxruntime as ort
import torch
# Set onnxruntime's default logger severity at import time, i.e. before any
# InferenceSession is created by this script.
# NOTE(review): level 3 presumably means "errors and above only" -- confirm
# against the onnxruntime documentation.
ort.set_default_logger_severity(3)
def get_parser():
    """Build the command-line parser of the ONNX checking script.

    Returns:
      An ``argparse.ArgumentParser`` with four required string options: the
      reference torchscript model and the encoder/decoder/joiner ONNX files.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    required_paths = (
        ("--jit-filename", "Path to the torchscript model"),
        ("--onnx-encoder-filename", "Path to the onnx encoder model"),
        ("--onnx-decoder-filename", "Path to the onnx decoder model"),
        ("--onnx-joiner-filename", "Path to the onnx joiner model"),
    )
    for flag, description in required_paths:
        parser.add_argument(flag, required=True, type=str, help=description)

    return parser
def test_encoder(
    model: torch.jit.ScriptModule,
    encoder_session: ort.InferenceSession,
):
    """Compare the ONNX encoder with ``model.encoder`` on random inputs.

    First checks that the session's inputs are named ``x`` / ``x_lens`` with
    shapes ``["N", "T", 80]`` / ``["N"]``. Then, for several (N, T) pairs, the
    same random batch is run through both models and ``encoder_out`` must
    agree within ``atol=1e-05``. The output lengths are fetched but not
    compared.

    Args:
      model:
        The reference torchscript model; only ``model.encoder`` is used.
      encoder_session:
        The onnxruntime session of the exported encoder.
    """
    input_nodes = encoder_session.get_inputs()
    assert input_nodes[0].name == "x"
    assert input_nodes[1].name == "x_lens"
    assert input_nodes[0].shape == ["N", "T", 80]
    assert input_nodes[1].shape == ["N"]

    shapes = ((1, 12), (1, 25), (5, 12), (5, 25))
    for batch, num_frames in shapes:
        print("N, T", batch, num_frames)
        x = torch.rand(batch, num_frames, 80, dtype=torch.float32)
        x_lens = torch.randint(low=10, high=num_frames + 1, size=(batch,))
        # Make sure at least one utterance spans the whole padded length.
        x_lens[0] = num_frames

        feeds = {
            "x": x.numpy(),
            "x_lens": x_lens.numpy(),
        }
        onnx_out, _onnx_out_lens = encoder_session.run(
            ["encoder_out", "encoder_out_lens"],
            feeds,
        )
        onnx_out = torch.from_numpy(onnx_out)

        ref_out, _ref_out_lens = model.encoder(x, x_lens)

        assert torch.allclose(onnx_out, ref_out, atol=1e-05), (
            (onnx_out - ref_out).abs().max()
        )
def test_decoder(
    model: torch.jit.ScriptModule,
    decoder_session: ort.InferenceSession,
):
    """Compare the ONNX decoder with ``model.decoder`` on random tokens.

    Checks that the session's single input is named ``y`` with shape
    ``["N", 2]``, then feeds random token ids of batch sizes 1, 5 and 10 to
    both models and requires the outputs to agree within ``atol=1e-5``.

    Args:
      model:
        The reference torchscript model; only ``model.decoder`` is used.
      decoder_session:
        The onnxruntime session of the exported decoder.
    """
    decoder_inputs = decoder_session.get_inputs()
    assert decoder_inputs[0].name == "y"
    assert decoder_inputs[0].shape == ["N", 2]
    for N in [1, 5, 10]:
        # Use the loop's batch size so that the dynamic "N" axis is really
        # exercised (a fixed batch of 10 would test only one size).
        y = torch.randint(low=1, high=500, size=(N, 2))

        decoder_inputs = {"y": y.numpy()}
        decoder_out = decoder_session.run(
            ["decoder_out"],
            decoder_inputs,
        )[0]
        decoder_out = torch.from_numpy(decoder_out)

        torch_decoder_out = model.decoder(y, need_pad=False)
        assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), (
            (decoder_out - torch_decoder_out).abs().max()
        )
def test_joiner(
    model: torch.jit.ScriptModule,
    joiner_session: ort.InferenceSession,
):
    """Compare the ONNX joiner with ``model.joiner`` on random inputs.

    Checks that the session's first two inputs are ``encoder_out`` and
    ``decoder_out``, both of shape ``["N", 512]``. Then, for batch sizes 1, 5
    and 10, the ``logit`` output must match ``model.joiner(...,
    project_input=True)`` within ``atol=1e-5``.

    Args:
      model:
        The reference torchscript model; only ``model.joiner`` is used.
      joiner_session:
        The onnxruntime session of the exported joiner.
    """
    input_nodes = joiner_session.get_inputs()
    assert input_nodes[0].name == "encoder_out"
    assert input_nodes[0].shape == ["N", 512]
    assert input_nodes[1].name == "decoder_out"
    assert input_nodes[1].shape == ["N", 512]

    for batch in (1, 5, 10):
        encoder_out = torch.rand(batch, 512)
        decoder_out = torch.rand(batch, 512)

        feeds = {
            "encoder_out": encoder_out.numpy(),
            "decoder_out": decoder_out.numpy(),
        }
        onnx_logit = torch.from_numpy(joiner_session.run(["logit"], feeds)[0])

        ref_logit = model.joiner(
            encoder_out,
            decoder_out,
            project_input=True,
        )
        assert torch.allclose(onnx_logit, ref_logit, atol=1e-5), (
            (onnx_logit - ref_logit).abs().max()
        )
@torch.no_grad()
def main():
    """Check each exported ONNX model against the torchscript model.

    Loads the torchscript reference, then for encoder, decoder and joiner in
    turn creates a single-threaded onnxruntime session and runs the matching
    ``test_*`` function.
    """
    args = get_parser().parse_args()
    logging.info(vars(args))

    model = torch.jit.load(args.jit_filename)

    options = ort.SessionOptions()
    options.inter_op_num_threads = 1
    options.intra_op_num_threads = 1

    checks = (
        ("Test encoder", args.onnx_encoder_filename, test_encoder),
        ("Test decoder", args.onnx_decoder_filename, test_decoder),
        ("Test joiner", args.onnx_joiner_filename, test_joiner),
    )
    for banner, onnx_path, check in checks:
        logging.info(banner)
        session = ort.InferenceSession(
            onnx_path,
            sess_options=options,
        )
        check(model, session)

    logging.info("Finished checking ONNX models")
if __name__ == "__main__":
    # Fixed seed so that the random test inputs are reproducible.
    torch.manual_seed(20220727)
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )
    main()

View File

@ -0,0 +1,284 @@
#!/usr/bin/env python3
#
# Copyright 2022 Xiaomi Corporation (Author: Yunus Emre Ozkose)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks that exported onnx models produce the same output
with the given torchscript model for the same input.
"""
import argparse
import logging
import os
import onnx
import onnx_graphsurgeon as gs
import onnxruntime
import onnxruntime as ort
import torch
# Set onnxruntime's default logger severity at import time, i.e. before any
# InferenceSession is created by this script.
# NOTE(review): level 3 presumably means "errors and above only" -- confirm
# against the onnxruntime documentation.
ort.set_default_logger_severity(3)
def get_parser():
    """Build the command-line parser of the all-in-one ONNX checking script.

    Returns:
      An ``argparse.ArgumentParser`` with two required string options: the
      reference torchscript model and the all-in-one ONNX model.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    required_paths = (
        ("--jit-filename", "Path to the torchscript model"),
        ("--onnx-all-in-one-filename", "Path to the onnx all in one model"),
    )
    for flag, description in required_paths:
        parser.add_argument(flag, required=True, type=str, help=description)

    return parser
def test_encoder(
    model: torch.jit.ScriptModule,
    encoder_session: ort.InferenceSession,
):
    """Compare the extracted ONNX encoder with ``model.encoder``.

    Input and output names are read from the session instead of being
    hard-coded. For several (N, T) pairs the same random batch is run through
    both models and the encoder output must agree within ``atol=1e-05``.
    The output lengths are fetched but not compared.

    Args:
      model:
        The reference torchscript model; only ``model.encoder`` is used.
      encoder_session:
        The onnxruntime session of the encoder sub-graph.
    """
    input_nodes = encoder_session.get_inputs()
    assert input_nodes[0].shape == ["N", "T", 80]
    assert input_nodes[1].shape == ["N"]

    input_names = [node.name for node in input_nodes]
    output_names = [node.name for node in encoder_session.get_outputs()]

    shapes = ((1, 12), (1, 25), (5, 12), (5, 25))
    for batch, num_frames in shapes:
        print("N, T", batch, num_frames)
        x = torch.rand(batch, num_frames, 80, dtype=torch.float32)
        x_lens = torch.randint(low=10, high=num_frames + 1, size=(batch,))
        # Make sure at least one utterance spans the whole padded length.
        x_lens[0] = num_frames

        feeds = {
            input_names[0]: x.numpy(),
            input_names[1]: x_lens.numpy(),
        }
        # NOTE(review): output index 1 is taken as encoder_out and index 0 as
        # its lengths -- presumably the extracted sub-graph lists the lengths
        # first; confirm against the extracted model.
        onnx_out, _onnx_out_lens = encoder_session.run(
            [output_names[1], output_names[0]],
            feeds,
        )
        onnx_out = torch.from_numpy(onnx_out)

        ref_out, _ref_out_lens = model.encoder(x, x_lens)

        assert torch.allclose(onnx_out, ref_out, atol=1e-05), (
            (onnx_out - ref_out).abs().max()
        )
def test_decoder(
    model: torch.jit.ScriptModule,
    decoder_session: ort.InferenceSession,
):
    """Compare the extracted ONNX decoder with ``model.decoder``.

    Checks that the session's first input has shape ``["N", 2]``, then feeds
    random token ids of batch sizes 1, 5 and 10 to both models and requires
    the outputs to agree within ``atol=1e-5``. Input/output names are read
    from the session.

    Args:
      model:
        The reference torchscript model; only ``model.decoder`` is used.
      decoder_session:
        The onnxruntime session of the decoder sub-graph.
    """
    decoder_inputs = decoder_session.get_inputs()
    assert decoder_inputs[0].shape == ["N", 2]
    decoder_input_names = [i.name for i in decoder_inputs]
    decoder_output_names = [i.name for i in decoder_session.get_outputs()]

    for N in [1, 5, 10]:
        # Use the loop's batch size so that the dynamic "N" axis is really
        # exercised (a fixed batch of 10 would test only one size).
        y = torch.randint(low=1, high=500, size=(N, 2))

        decoder_inputs = {decoder_input_names[0]: y.numpy()}
        decoder_out = decoder_session.run(
            [decoder_output_names[0]],
            decoder_inputs,
        )[0]
        decoder_out = torch.from_numpy(decoder_out)

        torch_decoder_out = model.decoder(y, need_pad=False)
        assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), (
            (decoder_out - torch_decoder_out).abs().max()
        )
def test_joiner(
    model: torch.jit.ScriptModule,
    joiner_session: ort.InferenceSession,
):
    """Compare the extracted ONNX joiner with ``model.joiner``.

    Checks that the session's first two inputs have shape ``["N", 512]``.
    Then, for batch sizes 1, 5 and 10, the first session output must match
    ``model.joiner(..., project_input=True)`` within ``atol=1e-5``.
    Input/output names are read from the session.

    Args:
      model:
        The reference torchscript model; only ``model.joiner`` is used.
      joiner_session:
        The onnxruntime session of the joiner sub-graph.
    """
    input_nodes = joiner_session.get_inputs()
    assert input_nodes[0].shape == ["N", 512]
    assert input_nodes[1].shape == ["N", 512]

    input_names = [node.name for node in input_nodes]
    output_names = [node.name for node in joiner_session.get_outputs()]

    for batch in (1, 5, 10):
        encoder_out = torch.rand(batch, 512)
        decoder_out = torch.rand(batch, 512)

        feeds = {
            input_names[0]: encoder_out.numpy(),
            input_names[1]: decoder_out.numpy(),
        }
        onnx_logit = torch.from_numpy(
            joiner_session.run([output_names[0]], feeds)[0]
        )

        ref_logit = model.joiner(
            encoder_out,
            decoder_out,
            project_input=True,
        )
        assert torch.allclose(onnx_logit, ref_logit, atol=1e-5), (
            (onnx_logit - ref_logit).abs().max()
        )
def extract_sub_model(
    onnx_graph: onnx.ModelProto,
    input_op_names: list,
    output_op_names: list,
    non_verbose=False,
):
    """Cut a sub-model out of ``onnx_graph`` between named tensors.

    Args:
      onnx_graph:
        The ONNX model to extract from. Shape inference is run on it first.
      input_op_names:
        Names that are matched against the names of each node's input
        tensors; nodes consuming such a tensor mark the sub-model's entry.
      output_op_names:
        Names that are matched against the names of each node's output
        tensors; nodes producing such a tensor mark the sub-model's exit.
      non_verbose:
        If False, print a warning when shape inference on the extracted
        graph raises.
    Returns:
      The extracted ONNX model. Shape inference is applied to it when
      possible; otherwise the plain exported graph is returned.
    """
    onnx_graph = onnx.shape_inference.infer_shapes(onnx_graph)
    graph = gs.import_onnx(onnx_graph)
    graph.cleanup().toposort()

    # Extraction of input OP and output OP
    # Note: a node is listed once per matching tensor, so a node with two
    # matching inputs appears twice in graph_node_inputs.
    graph_node_inputs = [
        graph_nodes
        for graph_nodes in graph.nodes
        for graph_nodes_input in graph_nodes.inputs
        if graph_nodes_input.name in input_op_names
    ]
    graph_node_outputs = [
        graph_nodes
        for graph_nodes in graph.nodes
        for graph_nodes_output in graph_nodes.outputs
        if graph_nodes_output.name in output_op_names
    ]

    # Init graph INPUT/OUTPUT
    graph.inputs.clear()
    graph.outputs.clear()

    # Update graph INPUT/OUTPUT
    # Every input of an entry node with a truthy `shape` becomes a graph
    # input (presumably this filters out tensors without shape information
    # -- TODO confirm); every output of an exit node becomes a graph output.
    graph.inputs = [
        graph_node_input
        for graph_node in graph_node_inputs
        for graph_node_input in graph_node.inputs
        if graph_node_input.shape
    ]
    graph.outputs = [
        graph_node_output
        for graph_node in graph_node_outputs
        for graph_node_output in graph_node.outputs
    ]

    # Cleanup
    graph.cleanup().toposort()

    # Shape Estimation
    extracted_graph = None
    try:
        extracted_graph = onnx.shape_inference.infer_shapes(
            gs.export_onnx(graph)
        )
    except Exception:
        # Best effort: fall back to the graph without inferred shapes.
        extracted_graph = gs.export_onnx(graph)
        if not non_verbose:
            print(
                "WARNING: "
                + "The input shape of the next OP does not match the output shape. "
                + "Be sure to open the .onnx file to verify the certainty of the geometry."
            )
    return extracted_graph
def extract_encoder(onnx_model: onnx.ModelProto):
    """Extract the encoder from the all-in-one model and open a session.

    The sub-graph between the ``encoder/x`` / ``encoder/x_lens`` inputs and
    the ``encoder/encoder_out`` / ``encoder/encoder_out_lens`` outputs is
    saved to a temporary file in the current directory, validated, and
    loaded into onnxruntime.

    Args:
      onnx_model:
        The all-in-one ONNX model.
    Returns:
      An ``onnxruntime.InferenceSession`` for the encoder sub-graph.
    """
    encoder_ = extract_sub_model(
        onnx_model,
        ["encoder/x", "encoder/x_lens"],
        ["encoder/encoder_out", "encoder/encoder_out_lens"],
        False,
    )
    tmp_filename = "tmp_encoder.onnx"
    onnx.save(encoder_, tmp_filename)
    try:
        onnx.checker.check_model(encoder_)
        sess = onnxruntime.InferenceSession(tmp_filename)
    finally:
        # Remove the temporary file even if validation or loading fails.
        os.remove(tmp_filename)
    return sess
def extract_decoder(onnx_model: onnx.ModelProto):
    """Extract the decoder from the all-in-one model and open a session.

    The sub-graph between the ``decoder/y`` input and the
    ``decoder/decoder_out`` output is saved to a temporary file in the
    current directory, validated, and loaded into onnxruntime.

    Args:
      onnx_model:
        The all-in-one ONNX model.
    Returns:
      An ``onnxruntime.InferenceSession`` for the decoder sub-graph.
    """
    decoder_ = extract_sub_model(
        onnx_model, ["decoder/y"], ["decoder/decoder_out"], False
    )
    tmp_filename = "tmp_decoder.onnx"
    onnx.save(decoder_, tmp_filename)
    try:
        onnx.checker.check_model(decoder_)
        sess = onnxruntime.InferenceSession(tmp_filename)
    finally:
        # Remove the temporary file even if validation or loading fails.
        os.remove(tmp_filename)
    return sess
def extract_joiner(onnx_model: onnx.ModelProto):
    """Extract the joiner from the all-in-one model and open a session.

    The sub-graph between the ``joiner/encoder_out`` /
    ``joiner/decoder_out`` inputs and the ``joiner/logit`` output is saved to
    a temporary file in the current directory, validated, and loaded into
    onnxruntime.

    Args:
      onnx_model:
        The all-in-one ONNX model.
    Returns:
      An ``onnxruntime.InferenceSession`` for the joiner sub-graph.
    """
    joiner_ = extract_sub_model(
        onnx_model,
        ["joiner/encoder_out", "joiner/decoder_out"],
        ["joiner/logit"],
        False,
    )
    tmp_filename = "tmp_joiner.onnx"
    onnx.save(joiner_, tmp_filename)
    try:
        onnx.checker.check_model(joiner_)
        sess = onnxruntime.InferenceSession(tmp_filename)
    finally:
        # Remove the temporary file even if validation or loading fails.
        os.remove(tmp_filename)
    return sess
@torch.no_grad()
def main():
    """Check the all-in-one ONNX model against the torchscript model.

    Loads both models, extracts the encoder, decoder and joiner sub-graphs
    from the all-in-one model one after another, and runs the matching
    ``test_*`` function on each. The sessions are created by the
    ``extract_*`` helpers, which construct them without session options.
    """
    args = get_parser().parse_args()
    logging.info(vars(args))

    model = torch.jit.load(args.jit_filename)
    onnx_model = onnx.load(args.onnx_all_in_one_filename)

    logging.info("Test encoder")
    encoder_session = extract_encoder(onnx_model)
    test_encoder(model, encoder_session)

    logging.info("Test decoder")
    decoder_session = extract_decoder(onnx_model)
    test_decoder(model, decoder_session)

    logging.info("Test joiner")
    joiner_session = extract_joiner(onnx_model)
    test_joiner(model, joiner_session)

    logging.info("Finished checking ONNX models")
if __name__ == "__main__":
    # Fixed seed so that the random test inputs are reproducible.
    torch.manual_seed(20220727)
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )
    main()

View File

@ -0,0 +1,337 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads ONNX models and uses them to decode waves.
You can use the following command to get the exported models:
./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10 \
--onnx 1
Usage of this script:
./pruned_transducer_stateless3/onnx_pretrained.py \
--encoder-model-filename ./pruned_transducer_stateless3/exp/encoder.onnx \
--decoder-model-filename ./pruned_transducer_stateless3/exp/decoder.onnx \
--joiner-model-filename ./pruned_transducer_stateless3/exp/joiner.onnx \
--bpe-model ./data/lang_bpe_500/bpe.model \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from typing import List
import kaldifeat
import numpy as np
import onnxruntime as ort
import sentencepiece as spm
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
def get_parser():
    """Build the command-line parser of the ONNX decoding script.

    Returns:
      An ``argparse.ArgumentParser`` accepting the three exported ONNX model
      files, the BPE model, the expected sample rate, the decoder context
      size, and one or more positional sound files.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--encoder-model-filename",
        type=str,
        required=True,
        help="Path to the encoder ONNX model. ",
    )

    parser.add_argument(
        "--decoder-model-filename",
        type=str,
        required=True,
        help="Path to the decoder ONNX model. ",
    )

    parser.add_argument(
        "--joiner-model-filename",
        type=str,
        required=True,
        help="Path to the joiner ONNX model. ",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        # main() loads this file unconditionally, so report a missing value
        # at argument-parsing time rather than deep inside sentencepiece.
        required=True,
        help="""Path to bpe.model.""",
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="Context size of the decoder model",
    )

    return parser
def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Load sound files and keep the first channel of each.

    Args:
      filenames:
        Paths of the sound files to load with ``torchaudio.load``.
      expected_sample_rate:
        Sample rate every file must have; a mismatch trips an assertion.
    Returns:
      A list with one tensor per file, in the order of ``filenames``. Each
      entry is channel 0 of the loaded waveform.
    """

    def _first_channel(path: str) -> torch.Tensor:
        wave, sample_rate = torchaudio.load(path)
        assert sample_rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. "
            f"Given: {sample_rate}"
        )
        # We use only the first channel
        return wave[0]

    return [_first_channel(path) for path in filenames]
def greedy_search(
    decoder: ort.InferenceSession,
    joiner: ort.InferenceSession,
    encoder_out: np.ndarray,
    encoder_out_lens: np.ndarray,
    context_size: int,
) -> List[List[int]]:
    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.

    The utterances are packed with ``pack_padded_sequence`` so that, at each
    time step, only the utterances that still have frames are fed to the
    joiner. The packed (length-sorted) order is undone before returning.

    Args:
      decoder:
        The decoder session. Its first input receives the token history and
        its first output is squeezed on axis 1.
      joiner:
        The joiner session. Its first two inputs receive the encoder frame
        and the decoder output; its first output must be 2-D logits of shape
        (batch_size, vocab_size).
      encoder_out:
        A 3-D array of shape (N, T, C)
      encoder_out_lens:
        A 1-D array of shape (N,). All entries must be positive.
      context_size:
        The context size of the decoder model.
    Returns:
      Return the decoded results for each utterance, in the same order as
      the rows of ``encoder_out``. The leading blank context is stripped.
    """
    # torch.from_numpy() yields CPU tensors sharing memory with the arrays.
    encoder_out = torch.from_numpy(encoder_out)
    encoder_out_lens = torch.from_numpy(encoder_out_lens)
    assert encoder_out.ndim == 3
    assert encoder_out.size(0) >= 1, encoder_out.size(0)

    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens,
        batch_first=True,
        enforce_sorted=False,
    )

    blank_id = 0  # hard-code to 0

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)

    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == batch_size_list[0], (N, batch_size_list)

    # Every hypothesis starts with `context_size` blanks as decoder history.
    hyps = [[blank_id] * context_size for _ in range(N)]

    # Node names do not change between calls; look them up once.
    decoder_input_name = decoder.get_inputs()[0].name
    decoder_output_name = decoder.get_outputs()[0].name
    joiner_input_nodes = joiner.get_inputs()
    joiner_encoder_name = joiner_input_nodes[0].name
    joiner_decoder_name = joiner_input_nodes[1].name
    joiner_output_name = joiner.get_outputs()[0].name

    decoder_input = torch.tensor(
        hyps,
        dtype=torch.int64,
    )  # (N, context_size)

    decoder_out = decoder.run(
        [decoder_output_name],
        {decoder_input_name: decoder_input.numpy()},
    )[0].squeeze(1)

    offset = 0
    for batch_size in batch_size_list:
        start = offset
        end = offset + batch_size
        current_encoder_out = packed_encoder_out.data[start:end]
        # current_encoder_out's shape: (batch_size, encoder_out_dim)
        offset = end

        # Utterances are sorted by decreasing length, so the ones that have
        # run out of frames are always at the tail and can be sliced off.
        decoder_out = decoder_out[:batch_size]

        logits = joiner.run(
            [joiner_output_name],
            {
                joiner_encoder_name: current_encoder_out.numpy(),
                joiner_decoder_name: decoder_out,
            },
        )[0]
        logits = torch.from_numpy(logits)
        # logits'shape (batch_size, vocab_size)
        assert logits.ndim == 2, logits.shape

        y = logits.argmax(dim=1).tolist()
        emitted = False
        for i, v in enumerate(y):
            if v != blank_id:
                hyps[i].append(v)
                emitted = True

        if emitted:
            # update decoder output
            decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
            decoder_input = torch.tensor(
                decoder_input,
                dtype=torch.int64,
            )
            decoder_out = decoder.run(
                [decoder_output_name],
                {decoder_input_name: decoder_input.numpy()},
            )[0].squeeze(1)

    sorted_ans = [h[context_size:] for h in hyps]

    # Map results from the packed (sorted) order back to the input order.
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
    return [sorted_ans[unsorted_indices[i]] for i in range(N)]
@torch.no_grad()
def main():
    """Decode the given sound files with exported ONNX models.

    Steps: parse the command line, open single-threaded onnxruntime sessions
    for encoder/decoder/joiner, compute 80-bin fbank features on CPU, run the
    encoder, greedy-search the result, and log one transcript per input
    file.
    """
    args = get_parser().parse_args()
    logging.info(vars(args))

    session_opts = ort.SessionOptions()
    session_opts.inter_op_num_threads = 1
    session_opts.intra_op_num_threads = 1

    model_paths = (
        args.encoder_model_filename,
        args.decoder_model_filename,
        args.joiner_model_filename,
    )
    encoder, decoder, joiner = (
        ort.InferenceSession(model_path, sess_options=session_opts)
        for model_path in model_paths
    )

    sp = spm.SentencePieceProcessor()
    sp.load(args.bpe_model)

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = "cpu"
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = args.sample_rate
    opts.mel_opts.num_bins = 80
    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {args.sound_files}")
    waves = read_sound_files(
        filenames=args.sound_files,
        expected_sample_rate=args.sample_rate,
    )

    logging.info("Decoding started")
    features = fbank(waves)
    # Record the true lengths before padding the batch.
    feature_lengths = torch.tensor(
        [f.size(0) for f in features], dtype=torch.int64
    )
    features = pad_sequence(
        features,
        batch_first=True,
        padding_value=math.log(1e-10),
    )

    input_nodes = encoder.get_inputs()
    output_nodes = encoder.get_outputs()
    encoder_out, encoder_out_lens = encoder.run(
        [output_nodes[0].name, output_nodes[1].name],
        {
            input_nodes[0].name: features.numpy(),
            input_nodes[1].name: feature_lengths.numpy(),
        },
    )

    hyps = greedy_search(
        decoder=decoder,
        joiner=joiner,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        context_size=args.context_size,
    )

    report = "\n" + "".join(
        f"{filename}:\n{sp.decode(hyp)}\n\n"
        for filename, hyp in zip(args.sound_files, hyps)
    )
    logging.info(report)

    logging.info("Decoding Done")
if __name__ == "__main__":
    # Configure INFO-level logging with timestamp and source location,
    # then run the script.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )
    main()

View File

@ -15,7 +15,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
This script loads a checkpoint and uses it to decode waves.
You can generate the checkpoint with the following command:
./pruned_transducer_stateless3/export.py \
--exp-dir ./pruned_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10
Usage of this script:
(1) greedy search
./pruned_transducer_stateless3/pretrained.py \

View File

@ -0,0 +1,269 @@
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file provides functions to convert `ScaledLinear`, `ScaledConv1d`,
`ScaledConv2d`, `ScaledEmbedding`, and `ScaledLSTM` to their non-scaled
counterparts: `nn.Linear`, `nn.Conv1d`, `nn.Conv2d`, `nn.Embedding`,
and `nn.LSTM`.
The scaled versions are required only at training time. Converting them to
their non-scaled versions simplifies our life during inference.
"""
import copy
import re
from typing import List
import torch
import torch.nn as nn
from scaling import (
ScaledConv1d,
ScaledConv2d,
ScaledEmbedding,
ScaledLinear,
ScaledLSTM,
)
def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
    """Build an `nn.Linear` that is equivalent to a `ScaledLinear`.

    Args:
      scaled_linear:
        The layer to be converted.
    Returns:
      Return a linear layer whose weight and bias are copied from
      `scaled_linear.get_weight()` and `scaled_linear.get_bias()`.
      It satisfies:

        scaled_linear(x) == linear(x)

      for any given input tensor `x`. The returned layer always has a bias;
      it is filled with zeros if `scaled_linear.get_bias()` returns None.
    """
    assert isinstance(scaled_linear, ScaledLinear), type(scaled_linear)

    effective_weight = scaled_linear.get_weight()
    effective_bias = scaled_linear.get_bias()

    # NOTE: `bias=True` is used unconditionally; otherwise, it throws errors
    # when converting to PNNX format.
    #
    # NOTE: `device=weight.device` is not passed on purpose, since PyTorch
    # versions before v1.9.0 do not have this argument.
    linear = torch.nn.Linear(
        in_features=scaled_linear.in_features,
        out_features=scaled_linear.out_features,
        bias=True,
    )

    linear.weight.data.copy_(effective_weight)
    if effective_bias is None:
        linear.bias.data.zero_()
    else:
        linear.bias.data.copy_(effective_bias)

    return linear
def scaled_conv1d_to_conv1d(scaled_conv1d: ScaledConv1d) -> nn.Conv1d:
    """Build an `nn.Conv1d` that is equivalent to a `ScaledConv1d`.

    Args:
      scaled_conv1d:
        The layer to be converted.
    Returns:
      Return an instance of nn.Conv1d that has the same `forward()` behavior
      as the given `scaled_conv1d`. Its hyper-parameters are taken from
      `scaled_conv1d`; its weight and bias are copied from
      `scaled_conv1d.get_weight()` and `scaled_conv1d.get_bias()`.
    """
    assert isinstance(scaled_conv1d, ScaledConv1d), type(scaled_conv1d)

    effective_weight = scaled_conv1d.get_weight()
    effective_bias = scaled_conv1d.get_bias()

    # Constructor arguments that are copied verbatim from the scaled layer.
    copied_attrs = (
        "in_channels",
        "out_channels",
        "kernel_size",
        "stride",
        "padding",
        "dilation",
        "groups",
        "padding_mode",
    )
    conv1d = nn.Conv1d(
        bias=scaled_conv1d.bias is not None,
        **{attr: getattr(scaled_conv1d, attr) for attr in copied_attrs},
    )

    conv1d.weight.data.copy_(effective_weight)
    if effective_bias is not None:
        conv1d.bias.data.copy_(effective_bias)

    return conv1d
def scaled_conv2d_to_conv2d(scaled_conv2d: ScaledConv2d) -> nn.Conv2d:
    """Build an `nn.Conv2d` that is equivalent to a `ScaledConv2d`.

    Args:
      scaled_conv2d:
        The layer to be converted.
    Returns:
      Return an instance of nn.Conv2d that has the same `forward()` behavior
      as the given `scaled_conv2d`. Its hyper-parameters are taken from
      `scaled_conv2d`; its weight and bias are copied from
      `scaled_conv2d.get_weight()` and `scaled_conv2d.get_bias()`.
    """
    assert isinstance(scaled_conv2d, ScaledConv2d), type(scaled_conv2d)

    effective_weight = scaled_conv2d.get_weight()
    effective_bias = scaled_conv2d.get_bias()

    # Constructor arguments that are copied verbatim from the scaled layer.
    copied_attrs = (
        "in_channels",
        "out_channels",
        "kernel_size",
        "stride",
        "padding",
        "dilation",
        "groups",
        "padding_mode",
    )
    conv2d = nn.Conv2d(
        bias=scaled_conv2d.bias is not None,
        **{attr: getattr(scaled_conv2d, attr) for attr in copied_attrs},
    )

    conv2d.weight.data.copy_(effective_weight)
    if effective_bias is not None:
        conv2d.bias.data.copy_(effective_bias)

    return conv2d
def scaled_embedding_to_embedding(
    scaled_embedding: ScaledEmbedding,
) -> nn.Embedding:
    """Build an `nn.Embedding` that is equivalent to a `ScaledEmbedding`.

    Args:
      scaled_embedding:
        The layer to be converted.
    Returns:
      Return an instance of nn.Embedding that has the same `forward()`
      behavior as the given `scaled_embedding`. Its weight is
      `scaled_embedding.weight * scaled_embedding.scale.exp()`.
    """
    assert isinstance(scaled_embedding, ScaledEmbedding), type(scaled_embedding)

    # Constructor arguments that are copied verbatim from the scaled layer.
    copied_attrs = (
        "num_embeddings",
        "embedding_dim",
        "padding_idx",
        "scale_grad_by_freq",
        "sparse",
    )
    embedding = nn.Embedding(
        **{attr: getattr(scaled_embedding, attr) for attr in copied_attrs}
    )

    # Fold the learned log-scale into the weight.
    effective_weight = scaled_embedding.weight * scaled_embedding.scale.exp()
    embedding.weight.data.copy_(effective_weight)

    return embedding
def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM:
    """Build an `nn.LSTM` that is equivalent to a `ScaledLSTM`.

    Args:
      scaled_lstm:
        The layer to be converted.
    Returns:
      Return an instance of nn.LSTM that has the same `forward()` behavior
      as the given `scaled_lstm`. Each of its flat weights is the
      corresponding flat weight of `scaled_lstm` multiplied by the
      exponential of its scale.
    """
    assert isinstance(scaled_lstm, ScaledLSTM), type(scaled_lstm)

    # Constructor arguments that are copied verbatim from the scaled layer.
    copied_attrs = (
        "input_size",
        "hidden_size",
        "num_layers",
        "bias",
        "batch_first",
        "dropout",
        "bidirectional",
        "proj_size",
    )
    lstm = nn.LSTM(
        **{attr: getattr(scaled_lstm, attr) for attr in copied_attrs}
    )

    # Both layers must expose their parameters in the same order, since
    # the weights are copied by position below.
    assert lstm._flat_weights_names == scaled_lstm._flat_weights_names

    for idx, _ in enumerate(scaled_lstm._flat_weights_names):
        effective_weight = (
            scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp()
        )
        lstm._flat_weights[idx].data.copy_(effective_weight)

    return lstm
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule
# get_submodule was added to nn.Module at v1.9.0
def get_submodule(model, target):
    """Return the submodule of `model` addressed by the dotted path `target`.

    Args:
      model:
        The module to start the lookup from.
      target:
        A dot-separated attribute path, e.g., "encoder.layers.0".
        An empty string refers to `model` itself.
    Returns:
      Return the module found at `target`.
    Raises:
      AttributeError:
        If some component of the path does not exist or is not an
        instance of `torch.nn.Module`.
    """
    if target == "":
        return model

    current: torch.nn.Module = model
    for attr_name in target.split("."):
        if not hasattr(current, attr_name):
            raise AttributeError(
                current._get_name() + " has no " "attribute `" + attr_name + "`"
            )

        current = getattr(current, attr_name)

        if not isinstance(current, torch.nn.Module):
            raise AttributeError("`" + attr_name + "` is not " "an nn.Module")

    return current
def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
    """Convert the scaled layers in the given model to their unscaled
    versions:

      - `ScaledLinear` -> `nn.Linear`
      - `ScaledConv1d` -> `nn.Conv1d`
      - `ScaledConv2d` -> `nn.Conv2d`
      - `ScaledEmbedding` -> `nn.Embedding`
      - `ScaledLSTM` -> `nn.LSTM`

    `ScaledLinear` layers whose name matches `self_attn.in_proj` or
    `self_attn.out_proj` are left unchanged.

    Args:
      model:
        The model to be converted.
      inplace:
        If True, the input model is modified inplace.
        If False, the input model is copied and we modify the copied version.
    Return:
      Return a model without scaled layers (except for the excluded
      attention projections). If `model` itself is a scaled layer, the
      converted layer is returned, as the root module cannot be replaced
      inplace.
    """
    if not inplace:
        model = copy.deepcopy(model)

    # NOTE(review): presumably the attention module accesses these
    # projections in a way that requires the scaled version; confirm.
    excluded_patterns = r"self_attn\.(in|out)_proj"
    p = re.compile(excluded_patterns)

    # Collect the replacements first and apply them afterwards, so that the
    # module tree is not modified while we are iterating over it.
    d = {}
    for name, m in model.named_modules():
        if isinstance(m, ScaledLinear):
            if p.search(name) is not None:
                continue
            d[name] = scaled_linear_to_linear(m)
        elif isinstance(m, ScaledConv1d):
            d[name] = scaled_conv1d_to_conv1d(m)
        elif isinstance(m, ScaledConv2d):
            d[name] = scaled_conv2d_to_conv2d(m)
        elif isinstance(m, ScaledEmbedding):
            d[name] = scaled_embedding_to_embedding(m)
        elif isinstance(m, ScaledLSTM):
            d[name] = scaled_lstm_to_lstm(m)

    # `named_modules()` yields the root module with an empty name. If the
    # root itself is a scaled layer, return its converted version directly;
    # `setattr(model, "", v)` would only register a bogus child module.
    if "" in d:
        return d[""]

    for k, v in d.items():
        if "." in k:
            parent, child = k.rsplit(".", maxsplit=1)
            setattr(get_submodule(model, parent), child, v)
        else:
            setattr(model, k, v)

    return model

View File

@ -359,6 +359,7 @@ def decode_dataset(
# each utterance has a DecodeStream.
decode_stream = DecodeStream(
params=params,
cut_id=cut.id,
initial_states=initial_states,
decoding_graph=decoding_graph,
device=device,
@ -389,6 +390,7 @@ def decode_dataset(
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
@ -406,6 +408,7 @@ def decode_dataset(
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)

View File

@ -0,0 +1,218 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless3/test_scaling_converter.py
"""
import copy
import torch
from scaling import ScaledConv1d, ScaledConv2d, ScaledEmbedding, ScaledLinear
from scaling_converter import (
convert_scaled_to_non_scaled,
scaled_conv1d_to_conv1d,
scaled_conv2d_to_conv2d,
scaled_embedding_to_embedding,
scaled_linear_to_linear,
)
from train import get_params, get_transducer_model
def get_model():
    """Create the transducer model used by the conversion test.

    It starts from the default training parameters, overrides the fields
    listed below and builds the model with `enable_giga=False`.
    """
    params = get_params()

    overrides = {
        "vocab_size": 500,
        "blank_id": 0,
        "context_size": 2,
        "unk_id": 2,
        "dynamic_chunk_training": False,
        "short_chunk_size": 25,
        "num_left_chunks": 4,
        "causal_convolution": False,
    }
    for key, value in overrides.items():
        setattr(params, key, value)

    return get_transducer_model(params, enable_giga=False)
def test_scaled_linear_to_linear():
    """Check that the converted linear layer matches `ScaledLinear`.

    The outputs are compared for both the eager and the scripted modules,
    with and without a bias.
    """
    batch_size = 5
    in_dim = 10
    out_dim = 20

    for use_bias in (True, False):
        scaled = ScaledLinear(
            in_features=in_dim,
            out_features=out_dim,
            bias=use_bias,
        )
        plain = scaled_linear_to_linear(scaled)
        x = torch.rand(batch_size, in_dim)

        # eager mode
        scaled_out = scaled(x)
        plain_out = plain(x)
        assert torch.allclose(scaled_out, plain_out)

        # torchscript
        scripted_scaled = torch.jit.script(scaled)
        scripted_plain = torch.jit.script(plain)
        scripted_scaled_out = scripted_scaled(x)
        scripted_plain_out = scripted_plain(x)
        assert torch.allclose(scripted_scaled_out, scripted_plain_out)
        assert torch.allclose(scaled_out, scripted_plain_out)
def test_scaled_conv1d_to_conv1d():
    """Check that the converted 1-D convolution matches `ScaledConv1d`.

    The outputs are compared for both the eager and the scripted modules,
    with and without a bias.
    """
    num_in_channels = 3

    for use_bias in (True, False):
        scaled = ScaledConv1d(
            num_in_channels,
            6,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=use_bias,
        )
        plain = scaled_conv1d_to_conv1d(scaled)
        x = torch.rand(20, num_in_channels, 10)

        # eager mode
        scaled_out = scaled(x)
        plain_out = plain(x)
        assert torch.allclose(scaled_out, plain_out)

        # torchscript
        scripted_scaled = torch.jit.script(scaled)
        scripted_plain = torch.jit.script(plain)
        scripted_scaled_out = scripted_scaled(x)
        scripted_plain_out = scripted_plain(x)
        assert torch.allclose(scripted_scaled_out, scripted_plain_out)
        assert torch.allclose(scaled_out, scripted_plain_out)
def test_scaled_conv2d_to_conv2d():
    """Check that the converted 2-D convolution matches `ScaledConv2d`.

    The outputs are compared for both the eager and the scripted modules,
    with and without a bias.
    """
    num_in_channels = 1

    for use_bias in (True, False):
        scaled = ScaledConv2d(
            in_channels=num_in_channels,
            out_channels=3,
            kernel_size=3,
            padding=1,
            bias=use_bias,
        )
        plain = scaled_conv2d_to_conv2d(scaled)
        x = torch.rand(20, num_in_channels, 10, 20)

        # eager mode
        scaled_out = scaled(x)
        plain_out = plain(x)
        assert torch.allclose(scaled_out, plain_out)

        # torchscript
        scripted_scaled = torch.jit.script(scaled)
        scripted_plain = torch.jit.script(plain)
        scripted_scaled_out = scripted_scaled(x)
        scripted_plain_out = scripted_plain(x)
        assert torch.allclose(scripted_scaled_out, scripted_plain_out)
        assert torch.allclose(scaled_out, scripted_plain_out)
def test_scaled_embedding_to_embedding():
    """Check that the converted embedding matches `ScaledEmbedding`.

    Random index tensors of several sizes are looked up in both layers and
    the results are required to be exactly equal.
    """
    num_embeddings = 500
    scaled = ScaledEmbedding(
        num_embeddings=num_embeddings,
        embedding_dim=10,
        padding_idx=0,
    )
    plain = scaled_embedding_to_embedding(scaled)

    for num_indexes in (10, 100, 300, 500, 800, 1000):
        indexes = torch.randint(low=0, high=500, size=(num_indexes,))
        expected = scaled(indexes)
        actual = plain(indexes)
        assert torch.equal(expected, actual)
def test_convert_scaled_to_non_scaled():
    """Check that a converted transducer model matches the original one.

    The encoder, the decoder, the simple projections and the joiner of the
    converted model are compared with those of an unconverted copy, for
    both `inplace=False` and `inplace=True`.
    """
    for inplace in [False, True]:
        model = get_model()
        model.eval()

        # With inplace=True, `model` itself gets converted, so we keep an
        # unconverted copy to compare against.
        orig_model = copy.deepcopy(model)

        converted_model = convert_scaled_to_non_scaled(model, inplace=inplace)

        model = orig_model

        # test encoder
        N = 2
        T = 100
        vocab_size = model.decoder.vocab_size

        x = torch.randn(N, T, 80, dtype=torch.float32)
        x_lens = torch.full((N,), x.size(1))

        e1, e1_lens = model.encoder(x, x_lens)
        e2, e2_lens = converted_model.encoder(x, x_lens)

        assert torch.all(torch.eq(e1_lens, e2_lens))
        assert torch.allclose(e1, e2), (e1 - e2).abs().max()

        # test decoder
        U = 50
        y = torch.randint(low=1, high=vocab_size - 1, size=(N, U))

        d1 = model.decoder(y)
        # Use the converted decoder here; comparing `model.decoder(y)` with
        # itself would not test the conversion at all.
        d2 = converted_model.decoder(y)

        assert torch.allclose(d1, d2)

        # test simple projection
        lm1 = model.simple_lm_proj(d1)
        am1 = model.simple_am_proj(e1)

        lm2 = converted_model.simple_lm_proj(d2)
        am2 = converted_model.simple_am_proj(e2)

        assert torch.allclose(lm1, lm2)
        assert torch.allclose(am1, am2)

        # test joiner
        e = torch.rand(2, 3, 4, 512)
        d = torch.rand(2, 3, 4, 512)

        j1 = model.joiner(e, d)
        j2 = converted_model.joiner(e, d)
        assert torch.allclose(j1, j2)
@torch.no_grad()
def main():
    """Run all tests in this file, with gradient computation disabled."""
    all_tests = (
        test_scaled_linear_to_linear,
        test_scaled_conv1d_to_conv1d,
        test_scaled_conv2d_to_conv2d,
        test_scaled_embedding_to_embedding,
        test_convert_scaled_to_non_scaled,
    )
    for run_test in all_tests:
        run_test()
if __name__ == "__main__":
    # Seed the RNG so that the randomly initialized layers and the random
    # inputs used by the tests are the same from run to run.
    torch.manual_seed(20220730)
    main()

View File

@ -84,7 +84,13 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
display_and_save_batch,
setup_logger,
str2bool,
)
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@ -436,13 +442,22 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
return joiner
def get_transducer_model(params: AttributeDict) -> nn.Module:
def get_transducer_model(
params: AttributeDict,
enable_giga: bool = True,
) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
decoder_giga = get_decoder_model(params)
joiner_giga = get_joiner_model(params)
if enable_giga:
logging.info("Use giga")
decoder_giga = get_decoder_model(params)
joiner_giga = get_joiner_model(params)
else:
logging.info("Disable giga")
decoder_giga = None
joiner_giga = None
model = Transducer(
encoder=encoder,
@ -628,7 +643,33 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
is_finite = simple_loss_is_finite & pruned_loss_is_finite
if not torch.all(is_finite):
logging.info(
"Not all losses are finite!\n"
f"simple_loss: {simple_loss}\n"
f"pruned_loss: {pruned_loss}"
)
display_and_save_batch(batch, params=params, sp=sp)
simple_loss = simple_loss[simple_loss_is_finite]
pruned_loss = pruned_loss[pruned_loss_is_finite]
# If either all simple_loss or pruned_loss is inf or nan,
# we stop the training process by raising an exception
if torch.all(~simple_loss_is_finite) or torch.all(
~pruned_loss_is_finite
):
raise ValueError(
"There are too many utterances in this batch "
"leading to inf or nan losses."
)
simple_loss = simple_loss.sum()
pruned_loss = pruned_loss.sum()
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid
# overwhelming the simple_loss and causing it to diverge,
@ -648,6 +689,10 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# info["frames"] is an approximate number for two reasons:
# (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
# (2) If some utterances in the batch lead to inf/nan loss, they
# are filtered out.
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
@ -1049,14 +1094,15 @@ def run(rank, world_size, args):
# It's time consuming to include `giga_train_dl` here
# for dl in [train_dl, giga_train_dl]:
for dl in [train_dl]:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=dl,
optimizer=optimizer,
sp=sp,
params=params,
warmup=0.0 if params.start_epoch == 0 else 1.0,
)
if params.start_batch <= 0:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=dl,
optimizer=optimizer,
sp=sp,
params=params,
warmup=0.0 if params.start_epoch == 0 else 1.0,
)
scaler = GradScaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:

View File

@ -578,6 +578,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -591,9 +592,9 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
@ -618,6 +619,7 @@ def save_results(
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
@ -831,6 +833,8 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()

View File

@ -371,6 +371,7 @@ def decode_dataset(
# each utterance has a DecodeStream.
decode_stream = DecodeStream(
params=params,
cut_id=cut.id,
initial_states=initial_states,
decoding_graph=decoding_graph,
device=device,
@ -401,6 +402,7 @@ def decode_dataset(
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
@ -418,6 +420,7 @@ def decode_dataset(
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)

View File

@ -93,7 +93,13 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
display_and_save_batch,
setup_logger,
str2bool,
)
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@ -525,9 +531,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params
@ -633,7 +636,33 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
is_finite = simple_loss_is_finite & pruned_loss_is_finite
if not torch.all(is_finite):
logging.info(
"Not all losses are finite!\n"
f"simple_loss: {simple_loss}\n"
f"pruned_loss: {pruned_loss}"
)
display_and_save_batch(batch, params=params, sp=sp)
simple_loss = simple_loss[simple_loss_is_finite]
pruned_loss = pruned_loss[pruned_loss_is_finite]
# If either all simple_loss or pruned_loss is inf or nan,
# we stop the training process by raising an exception
if torch.all(~simple_loss_is_finite) or torch.all(
~pruned_loss_is_finite
):
raise ValueError(
"There are too many utterances in this batch "
"leading to inf or nan losses."
)
simple_loss = simple_loss.sum()
pruned_loss = pruned_loss.sum()
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid
# overwhelming the simple_loss and causing it to diverge,
@ -653,6 +682,10 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# info["frames"] is an approximate number for two reasons:
# (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
# (2) If some utterances in the batch lead to inf/nan loss, they
# are filtered out.
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
@ -757,13 +790,7 @@ def train_one_epoch(
tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl):
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
@ -805,7 +832,6 @@ def train_one_epoch(
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
@ -818,7 +844,6 @@ def train_one_epoch(
scaler=scaler,
rank=rank,
)
del params.cur_batch_idx
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
@ -993,7 +1018,7 @@ def run(rank, world_size, args):
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
if params.start_batch <= 0 and not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,

View File

@ -564,6 +564,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -577,9 +578,9 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
@ -604,6 +605,7 @@ def save_results(
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
@ -817,6 +819,8 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()

View File

@ -371,6 +371,7 @@ def decode_dataset(
# each utterance has a DecodeStream.
decode_stream = DecodeStream(
params=params,
cut_id=cut.id,
initial_states=initial_states,
decoding_graph=decoding_graph,
device=device,
@ -401,6 +402,7 @@ def decode_dataset(
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
@ -418,6 +420,7 @@ def decode_dataset(
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)

View File

@ -81,7 +81,13 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
display_and_save_batch,
setup_logger,
str2bool,
)
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@ -550,9 +556,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params
@ -658,7 +661,34 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
is_finite = simple_loss_is_finite & pruned_loss_is_finite
if not torch.all(is_finite):
logging.info(
"Not all losses are finite!\n"
f"simple_loss: {simple_loss}\n"
f"pruned_loss: {pruned_loss}"
)
display_and_save_batch(batch, params=params, sp=sp)
simple_loss = simple_loss[simple_loss_is_finite]
pruned_loss = pruned_loss[pruned_loss_is_finite]
# If all values of either simple_loss or pruned_loss are inf or nan,
# we stop the training process by raising an exception
if torch.all(~simple_loss_is_finite) or torch.all(
~pruned_loss_is_finite
):
raise ValueError(
"There are too many utterances in this batch "
"leading to inf or nan losses."
)
simple_loss = simple_loss.sum()
pruned_loss = pruned_loss.sum()
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid
# overwhelming the simple_loss and causing it to diverge,
@ -678,6 +708,10 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# info["frames"] is an approximate number for two reasons:
# (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
# (2) If some utterances in the batch lead to inf/nan loss, they
# are filtered out.
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
@ -782,13 +816,7 @@ def train_one_epoch(
tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl):
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
@ -834,7 +862,6 @@ def train_one_epoch(
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
@ -847,7 +874,6 @@ def train_one_epoch(
scaler=scaler,
rank=rank,
)
del params.cur_batch_idx
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
@ -1025,7 +1051,7 @@ def run(rank, world_size, args):
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
if params.start_batch <= 0 and not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
@ -1087,38 +1113,6 @@ def run(rank, world_size, args):
cleanup_dist()
def display_and_save_batch(
batch: dict,
params: AttributeDict,
sp: spm.SentencePieceProcessor,
) -> None:
"""Display the batch statistics and save the batch into disk.
Args:
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
params:
Parameters for training. See :func:`get_params`.
sp:
The BPE model.
"""
from lhotse.utils import uuid4
filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
logging.info(f"Saving batch to {filename}")
torch.save(batch, filename)
supervisions = batch["supervisions"]
features = batch["inputs"]
logging.info(f"features shape: {features.shape}")
y = sp.encode(supervisions["text"], out_type=int)
num_tokens = sum(len(i) for i in y)
logging.info(f"num tokens: {num_tokens}")
def scan_pessimistic_batches_for_oom(
model: Union[nn.Module, DDP],
train_dl: torch.utils.data.DataLoader,

View File

@ -387,6 +387,7 @@ def decode_dataset(
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
@ -399,9 +400,9 @@ def decode_dataset(
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
@ -426,6 +427,7 @@ def save_results(
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
@ -608,6 +610,8 @@ def main():
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()

View File

@ -15,16 +15,17 @@
# limitations under the License.
from typing import Tuple
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from multi_quantization.prediction import JointCodebookLoss
from scaling import ScaledLinear
from icefall.utils import add_sos
from multi_quantization.prediction import JointCodebookLoss
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
@ -89,8 +90,9 @@ class Transducer(nn.Module):
am_scale: float = 0.0,
lm_scale: float = 0.0,
warmup: float = 1.0,
reduction: str = "sum",
codebook_indexes: torch.Tensor = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
@ -113,6 +115,10 @@ class Transducer(nn.Module):
warmup:
A value warmup >= 0 that determines which modules are active, values
warmup > 1 "are fully warmed up" and all modules will be active.
reduction:
"sum" to sum the losses over all utterances in the batch.
"none" to return the loss in a 1-D tensor for each utterance
in the batch.
codebook_indexes:
codebook_indexes extracted from a teacher model.
Returns:
@ -124,6 +130,7 @@ class Transducer(nn.Module):
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert reduction in ("sum", "none"), reduction
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
@ -184,7 +191,7 @@ class Transducer(nn.Module):
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
reduction=reduction,
return_grad=True,
)
@ -217,7 +224,7 @@ class Transducer(nn.Module):
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
reduction=reduction,
)
return (simple_loss, pruned_loss, codebook_loss)

View File

@ -93,7 +93,13 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
MetricsTracker,
display_and_save_batch,
setup_logger,
str2bool,
)
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
@ -507,9 +513,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params
@ -634,8 +637,34 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
codebook_indexes=codebook_indexes,
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
is_finite = simple_loss_is_finite & pruned_loss_is_finite
if not torch.all(is_finite):
logging.info(
"Not all losses are finite!\n"
f"simple_loss: {simple_loss}\n"
f"pruned_loss: {pruned_loss}"
)
display_and_save_batch(batch, params=params, sp=sp)
simple_loss = simple_loss[simple_loss_is_finite]
pruned_loss = pruned_loss[pruned_loss_is_finite]
# If either all simple_loss values or all pruned_loss values in the
# batch are inf or nan, we stop the training process by raising an
# exception. (No utterance-count threshold is checked here.)
if torch.all(~simple_loss_is_finite) or torch.all(
~pruned_loss_is_finite
):
raise ValueError(
"There are too many utterances in this batch "
"leading to inf or nan losses."
)
simple_loss = simple_loss.sum()
pruned_loss = pruned_loss.sum()
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid
# overwhelming the simple_loss and causing it to diverge,
@ -657,6 +686,10 @@ def compute_loss(
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# info["frames"] is an approximate number for two reasons:
# (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
# (2) If some utterances in the batch lead to inf/nan loss, they
# are filtered out.
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
@ -763,13 +796,7 @@ def train_one_epoch(
tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl):
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
@ -811,7 +838,6 @@ def train_one_epoch(
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
@ -824,7 +850,6 @@ def train_one_epoch(
scaler=scaler,
rank=rank,
)
del params.cur_batch_idx
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
@ -999,7 +1024,7 @@ def run(rank, world_size, args):
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
if params.start_batch <= 0 and not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,

Some files were not shown because too many files have changed in this diff Show More