Add Streaming Zipformer-Transducer recipe for KsponSpeech (#1651)

This commit is contained in:
Seung Hyun Lee 2024-06-16 17:20:44 +09:00 committed by GitHub
parent 890eeec82c
commit c13c7aa30b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 4294 additions and 0 deletions

View File

@ -0,0 +1,32 @@
# Introduction
KsponSpeech is a large-scale spontaneous speech corpus of Korean.
This corpus contains 969 hours of open-domain dialog utterances,
spoken by about 2,000 native Korean speakers in a clean environment.
All data were constructed by recording the dialogue of two people
freely conversing on a variety of topics and manually transcribing the utterances.
The transcription provides a dual transcription consisting of orthography and pronunciation,
and disfluency tags for spontaneity of speech, such as filler words, repeated words, and word fragments.
The original audio data has a pcm extension.
During preprocessing, it is converted into a file in the flac extension and saved anew.
KsponSpeech is publicly available on an open data hub site of the Korea government.
The dataset must be downloaded manually.
For more details, please visit:
- Dataset: https://aihub.or.kr/aihubdata/data/view.do?currMenu=115&topMenu=100&aihubDataSe=realm&dataSetSn=123
- Paper: https://www.mdpi.com/2076-3417/10/19/6936
[./RESULTS.md](./RESULTS.md) contains the latest results.
# Transducers
There are various folders containing the name `transducer` in this folder. The following table lists the differences among them.
| | Encoder | Decoder | Comment |
| ---------------------------------------- | -------------------- | ------------------ | ------------------------------------------------- |
| `pruned_transducer_stateless7_streaming` | Streaming Zipformer | Embedding + Conv1d | streaming version of pruned_transducer_stateless7 |
The decoder in `transducer_stateless` is modified from the paper [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). We place an additional Conv1d layer right after the input embedding layer.

View File

@ -0,0 +1,70 @@
## Results
### Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer)
#### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming)
Number of model parameters: 79,022,891, i.e., 79.02 M
##### Training on KsponSpeech (with MUSAN)
Model: [johnBamma/icefall-asr-ksponspeech-pruned-transducer-stateless7-streaming-2024-06-12](https://huggingface.co/johnBamma/icefall-asr-ksponspeech-pruned-transducer-stateless7-streaming-2024-06-12)
The CERs are:
| decoding method | chunk size | eval_clean | eval_other | comment | decoding mode |
|----------------------|------------|------------|------------|---------------------|----------------------|
| greedy search | 320ms | 10.21 | 11.07 | --epoch 30 --avg 9 | simulated streaming |
| greedy search | 320ms | 10.22 | 11.07 | --epoch 30 --avg 9 | chunk-wise |
| fast beam search | 320ms | 10.21 | 11.04 | --epoch 30 --avg 9 | simulated streaming |
| fast beam search | 320ms | 10.25 | 11.08 | --epoch 30 --avg 9 | chunk-wise |
| modified beam search | 320ms | 10.13 | 10.88 | --epoch 30 --avg 9 | simulated streaming |
| modified beam search | 320ms | 10.1 | 10.93 | --epoch 30 --avg 9 | chunk-size |
| greedy search | 640ms | 9.94 | 10.82 | --epoch 30 --avg 9 | simulated streaming |
| greedy search | 640ms | 10.04 | 10.85 | --epoch 30 --avg 9 | chunk-wise |
| fast beam search | 640ms | 10.01 | 10.81 | --epoch 30 --avg 9 | simulated streaming |
| fast beam search | 640ms | 10.04 | 10.7 | --epoch 30 --avg 9 | chunk-wise |
| modified beam search | 640ms | 9.91 | 10.72 | --epoch 30 --avg 9 | simulated streaming |
| modified beam search | 640ms | 9.92 | 10.72 | --epoch 30 --avg 9 | chunk-size |
Note: `simulated streaming` indicates feeding full utterance during decoding using `decode.py`,
while `chunk-size` indicates feeding certain number of frames at each time using `streaming_decode.py`.
The training command is:
```bash
./pruned_transducer_stateless7_streaming/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless7_streaming/exp \
--max-duration 750 \
--enable-musan True
```
The simulated streaming decoding command (e.g., chunk-size=320ms) is:
```bash
for m in greedy_search fast_beam_search modified_beam_search; do
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 30 \
--avg 9 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method $m
done
```
The streaming chunk-size decoding command (e.g., chunk-size=320ms) is:
```bash
for m in greedy_search modified_beam_search fast_beam_search; do
./pruned_transducer_stateless7_streaming/streaming_decode.py \
--epoch 30 \
--avg 9 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--decoding-method $m \
--decode-chunk-len 32 \
--num-decode-streams 2000
done
```

View File

View File

@ -0,0 +1,185 @@
#!/usr/bin/env python3
# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
from pathlib import Path
from typing import Optional
import sentencepiece as spm
import torch
from filter_cuts import filter_cuts
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to the bpe.model. If not None, we will remove short and
long utterances before extracting features""",
)
parser.add_argument(
"--dataset",
type=str,
help="""Dataset parts to compute fbank. If None, we will use all""",
)
parser.add_argument(
"--perturb-speed",
type=str2bool,
default=True,
help="""Perturb speed with factor 0.9 and 1.1 on train subset.""",
)
parser.add_argument(
"--data-dir",
type=str,
default="data",
help="""Path of data directory""",
)
return parser.parse_args()
def compute_fbank_speechtools(
bpe_model: Optional[str] = None,
dataset: Optional[str] = None,
perturb_speed: Optional[bool] = False,
data_dir: Optional[str] = "data",
):
src_dir = Path(data_dir) / "manifests"
output_dir = Path(data_dir) / "fbank"
num_jobs = min(4, os.cpu_count())
num_mel_bins = 80
if bpe_model:
logging.info(f"Loading {bpe_model}")
sp = spm.SentencePieceProcessor()
sp.load(bpe_model)
if dataset is None:
dataset_parts = (
"train",
"dev",
"eval_clean",
"eval_other",
)
else:
dataset_parts = dataset.split(" ", -1)
prefix = "ksponspeech"
suffix = "jsonl.gz"
logging.info(f"Read manifests...")
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
if torch.cuda.is_available():
# Use cuda for fbank compute
device = "cuda"
else:
device = "cpu"
logging.info(f"Device: {device}")
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, device=device))
with get_executor() as ex: # Initialize the executor only once.
logging.info(f"Executor: {ex}")
for partition, m in manifests.items():
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
if (output_dir / cuts_filename).is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
# Filter duration
cut_set = cut_set.filter(
lambda x: x.duration > 1 and x.sampling_rate == 16000
)
if "train" in partition:
if bpe_model:
cut_set = filter_cuts(cut_set, sp)
if perturb_speed:
logging.info(f"Doing speed perturb")
cut_set = (
cut_set
+ cut_set.perturb_speed(0.9)
+ cut_set.perturb_speed(1.1)
)
logging.info(f"Compute & Store features...")
if device == "cuda":
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
num_workers=4,
storage_type=LilcomChunkyWriter,
)
else:
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(output_dir / cuts_filename)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
logging.info(vars(args))
compute_fbank_speechtools(
bpe_model=args.bpe_model,
dataset=args.dataset,
perturb_speed=args.perturb_speed,
data_dir=args.data_dir,
)

View File

@ -0,0 +1,158 @@
#!/usr/bin/env python3
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the musan dataset.
It looks for manifests in the directory `src_dir` (default is data/manifests).
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
import torch
from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
MonoCut,
WhisperFbank,
WhisperFbankConfig,
combine,
)
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def is_cut_long(c: MonoCut) -> bool:
return c.duration > 5
def compute_fbank_musan(
src_dir: str = "data/manifests",
num_mel_bins: int = 80,
whisper_fbank: bool = False,
output_dir: str = "data/fbank",
):
src_dir = Path(src_dir)
output_dir = Path(output_dir)
num_jobs = min(15, os.cpu_count())
dataset_parts = (
"music",
"speech",
"noise",
)
prefix = "musan"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
musan_cuts_path = output_dir / "musan_cuts.jsonl.gz"
if musan_cuts_path.is_file():
logging.info(f"{musan_cuts_path} already exists - skipping")
return
logging.info("Extracting features for Musan")
if whisper_fbank:
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
with get_executor() as ex: # Initialize the executor only once.
# create chunks of Musan with duration 5 - 10 seconds
musan_cuts = (
CutSet.from_manifests(
recordings=combine(part["recordings"] for part in manifests.values())
)
.cut_into_windows(10.0)
.filter(is_cut_long)
.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/musan_feats",
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomChunkyWriter,
)
)
musan_cuts.to_file(musan_cuts_path)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--src-dir",
type=str,
default="data/manifests",
help="Source manifests directory.",
)
parser.add_argument(
"--num-mel-bins",
type=int,
default=80,
help="""The number of mel bins for Fbank""",
)
parser.add_argument(
"--whisper-fbank",
type=str2bool,
default=False,
help="Use WhisperFbank instead of Fbank. Default: False.",
)
parser.add_argument(
"--output-dir",
type=str,
default="data/fbank",
help="Output directory. Default: data/fbank.",
)
return parser.parse_args()
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_musan(
src_dir=args.src_dir,
num_mel_bins=args.num_mel_bins,
whisper_fbank=args.whisper_fbank,
output_dir=args.output_dir,
)

View File

@ -0,0 +1,157 @@
#!/usr/bin/env python3
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script removes short and long utterances from a cutset.
Caution:
You may need to tune the thresholds for your own dataset.
Usage example:
python3 ./local/filter_cuts.py \
--bpe-model data/lang_bpe_5000/bpe.model \
--in-cuts data/fbank/speechtools_cuts_test.jsonl.gz \
--out-cuts data/fbank-filtered/speechtools_cuts_test.jsonl.gz
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
from lhotse import CutSet, load_manifest_lazy
from lhotse.cut import Cut
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--bpe-model",
type=Path,
help="Path to the bpe.model",
)
parser.add_argument(
"--in-cuts",
type=Path,
help="Path to the input cutset",
)
parser.add_argument(
"--out-cuts",
type=Path,
help="Path to the output cutset",
)
return parser.parse_args()
def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor):
total = 0 # number of total utterances before removal
removed = 0 # number of removed utterances
def remove_short_and_long_utterances(c: Cut):
"""Return False to exclude the input cut"""
nonlocal removed, total
# Keep only utterances with duration between 1 second and 20 seconds
#
# Caution: There is a reason to select 20.0 here. Please see
# ./display_manifest_statistics.py
#
# You should use ./display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
total += 1
if c.duration < 1.0 or c.duration > 20.0:
logging.warning(
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
)
removed += 1
return False
# In pruned RNN-T, we require that T >= S
# where T is the number of feature frames after subsampling
# and S is the number of tokens in the utterance
# In ./pruned_transducer_stateless2/conformer.py, the
# conv module uses the following expression
# for subsampling
if c.num_frames is None:
num_frames = c.duration * 100 # approximate
else:
num_frames = c.num_frames
T = ((num_frames - 1) // 2 - 1) // 2
# Note: for ./lstm_transducer_stateless/lstm.py, the formula is
# T = ((num_frames - 3) // 2 - 1) // 2
# Note: for ./pruned_transducer_stateless7/zipformer.py, the formula is
# T = ((num_frames - 7) // 2 + 1) // 2
tokens = sp.encode(c.supervisions[0].text, out_type=str)
if T < len(tokens):
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before subsampling): {c.num_frames}. "
f"Number of frames (after subsampling): {T}. "
f"Text: {c.supervisions[0].text}. "
f"Tokens: {tokens}. "
f"Number of tokens: {len(tokens)}"
)
removed += 1
return False
return True
# We use to_eager() here so that we can print out the value of total
# and removed below.
ans = cut_set.filter(remove_short_and_long_utterances).to_eager()
ratio = removed / total * 100
logging.info(
f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed."
)
return ans
def main():
args = get_args()
logging.info(vars(args))
if args.out_cuts.is_file():
logging.info(f"{args.out_cuts} already exists - skipping")
return
assert args.in_cuts.is_file(), f"{args.in_cuts} does not exist"
assert args.bpe_model.is_file(), f"{args.bpe_model} does not exist"
sp = spm.SentencePieceProcessor()
sp.load(str(args.bpe_model))
cut_set = load_manifest_lazy(args.in_cuts)
assert isinstance(cut_set, CutSet)
cut_set = filter_cuts(cut_set, sp)
logging.info(f"Saving to {args.out_cuts}")
args.out_cuts.parent.mkdir(parents=True, exist_ok=True)
cut_set.to_file(args.out_cuts)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,115 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# You can install sentencepiece via:
#
# pip install sentencepiece
#
# Due to an issue reported in
# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030
#
# Please install a version >=0.1.96
import argparse
import shutil
from pathlib import Path
from typing import Dict
import sentencepiece as spm
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
The generated bpe.model is saved to this directory.
""",
)
parser.add_argument(
"--transcript",
type=str,
help="Training transcript.",
)
parser.add_argument(
"--vocab-size",
type=int,
help="Vocabulary size for BPE training",
)
return parser.parse_args()
def generate_tokens(lang_dir: Path):
"""
Generate the tokens.txt from a bpe model.
"""
sp = spm.SentencePieceProcessor()
sp.load(str(lang_dir / "bpe.model"))
token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())}
with open(lang_dir / "tokens.txt", "w", encoding="utf-8") as f:
for sym, i in token2id.items():
f.write(f"{sym} {i}\n")
def main():
args = get_args()
vocab_size = args.vocab_size
lang_dir = Path(args.lang_dir)
model_type = "unigram"
model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
train_text = args.transcript
character_coverage = 1.0
input_sentence_size = 100000000
user_defined_symbols = ["<blk>", "<sos/eos>"]
unk_id = len(user_defined_symbols)
# Note: unk_id is fixed to 2.
# If you change it, you should also change other
# places that are using it.
model_file = Path(model_prefix + ".model")
if not model_file.is_file():
spm.SentencePieceTrainer.train(
input=train_text,
vocab_size=vocab_size,
model_type=model_type,
model_prefix=model_prefix,
input_sentence_size=input_sentence_size,
character_coverage=character_coverage,
user_defined_symbols=user_defined_symbols,
unk_id=unk_id,
bos_id=-1,
eos_id=-1,
)
else:
print(f"{model_file} exists - skipping")
return
shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
generate_tokens(lang_dir)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,101 @@
#!/usr/bin/env python3
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks the following assumptions of the generated manifest:
- Single supervision per cut
- Supervision time bounds are within cut time bounds
We will add more checks later if needed.
Usage example:
python3 ./local/validate_manifest.py \
./data/fbank/speechtools_cuts_train.jsonl.gz
"""
import argparse
import logging
from pathlib import Path
from lhotse import CutSet, load_manifest_lazy
from lhotse.cut import Cut
from lhotse.dataset.speech_recognition import validate_for_asr
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"manifest",
type=Path,
help="Path to the manifest file",
)
return parser.parse_args()
def validate_one_supervision_per_cut(c: Cut):
if len(c.supervisions) != 1:
raise ValueError(f"{c.id} has {len(c.supervisions)} supervisions")
def validate_supervision_and_cut_time_bounds(c: Cut):
tol = 2e-3 # same tolerance as in 'validate_for_asr()'
s = c.supervisions[0]
# Supervision start time is relative to Cut ...
# https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
if s.start < -tol:
raise ValueError(
f"{c.id}: Supervision start time {s.start} must not be negative."
)
if s.start > tol:
raise ValueError(
f"{c.id}: Supervision start time {s.start} is not at the beginning of the Cut. Please apply `lhotse cut trim-to-supervisions`."
)
if c.start + s.end > c.end + tol:
raise ValueError(
f"{c.id}: Supervision end time {c.start+s.end} is larger "
f"than cut end time {c.end}"
)
def main():
args = get_args()
manifest = args.manifest
logging.info(f"Validating {manifest}")
assert manifest.is_file(), f"{manifest} does not exist"
cut_set = load_manifest_lazy(manifest)
assert isinstance(cut_set, CutSet)
for c in cut_set:
validate_one_supervision_per_cut(c)
validate_supervision_and_cut_time_bounds(c)
# Validation from K2 training
# - checks supervision start is 0
# - checks supervision.duration is not longer than cut.duration
# - there is tolerance 2ms
validate_for_asr(cut_set)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
This recipe implements Streaming Zipformer-Transducer model.

View File

@ -0,0 +1,415 @@
# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class KsponSpeechAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader.
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=self.args.num_buckets * 2000,
shuffle_buffer_size=self.args.num_buckets * 5000,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts.")
return load_manifest_lazy(
self.args.manifest_dir / "ksponspeech_cuts_train.jsonl.gz"
)
@lru_cache()
def dev_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
return load_manifest_lazy(
self.args.manifest_dir / "ksponspeech_cuts_dev.jsonl.gz"
)
@lru_cache()
def eval_clean_cuts(self) -> CutSet:
logging.info("About to get eval_clean cuts")
return load_manifest_lazy(
self.args.manifest_dir / "ksponspeech_cuts_eval_clean.jsonl.gz"
)
@lru_cache()
def eval_other_cuts(self) -> CutSet:
logging.info("About to get eval_other cuts")
return load_manifest_lazy(
self.args.manifest_dir / "ksponspeech_cuts_eval_other.jsonl.gz"
)

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py

View File

@ -0,0 +1,993 @@
#!/usr/bin/env python3
# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method greedy_search
(2) beam search (not recommended)
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search (one best)
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest)
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import KsponSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
modified_beam_search_lm_rescore,
modified_beam_search_lm_rescore_LODR,
modified_beam_search_lm_shallow_fusion,
modified_beam_search_LODR,
)
from train import add_model_arguments, get_params, get_transducer_model
from icefall import LmScorer, NgramLm
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7_streaming/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--use-shallow-fusion",
type=str2bool,
default=False,
help="""Use neural network LM for shallow fusion.
If you want to use LODR, you will also need to set this to true
""",
)
parser.add_argument(
"--lm-type",
type=str,
default="rnn",
help="Type of NN lm",
choices=["rnn", "transformer"],
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.3,
help="""The scale of the neural network LM
Used only when `--use-shallow-fusion` is set to True.
""",
)
parser.add_argument(
"--tokens-ngram",
type=int,
default=2,
help="""The order of the ngram lm.
""",
)
parser.add_argument(
"--backoff-id",
type=int,
default=500,
help="ID of the backoff symbol in the ngram LM",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
LM: Optional[LmScorer] = None,
ngram_lm=None,
ngram_lm_scale: float = 0.0,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
LM:
A neural network language model.
ngram_lm:
A ngram language model
ngram_lm_scale:
The scale for the ngram language model.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
feature_lens += 30
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, 30),
value=LOG_EPS,
)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
hyp_tokens = modified_beam_search_lm_shallow_fusion(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
LM=LM,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_LODR":
hyp_tokens = modified_beam_search_LODR(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
LODR_lm=ngram_lm,
LODR_lm_scale=ngram_lm_scale,
LM=LM,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search_lm_rescore":
lm_scale_list = [0.01 * i for i in range(10, 50)]
ans_dict = modified_beam_search_lm_rescore(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
LM=LM,
lm_scale_list=lm_scale_list,
)
elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
lm_scale_list = [0.02 * i for i in range(2, 30)]
ans_dict = modified_beam_search_lm_rescore_LODR(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
LM=LM,
LODR_lm=ngram_lm,
sp=sp,
lm_scale_list=lm_scale_list,
)
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
elif params.decoding_method in (
"modified_beam_search_lm_rescore",
"modified_beam_search_lm_rescore_LODR",
):
ans = dict()
assert ans_dict is not None
for key, hyps in ans_dict.items():
hyps = [sp.decode(hyp).split() for hyp in hyps]
ans[f"beam_size_{params.beam_size}_{key}"] = hyps
return ans
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
LM: Optional[LmScorer] = None,
ngram_lm=None,
ngram_lm_scale: float = 0.0,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
ngram_lm:
A n-gram LM to be used for LODR.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
LM=LM,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_cers = dict()
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out CERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
cer = write_error_stats(
f,
f"{test_set_name}-{key}",
results,
enable_log=True,
compute_CER=True,
)
test_set_cers[key] = cer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
for key, val in test_set_cers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_cers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
KsponSpeechAsrDataModule.add_arguments(parser)
LmScorer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
"modified_beam_search_LODR",
"modified_beam_search_lm_shallow_fusion",
"modified_beam_search_lm_rescore",
"modified_beam_search_lm_rescore_LODR",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_shallow_fusion:
params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}"
if "LODR" in params.decoding_method:
params.suffix += (
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
)
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, (
model.encoder.decode_chunk_size,
params.decode_chunk_len,
)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
# only load the neural network LM if required
if params.use_shallow_fusion or params.decoding_method in (
"modified_beam_search_lm_rescore",
"modified_beam_search_lm_rescore_LODR",
"modified_beam_search_lm_shallow_fusion",
"modified_beam_search_LODR",
):
LM = LmScorer(
lm_type=params.lm_type,
params=params,
device=device,
lm_scale=params.lm_scale,
)
LM.to(device)
LM.eval()
else:
LM = None
# only load N-gram LM when needed
if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
try:
import kenlm
except ImportError:
print("Please install kenlm first. You can use")
print(" pip install https://github.com/kpu/kenlm/archive/master.zip")
print("to install it")
import sys
sys.exit(-1)
ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
logging.info(f"lm filename: {ngram_file_name}")
ngram_lm = kenlm.Model(ngram_file_name)
ngram_lm_scale = None # use a list to search
elif params.decoding_method == "modified_beam_search_LODR":
lm_filename = f"{params.tokens_ngram}gram.fst.txt"
logging.info(f"Loading token level lm: {lm_filename}")
ngram_lm = NgramLm(
str(params.lang_dir / lm_filename),
backoff_id=params.backoff_id,
is_binary=False,
)
logging.info(f"num states: {ngram_lm.lm.num_states}")
ngram_lm_scale = params.ngram_lm_scale
else:
ngram_lm = None
ngram_lm_scale = None
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
ksponspeech = KsponSpeechAsrDataModule(args)
eval_clean_cuts = ksponspeech.eval_clean_cuts()
eval_other_cuts = ksponspeech.eval_other_cuts()
eval_clean_dl = ksponspeech.test_dataloaders(eval_clean_cuts)
eval_other_dl = ksponspeech.test_dataloaders(eval_other_cuts)
test_sets = ["eval_clean", "eval_other"]
test_dl = [eval_clean_dl, eval_other_dl]
import time
for test_set, test_dl in zip(test_sets, test_dl):
start = time.time()
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
LM=LM,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,
)
logging.info(f"Elasped time for {test_set}: {time.time() - start}")
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/joiner.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/model.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/optim.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py

View File

@ -0,0 +1,619 @@
#!/usr/bin/env python3
# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./pruned_transducer_stateless7_streaming/streaming_decode.py \
--epoch 28 \
--avg 15 \
--decode-chunk-len 32 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--decoding-method greedy_search \
--num-decode-streams 2000
"""
import argparse
import logging
import math
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import KsponSpeechAsrDataModule
from decode_stream import DecodeStream
from lhotse import CutSet, Fbank, FbankConfig
from streaming_beam_search import (
fast_beam_search_one_best,
greedy_search,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from zipformer import stack_states, unstack_states
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7_streaming/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Supported decoding methods are:
greedy_search
modified_beam_search
fast_beam_search
""",
)
parser.add_argument(
"--num_active_paths",
type=int,
default=4,
help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=32,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--num-decode-streams",
type=int,
default=2000,
help="The number of streams that can be decoded parallel.",
)
add_model_arguments(parser)
return parser
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
decode_streams: List[DecodeStream],
) -> List[int]:
"""Decode one chunk frames of features for each decode_streams and
return the indexes of finished streams in a List.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
decode_streams:
A List of DecodeStream, each belonging to a utterance.
Returns:
Return a List containing which DecodeStreams are finished.
"""
device = model.device
features = []
feature_lens = []
states = []
processed_lens = []
for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(params.decode_chunk_len)
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
processed_lens.append(stream.done_frames)
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
# We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling
# factor in encoders is 8.
# After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8.
tail_length = 23
if features.size(1) < tail_length:
pad_length = tail_length - features.size(1)
feature_lens += pad_length
features = torch.nn.functional.pad(
features,
(0, 0, 0, pad_length),
mode="constant",
value=LOG_EPS,
)
states = stack_states(states)
processed_lens = torch.tensor(processed_lens, device=device)
encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward(
x=features,
x_lens=feature_lens,
states=states,
)
encoder_out = model.joiner.encoder_proj(encoder_out)
if params.decoding_method == "greedy_search":
greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
elif params.decoding_method == "fast_beam_search":
processed_lens = processed_lens + encoder_out_lens
fast_beam_search_one_best(
model=model,
encoder_out=encoder_out,
processed_lens=processed_lens,
streams=decode_streams,
beam=params.beam,
max_states=params.max_states,
max_contexts=params.max_contexts,
)
elif params.decoding_method == "modified_beam_search":
modified_beam_search(
model=model,
streams=decode_streams,
encoder_out=encoder_out,
num_active_paths=params.num_active_paths,
)
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
states = unstack_states(new_states)
finished_streams = []
for i in range(len(decode_streams)):
decode_streams[i].states = states[i]
decode_streams[i].done_frames += encoder_out_lens[i]
if decode_streams[i].done:
finished_streams.append(i)
return finished_streams
def decode_dataset(
cuts: CutSet,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
cuts:
Lhotse Cutset containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
device = model.device
opts = FbankConfig(
device=device,
dither=0.0,
snip_edges=False,
sampling_rate=16000,
num_mel_bins=80,
high_freq=-400.0,
)
log_interval = 50
decode_results = []
# Contain decode streams currently running.
decode_streams = []
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
initial_states = model.encoder.get_init_state(device=device)
decode_stream = DecodeStream(
params=params,
cut_id=cut.id,
initial_states=initial_states,
decoding_graph=decoding_graph,
device=device,
)
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
# - this is to avoid sending [-32k,+32k] signal in...
# - some lhotse AudioTransform classes can make the signal
# be out of range [-1, 1], hence the tolerance 10
assert (
np.abs(audio).max() <= 10
), "Should be normalized to [-1, 1], 10 for tolerance..."
samples = torch.from_numpy(audio).squeeze(0)
fbank = Fbank(opts)
feature = fbank.extract(samples.to(device), sampling_rate=16000)
decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len)
decode_stream.ground_truth = cut.supervisions[0].text
decode_streams.append(decode_stream)
while len(decode_streams) >= params.num_decode_streams:
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
if num % log_interval == 0:
logging.info(f"Cuts processed until now is {num}.")
# decode final chunks of last sequences
while len(decode_streams):
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
if params.decoding_method == "greedy_search":
key = "greedy_search"
elif params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
elif params.decoding_method == "modified_beam_search":
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
return {key: decode_results}
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
test_set_cers = dict()
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out CERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
cer = write_error_stats(
f,
f"{test_set_name}-{key}",
results,
enable_log=True,
compute_CER=True,
)
test_set_cers[key] = cer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_cers = sorted(test_set_cers.items(), key=lambda x: x[1])
errs_info = params.res_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tCER", file=f)
for key, val in test_set_cers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, CER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_cers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
KsponSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
# for streaming
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
# for fast_beam_search
if params.decoding_method == "fast_beam_search":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
model.device = device
decoding_graph = None
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
ksponspeech = KsponSpeechAsrDataModule(args)
eval_clean_cuts = ksponspeech.eval_clean_cuts()
eval_other_cuts = ksponspeech.eval_other_cuts()
test_sets = ["eval_clean", "eval_other"]
test_cuts = [eval_clean_cuts, eval_other_cuts]
for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset(
cuts=test_cut,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,187 @@
#!/usr/bin/env python3
# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/ksponspeech/ASR
python ./pruned_transducer_stateless7_streaming/test_model.py
"""
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model
def test_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.num_encoder_layers = "2,4,3,2,4"
params.feedforward_dims = "1024,1024,2048,2048,1024"
params.nhead = "8,8,8,8,8"
params.encoder_dims = "384,384,384,384,384"
params.attention_dims = "192,192,192,192,192"
params.encoder_unmasked_dims = "256,256,256,256,256"
params.zipformer_downsampling_factors = "1,2,4,8,2"
params.cnn_module_kernels = "31,31,31,31,31"
params.decoder_dim = 512
params.joiner_dim = 512
params.num_left_chunks = 4
params.short_chunk_size = 50
params.decode_chunk_len = 32
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
# Test jit script
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
print("Using torch.jit.script")
model = torch.jit.script(model)
def test_model_small():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.num_encoder_layers = "2,2,2,2,2"
params.feedforward_dims = "256,256,512,512,256"
params.nhead = "4,4,4,4,4"
params.encoder_dims = "128,128,128,128,128"
params.attention_dims = "96,96,96,96,96"
params.encoder_unmasked_dims = "96,96,96,96,96"
params.zipformer_downsampling_factors = "1,2,4,8,2"
params.cnn_module_kernels = "31,31,31,31,31"
params.decoder_dim = 320
params.joiner_dim = 320
params.num_left_chunks = 4
params.short_chunk_size = 50
params.decode_chunk_len = 32
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
import pdb
pdb.set_trace()
# Test jit script
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
print("Using torch.jit.script")
model = torch.jit.script(model)
def test_model_jit_trace():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.num_encoder_layers = "2,4,3,2,4"
params.feedforward_dims = "1024,1024,2048,2048,1024"
params.nhead = "8,8,8,8,8"
params.encoder_dims = "384,384,384,384,384"
params.attention_dims = "192,192,192,192,192"
params.encoder_unmasked_dims = "256,256,256,256,256"
params.zipformer_downsampling_factors = "1,2,4,8,2"
params.cnn_module_kernels = "31,31,31,31,31"
params.decoder_dim = 512
params.joiner_dim = 512
params.num_left_chunks = 4
params.short_chunk_size = 50
params.decode_chunk_len = 32
model = get_transducer_model(params)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
convert_scaled_to_non_scaled(model, inplace=True)
# Test encoder
def _test_encoder():
encoder = model.encoder
assert encoder.decode_chunk_size == params.decode_chunk_len // 2, (
encoder.decode_chunk_size,
params.decode_chunk_len,
)
T = params.decode_chunk_len + 7
x = torch.zeros(1, T, 80, dtype=torch.float32)
x_lens = torch.full((1,), T, dtype=torch.int32)
states = encoder.get_init_state(device=x.device)
encoder.__class__.forward = encoder.__class__.streaming_forward
traced_encoder = torch.jit.trace(encoder, (x, x_lens, states))
states1 = encoder.get_init_state(device=x.device)
states2 = traced_encoder.get_init_state(device=x.device)
for i in range(5):
x = torch.randn(1, T, 80, dtype=torch.float32)
x_lens = torch.full((1,), T, dtype=torch.int32)
y1, _, states1 = encoder.streaming_forward(x, x_lens, states1)
y2, _, states2 = traced_encoder(x, x_lens, states2)
assert torch.allclose(y1, y2, atol=1e-6), (i, (y1 - y2).abs().mean())
# Test decoder
def _test_decoder():
decoder = model.decoder
y = torch.zeros(10, decoder.context_size, dtype=torch.int64)
need_pad = torch.tensor([False])
traced_decoder = torch.jit.trace(decoder, (y, need_pad))
d1 = decoder(y, need_pad)
d2 = traced_decoder(y, need_pad)
assert torch.equal(d1, d2), (d1 - d2).abs().mean()
# Test joiner
def _test_joiner():
joiner = model.joiner
encoder_out_dim = joiner.encoder_proj.weight.shape[1]
decoder_out_dim = joiner.decoder_proj.weight.shape[1]
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
traced_joiner = torch.jit.trace(joiner, (encoder_out, decoder_out))
j1 = joiner(encoder_out, decoder_out)
j2 = traced_joiner(encoder_out, decoder_out)
assert torch.equal(j1, j2), (j1 - j2).abs().mean()
_test_encoder()
_test_decoder()
_test_joiner()
def main():
test_model_small()
test_model_jit_trace()
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py

1
egs/ksponspeech/ASR/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared/