[WIP] Pruned-transducer-stateless5-for-WenetSpeech (offline and streaming) (#447)

* pruned-rnnt5-for-wenetspeech

* style check

* style check

* add streaming conformer

* add streaming decode

* changes codes for fast_beam_search and export cpu jit

* add modified-beam-search for streaming decoding

* add modified-beam-search for streaming decoding

* change for streaming_beam_search.py

* add README.md and RESULTS.md

* change for style_check.yml

* do some changes

* do some changes for export.py

* add some decode commands for usage

* add streaming results on README.md
This commit is contained in:
Mingshuang Luo 2022-07-28 12:54:27 +08:00 committed by GitHub
parent 385645d533
commit f26b62ac00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 5311 additions and 9 deletions

View File

@ -29,7 +29,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04, macos-10.15]
os: [ubuntu-18.04, macos-latest]
python-version: [3.7, 3.9]
fail-fast: false

View File

@ -250,9 +250,9 @@ We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless mod
### WenetSpeech
We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless2].
We provide some models for this recipe: [Pruned stateless RNN-T_2: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless2] and [Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][WenetSpeech_pruned_transducer_stateless5].
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset)
#### Pruned stateless RNN-T_2: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset, offline ASR)
| | Dev | Test-Net | Test-Meeting |
|----------------------|-------|----------|--------------|
@ -260,7 +260,15 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder
| fast beam search | 7.94 | 8.74 | 13.80 |
| modified beam search | 7.76 | 8.71 | 13.41 |
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1EV4e1CHa1GZgEF-bZgizqI9RyFFehIiN?usp=sharing)
#### Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset)
**Streaming**:
| | Dev | Test-Net | Test-Meeting |
|----------------------|-------|----------|--------------|
| greedy_search | 8.78 | 10.12 | 16.16 |
| modified_beam_search | 8.53| 9.95 | 15.81 |
| fast_beam_search| 9.01 | 10.47 | 16.28 |
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless2 model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1EV4e1CHa1GZgEF-bZgizqI9RyFFehIiN?usp=sharing)
### Alimeeting
@ -333,6 +341,7 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
[GigaSpeech_pruned_transducer_stateless2]: egs/gigaspeech/ASR/pruned_transducer_stateless2
[Aidatatang_200zh_pruned_transducer_stateless2]: egs/aidatatang_200zh/ASR/pruned_transducer_stateless2
[WenetSpeech_pruned_transducer_stateless2]: egs/wenetspeech/ASR/pruned_transducer_stateless2
[WenetSpeech_pruned_transducer_stateless5]: egs/wenetspeech/ASR/pruned_transducer_stateless5
[Alimeeting_pruned_transducer_stateless2]: egs/alimeeting/ASR/pruned_transducer_stateless2
[Aishell4_pruned_transducer_stateless5]: egs/aishell4/ASR/pruned_transducer_stateless5
[TAL_CSASR_pruned_transducer_stateless5]: egs/tal_csasr/ASR/pruned_transducer_stateless5

View File

@ -13,6 +13,7 @@ The following table lists the differences among them.
| | Encoder | Decoder | Comment |
|---------------------------------------|---------------------|--------------------|-----------------------------|
| `pruned_transducer_stateless2` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss | |
| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss | |
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).

View File

@ -1,12 +1,84 @@
## Results
### WenetSpeech char-based training results (offline and streaming) (Pruned Transducer 5)
#### 2022-07-22
Using the codes from this PR https://github.com/k2-fsa/icefall/pull/447.
When training with the L subset, the CERs are
**Offline**:
|decoding-method| epoch | avg | use-averaged-model | DEV | TEST-NET | TEST-MEETING|
|-- | -- | -- | -- | -- | -- | --|
|greedy_search | 4 | 1 | True | 8.22 | 9.03 | 14.54|
|modified_beam_search | 4 | 1 | True | **8.17** | **9.04** | **14.44**|
|fast_beam_search | 4 | 1 | True | 8.29 | 9.00 | 14.93|
The offline training command for reproducing is given below:
```
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
./pruned_transducer_stateless5/train.py \
--lang-dir data/lang_char \
--exp-dir pruned_transducer_stateless5/exp_L_offline \
--world-size 8 \
--num-epochs 15 \
--start-epoch 2 \
--max-duration 120 \
--valid-interval 3000 \
--model-warm-step 3000 \
--save-every-n 8000 \
--average-period 1000 \
--training-subset L
```
The tensorboard training log can be found at https://tensorboard.dev/experiment/SvnN2jfyTB2Hjqu22Z7ZoQ/#scalars .
A pre-trained offline model and decoding logs can be found at <https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless5_offline>
**Streaming**:
|decoding-method| epoch | avg | use-averaged-model | DEV | TEST-NET | TEST-MEETING|
|--|--|--|--|--|--|--|
| greedy_search | 7| 1| True | 8.78 | 10.12 | 16.16 |
| modified_beam_search | 7| 1| True| **8.53**| **9.95** | **15.81** |
| fast_beam_search | 7 | 1| True | 9.01 | 10.47 | 16.28 |
The streaming training command for reproducing is given below:
```
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
./pruned_transducer_stateless5/train.py \
--lang-dir data/lang_char \
--exp-dir pruned_transducer_stateless5/exp_L_streaming \
--world-size 8 \
--num-epochs 15 \
--start-epoch 1 \
--max-duration 140 \
--valid-interval 3000 \
--model-warm-step 3000 \
--save-every-n 8000 \
--average-period 1000 \
--training-subset L \
--dynamic-chunk-training True \
--causal-convolution True \
--short-chunk-size 25 \
--num-left-chunks 4
```
The tensorboard training log can be found at https://tensorboard.dev/experiment/E2NXPVflSOKWepzJ1a1uDQ/#scalars .
A pre-trained offline model and decoding logs can be found at <https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless5_streaming>
### WenetSpeech char-based training results (Pruned Transducer 2)
#### 2022-05-19
Using the codes from this PR https://github.com/k2-fsa/icefall/pull/349.
When training with the L subset, the WERs are
When training with the L subset, the CERs are
| | dev | test-net | test-meeting | comment |
|------------------------------------|-------|----------|--------------|------------------------------------------|
@ -72,7 +144,7 @@ avg=2
--max-states 8
```
When training with the M subset, the WERs are
When training with the M subset, the CERs are
| | dev | test-net | test-meeting | comment |
|------------------------------------|--------|-----------|---------------|-------------------------------------------|
@ -81,7 +153,7 @@ When training with the M subset, the WERs are
| fast beam search (set as default) | 10.18 | 11.10 | 19.32 | --epoch 29, --avg 11, --max-duration 1500 |
When training with the S subset, the WERs are
When training with the S subset, the CERs are
| | dev | test-net | test-meeting | comment |
|------------------------------------|--------|-----------|---------------|-------------------------------------------|

View File

@ -348,7 +348,6 @@ def get_params() -> AttributeDict:
epochs.
- log_interval: Print training loss if batch_idx % log_interval` is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- feature_dim: The model input dim. It has to match the one used
in computing features.
- subsampling_factor: The subsampling factor for the model.
@ -376,7 +375,6 @@ def get_params() -> AttributeDict:
"decoder_dim": 512,
# parameters for joiner
"joiner_dim": 512,
# parameters for Noam
"env_info": get_env_info(),
}
)

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/asr_datamodule.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless5/beam_search.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,778 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
When training with the L subset, the offline usage:
(1) greedy search
./pruned_transducer_stateless5/decode.py \
--epoch 4 \
--avg 1 \
--exp-dir ./pruned_transducer_stateless5/exp_L_offline \
--lang-dir data/lang_char \
--max-duration 100 \
--decoding-method greedy_search
(2) modified beam search
./pruned_transducer_stateless5/decode.py \
--epoch 4 \
--avg 1 \
--exp-dir ./pruned_transducer_stateless5/exp_L_offline \
--lang-dir data/lang_char \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
(3) fast beam search
./pruned_transducer_stateless5/decode.py \
--epoch 4 \
--avg 1 \
--exp-dir ./pruned_transducer_stateless5/exp_L_offline \
--lang-dir data/lang_char \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
When training with the L subset, the streaming usage:
(1) greedy search
./pruned_transducer_stateless5/decode.py \
--lang-dir data/lang_char \
--exp-dir pruned_transducer_stateless5/exp_L_streaming \
--use-averaged-model True \
--max-duration 600 \
--epoch 7 \
--avg 1 \
--decoding-method greedy_search \
--simulate-streaming 1 \
--causal-convolution 1 \
--decode-chunk-size 16 \
--left-context 64
(2) modified beam search
./pruned_transducer_stateless5/decode.py \
--lang-dir data/lang_char \
--exp-dir pruned_transducer_stateless5/exp_L_streaming \
--use-averaged-model True \
--max-duration 600 \
--epoch 7 \
--avg 1 \
--decoding-method modified_beam_search \
--simulate-streaming 1 \
--causal-convolution 1 \
--decode-chunk-size 16 \
--left-context 64
(3) fast beam search
./pruned_transducer_stateless5/decode.py \
--lang-dir data/lang_char \
--exp-dir pruned_transducer_stateless5/exp_L_streaming \
--use-averaged-model True \
--max-duration 600 \
--epoch 7 \
--avg 1 \
--decoding-method fast_beam_search \
--simulate-streaming 1 \
--causal-convolution 1 \
--decode-chunk-size 16 \
--left-context 64
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import WenetSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless5/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="""Whether to simulate streaming in decoding, this is a good way to
test a streaming model.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
if params.simulate_streaming:
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,
chunk_size=params.decode_chunk_size,
left_context=params.left_context,
simulate_streaming=True,
)
else:
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
beam=params.beam_size,
encoder_out_lens=encoder_out_lens,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append([lexicon.token_table[idx] for idx in hyp])
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 100
else:
log_interval = 2
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
texts = [list(str(text)) for text in texts]
hyps_dict = decode_one_batch(
params=params,
model=model,
lexicon=lexicon,
decoding_graph=decoding_graph,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
this_batch.append((ref_text, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
WenetSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
model.device = device
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# Note: Please use "pip install webdataset==0.1.103"
# for installing the webdataset.
import glob
import os
from lhotse import CutSet
from lhotse.dataset.webdataset import export_to_webdataset
wenetspeech = WenetSpeechAsrDataModule(args)
dev = "dev"
test_net = "test_net"
test_meeting = "test_meeting"
if not os.path.exists(f"{dev}/shared-0.tar"):
os.makedirs(dev)
dev_cuts = wenetspeech.valid_cuts()
export_to_webdataset(
dev_cuts,
output_path=f"{dev}/shared-%d.tar",
shard_size=300,
)
if not os.path.exists(f"{test_net}/shared-0.tar"):
os.makedirs(test_net)
test_net_cuts = wenetspeech.test_net_cuts()
export_to_webdataset(
test_net_cuts,
output_path=f"{test_net}/shared-%d.tar",
shard_size=300,
)
if not os.path.exists(f"{test_meeting}/shared-0.tar"):
os.makedirs(test_meeting)
test_meeting_cuts = wenetspeech.test_meeting_cuts()
export_to_webdataset(
test_meeting_cuts,
output_path=f"{test_meeting}/shared-%d.tar",
shard_size=300,
)
dev_shards = [
str(path)
for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
]
cuts_dev_webdataset = CutSet.from_webdataset(
dev_shards,
split_by_worker=True,
split_by_node=True,
shuffle_shards=True,
)
test_net_shards = [
str(path)
for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
]
cuts_test_net_webdataset = CutSet.from_webdataset(
test_net_shards,
split_by_worker=True,
split_by_node=True,
shuffle_shards=True,
)
test_meeting_shards = [
str(path)
for path in sorted(
glob.glob(os.path.join(test_meeting, "shared-*.tar"))
)
]
cuts_test_meeting_webdataset = CutSet.from_webdataset(
test_meeting_shards,
split_by_worker=True,
split_by_node=True,
shuffle_shards=True,
)
dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset)
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
test_dl = [dev_dl, test_net_dl, test_meeting_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
lexicon=lexicon,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,147 @@
# Copyright 2022 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple
import k2
import torch
from beam_search import Hypothesis, HypothesisList
from icefall.utils import AttributeDict
class DecodeStream(object):
def __init__(
self,
params: AttributeDict,
initial_states: List[torch.Tensor],
decoding_graph: Optional[k2.Fsa] = None,
device: torch.device = torch.device("cpu"),
) -> None:
"""
Args:
initial_states:
Initial decode states of the model, e.g. the return value of
`get_init_state` in conformer.py
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
Used only when decoding_method is fast_beam_search.
device:
The device to run this stream.
"""
if params.decoding_method == "fast_beam_search":
assert decoding_graph is not None
assert device == decoding_graph.device
self.params = params
self.LOG_EPS = math.log(1e-10)
self.states = initial_states
# It contains a 2-D tensors representing the feature frames.
self.features: torch.Tensor = None
self.num_frames: int = 0
# how many frames have been processed. (before subsampling).
# we only modify this value in `func:get_feature_frames`.
self.num_processed_frames: int = 0
self._done: bool = False
# The transcript of current utterance.
self.ground_truth: str = ""
# The decoding result (partial or final) of current utterance.
self.hyp: List = []
# how many frames have been processed, after subsampling (i.e. a
# cumulative sum of the second return value of
# encoder.streaming_forward
self.done_frames: int = 0
self.pad_length = (
params.right_context + 2
) * params.subsampling_factor + 3
if params.decoding_method == "greedy_search":
self.hyp = [params.blank_id] * params.context_size
elif params.decoding_method == "modified_beam_search":
self.hyps = HypothesisList()
self.hyps.add(
Hypothesis(
ys=[params.blank_id] * params.context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
)
elif params.decoding_method == "fast_beam_search":
# The rnnt_decoding_stream for fast_beam_search.
self.rnnt_decoding_stream: k2.RnntDecodingStream = (
k2.RnntDecodingStream(decoding_graph)
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
@property
def done(self) -> bool:
"""Return True if all the features are processed."""
return self._done
def set_features(
self,
features: torch.Tensor,
) -> None:
"""Set features tensor of current utterance."""
assert features.dim() == 2, features.dim()
self.features = torch.nn.functional.pad(
features,
(0, 0, 0, self.pad_length),
mode="constant",
value=self.LOG_EPS,
)
self.num_frames = self.features.size(0)
def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]:
"""Consume chunk_size frames of features"""
chunk_length = chunk_size + self.pad_length
ret_length = min(
self.num_frames - self.num_processed_frames, chunk_length
)
ret_features = self.features[
self.num_processed_frames : self.num_processed_frames # noqa
+ ret_length
]
self.num_processed_frames += chunk_size
if self.num_processed_frames >= self.num_frames:
self._done = True
return ret_features, ret_length
def decoding_result(self) -> List[int]:
"""Obtain current decoding result."""
if self.params.decoding_method == "greedy_search":
return self.hyp[self.params.context_size :] # noqa
elif self.params.decoding_method == "modified_beam_search":
best_hyp = self.hyps.get_most_probable(length_norm=True)
return best_hyp.ys[self.params.context_size :] # noqa
else:
assert self.params.decoding_method == "fast_beam_search"
return self.hyp

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless5/decoder.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless5/encoder_interface.py

View File

@ -0,0 +1,209 @@
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage for offline:
./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp_L_offline \
--lang-dir data/lang_char \
--epoch 4 \
--avg 1
It will generate a file exp_dir/pretrained.pt for offline ASR.
./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp_L_offline \
--lang-dir data/lang_char \
--epoch 4 \
--avg 1 \
--jit True
It will generate a file exp_dir/cpu_jit.pt for offline ASR.
Usage for streaming:
./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp_L_streaming \
--lang-dir data/lang_char \
--epoch 7 \
--avg 1
It will generate a file exp_dir/pretrained.pt for streaming ASR.
./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp_L_streaming \
--lang-dir data/lang_char \
--epoch 7 \
--avg 1 \
--jit True
It will generate a file exp_dir/cpu_jit.pt for streaming ASR.
To use the generated file with `pruned_transducer_stateless5/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/wenetspeech/ASR
./pruned_transducer_stateless5/decode.py \
--exp-dir ./pruned_transducer_stateless5/exp \
--epoch 4 \
--avg 1 \
--decoding-method greedy_search \
--max-duration 100 \
--lang-dir data/lang_char
"""
import argparse
import logging
from pathlib import Path
import torch
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless5/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="The lang dir",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
add_model_arguments(parser)
return parser
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.eval()
model.to("cpu")
model.eval()
if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless5/joiner.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless5/model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless5/optim.py

View File

@ -0,0 +1,343 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# 2022 Xiaomi Crop. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Offline Usage:
(1) greedy search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp_L_offline/pretrained.pt \
--lang-dir ./data/lang_char \
--method greedy_search \
--max-sym-per-frame 1 \
/path/to/foo.wav \
/path/to/bar.wav
(2) modified beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp_L_offline/pretrained.pt \
--lang-dir ./data/lang_char \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) fast beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless/exp_L_offline/pretrained.pt \
--lang-dir ./data/lang_char \
--method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./pruned_transducer_stateless5/exp_L_offline/epoch-xx.pt`.
Note: ./pruned_transducer_stateless5/exp_L_offline/pretrained.pt is generated by
./pruned_transducer_stateless5/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.lexicon import Lexicon
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--lang-dir",
type=str,
help="""Path to lang.
""",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=48000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="Used only when --method is beam_search and modified_beam_search ",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
with torch.no_grad():
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
hyps = []
msg = f"Using {params.decoding_method}"
logging.info(msg)
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append([lexicon.token_table[idx] for idx in hyp])
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless5/scaling.py

View File

@ -0,0 +1,287 @@
# Copyright 2022 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import List
import k2
import torch
import torch.nn as nn
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
from decode_stream import DecodeStream
from icefall.decode import one_best_decoding
from icefall.utils import get_texts
def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
) -> None:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C), where N >= 1.
streams:
A list of Stream objects.
"""
assert len(streams) == encoder_out.size(0)
assert encoder_out.ndim == 3
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
T = encoder_out.size(1)
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
# decoder_out is of shape (N, 1, decoder_out_dim)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
for t in range(T):
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
# print(current_encoder_out.shape)
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
# logits'shape (batch_size, vocab_size)
logits = logits.squeeze(1).squeeze(1)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
streams[i].hyp.append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = torch.tensor(
[stream.hyp[-context_size:] for stream in streams],
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
)
decoder_out = model.joiner.decoder_proj(decoder_out)
def modified_beam_search(
model: nn.Module,
encoder_out: torch.Tensor,
streams: List[DecodeStream],
num_active_paths: int = 4,
) -> None:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Args:
model:
The RNN-T model.
encoder_out:
A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of
the encoder model.
streams:
A list of stream objects.
num_active_paths:
Number of active paths during the beam search.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert len(streams) == encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size = len(streams)
T = encoder_out.size(1)
B = [stream.hyps for stream in streams]
for t in range(T):
current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.stack(
[hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, encoder_out_dim)
logits = model.joiner(
current_encoder_out,
decoder_out.unsqueeze(1),
project_input=False,
)
# logits is of shape (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(
shape=log_probs_shape, value=log_probs
)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(
num_active_paths
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
if new_token != blank_id:
new_ys.append(new_token)
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp)
for i in range(batch_size):
streams[i].hyps = B[i]
def fast_beam_search_one_best(
model: nn.Module,
encoder_out: torch.Tensor,
processed_lens: torch.Tensor,
streams: List[DecodeStream],
beam: float,
max_states: int,
max_contexts: int,
) -> None:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first generated by Fsa-based beam search, then we get the
recognition by applying shortest path on the lattice.
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
processed_lens:
A tensor of shape (N,) containing the number of processed frames
in `encoder_out` before padding.
streams:
A list of stream objects.
beam:
Beam value, similar to the beam used in Kaldi..
max_states:
Max states per stream per frame.
max_contexts:
Max contexts pre stream per frame.
"""
assert encoder_out.ndim == 3
B, T, C = encoder_out.shape
assert B == len(streams)
context_size = model.decoder.context_size
vocab_size = model.decoder.vocab_size
config = k2.RnntDecodingConfig(
vocab_size=vocab_size,
decoder_history_len=context_size,
beam=beam,
max_contexts=max_contexts,
max_states=max_states,
)
individual_streams = []
for i in range(B):
individual_streams.append(streams[i].rnnt_decoding_stream)
decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# current_encoder_out is of shape
# (shape.NumElements(), 1, joiner_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)
log_probs = logits.log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(processed_lens.tolist())
best_path = one_best_decoding(lattice)
hyp_tokens = get_texts(best_path)
for i in range(B):
streams[i].hyp = hyp_tokens[i]

View File

@ -0,0 +1,678 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
python pruned_transducer_stateless5/streaming_decode.py \
--epoch 7 \
--avg 1 \
--decode-chunk-size 16 \
--left-context 64 \
--right-context 0 \
--exp-dir ./pruned_transducer_stateless5/exp_L_streaming \
--decoding-method greedy_search \
--num-decode-streams 2000
(2) modified beam search
python pruned_transducer_stateless5/streaming_decode.py \
--epoch 7 \
--avg 1 \
--decode-chunk-size 16 \
--left-context 64 \
--right-context 0 \
--exp-dir ./pruned_transducer_stateless5/exp_L_streaming \
--decoding-method modified_beam_search \
--num-decode-streams 2000
(3) fast beam search
python pruned_transducer_stateless5/streaming_decode.py \
--epoch 7 \
--avg 1 \
--decode-chunk-size 16 \
--left-context 64 \
--right-context 0 \
--exp-dir ./pruned_transducer_stateless5/exp_L_streaming \
--decoding-method fast_beam_search \
--num-decode-streams 2000
"""
import argparse
import logging
import math
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import torch
import torch.nn as nn
from asr_datamodule import WenetSpeechAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from streaming_beam_search import (
fast_beam_search_one_best,
greedy_search,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless5/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Supported decoding methods are:
greedy_search
modified_beam_search
fast_beam_search
""",
)
parser.add_argument(
"--num-active-paths",
type=int,
default=4,
help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=32,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--right-context",
type=int,
default=0,
help="right context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--num-decode-streams",
type=int,
default=2000,
help="The number of streams that can be decoded parallel.",
)
add_model_arguments(parser)
return parser
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
decode_streams: List[DecodeStream],
) -> List[int]:
"""Decode one chunk frames of features for each decode_streams and
return the indexes of finished streams in a List.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
decode_streams:
A List of DecodeStream, each belonging to a utterance.
Returns:
Return a List containing which DecodeStreams are finished.
"""
device = model.device
features = []
feature_lens = []
states = []
processed_lens = []
for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(
params.decode_chunk_size * params.subsampling_factor
)
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
processed_lens.append(stream.done_frames)
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
# if T is less than 7 there will be an error in time reduction layer,
# because we subsample features with ((x_len - 1) // 2 - 1) // 2
# we plus 2 here because we will cut off one frame on each size of
# encoder_embed output as they see invalid paddings. so we need extra 2
# frames.
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
if features.size(1) < tail_length:
pad_length = tail_length - features.size(1)
feature_lens += pad_length
features = torch.nn.functional.pad(
features,
(0, 0, 0, pad_length),
mode="constant",
value=LOG_EPS,
)
states = [
torch.stack([x[0] for x in states], dim=2),
torch.stack([x[1] for x in states], dim=2),
]
processed_lens = torch.tensor(processed_lens, device=device)
encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
x=features,
x_lens=feature_lens,
states=states,
left_context=params.left_context,
right_context=params.right_context,
processed_lens=processed_lens,
)
encoder_out = model.joiner.encoder_proj(encoder_out)
if params.decoding_method == "greedy_search":
greedy_search(
model=model, encoder_out=encoder_out, streams=decode_streams
)
elif params.decoding_method == "fast_beam_search":
processed_lens = processed_lens + encoder_out_lens
fast_beam_search_one_best(
model=model,
encoder_out=encoder_out,
processed_lens=processed_lens,
streams=decode_streams,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
elif params.decoding_method == "modified_beam_search":
modified_beam_search(
model=model,
streams=decode_streams,
encoder_out=encoder_out,
num_active_paths=params.num_active_paths,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
finished_streams = []
for i in range(len(decode_streams)):
decode_streams[i].states = [states[0][i], states[1][i]]
decode_streams[i].done_frames += encoder_out_lens[i]
if decode_streams[i].done:
finished_streams.append(i)
return finished_streams
def decode_dataset(
cuts: CutSet,
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
cuts:
Lhotse Cutset containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
device = model.device
opts = FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
log_interval = 100
decode_results = []
# Contain decode streams currently running.
decode_streams = []
initial_states = model.encoder.get_init_state(
params.left_context, device=device
)
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
decode_stream = DecodeStream(
params=params,
initial_states=initial_states,
decoding_graph=decoding_graph,
device=device,
)
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
samples = torch.from_numpy(audio).squeeze(0)
fbank = Fbank(opts)
decode_stream.set_features(fbank(samples.to(device)))
decode_stream.ground_truth = cut.supervisions[0].text
decode_streams.append(decode_stream)
while len(decode_streams) >= params.num_decode_streams:
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
hyp = decode_streams[i].decoding_result()
decode_results.append(
(
list(decode_streams[i].ground_truth),
[lexicon.token_table[idx] for idx in hyp],
)
)
del decode_streams[i]
if num % log_interval == 0:
logging.info(f"Cuts processed until now is {num}.")
# decode final chunks of last sequences
while len(decode_streams):
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
hyp = decode_streams[i].decoding_result()
decode_results.append(
(
list(decode_streams[i].ground_truth),
[lexicon.token_table[idx] for idx in hyp],
)
)
del decode_streams[i]
if params.decoding_method == "greedy_search":
key = "greedy_search"
elif params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
elif params.decoding_method == "modified_beam_search":
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
return {key: decode_results}
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
# sort results so we can easily compare the difference between two
# recognition results
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
WenetSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
# for streaming
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
params.suffix += f"-right-context-{params.right_context}"
# for fast_beam_search
if params.decoding_method == "fast_beam_search":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
params.causal_convolution = True
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
model.device = device
decoding_graph = None
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
wenetspeech = WenetSpeechAsrDataModule(args)
dev_cuts = wenetspeech.valid_cuts()
test_net_cuts = wenetspeech.test_net_cuts()
test_meeting_cuts = wenetspeech.test_meeting_cuts()
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
test_cuts = [dev_cuts, test_net_cuts, test_meeting_cuts]
for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset(
cuts=test_cut,
params=params,
model=model,
lexicon=lexicon,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff