CSJ pruned_transducer_stateless7_streaming (#892)

* update manifest stats

* update transcript configs

* lang_char and compute_fbanks

* save cuts in fbank_dir

* add core codes

* update decode.py

* Create local/utils

* tidy up

* parse raw in prepare_lang_char.py

* working train

* Add compare_cer_transcript.py

* fix tokenizer decode, allow d2f only

* comment cleanup

* add export files and READMEs

* reword average column

* fix comments

* Update new results
Teo Wen Shen 2023-02-13 23:19:50 +09:00 committed by GitHub
parent 25ee50e27c
commit e63a8c27f8
37 changed files with 5847 additions and 1240 deletions

egs/csj/ASR/README.md (new file, 11 lines)

@@ -0,0 +1,11 @@
# Introduction
[./RESULTS.md](./RESULTS.md) contains the latest results.
# Transducers
These are the types of architectures currently available.
| | Encoder | Decoder | Comment |
|---------------------------------------|---------------------|--------------------|---------------------------------------------------|
| `pruned_transducer_stateless7_streaming` | Streaming Zipformer | Embedding + Conv1d | Adapted from librispeech pruned_transducer_stateless7_streaming |

egs/csj/ASR/RESULTS.md (new file, 200 lines)

@@ -0,0 +1,200 @@
# Results
## Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer)
### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming)
See <https://github.com/k2-fsa/icefall/pull/892> for more details.
You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208>
Number of model parameters: 75688409, i.e. 75.7M.
#### training on disfluent transcript
The CERs are:
| decoding method | chunk size | eval1 | eval2 | eval3 | excluded | valid | average | decoding mode |
| --------------- | ---------- | ----- | ----- | ----- | -------- | ----- | ------- | ------------- |
| fast beam search | 320ms | 5.39 | 4.08 | 4.16 | 5.40 | 5.02 | --epoch 30 --avg 17 | simulated streaming |
| fast beam search | 320ms | 5.34 | 4.10 | 4.26 | 5.61 | 4.91 | --epoch 30 --avg 17 | chunk-wise |
| greedy search | 320ms | 5.43 | 4.14 | 4.31 | 5.48 | 4.88 | --epoch 30 --avg 17 | simulated streaming |
| greedy search | 320ms | 5.44 | 4.14 | 4.39 | 5.70 | 4.98 | --epoch 30 --avg 17 | chunk-wise |
| modified beam search | 320ms | 5.20 | 3.95 | 4.09 | 5.12 | 4.75 | --epoch 30 --avg 17 | simulated streaming |
| modified beam search | 320ms | 5.18 | 4.07 | 4.12 | 5.36 | 4.77 | --epoch 30 --avg 17 | chunk-wise |
| fast beam search | 640ms | 5.01 | 3.78 | 3.96 | 4.85 | 4.60 | --epoch 30 --avg 17 | simulated streaming |
| fast beam search | 640ms | 4.97 | 3.88 | 3.96 | 4.91 | 4.61 | --epoch 30 --avg 17 | chunk-wise |
| greedy search | 640ms | 5.02 | 3.84 | 4.14 | 5.02 | 4.59 | --epoch 30 --avg 17 | simulated streaming |
| greedy search | 640ms | 5.32 | 4.22 | 4.33 | 5.39 | 4.99 | --epoch 30 --avg 17 | chunk-wise |
| modified beam search | 640ms | 4.78 | 3.66 | 3.85 | 4.72 | 4.42 | --epoch 30 --avg 17 | simulated streaming |
| modified beam search | 640ms | 5.77 | 4.72 | 4.73 | 5.85 | 5.36 | --epoch 30 --avg 17 | chunk-wise |
Note: `simulated streaming` means the full utterance is fed at once during decoding with `decode.py`,
while `chunk-wise` means a fixed number of frames is fed at a time with `streaming_decode.py`.
The training command was:
```bash
./pruned_transducer_stateless7_streaming/train.py \
--feedforward-dims "1024,1024,2048,2048,1024" \
--world-size 8 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \
--max-duration 375 \
--transcript-mode disfluent \
--lang data/lang_char \
--manifest-dir /mnt/host/corpus/csj/fbank \
--pad-feature 30 \
--musan-dir /mnt/host/corpus/musan/musan/fbank
```
The simulated streaming decoding command was:
```bash
for chunk in 64 32; do
for m in greedy_search fast_beam_search modified_beam_search; do
python pruned_transducer_stateless7_streaming/decode.py \
--feedforward-dims "1024,1024,2048,2048,1024" \
--exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \
--epoch 30 \
--avg 17 \
--max-duration 350 \
--decoding-method $m \
--manifest-dir /mnt/host/corpus/csj/fbank \
--lang data/lang_char \
--transcript-mode disfluent \
--res-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30/github/sim_"$chunk"_"$m" \
--decode-chunk-len $chunk \
--pad-feature 30 \
--gpu 0
done
done
```
The streaming chunk-wise decoding command was:
```bash
for chunk in 64 32; do
for m in greedy_search fast_beam_search modified_beam_search; do
python pruned_transducer_stateless7_streaming/streaming_decode.py \
--feedforward-dims "1024,1024,2048,2048,1024" \
--exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \
--epoch 30 \
--avg 17 \
--max-duration 350 \
--decoding-method $m \
--manifest-dir /mnt/host/corpus/csj/fbank \
--lang data/lang_char \
--transcript-mode disfluent \
--res-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30/github/stream_"$chunk"_"$m" \
--decode-chunk-len $chunk \
--gpu 2 \
--num-decode-streams 40
done
done
```
#### training on fluent transcript
The CERs are:
| decoding method | chunk size | eval1 | eval2 | eval3 | excluded | valid | average | decoding mode |
| --------------- | ---------- | ----- | ----- | ----- | -------- | ----- | ------- | ------------- |
| fast beam search | 320ms | 4.19 | 3.63 | 3.77 | 4.43 | 4.09 | --epoch 30 --avg 12 | simulated streaming |
| fast beam search | 320ms | 4.06 | 3.55 | 3.66 | 4.70 | 4.04 | --epoch 30 --avg 12 | chunk-wise |
| greedy search | 320ms | 4.22 | 3.62 | 3.82 | 4.45 | 3.98 | --epoch 30 --avg 12 | simulated streaming |
| greedy search | 320ms | 4.13 | 3.61 | 3.85 | 4.67 | 4.05 | --epoch 30 --avg 12 | chunk-wise |
| modified beam search | 320ms | 4.02 | 3.43 | 3.62 | 4.43 | 3.81 | --epoch 30 --avg 12 | simulated streaming |
| modified beam search | 320ms | 3.97 | 3.43 | 3.59 | 4.99 | 3.88 | --epoch 30 --avg 12 | chunk-wise |
| fast beam search | 640ms | 3.80 | 3.31 | 3.55 | 4.16 | 3.90 | --epoch 30 --avg 12 | simulated streaming |
| fast beam search | 640ms | 3.81 | 3.34 | 3.46 | 4.58 | 3.85 | --epoch 30 --avg 12 | chunk-wise |
| greedy search | 640ms | 3.92 | 3.38 | 3.65 | 4.31 | 3.88 | --epoch 30 --avg 12 | simulated streaming |
| greedy search | 640ms | 3.98 | 3.38 | 3.64 | 4.54 | 4.01 | --epoch 30 --avg 12 | chunk-wise |
| modified beam search | 640ms | 3.72 | 3.26 | 3.39 | 4.10 | 3.65 | --epoch 30 --avg 12 | simulated streaming |
| modified beam search | 640ms | 3.78 | 3.32 | 3.45 | 4.81 | 3.81 | --epoch 30 --avg 12 | chunk-wise |
Note: `simulated streaming` means the full utterance is fed at once during decoding with `decode.py`,
while `chunk-wise` means a fixed number of frames is fed at a time with `streaming_decode.py`.
The training command was:
```bash
./pruned_transducer_stateless7_streaming/train.py \
--feedforward-dims "1024,1024,2048,2048,1024" \
--world-size 8 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \
--max-duration 375 \
--transcript-mode fluent \
--lang data/lang_char \
--manifest-dir /mnt/host/corpus/csj/fbank \
--pad-feature 30 \
--musan-dir /mnt/host/corpus/musan/musan/fbank
```
The simulated streaming decoding command was:
```bash
for chunk in 64 32; do
for m in greedy_search fast_beam_search modified_beam_search; do
python pruned_transducer_stateless7_streaming/decode.py \
--feedforward-dims "1024,1024,2048,2048,1024" \
--exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \
--epoch 30 \
--avg 12 \
--max-duration 350 \
--decoding-method $m \
--manifest-dir /mnt/host/corpus/csj/fbank \
--lang data/lang_char \
--transcript-mode fluent \
--res-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30/github/sim_"$chunk"_"$m" \
--decode-chunk-len $chunk \
--pad-feature 30 \
--gpu 1
done
done
```
The streaming chunk-wise decoding command was:
```bash
for chunk in 64 32; do
for m in greedy_search fast_beam_search modified_beam_search; do
python pruned_transducer_stateless7_streaming/streaming_decode.py \
--feedforward-dims "1024,1024,2048,2048,1024" \
--exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \
--epoch 30 \
--avg 12 \
--max-duration 350 \
--decoding-method $m \
--manifest-dir /mnt/host/corpus/csj/fbank \
--lang data/lang_char \
--transcript-mode fluent \
--res-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30/github/stream_"$chunk"_"$m" \
--decode-chunk-len $chunk \
--gpu 3 \
--num-decode-streams 40
done
done
```
#### Comparing disfluent to fluent
$$ \texttt{CER}^{f}_d = \frac{\texttt{sub}_f + \texttt{ins} + \texttt{del}_f}{N_f} $$
This comparison evaluates the disfluent model on the fluent transcript (calculated by `disfluent_recogs_to_fluent.py`), forgiving the disfluent model's mistakes on fillers and partial words. It is meant as an illustrative metric only, so that the disfluent and fluent models can be compared.
| decoding method | chunk size | eval1 (d vs f) | eval2 (d vs f) | eval3 (d vs f) | excluded (d vs f) | valid (d vs f) | decoding mode |
| --------------- | ---------- | -------------- | --------------- | -------------- | -------------------- | --------------- | ----------- |
| fast beam search | 320ms | 4.54 vs 4.19 | 3.44 vs 3.63 | 3.56 vs 3.77 | 4.22 vs 4.43 | 4.22 vs 4.09 | simulated streaming |
| fast beam search | 320ms | 4.48 vs 4.06 | 3.41 vs 3.55 | 3.65 vs 3.66 | 4.26 vs 4.7 | 4.08 vs 4.04 | chunk-wise |
| greedy search | 320ms | 4.53 vs 4.22 | 3.48 vs 3.62 | 3.69 vs 3.82 | 4.38 vs 4.45 | 4.05 vs 3.98 | simulated streaming |
| greedy search | 320ms | 4.53 vs 4.13 | 3.46 vs 3.61 | 3.71 vs 3.85 | 4.48 vs 4.67 | 4.12 vs 4.05 | chunk-wise |
| modified beam search | 320ms | 4.45 vs 4.02 | 3.38 vs 3.43 | 3.57 vs 3.62 | 4.19 vs 4.43 | 4.04 vs 3.81 | simulated streaming |
| modified beam search | 320ms | 4.44 vs 3.97 | 3.47 vs 3.43 | 3.56 vs 3.59 | 4.28 vs 4.99 | 4.04 vs 3.88 | chunk-wise |
| fast beam search | 640ms | 4.14 vs 3.8 | 3.12 vs 3.31 | 3.38 vs 3.55 | 3.72 vs 4.16 | 3.81 vs 3.9 | simulated streaming |
| fast beam search | 640ms | 4.05 vs 3.81 | 3.23 vs 3.34 | 3.36 vs 3.46 | 3.65 vs 4.58 | 3.78 vs 3.85 | chunk-wise |
| greedy search | 640ms | 4.1 vs 3.92 | 3.17 vs 3.38 | 3.5 vs 3.65 | 3.87 vs 4.31 | 3.77 vs 3.88 | simulated streaming |
| greedy search | 640ms | 4.41 vs 3.98 | 3.56 vs 3.38 | 3.69 vs 3.64 | 4.26 vs 4.54 | 4.16 vs 4.01 | chunk-wise |
| modified beam search | 640ms | 4 vs 3.72 | 3.08 vs 3.26 | 3.33 vs 3.39 | 3.75 vs 4.1 | 3.71 vs 3.65 | simulated streaming |
| modified beam search | 640ms | 5.05 vs 3.78 | 4.22 vs 3.32 | 4.26 vs 3.45 | 5.02 vs 4.81 | 4.73 vs 3.81 | chunk-wise |
| average (d - f) | | 0.43 | -0.02 | -0.02 | -0.34 | 0.13 | |
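As a quick sanity check, the metric in the formula above reduces to a one-liner once the substitution, insertion, and deletion counts on the fluent reference are known. The counts in this sketch are illustrative only, not taken from the tables:
```python
# Minimal sketch of the disfluent-vs-fluent CER defined above.
# The counts are made-up illustrative values, not results from this PR.
def cer_d_to_f(sub_f: int, ins: int, del_f: int, n_f: int) -> float:
    """Errors counted on fluent (content) tokens only, divided by the fluent reference length."""
    return (sub_f + ins + del_f) / n_f

print(f"{cer_d_to_f(sub_f=120, ins=35, del_f=45, n_f=4800):.2%}")  # -> 4.17%
```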


@@ -0,0 +1,94 @@
import argparse
import logging
from configparser import ConfigParser
from pathlib import Path
from typing import List, Tuple
from lhotse import CutSet, SupervisionSet
from lhotse.recipes.csj import CSJSDBParser
ARGPARSE_DESCRIPTION = """
This script adds transcript modes to an existing CutSet or SupervisionSet.
"""
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description=ARGPARSE_DESCRIPTION,
)
parser.add_argument(
"-f",
"--fbank-dir",
type=Path,
help="Path to directory where manifests are stored.",
)
parser.add_argument(
"-c",
"--config",
type=Path,
nargs="+",
help="Path to config file for transcript parsing.",
)
return parser.parse_args()
def get_CSJParsers(config_files: List[Path]) -> List[Tuple[str, CSJSDBParser]]:
parsers = []
for config_file in config_files:
config = ConfigParser()
config.optionxform = str
assert config.read(config_file), f"{config_file} could not be found."
decisions = {}
for k, v in config["DECISIONS"].items():
try:
decisions[k] = int(v)
except ValueError:
decisions[k] = v
parsers.append(
(config["CONSTANTS"].get("MODE"), CSJSDBParser(decisions=decisions))
)
return parsers
def main():
args = get_args()
logging.basicConfig(
format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"),
level=logging.INFO,
)
parsers = get_CSJParsers(args.config)
logging.info(f"Adding {', '.join(x[0] for x in parsers)} transcript mode(s).")
manifests = list(args.fbank_dir.glob("csj_cuts_*.jsonl.gz"))
assert manifests, f"No cuts to be found in {args.fbank_dir}"
for manifest in manifests:
results = []
logging.info(f"Adding transcript modes to {manifest.name} now.")
cutset = CutSet.from_file(manifest)
for cut in cutset:
for name, parser in parsers:
cut.supervisions[0].custom[name] = parser.parse(
cut.supervisions[0].custom["raw"]
)
cut.supervisions[0].text = ""
results.append(cut)
results = CutSet.from_items(results)
res_file = manifest.as_posix()
manifest.replace(manifest.parent / ("bak." + manifest.name))
results.to_file(res_file)
if __name__ == "__main__":
main()
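For illustration, a cut processed by the loop above ends up with one custom transcript field per configured mode and an empty `text` field. The manifest path and mode names in this sketch are assumptions, not verified outputs:
```python
# Inspect a manifest after this script has run. The path is hypothetical and the
# mode names ("disfluent", "fluent") depend on the MODE values of the configs passed in.
from lhotse import CutSet

cuts = CutSet.from_file("fbank/csj_cuts_valid.jsonl.gz")  # hypothetical path
sup = next(iter(cuts)).supervisions[0]
print(repr(sup.text))            # '' -- cleared by the loop above
print(sup.custom["disfluent"])   # transcript parsed with the disfluent config
print(sup.custom["fluent"])      # transcript parsed with the fluent config
```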


@@ -1,5 +1,5 @@
#!/usr/bin/env python3
# Copyright 2022 The University of Electro-Communications (Author: Teo Wen Shen) # noqa
# Copyright 2023 The University of Electro-Communications (Author: Teo Wen Shen) # noqa
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -19,9 +19,7 @@
import argparse
import logging
import os
from itertools import islice
from pathlib import Path
from random import Random
from typing import List, Tuple
import torch
@@ -35,20 +33,10 @@ from lhotse import ( # See the following for why LilcomChunkyWriter is preferred
RecordingSet,
SupervisionSet,
)
from lhotse.recipes.csj import concat_csj_supervisions
# fmt: on
ARGPARSE_DESCRIPTION = """
This script follows the espnet method of splitting the remaining core+noncore
utterances into valid and train cutsets at an index which is by default 4000.
In other words, the core+noncore utterances are shuffled, where 4000 utterances
of the shuffled set go to the `valid` cutset and are not subject to speed
perturbation. The remaining utterances become the `train` cutset and are speed-
perturbed (0.9x, 1.0x, 1.1x).
"""
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
@@ -57,66 +45,101 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
RNG_SEED = 42
# concat_params_train = [
# {"gap": 1.0, "maxlen": 10.0},
# {"gap": 1.5, "maxlen": 8.0},
# {"gap": 1.0, "maxlen": 18.0},
# ]
concat_params = {"gap": 1.0, "maxlen": 10.0}
def make_cutset_blueprints(
manifest_dir: Path,
split: int,
) -> List[Tuple[str, CutSet]]:
cut_sets = []
logging.info("Creating non-train cuts.")
# Create eval datasets
logging.info("Creating eval cuts.")
for i in range(1, 4):
sps = sorted(
SupervisionSet.from_file(
manifest_dir / f"csj_supervisions_eval{i}.jsonl.gz"
),
key=lambda x: x.id,
)
cut_set = CutSet.from_manifests(
recordings=RecordingSet.from_file(
manifest_dir / f"csj_recordings_eval{i}.jsonl.gz"
),
supervisions=SupervisionSet.from_file(
manifest_dir / f"csj_supervisions_eval{i}.jsonl.gz"
),
supervisions=concat_csj_supervisions(sps, **concat_params),
)
cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
cut_sets.append((f"eval{i}", cut_set))
# Create train and valid cuts
logging.info("Loading, trimming, and shuffling the remaining core+noncore cuts.")
recording_set = RecordingSet.from_file(
manifest_dir / "csj_recordings_core.jsonl.gz"
) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz")
supervision_set = SupervisionSet.from_file(
manifest_dir / "csj_supervisions_core.jsonl.gz"
) + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz")
# Create excluded dataset
sps = sorted(
SupervisionSet.from_file(manifest_dir / "csj_supervisions_excluded.jsonl.gz"),
key=lambda x: x.id,
)
cut_set = CutSet.from_manifests(
recordings=recording_set,
supervisions=supervision_set,
recordings=RecordingSet.from_file(
manifest_dir / "csj_recordings_excluded.jsonl.gz"
),
supervisions=concat_csj_supervisions(sps, **concat_params),
)
cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
cut_set = cut_set.shuffle(Random(RNG_SEED))
cut_sets.append(("excluded", cut_set))
logging.info(
"Creating valid and train cuts from core and noncore, split at {split}."
# Create valid dataset
sps = sorted(
SupervisionSet.from_file(manifest_dir / "csj_supervisions_valid.jsonl.gz"),
key=lambda x: x.id,
)
valid_set = CutSet.from_cuts(islice(cut_set, 0, split))
cut_set = CutSet.from_manifests(
recordings=RecordingSet.from_file(
manifest_dir / "csj_recordings_valid.jsonl.gz"
),
supervisions=concat_csj_supervisions(sps, **concat_params),
)
cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
cut_sets.append(("valid", cut_set))
train_set = CutSet.from_cuts(islice(cut_set, split, None))
logging.info("Creating train cuts.")
# Create train dataset
sps = sorted(
SupervisionSet.from_file(manifest_dir / "csj_supervisions_core.jsonl.gz")
+ SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz"),
key=lambda x: x.id,
)
recording = RecordingSet.from_file(
manifest_dir / "csj_recordings_core.jsonl.gz"
) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz")
train_set = CutSet.from_manifests(
recordings=recording, supervisions=concat_csj_supervisions(sps, **concat_params)
).trim_to_supervisions(keep_overlapping=False)
train_set = train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1)
cut_sets.extend([("valid", valid_set), ("train", train_set)])
cut_sets.append(("train", train_set))
return cut_sets
def get_args():
parser = argparse.ArgumentParser(
description=ARGPARSE_DESCRIPTION,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests")
parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features")
parser.add_argument("--split", type=int, default=4000, help="Split at this index")
parser.add_argument(
"-m", "--manifest-dir", type=Path, help="Path to save manifests"
)
parser.add_argument(
"-f", "--fbank-dir", type=Path, help="Path to save fbank features"
)
return parser.parse_args()
@@ -138,7 +161,7 @@ def main():
)
return
else:
cut_sets = make_cutset_blueprints(args.manifest_dir, args.split)
cut_sets = make_cutset_blueprints(args.manifest_dir)
for part, cut_set in cut_sets:
logging.info(f"Processing {part}")
cut_set = cut_set.compute_and_store_features(
@@ -147,7 +170,7 @@ def main():
storage_path=(args.fbank_dir / f"feats_{part}").as_posix(),
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(args.manifest_dir / f"csj_cuts_{part}.jsonl.gz")
cut_set.to_file(args.fbank_dir / f"csj_cuts_{part}.jsonl.gz")
logging.info("All fbank computed for CSJ.")
(args.fbank_dir / ".done").touch()


@@ -28,9 +28,7 @@ from icefall.utils import get_executor
ARGPARSE_DESCRIPTION = """
This file computes fbank features of the musan dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
# Torch's multithreaded behavior needs to be disabled or
@@ -42,8 +40,6 @@ torch.set_num_interop_threads(1)
def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path):
# src_dir = Path("data/manifests")
# output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
num_mel_bins = 80
@@ -104,8 +100,12 @@ def get_args():
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests")
parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features")
parser.add_argument(
"-m", "--manifest-dir", type=Path, help="Path to save manifests"
)
parser.add_argument(
"-f", "--fbank-dir", type=Path, help="Path to save fbank features"
)
return parser.parse_args()


@@ -1,320 +1,79 @@
; # This section is ignored if this file is not supplied as the first config file to
; # lhotse prepare csj
[SEGMENTS]
; # Allowed period of nonverbal noise. If exceeded, a new segment is created.
gap = 0.5
; # Maximum length of segment (s).
maxlen = 10
; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently.
minlen = 0.02
; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`.
; # Pass an empty string to avoid adding any symbol. It was "<sp>" in kaldi.
; # If you intend to use a multicharacter string for gap_sym, remember to register the
; # multicharacter string as part of userdef-string in prepare_lang_char.py.
gap_sym =
[CONSTANTS]
; # Name of this mode
MODE = disfluent
; # Suffixes to use after the word surface (no longer used)
MORPH = pos1 cForm cType2 pos2
; # Used to differentiate between A tag and A_num tag
JPN_NUM = ゼロ 零 一 二 三 四 五 六 七 八 九 十 百 千
; # Dummy character to delineate multiline words
PLUS =
[DECISIONS]
; # TAG+'^'とは、タグが一つの転記単位に独立していない場合
; # The PLUS (fullwidth) sign '' marks line boundaries for multiline entries
; # フィラー、感情表出系感動詞
; # 0 to remain, 1 to delete
; # Example: '(F ぎょっ)'
F = 0
; # Example: '(L (F ン))', '比べ(F えー)る'
F^ = 0
; # 言い直し、いいよどみなどによる語断片
; # 0 to remain, 1 to delete
; # Example: '(D だ)(D だいが) 大学の学部の会議'
D = 0
; # Example: '(L (D ドゥ)(D ヒ))'
D^ = 0
; # 助詞、助動詞、接辞の言い直し
; # 0 to remain, 1 to delete
; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか'
D2 = 0
; # Example: '(X (D2 ))'
D2^ = 0
; # 聞き取りや語彙の判断に自信がない場合
; # 0 to remain, 1 to delete
; # Example: (? 字数) の
; # If no option: empty string is returned regardless of output
; # Example: '(?) で'
? = 0
; # Example: '(D (? すー))+そう+です+よ+ね'
?^ = 0
; # タグ?で、値は複数の候補が想定される場合
; # 0 for main guess with matching morph info, 1 for second guess
; # Example: '(? 次数, 実数)', '(? これ,ここで)(? 説明+し+た+方+が+いい+か+な)'
?, = 0
; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))'
?,^ = 0
; # 音や言葉に関するメタ的な引用
; # 0 to remain, 1 to delete
; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)'
M = 0
; # Example: '(L (M ヒ)(M ヒ))', '(L (M (? ヒ+ヒ)))'
M^ = 0
; # 外国語や古語、方言など
; # 0 to remain, 1 to delete
; # Example: '(O ザッツファイン)'
O = 0
; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))'
O^ = 0
; # 講演者の名前、差別語、誹謗中傷など
; # 0 to remain, 1 to delete
; # Example: '国語研の (R ××) です'
R = 0
R^ = 0
; # 非朗読対象発話(朗読における言い間違い等)
; # 0 to remain, 1 to delete
; # Example: '(X 実際は) 実際には'
X = 0
; # Example: '(L (X (D2 ニ)))'
X^ = 0
; # アルファベットや算用数字、記号の表記
; # 0 to use Japanese form, 1 to use alphabet form
; # Example: '(A シーディーアール;)'
A = 1
; # Example: 'スモール(A エヌ;)', 'ラージ(A キュー;)', '(A ティーエフ;)(A アイディーエフ;)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)')
A^ = 1
; # タグAで、単語は算用数字の場合
; # 0 to use Japanese form, 1 to use Arabic numerals
; # Example: (A 二千;)
A_num = eval:self.notag
A_num^ = eval:self.notag
A_num = 0
; # 何らかの原因で漢字表記できなくなった場合
; # 0 to use broken form, 1 to use orthodox form
; # Example: '(K たち (F えー) ばな;橘)'
K = 1
; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)'
K^ = 1
; # 転訛、発音の怠けなど、一時的な発音エラー
; # 0 to use wrong form, 1 to use orthodox form
; # Example: '(W ギーツ;ギジュツ)'
W = 1
; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)'
W^ = 1
; # 語の読みに関する知識レベルのいい間違い
; # 0 to use wrong form, 1 to use orthodox form
; # Example: '(B シブタイ;ジュータイ)'
B = 0
; # Example: 'データー(B カズ;スー)'
B^ = 0
; # 笑いながら発話
; # 0 to remain, 1 to delete
; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)'
笑 = 0
; # Example: 'コク(笑 サイ+(D オン))',
笑^ = 0
; # 泣きながら発話
; # 0 to remain, 1 to delete
; # Example: '(泣 ドンナニ)'
泣 = 0
泣^ = 0
; # 咳をしながら発話
; # 0 to remain, 1 to delete
; # Example: 'シャ(咳 リン) '
咳 = 0
; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)'
咳^ = 0
; # ささやき声や独り言などの小さな声
; # 0 to remain, 1 to delete
; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
L = 0
; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト'
L^ = 0
[REPLACEMENTS]
; # ボーカルフライなどで母音が同定できない場合
<FV> =
; # 「うん/うーん/ふーん」の音の特定が困難な場合
<VN> =
; # 非語彙的な母音の引き延ばし
<H> =
; # 非語彙的な子音の引き延ばし
<Q> =
; # 言語音と独立に講演者の笑いが生じている場合
<笑> =
; # 言語音と独立に講演者の咳が生じている場合
<咳> =
; # 言語音と独立に講演者の息が生じている場合
<息> =
; # 講演者の泣き声
<泣> =
; # 聴衆(司会者なども含む)の発話
<フロア発話> =
; # 聴衆の笑い
<フロア笑> =
; # 聴衆の拍手
<拍手> =
; # 講演者が発表中に用いたデモンストレーションの音声
<デモ> =
; # 学会講演に発表時間を知らせるためにならすベルの音
<ベル> =
; # 転記単位全体が再度読み直された場合
<朗読間違い> =
; # 上記以外の音で特に目立った音
<雑音> =
; # 0.2秒以上のポーズ
<P> =
; # Redacted information, for R
; # It is \x00D7 multiplication sign, not your normal 'x'
× = ×
[FIELDS]
; # Time information for segment
time = 3
; # Word surface
surface = 5
; # Word surface root form without CSJ tags
notag = 9
; # Part Of Speech
pos1 = 11
; # Conjugated Form
cForm = 12
; # Conjugation Type
cType1 = 13
; # Subcategory of POS
pos2 = 14
; # Euphonic Change / Subcategory of Conjugation Type
cType2 = 15
; # Other information
other = 16
; # Pronunciation for lexicon
pron = 10
; # Speaker ID
spk_id = 2
[KATAKANA2ROMAJI]
= 'a
= 'i
= 'u
= 'e
= 'o
= ka
= ki
= ku
= ke
= ko
= ga
= gi
= gu
= ge
= go
= sa
= si
= su
= se
= so
= za
= zi
= zu
= ze
= zo
= ta
= ti
= tu
= te
= to
= da
= di
= du
= de
= do
= na
= ni
= nu
= ne
= no
= ha
= hi
= hu
= he
= ho
= ba
= bi
= bu
= be
= bo
= pa
= pi
= pu
= pe
= po
= ma
= mi
= mu
= me
= mo
= ya
= yu
= yo
= ra
= ri
= ru
= re
= ro
= wa
= we
= wi
= wo
= ŋ
= q
= -
キャ = kǐa
キュ = kǐu
キョ = kǐo
ギャ = gǐa
ギュ = gǐu
ギョ = gǐo
シャ = sǐa
シュ = sǐu
ショ = sǐo
ジャ = zǐa
ジュ = zǐu
ジョ = zǐo
チャ = tǐa
チュ = tǐu
チョ = tǐo
ヂャ = dǐa
ヂュ = dǐu
ヂョ = dǐo
ニャ = nǐa
ニュ = nǐu
ニョ = nǐo
ヒャ = hǐa
ヒュ = hǐu
ヒョ = hǐo
ビャ = bǐa
ビュ = bǐu
ビョ = bǐo
ピャ = pǐa
ピュ = pǐu
ピョ = pǐo
ミャ = mǐa
ミュ = mǐu
ミョ = mǐo
リャ = rǐa
リュ = rǐu
リョ = rǐo
= a
= i
= u
= e
= o
= ʍ
= vu
= ǐa
= ǐu
= ǐo


@@ -1,320 +1,79 @@
; # This section is ignored if this file is not supplied as the first config file to
; # lhotse prepare csj
[SEGMENTS]
; # Allowed period of nonverbal noise. If exceeded, a new segment is created.
gap = 0.5
; # Maximum length of segment (s).
maxlen = 10
; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently.
minlen = 0.02
; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`.
; # Pass an empty string to avoid adding any symbol. It was "<sp>" in kaldi.
; # If you intend to use a multicharacter string for gap_sym, remember to register the
; # multicharacter string as part of userdef-string in prepare_lang_char.py.
gap_sym =
[CONSTANTS]
; # Name of this mode
MODE = fluent
; # Suffixes to use after the word surface (no longer used)
MORPH = pos1 cForm cType2 pos2
; # Used to differentiate between A tag and A_num tag
JPN_NUM = ゼロ 零 一 二 三 四 五 六 七 八 九 十 百 千
; # Dummy character to delineate multiline words
PLUS =
[DECISIONS]
; # TAG+'^'とは、タグが一つの転記単位に独立していない場合
; # The PLUS (fullwidth) sign '' marks line boundaries for multiline entries
; # フィラー、感情表出系感動詞
; # 0 to remain, 1 to delete
; # Example: '(F ぎょっ)'
F = 1
; # Example: '(L (F ン))', '比べ(F えー)る'
F^ = 1
; # 言い直し、いいよどみなどによる語断片
; # 0 to remain, 1 to delete
; # Example: '(D だ)(D だいが) 大学の学部の会議'
D = 1
; # Example: '(L (D ドゥ)(D ヒ))'
D^ = 1
; # 助詞、助動詞、接辞の言い直し
; # 0 to remain, 1 to delete
; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか'
D2 = 1
; # Example: '(X (D2 ))'
D2^ = 1
; # 聞き取りや語彙の判断に自信がない場合
; # 0 to remain, 1 to delete
; # Example: (? 字数) の
; # If no option: empty string is returned regardless of output
; # Example: '(?) で'
? = 0
; # Example: '(D (? すー))+そう+です+よ+ね'
?^ = 0
; # タグ?で、値は複数の候補が想定される場合
; # 0 for main guess with matching morph info, 1 for second guess
; # Example: '(? 次数, 実数)', '(? これ,ここで)(? 説明+し+た+方+が+いい+か+な)'
?, = 0
; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))'
?,^ = 0
; # 音や言葉に関するメタ的な引用
; # 0 to remain, 1 to delete
; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)'
M = 0
; # Example: '(L (M ヒ)(M ヒ))', '(L (M (? ヒ+ヒ)))'
M^ = 0
; # 外国語や古語、方言など
; # 0 to remain, 1 to delete
; # Example: '(O ザッツファイン)'
O = 0
; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))'
O^ = 0
; # 講演者の名前、差別語、誹謗中傷など
; # 0 to remain, 1 to delete
; # Example: '国語研の (R ××) です'
R = 0
R^ = 0
; # 非朗読対象発話(朗読における言い間違い等)
; # 0 to remain, 1 to delete
; # Example: '(X 実際は) 実際には'
X = 0
; # Example: '(L (X (D2 ニ)))'
X^ = 0
; # アルファベットや算用数字、記号の表記
; # 0 to use Japanese form, 1 to use alphabet form
; # Example: '(A シーディーアール;)'
A = 1
; # Example: 'スモール(A エヌ;)', 'ラージ(A キュー;)', '(A ティーエフ;)(A アイディーエフ;)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)')
A^ = 1
; # タグAで、単語は算用数字の場合
; # 0 to use Japanese form, 1 to use Arabic numerals
; # Example: (A 二千;)
A_num = eval:self.notag
A_num^ = eval:self.notag
A_num = 0
; # 何らかの原因で漢字表記できなくなった場合
; # 0 to use broken form, 1 to use orthodox form
; # Example: '(K たち (F えー) ばな;橘)'
K = 1
; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)'
K^ = 1
; # 転訛、発音の怠けなど、一時的な発音エラー
; # 0 to use wrong form, 1 to use orthodox form
; # Example: '(W ギーツ;ギジュツ)'
W = 1
; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)'
W^ = 1
; # 語の読みに関する知識レベルのいい間違い
; # 0 to use wrong form, 1 to use orthodox form
; # Example: '(B シブタイ;ジュータイ)'
B = 0
; # Example: 'データー(B カズ;スー)'
B^ = 0
; # 笑いながら発話
; # 0 to remain, 1 to delete
; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)'
笑 = 0
; # Example: 'コク(笑 サイ+(D オン))',
笑^ = 0
; # 泣きながら発話
; # 0 to remain, 1 to delete
; # Example: '(泣 ドンナニ)'
泣 = 0
泣^ = 0
; # 咳をしながら発話
; # 0 to remain, 1 to delete
; # Example: 'シャ(咳 リン) '
咳 = 0
; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)'
咳^ = 0
; # ささやき声や独り言などの小さな声
; # 0 to remain, 1 to delete
; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
L = 0
; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト'
L^ = 0
[REPLACEMENTS]
; # ボーカルフライなどで母音が同定できない場合
<FV> =
; # 「うん/うーん/ふーん」の音の特定が困難な場合
<VN> =
; # 非語彙的な母音の引き延ばし
<H> =
; # 非語彙的な子音の引き延ばし
<Q> =
; # 言語音と独立に講演者の笑いが生じている場合
<笑> =
; # 言語音と独立に講演者の咳が生じている場合
<咳> =
; # 言語音と独立に講演者の息が生じている場合
<息> =
; # 講演者の泣き声
<泣> =
; # 聴衆(司会者なども含む)の発話
<フロア発話> =
; # 聴衆の笑い
<フロア笑> =
; # 聴衆の拍手
<拍手> =
; # 講演者が発表中に用いたデモンストレーションの音声
<デモ> =
; # 学会講演に発表時間を知らせるためにならすベルの音
<ベル> =
; # 転記単位全体が再度読み直された場合
<朗読間違い> =
; # 上記以外の音で特に目立った音
<雑音> =
; # 0.2秒以上のポーズ
<P> =
; # Redacted information, for R
; # It is \x00D7 multiplication sign, not your normal 'x'
× = ×
[FIELDS]
; # Time information for segment
time = 3
; # Word surface
surface = 5
; # Word surface root form without CSJ tags
notag = 9
; # Part Of Speech
pos1 = 11
; # Conjugated Form
cForm = 12
; # Conjugation Type
cType1 = 13
; # Subcategory of POS
pos2 = 14
; # Euphonic Change / Subcategory of Conjugation Type
cType2 = 15
; # Other information
other = 16
; # Pronunciation for lexicon
pron = 10
; # Speaker ID
spk_id = 2
[KATAKANA2ROMAJI]
= 'a
= 'i
= 'u
= 'e
= 'o
= ka
= ki
= ku
= ke
= ko
= ga
= gi
= gu
= ge
= go
= sa
= si
= su
= se
= so
= za
= zi
= zu
= ze
= zo
= ta
= ti
= tu
= te
= to
= da
= di
= du
= de
= do
= na
= ni
= nu
= ne
= no
= ha
= hi
= hu
= he
= ho
= ba
= bi
= bu
= be
= bo
= pa
= pi
= pu
= pe
= po
= ma
= mi
= mu
= me
= mo
= ya
= yu
= yo
= ra
= ri
= ru
= re
= ro
= wa
= we
= wi
= wo
= ŋ
= q
= -
キャ = kǐa
キュ = kǐu
キョ = kǐo
ギャ = gǐa
ギュ = gǐu
ギョ = gǐo
シャ = sǐa
シュ = sǐu
ショ = sǐo
ジャ = zǐa
ジュ = zǐu
ジョ = zǐo
チャ = tǐa
チュ = tǐu
チョ = tǐo
ヂャ = dǐa
ヂュ = dǐu
ヂョ = dǐo
ニャ = nǐa
ニュ = nǐu
ニョ = nǐo
ヒャ = hǐa
ヒュ = hǐu
ヒョ = hǐo
ビャ = bǐa
ビュ = bǐu
ビョ = bǐo
ピャ = pǐa
ピュ = pǐu
ピョ = pǐo
ミャ = mǐa
ミュ = mǐu
ミョ = mǐo
リャ = rǐa
リュ = rǐu
リョ = rǐo
= a
= i
= u
= e
= o
= ʍ
= vu
= ǐa
= ǐu
= ǐo


@@ -1,320 +1,79 @@
; # This section is ignored if this file is not supplied as the first config file to
; # lhotse prepare csj
[SEGMENTS]
; # Allowed period of nonverbal noise. If exceeded, a new segment is created.
gap = 0.5
; # Maximum length of segment (s).
maxlen = 10
; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently.
minlen = 0.02
; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`.
; # Pass an empty string to avoid adding any symbol. It was "<sp>" in kaldi.
; # If you intend to use a multicharacter string for gap_sym, remember to register the
; # multicharacter string as part of userdef-string in prepare_lang_char.py.
gap_sym =
[CONSTANTS]
; # Name of this mode
MODE = number
; # Suffixes to use after the word surface (no longer used)
MORPH = pos1 cForm cType2 pos2
; # Used to differentiate between A tag and A_num tag
JPN_NUM = ゼロ 零 一 二 三 四 五 六 七 八 九 十 百 千
; # Dummy character to delineate multiline words
PLUS =
[DECISIONS]
; # TAG+'^'とは、タグが一つの転記単位に独立していない場合
; # The PLUS (fullwidth) sign '' marks line boundaries for multiline entries
; # フィラー、感情表出系感動詞
; # 0 to remain, 1 to delete
; # Example: '(F ぎょっ)'
F = 1
; # Example: '(L (F ン))', '比べ(F えー)る'
F^ = 1
; # 言い直し、いいよどみなどによる語断片
; # 0 to remain, 1 to delete
; # Example: '(D だ)(D だいが) 大学の学部の会議'
D = 1
; # Example: '(L (D ドゥ)(D ヒ))'
D^ = 1
; # 助詞、助動詞、接辞の言い直し
; # 0 to remain, 1 to delete
; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか'
D2 = 1
; # Example: '(X (D2 ))'
D2^ = 1
; # 聞き取りや語彙の判断に自信がない場合
; # 0 to remain, 1 to delete
; # Example: (? 字数) の
; # If no option: empty string is returned regardless of output
; # Example: '(?) で'
? = 0
; # Example: '(D (? すー))+そう+です+よ+ね'
?^ = 0
; # タグ?で、値は複数の候補が想定される場合
; # 0 for main guess with matching morph info, 1 for second guess
; # Example: '(? 次数, 実数)', '(? これ,ここで)(? 説明+し+た+方+が+いい+か+な)'
?, = 0
; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))'
?,^ = 0
; # 音や言葉に関するメタ的な引用
; # 0 to remain, 1 to delete
; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)'
M = 0
; # Example: '(L (M ヒ)(M ヒ))', '(L (M (? ヒ+ヒ)))'
M^ = 0
; # 外国語や古語、方言など
; # 0 to remain, 1 to delete
; # Example: '(O ザッツファイン)'
O = 0
; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))'
O^ = 0
; # 講演者の名前、差別語、誹謗中傷など
; # 0 to remain, 1 to delete
; # Example: '国語研の (R ××) です'
R = 0
R^ = 0
; # 非朗読対象発話(朗読における言い間違い等)
; # 0 to remain, 1 to delete
; # Example: '(X 実際は) 実際には'
X = 0
; # Example: '(L (X (D2 ニ)))'
X^ = 0
; # アルファベットや算用数字、記号の表記
; # 0 to use Japanese form, 1 to use alphabet form
; # Example: '(A シーディーアール;)'
A = 1
; # Example: 'スモール(A エヌ;)', 'ラージ(A キュー;)', '(A ティーエフ;)(A アイディーエフ;)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)')
A^ = 1
; # タグAで、単語は算用数字の場合
; # 0 to use Japanese form, 1 to use Arabic numerals
; # Example: (A 二千;)
A_num = 1
A_num^ = 1
; # 何らかの原因で漢字表記できなくなった場合
; # 0 to use broken form, 1 to use orthodox form
; # Example: '(K たち (F えー) ばな;橘)'
K = 1
; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)'
K^ = 1
; # 転訛、発音の怠けなど、一時的な発音エラー
; # 0 to use wrong form, 1 to use orthodox form
; # Example: '(W ギーツ;ギジュツ)'
W = 1
; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)'
W^ = 1
; # 語の読みに関する知識レベルのいい間違い
; # 0 to use wrong form, 1 to use orthodox form
; # Example: '(B シブタイ;ジュータイ)'
B = 0
; # Example: 'データー(B カズ;スー)'
B^ = 0
; # 笑いながら発話
; # 0 to remain, 1 to delete
; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)'
笑 = 0
; # Example: 'コク(笑 サイ+(D オン))',
笑^ = 0
; # 泣きながら発話
; # 0 to remain, 1 to delete
; # Example: '(泣 ドンナニ)'
泣 = 0
泣^ = 0
; # 咳をしながら発話
; # 0 to remain, 1 to delete
; # Example: 'シャ(咳 リン) '
咳 = 0
; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)'
咳^ = 0
; # ささやき声や独り言などの小さな声
; # 0 to remain, 1 to delete
; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
L = 0
; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト'
L^ = 0
[REPLACEMENTS]
; # ボーカルフライなどで母音が同定できない場合
<FV> =
; # 「うん/うーん/ふーん」の音の特定が困難な場合
<VN> =
; # 非語彙的な母音の引き延ばし
<H> =
; # 非語彙的な子音の引き延ばし
<Q> =
; # 言語音と独立に講演者の笑いが生じている場合
<笑> =
; # 言語音と独立に講演者の咳が生じている場合
<咳> =
; # 言語音と独立に講演者の息が生じている場合
<息> =
; # 講演者の泣き声
<泣> =
; # 聴衆(司会者なども含む)の発話
<フロア発話> =
; # 聴衆の笑い
<フロア笑> =
; # 聴衆の拍手
<拍手> =
; # 講演者が発表中に用いたデモンストレーションの音声
<デモ> =
; # 学会講演に発表時間を知らせるためにならすベルの音
<ベル> =
; # 転記単位全体が再度読み直された場合
<朗読間違い> =
; # 上記以外の音で特に目立った音
<雑音> =
; # 0.2秒以上のポーズ
<P> =
; # Redacted information, for R
; # It is \x00D7 multiplication sign, not your normal 'x'
× = ×
[FIELDS]
; # Time information for segment
time = 3
; # Word surface
surface = 5
; # Word surface root form without CSJ tags
notag = 9
; # Part Of Speech
pos1 = 11
; # Conjugated Form
cForm = 12
; # Conjugation Type
cType1 = 13
; # Subcategory of POS
pos2 = 14
; # Euphonic Change / Subcategory of Conjugation Type
cType2 = 15
; # Other information
other = 16
; # Pronunciation for lexicon
pron = 10
; # Speaker ID
spk_id = 2
[KATAKANA2ROMAJI]
= 'a
= 'i
= 'u
= 'e
= 'o
= ka
= ki
= ku
= ke
= ko
= ga
= gi
= gu
= ge
= go
= sa
= si
= su
= se
= so
= za
= zi
= zu
= ze
= zo
= ta
= ti
= tu
= te
= to
= da
= di
= du
= de
= do
= na
= ni
= nu
= ne
= no
= ha
= hi
= hu
= he
= ho
= ba
= bi
= bu
= be
= bo
= pa
= pi
= pu
= pe
= po
= ma
= mi
= mu
= me
= mo
= ya
= yu
= yo
= ra
= ri
= ru
= re
= ro
= wa
= we
= wi
= wo
= ŋ
= q
= -
キャ = kǐa
キュ = kǐu
キョ = kǐo
ギャ = gǐa
ギュ = gǐu
ギョ = gǐo
シャ = sǐa
シュ = sǐu
ショ = sǐo
ジャ = zǐa
ジュ = zǐu
ジョ = zǐo
チャ = tǐa
チュ = tǐu
チョ = tǐo
ヂャ = dǐa
ヂュ = dǐu
ヂョ = dǐo
ニャ = nǐa
ニュ = nǐu
ニョ = nǐo
ヒャ = hǐa
ヒュ = hǐu
ヒョ = hǐo
ビャ = bǐa
ビュ = bǐu
ビョ = bǐo
ピャ = pǐa
ピュ = pǐu
ピョ = pǐo
ミャ = mǐa
ミュ = mǐu
ミョ = mǐo
リャ = rǐa
リュ = rǐu
リョ = rǐo
= a
= i
= u
= e
= o
= ʍ
= vu
= ǐa
= ǐu
= ǐo


@@ -1,321 +1,80 @@
; # This section is ignored if this file is not supplied as the first config file to
; # lhotse prepare csj
[SEGMENTS]
; # Allowed period of nonverbal noise. If exceeded, a new segment is created.
gap = 0.5
; # Maximum length of segment (s).
maxlen = 10
; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently.
minlen = 0.02
; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`.
; # Pass an empty string to avoid adding any symbol. It was "<sp>" in kaldi.
; # If you intend to use a multicharacter string for gap_sym, remember to register the
; # multicharacter string as part of userdef-string in prepare_lang_char.py.
gap_sym =
[CONSTANTS]
; # Name of this mode
; # See https://www.isca-speech.org/archive/pdfs/interspeech_2022/horii22_interspeech.pdf
; # From https://www.isca-speech.org/archive/pdfs/interspeech_2022/horii22_interspeech.pdf
MODE = symbol
; # Suffixes to use after the word surface (no longer used)
MORPH = pos1 cForm cType2 pos2
; # Used to differentiate between A tag and A_num tag
JPN_NUM = ゼロ 零 一 二 三 四 五 六 七 八 九 十 百 千
; # Dummy character to delineate multiline words
PLUS =
[DECISIONS]
; # TAG+'^'とは、タグが一つの転記単位に独立していない場合
; # The PLUS (fullwidth) sign '' marks line boundaries for multiline entries
; # フィラー、感情表出系感動詞
; # 0 to remain, 1 to delete
; # Example: '(F ぎょっ)'
F =
; # Example: '(L (F ン))', '比べ(F えー)る'
F^ =
F = "", ["F"]
; # 言い直し、いいよどみなどによる語断片
; # 0 to remain, 1 to delete
; # Example: '(D だ)(D だいが) 大学の学部の会議'
D =
; # Example: '(L (D ドゥ)(D ヒ))'
D^ =
D = "", ["D"]
; # 助詞、助動詞、接辞の言い直し
; # 0 to remain, 1 to delete
; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか'
D2 =
; # Example: '(X (D2 ))'
D2^ =
D2 = "", ["D2"]
; # 聞き取りや語彙の判断に自信がない場合
; # 0 to remain, 1 to delete
; # Example: (? 字数) の
; # If no option: empty string is returned regardless of output
; # Example: '(?) で'
? = 0
; # Example: '(D (? すー))+そう+です+よ+ね'
?^ = 0
; # タグ?で、値は複数の候補が想定される場合
; # 0 for main guess with matching morph info, 1 for second guess
; # Example: '(? 次数, 実数)', '(? これ,ここで)(? 説明+し+た+方+が+いい+か+な)'
?, = 0
; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))'
?,^ = 0
; # 音や言葉に関するメタ的な引用
; # 0 to remain, 1 to delete
; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)'
M = 0
; # Example: '(L (M ヒ)(M ヒ))', '(L (M (? ヒ+ヒ)))'
M^ = 0
; # 外国語や古語、方言など
; # 0 to remain, 1 to delete
; # Example: '(O ザッツファイン)'
O = 0
; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))'
O^ = 0
; # 講演者の名前、差別語、誹謗中傷など
; # 0 to remain, 1 to delete
; # Example: '国語研の (R ××) です'
R = 0
R^ = 0
; # 非朗読対象発話(朗読における言い間違い等)
; # 0 to remain, 1 to delete
; # Example: '(X 実際は) 実際には'
X = 0
; # Example: '(L (X (D2 ニ)))'
X^ = 0
; # アルファベットや算用数字、記号の表記
; # 0 to use Japanese form, 1 to use alphabet form
; # Example: '(A シーディーアール;)'
A = 1
; # Example: 'スモール(A エヌ;)', 'ラージ(A キュー;)', '(A ティーエフ;)(A アイディーエフ;)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)')
A^ = 1
; # タグAで、単語は算用数字の場合
; # 0 to use Japanese form, 1 to use Arabic numerals
; # Example: (A 二千;)
A_num = eval:self.notag
A_num^ = eval:self.notag
A_num = 1
; # 何らかの原因で漢字表記できなくなった場合
; # 0 to use broken form, 1 to use orthodox form
; # Example: '(K たち (F えー) ばな;橘)'
K = 1
; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)'
K^ = 1
; # 転訛、発音の怠けなど、一時的な発音エラー
; # 0 to use wrong form, 1 to use orthodox form
; # Example: '(W ギーツ;ギジュツ)'
W = 1
; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)'
W^ = 1
; # 語の読みに関する知識レベルのいい間違い
; # 0 to use wrong form, 1 to use orthodox form
; # Example: '(B シブタイ;ジュータイ)'
B = 0
; # Example: 'データー(B カズ;スー)'
B^ = 0
; # 笑いながら発話
; # 0 to remain, 1 to delete
; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)'
笑 = 0
; # Example: 'コク(笑 サイ+(D オン))',
笑^ = 0
; # 泣きながら発話
; # 0 to remain, 1 to delete
; # Example: '(泣 ドンナニ)'
泣 = 0
泣^ = 0
; # 咳をしながら発話
; # 0 to remain, 1 to delete
; # Example: 'シャ(咳 リン) '
咳 = 0
; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)'
咳^ = 0
; # ささやき声や独り言などの小さな声
; # 0 to remain, 1 to delete
; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
L = 0
; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト'
L^ = 0
[REPLACEMENTS]
; # ボーカルフライなどで母音が同定できない場合
<FV> =
; # 「うん/うーん/ふーん」の音の特定が困難な場合
<VN> =
; # 非語彙的な母音の引き延ばし
<H> =
; # 非語彙的な子音の引き延ばし
<Q> =
; # 言語音と独立に講演者の笑いが生じている場合
<笑> =
; # 言語音と独立に講演者の咳が生じている場合
<咳> =
; # 言語音と独立に講演者の息が生じている場合
<息> =
; # 講演者の泣き声
<泣> =
; # 聴衆(司会者なども含む)の発話
<フロア発話> =
; # 聴衆の笑い
<フロア笑> =
; # 聴衆の拍手
<拍手> =
; # 講演者が発表中に用いたデモンストレーションの音声
<デモ> =
; # 学会講演に発表時間を知らせるためにならすベルの音
<ベル> =
; # 転記単位全体が再度読み直された場合
<朗読間違い> =
; # 上記以外の音で特に目立った音
<雑音> =
; # 0.2秒以上のポーズ
<P> =
; # Redacted information, for R
; # It is \x00D7 multiplication sign, not your normal 'x'
× = ×
[FIELDS]
; # Time information for segment
time = 3
; # Word surface
surface = 5
; # Word surface root form without CSJ tags
notag = 9
; # Part Of Speech
pos1 = 11
; # Conjugated Form
cForm = 12
; # Conjugation Type
cType1 = 13
; # Subcategory of POS
pos2 = 14
; # Euphonic Change / Subcategory of Conjugation Type
cType2 = 15
; # Other information
other = 16
; # Pronunciation for lexicon
pron = 10
; # Speaker ID
spk_id = 2
[KATAKANA2ROMAJI]
= 'a
= 'i
= 'u
= 'e
= 'o
= ka
= ki
= ku
= ke
= ko
= ga
= gi
= gu
= ge
= go
= sa
= si
= su
= se
= so
= za
= zi
= zu
= ze
= zo
= ta
= ti
= tu
= te
= to
= da
= di
= du
= de
= do
= na
= ni
= nu
= ne
= no
= ha
= hi
= hu
= he
= ho
= ba
= bi
= bu
= be
= bo
= pa
= pi
= pu
= pe
= po
= ma
= mi
= mu
= me
= mo
= ya
= yu
= yo
= ra
= ri
= ru
= re
= ro
= wa
= we
= wi
= wo
= ŋ
= q
= -
キャ = kǐa
キュ = kǐu
キョ = kǐo
ギャ = gǐa
ギュ = gǐu
ギョ = gǐo
シャ = sǐa
シュ = sǐu
ショ = sǐo
ジャ = zǐa
ジュ = zǐu
ジョ = zǐo
チャ = tǐa
チュ = tǐu
チョ = tǐo
ヂャ = dǐa
ヂュ = dǐu
ヂョ = dǐo
ニャ = nǐa
ニュ = nǐu
ニョ = nǐo
ヒャ = hǐa
ヒュ = hǐu
ヒョ = hǐo
ビャ = bǐa
ビュ = bǐu
ビョ = bǐo
ピャ = pǐa
ピュ = pǐu
ピョ = pǐo
ミャ = mǐa
ミュ = mǐu
ミョ = mǐo
リャ = rǐa
リュ = rǐu
リョ = rǐo
= a
= i
= u
= e
= o
= ʍ
= vu
= ǐa
= ǐu
= ǐo


@@ -0,0 +1,202 @@
import argparse
from pathlib import Path
import kaldialign
from lhotse import CutSet
ARGPARSE_DESCRIPTION = """
This helper code takes in a disfluent recogs file generated from icefall.utils.store_transcript,
compares it against a fluent transcript, and saves the results in a separate directory.
This is useful to compare disfluent models with fluent models on the same metric.
"""
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description=ARGPARSE_DESCRIPTION,
)
parser.add_argument(
"--recogs",
type=Path,
required=True,
help="Path to the recogs-XXX file generated by icefall.utils.store_transcript.",
)
parser.add_argument(
"--cut",
type=Path,
required=True,
help="Path to the cut manifest to be compared to. Assumes that disfluent_tag exists in the custom dict.",
)
parser.add_argument(
"--res-dir", type=Path, required=True, help="Path to save results"
)
return parser.parse_args()
def d2f(stats):
"""
Compare the outputs of a disfluent model against a fluent reference.
Indicates a disfluent model's performance only on the content words
CER^d_f = (sub_f + ins + del_f) / Nf
"""
return stats["base"] / stats["Nf"]
def calc_cer(refs, hyps):
subs = {
"F": 0,
"D": 0,
}
ins = 0
dels = {
"F": 0,
"D": 0,
}
cors = {
"F": 0,
"D": 0,
}
dis_ref_len = 0
flu_ref_len = 0
for ref, hyp in zip(refs, hyps):
assert (
ref[0] == hyp[0]
), f"Expected ref cut id {ref[0]} to be the same as hyp cut id {hyp[0]}."
tag = ref[2].copy()
ref = ref[1]
dis_ref_len += len(ref)
# Remember that the 'D' and 'F' tags here refer to CSJ tags, not disfluent and fluent respectively.
flu_ref_len += len([t for t in tag if ("D" not in t and "F" not in t)])
hyp = hyp[1]
ali = kaldialign.align(ref, hyp, "*")
tags = ["*" if r[0] == "*" else tag.pop(0) for r in ali]
for tag, (ref_word, hyp_word) in zip(tags, ali):
if "D" in tag or "F" in tag:
tag = "D"
else:
tag = "F"
if ref_word == "*":
ins += 1
elif hyp_word == "*":
dels[tag] += 1
elif ref_word != hyp_word:
subs[tag] += 1
else:
cors[tag] += 1
return {
"subs": subs,
"ins": ins,
"dels": dels,
"cors": cors,
"dis_ref_len": dis_ref_len,
"flu_ref_len": flu_ref_len,
}
def for_each_recogs(recogs_file: Path, refs, out_dir):
hyps = []
with recogs_file.open() as fin:
for line in fin:
if "ref" in line:
continue
cutid, hyp = line.split(":\thyp=")
hyps.append((cutid, eval(hyp)))
assert len(refs) == len(
hyps
), f"Expected refs len {len(refs)} and hyps len {len(hyps)} to be equal."
stats = calc_cer(refs, hyps)
stat_table = ["tag,yes,no"]
for cer_type in ["subs", "dels", "cors", "ins"]:
ret = f"{cer_type}"
for df in ["D", "F"]:
try:
ret += f",{stats[cer_type][df]}"
except TypeError:
# Insertions do not belong to F or D; the insertion count is an int and is not subscriptable.
ret += f",{stats[cer_type]},"
break
stat_table.append(ret)
stat_table = "\n".join(stat_table)
stats = {
"subd": stats["subs"]["D"],
"deld": stats["dels"]["D"],
"cord": stats["cors"]["D"],
"Nf": stats["flu_ref_len"],
"base": stats["subs"]["F"] + stats["ins"] + stats["dels"]["F"],
}
cer = d2f(stats)
results = [
f"{cer:.2%}",
f"Nf,{stats['Nf']}",
]
results = "\n".join(results)
with (out_dir / (recogs_file.stem + ".dfcer")).open("w") as fout:
fout.write(results)
fout.write("\n\n")
fout.write(stat_table)
def main():
args = get_args()
recogs_file: Path = args.recogs
assert (
recogs_file.is_file() or recogs_file.is_dir()
), f"recogs_file cannot be found at {recogs_file}."
args.res_dir.mkdir(parents=True, exist_ok=True)
if recogs_file.is_file() and recogs_file.stem.startswith("recogs-"):
assert (
"csj_cuts" in args.cut.name
), f"Expected {args.cut} to be a cuts manifest."
refs: CutSet = CutSet.from_file(args.cut)
refs = sorted(
[
(
e.id,
list(e.supervisions[0].custom["disfluent"]),
e.supervisions[0].custom["disfluent_tag"].split(","),
)
for e in refs
],
key=lambda x: x[0],
)
for_each_recogs(recogs_file, refs, args.res_dir)
elif recogs_file.is_dir():
recogs_file_path = recogs_file
for partname in ["eval1", "eval2", "eval3", "excluded", "valid"]:
refs: CutSet = CutSet.from_file(args.cut / f"csj_cuts_{partname}.jsonl.gz")
refs = sorted(
[
(
r.id,
list(r.supervisions[0].custom["disfluent"]),
r.supervisions[0].custom["disfluent_tag"].split(","),
)
for r in refs
],
key=lambda x: x[0],
)
for recogs_file in recogs_file_path.glob(f"recogs-{partname}-*.txt"):
for_each_recogs(recogs_file, refs, args.res_dir)
else:
raise TypeError(f"Unrecognised recogs file provided: {recogs_file}")
if __name__ == "__main__":
main()
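To make the forgiving alignment in `calc_cer` concrete, here is a toy example with made-up per-character tags. Recall that inside `calc_cer` the bucket names invert the CSJ meaning of D and F, as the comment in the code notes:
```python
# Toy example of the alignment step above: reference characters carrying CSJ
# "F"/"D" tags fall into the forgiven bucket, so deleting them does not add to
# sub_f + ins + del_f. The characters and tags are illustrative only.
import kaldialign

ref = ["あ", "の", "大", "学"]   # disfluent reference characters
tags = ["F", "F", "", ""]        # per-character CSJ tags; "" marks a content character
hyp = ["大", "学"]               # hypothesis from the disfluent model

ali = kaldialign.align(ref, hyp, "*")
# -> [('あ', '*'), ('の', '*'), ('大', '大'), ('学', '学')]
# The two deletions hit "F"-tagged characters, so against the fluent reference
# this utterance scores (0 + 0 + 0) / 2 = 0.
```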


@@ -45,8 +45,8 @@ def get_parser():
def main():
args = get_parser()
for path in args.manifest_dir.glob("csj_cuts_*.jsonl.gz"):
for part in ["eval1", "eval2", "eval3", "valid", "excluded", "train"]:
path = args.manifest_dir / f"csj_cuts_{part}.jsonl.gz"
cuts: CutSet = load_manifest(path)
print("\n---------------------------------\n")
@@ -58,123 +58,271 @@ if __name__ == "__main__":
main()
"""
## eval1
Cuts count: 1272
Total duration (hh:mm:ss): 01:50:07
Speech duration (hh:mm:ss): 01:50:07 (100.0%)
Duration statistics (seconds):
mean 5.2
std 3.9
min 0.2
25% 1.9
50% 4.0
75% 8.1
99% 14.3
99.5% 14.7
99.9% 16.0
max 16.9
Recordings available: 1272
Features available: 1272
Supervisions available: 1272
csj_cuts_eval1.jsonl.gz:
Cut statistics:
Cuts count: 1023
Total duration (hh:mm:ss) 01:55:40
mean 6.8
std 2.7
min 0.2
25% 4.9
50% 7.7
75% 9.0
99% 10.0
99.5% 10.0
99.9% 10.0
max 10.0
Recordings available: 1023
Features available: 0
Supervisions available: 1023
SUPERVISION custom fields:
- fluent (in 1272 cuts)
- disfluent (in 1272 cuts)
- number (in 1272 cuts)
- symbol (in 1272 cuts)
Speech duration statistics:
Total speech duration 01:55:40 100.00% of recording
Total speaking time duration 01:55:40 100.00% of recording
Total silence duration 00:00:00 0.00% of recording
## eval2
Cuts count: 1292
Total duration (hh:mm:ss): 01:56:50
Speech duration (hh:mm:ss): 01:56:50 (100.0%)
Duration statistics (seconds):
mean 5.4
std 3.9
min 0.1
25% 2.1
50% 4.6
75% 8.6
99% 14.1
99.5% 15.2
99.9% 16.1
max 16.9
Recordings available: 1292
Features available: 1292
Supervisions available: 1292
SUPERVISION custom fields:
- fluent (in 1292 cuts)
- number (in 1292 cuts)
- symbol (in 1292 cuts)
- disfluent (in 1292 cuts)
---------------------------------
## eval3
Cuts count: 1385
Total duration (hh:mm:ss): 01:19:21
Speech duration (hh:mm:ss): 01:19:21 (100.0%)
Duration statistics (seconds):
mean 3.4
std 3.0
min 0.2
25% 1.2
50% 2.5
75% 4.6
99% 12.7
99.5% 13.7
99.9% 15.0
max 15.9
Recordings available: 1385
Features available: 1385
Supervisions available: 1385
csj_cuts_eval2.jsonl.gz:
Cut statistics:
Cuts count: 1025
Total duration (hh:mm:ss) 02:02:07
mean 7.1
std 2.5
min 0.1
25% 5.9
50% 7.9
75% 9.1
99% 10.0
99.5% 10.0
99.9% 10.0
max 10.0
Recordings available: 1025
Features available: 0
Supervisions available: 1025
SUPERVISION custom fields:
- number (in 1385 cuts)
- symbol (in 1385 cuts)
- fluent (in 1385 cuts)
- disfluent (in 1385 cuts)
Speech duration statistics:
Total speech duration 02:02:07 100.00% of recording
Total speaking time duration 02:02:07 100.00% of recording
Total silence duration 00:00:00 0.00% of recording
## valid
Cuts count: 4000
Total duration (hh:mm:ss): 05:08:09
Speech duration (hh:mm:ss): 05:08:09 (100.0%)
Duration statistics (seconds):
mean 4.6
std 3.8
min 0.1
25% 1.5
50% 3.4
75% 7.0
99% 13.8
99.5% 14.8
99.9% 16.0
max 17.3
Recordings available: 4000
Features available: 4000
Supervisions available: 4000
SUPERVISION custom fields:
- fluent (in 4000 cuts)
- symbol (in 4000 cuts)
- disfluent (in 4000 cuts)
- number (in 4000 cuts)
---------------------------------
## train
Cuts count: 1291134
Total duration (hh:mm:ss): 1596:37:27
Speech duration (hh:mm:ss): 1596:37:27 (100.0%)
Duration statistics (seconds):
mean 4.5
std 3.6
min 0.0
25% 1.6
50% 3.3
75% 6.4
99% 14.0
99.5% 14.8
99.9% 16.6
max 27.8
Recordings available: 1291134
Features available: 1291134
Supervisions available: 1291134
csj_cuts_eval3.jsonl.gz:
Cut statistics:
Cuts count: 865
Total duration (hh:mm:ss) 01:26:44
mean 6.0
std 3.0
min 0.3
25% 3.3
50% 6.8
75% 8.7
99% 10.0
99.5% 10.0
99.9% 10.0
max 10.0
Recordings available: 865
Features available: 0
Supervisions available: 865
SUPERVISION custom fields:
- disfluent (in 1291134 cuts)
- fluent (in 1291134 cuts)
- symbol (in 1291134 cuts)
- number (in 1291134 cuts)
Speech duration statistics:
Total speech duration 01:26:44 100.00% of recording
Total speaking time duration 01:26:44 100.00% of recording
Total silence duration 00:00:00 0.00% of recording
---------------------------------
csj_cuts_valid.jsonl.gz:
Cut statistics:
Cuts count: 3743
Total duration (hh:mm:ss) 06:40:15
mean 6.4
std 3.0
min 0.1
25% 3.9
50% 7.4
75% 9.0
99% 10.0
99.5% 10.0
99.9% 10.1
max 11.8
Recordings available: 3743
Features available: 0
Supervisions available: 3743
SUPERVISION custom fields:
Speech duration statistics:
Total speech duration 06:40:15 100.00% of recording
Total speaking time duration 06:40:15 100.00% of recording
Total silence duration 00:00:00 0.00% of recording
---------------------------------
csj_cuts_excluded.jsonl.gz:
Cut statistics:
Cuts count: 980
Total duration (hh:mm:ss) 00:56:06
mean 3.4
std 3.1
min 0.1
25% 0.8
50% 2.2
75% 5.8
99% 9.9
99.5% 9.9
99.9% 10.0
max 10.0
Recordings available: 980
Features available: 0
Supervisions available: 980
SUPERVISION custom fields:
Speech duration statistics:
Total speech duration 00:56:06 100.00% of recording
Total speaking time duration 00:56:06 100.00% of recording
Total silence duration 00:00:00 0.00% of recording
---------------------------------
csj_cuts_train.jsonl.gz:
Cut statistics:
Cuts count: 914151
Total duration (hh:mm:ss) 1695:29:43
mean 6.7
std 2.9
min 0.1
25% 4.6
50% 7.5
75% 8.9
99% 11.0
99.5% 11.0
99.9% 11.1
max 18.0
Recordings available: 914151
Features available: 0
Supervisions available: 914151
SUPERVISION custom fields:
Speech duration statistics:
Total speech duration 1695:29:43 100.00% of recording
Total speaking time duration 1695:29:43 100.00% of recording
Total silence duration 00:00:00 0.00% of recording
"""


@@ -21,24 +21,14 @@ import logging
from pathlib import Path
from lhotse import CutSet
from lhotse.recipes.csj import CSJSDBParser
ARGPARSE_DESCRIPTION = """
This script gathers all training transcripts of the specified {trans_mode} type
and produces a token_list that would be output set of the ASR system.
This script gathers all training transcripts, parses them in disfluent mode, and produces a token list that would be the output set of the ASR system.
It splits transcripts by whitespace into lists, then, for each word in the
list, if the word does not appear in the list of user-defined multicharacter
strings, it further splits that word into individual characters to be counted
into the output token set.
It outputs 4 files into the lang directory:
- trans_mode: the name of transcript mode. If trans_mode was not specified,
this will be an empty file.
- userdef_string: a list of user defined strings that should not be split
further into individual characters. By default, it contains "<unk>", "<blk>",
"<sos/eos>"
- words_len: the total number of tokens in the output set.
- words.txt: a list of tokens in the output set. The length matches words_len.
It outputs 3 files into the lang directory:
- tokens.txt: a list of tokens in the output set.
- lang_type: a file that contains the string "char"
"""
@@ -50,98 +40,52 @@ def get_args():
)
parser.add_argument(
"--train-cut", type=Path, required=True, help="Path to the train cut"
)
parser.add_argument(
"--trans-mode",
type=str,
default=None,
help=(
"Name of the transcript mode to use. "
"If lang-dir is not set, this will also name the lang-dir"
),
"train_cut", metavar="train-cut", type=Path, help="Path to the train cut"
)
parser.add_argument(
"--lang-dir",
type=Path,
default=None,
default=Path("data/lang_char"),
help=(
"Name of lang dir. "
"If not set, this will default to lang_char_{trans-mode}"
),
)
parser.add_argument(
"--userdef-string",
type=Path,
default=None,
help="Multicharacter strings that do not need to be split",
)
return parser.parse_args()
def main():
args = get_args()
logging.basicConfig(
format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"),
level=logging.INFO,
)
if not args.lang_dir:
p = "lang_char"
if args.trans_mode:
p += f"_{args.trans_mode}"
args.lang_dir = Path(p)
sysdef_string = set(["<blk>", "<unk>", "<sos/eos>"])
if args.userdef_string:
args.userdef_string = set(args.userdef_string.read_text().split())
else:
args.userdef_string = set()
# Using disfluent parsing as fluent is a subset of disfluent
parser = CSJSDBParser()
sysdef_string = ["<blk>", "<unk>", "<sos/eos>"]
args.userdef_string.update(sysdef_string)
token_set = set()
logging.info(f"Creating vocabulary from {args.train_cut}.")
train_cut: CutSet = CutSet.from_file(args.train_cut)
for cut in train_cut:
if "_sp" in cut.id:
continue
train_set: CutSet = CutSet.from_file(args.train_cut)
words = set()
logging.info(
f"Creating vocabulary from {args.train_cut.name} at {args.trans_mode} mode."
)
for cut in train_set:
try:
text: str = (
cut.supervisions[0].custom[args.trans_mode]
if args.trans_mode
else cut.supervisions[0].text
)
except KeyError:
raise KeyError(
f"Could not find {args.trans_mode} in {cut.supervisions[0].custom}"
)
for t in text.split():
if t in args.userdef_string:
words.add(t)
else:
words.update(c for c in list(t))
words -= set(sysdef_string)
words = sorted(words)
words = ["<blk>"] + words + ["<unk>", "<sos/eos>"]
text: str = cut.supervisions[0].custom["raw"]
for w in parser.parse(text, sep=" ").split(" "):
token_set.update(w)
token_set = ["<blk>"] + sorted(token_set - sysdef_string) + ["<unk>", "<sos/eos>"]
args.lang_dir.mkdir(parents=True, exist_ok=True)
(args.lang_dir / "words.txt").write_text(
"\n".join(f"{word}\t{i}" for i, word in enumerate(words))
(args.lang_dir / "tokens.txt").write_text(
"\n".join(f"{t}\t{i}" for i, t in enumerate(token_set))
)
(args.lang_dir / "words_len").write_text(f"{len(words)}")
(args.lang_dir / "userdef_string").write_text("\n".join(args.userdef_string))
(args.lang_dir / "trans_mode").write_text(args.trans_mode)
(args.lang_dir / "lang_type").write_text("char")
logging.info("Done.")

View File

@@ -0,0 +1,462 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
OnTheFlyFeatures,
)
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class AsrVariableTranscriptDataset(K2SpeechRecognitionDataset):
def __init__(
self,
*args,
transcript_mode: str = "",
return_cuts: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.transcript_mode = transcript_mode
# Always ask the parent dataset to return cuts so that the custom
# transcript field can be read in __getitem__; they are dropped later
# unless the caller explicitly requested them.
self.return_cuts = True
self._return_cuts = return_cuts
def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
batch = super().__getitem__(cuts)
if self.transcript_mode:
batch["supervisions"]["text"] = [
supervision.custom[self.transcript_mode]
for cut in batch["supervisions"]["cut"]
for supervision in cut.supervisions
]
if not self._return_cuts:
del batch["supervisions"]["cut"]
return batch
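# --------------------------------------------------------------------------
# Illustrative usage sketch (not part of the recipe). The manifest path and
# the custom transcript field name "fluent" below are assumptions for
# illustration; adjust both to your setup.
def _example_variable_transcript_batch():
    cuts = CutSet.from_file("data/fbank/csj_cuts_valid.jsonl.gz")  # hypothetical path
    dataset = AsrVariableTranscriptDataset(
        transcript_mode="fluent",  # read supervision.custom["fluent"]
        return_cuts=False,
    )
    sampler = DynamicBucketingSampler(cuts, max_duration=200, shuffle=False)
    dl = DataLoader(dataset, sampler=sampler, batch_size=None, num_workers=0)
    batch = next(iter(dl))
    # The texts now come from the "fluent" field of each supervision.
    return batch["supervisions"]["text"]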
class CSJAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. the CSJ eval1, eval2
and eval3 sets).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--transcript-mode",
type=str,
default="",
help="Mode of transcript in supervision to use.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/manifests"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--musan-dir", type=Path, help="Path to directory with musan cuts. "
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=False,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.musan_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = AsrVariableTranscriptDataset(
input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
transcript_mode=self.args.transcript_mode,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = AsrVariableTranscriptDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
transcript_mode=self.args.transcript_mode,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = AsrVariableTranscriptDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
transcript_mode=self.args.transcript_mode,
)
else:
validate = AsrVariableTranscriptDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
transcript_mode=self.args.transcript_mode,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = AsrVariableTranscriptDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
return_cuts=self.args.return_cuts,
transcript_mode=self.args.transcript_mode,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_train.jsonl.gz")
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get valid cuts")
return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_valid.jsonl.gz")
@lru_cache()
def excluded_cuts(self) -> CutSet:
logging.info("About to get excluded cuts")
return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_excluded.jsonl.gz")
@lru_cache()
def eval1_cuts(self) -> CutSet:
logging.info("About to get eval1 cuts")
return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_eval1.jsonl.gz")
@lru_cache()
def eval2_cuts(self) -> CutSet:
logging.info("About to get eval2 cuts")
return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_eval2.jsonl.gz")
@lru_cache()
def eval3_cuts(self) -> CutSet:
logging.info("About to get eval3 cuts")
return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_eval3.jsonl.gz")
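# --------------------------------------------------------------------------
# Illustrative usage sketch (not part of the recipe) of wiring this data
# module into a training script. The paths passed below are hypothetical;
# point them at the directories produced by prepare.sh.
def _example_build_train_dataloader():
    parser = argparse.ArgumentParser()
    CSJAsrDataModule.add_arguments(parser)
    args = parser.parse_args(
        ["--manifest-dir", "data/fbank", "--musan-dir", "/path/to/musan/fbank"]
    )
    csj = CSJAsrDataModule(args)
    # Builds the sampler, transforms, and DataLoader exactly as above.
    return csj.train_dataloaders(csj.train_cuts())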

View File

@@ -0,0 +1,253 @@
import argparse
from pathlib import Path
from typing import Callable, List, Union
import sentencepiece as spm
from k2 import SymbolTable
class Tokenizer:
text2word: Callable[[str], List[str]]
@staticmethod
def add_arguments(parser: argparse.ArgumentParser):
group = parser.add_argument_group(title="Lang related options")
group.add_argument("--lang", type=Path, help="Path to lang directory.")
group.add_argument(
"--lang-type",
type=str,
default=None,
help=(
"Either 'bpe' or 'char'. If not provided, it expects lang_dir/lang_type to exists. "
"Note: 'bpe' directly loads sentencepiece.SentencePieceProcessor"
),
)
@staticmethod
def Load(lang_dir: Path, lang_type="", oov="<unk>"):
if not lang_type:
assert (lang_dir / "lang_type").exists(), "lang_type not specified."
lang_type = (lang_dir / "lang_type").read_text().strip()
tokenizer = None
if lang_type == "bpe":
assert (
lang_dir / "bpe.model"
).exists(), f"No BPE .model could be found in {lang_dir}."
tokenizer = spm.SentencePieceProcessor()
tokenizer.Load(str(lang_dir / "bpe.model"))
elif lang_type == "char":
tokenizer = CharTokenizer(lang_dir, oov=oov)
else:
raise NotImplementedError(f"{lang_type} not supported at the moment.")
return tokenizer
load = Load
def PieceToId(self, piece: str) -> int:
raise NotImplementedError(
"You need to implement this function in the child class."
)
piece_to_id = PieceToId
def IdToPiece(self, id: int) -> str:
raise NotImplementedError(
"You need to implement this function in the child class."
)
id_to_piece = IdToPiece
def GetPieceSize(self) -> int:
raise NotImplementedError(
"You need to implement this function in the child class."
)
get_piece_size = GetPieceSize
def __len__(self) -> int:
return self.get_piece_size()
def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]:
raise NotImplementedError(
"You need to implement this function in the child class."
)
def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]:
raise NotImplementedError(
"You need to implement this function in the child class."
)
def EncodeAsIds(self, input: str) -> List[int]:
return self.EncodeAsIdsBatch([input])[0]
def EncodeAsPieces(self, input: str) -> List[str]:
return self.EncodeAsPiecesBatch([input])[0]
def Encode(
self, input: Union[str, List[str]], out_type=int
) -> Union[List, List[List]]:
if not input:
return []
if isinstance(input, list):
if out_type is int:
return self.EncodeAsIdsBatch(input)
if out_type is str:
return self.EncodeAsPiecesBatch(input)
if out_type is int:
return self.EncodeAsIds(input)
if out_type is str:
return self.EncodeAsPieces(input)
encode = Encode
def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]:
raise NotImplementedError(
"You need to implement this function in the child class."
)
def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]:
raise NotImplementedError(
"You need to implement this function in the child class."
)
def DecodeIds(self, input: List[int]) -> str:
return self.DecodeIdsBatch([input])[0]
def DecodePieces(self, input: List[str]) -> str:
return self.DecodePiecesBatch([input])[0]
def Decode(
self,
input: Union[int, List[int], List[str], List[List[int]], List[List[str]]],
) -> Union[List[str], str]:
if not input:
return ""
if isinstance(input, int):
return self.id_to_piece(input)
elif isinstance(input, str):
raise TypeError(
"Unlike spm.SentencePieceProcessor, cannot decode from type str."
)
if isinstance(input[0], list):
if not input[0] or isinstance(input[0][0], int):
return self.DecodeIdsBatch(input)
if isinstance(input[0][0], str):
return self.DecodePiecesBatch(input)
if isinstance(input[0], int):
return self.DecodeIds(input)
if isinstance(input[0], str):
return self.DecodePieces(input)
raise RuntimeError("Unknown input type")
decode = Decode
def SplitBatch(self, input: List[str]) -> List[List[str]]:
raise NotImplementedError(
"You need to implement this function in the child class."
)
def Split(self, input: Union[List[str], str]) -> Union[List[List[str]], List[str]]:
if isinstance(input, list):
return self.SplitBatch(input)
elif isinstance(input, str):
return self.SplitBatch([input])[0]
raise RuntimeError("Unknown input type")
split = Split
class CharTokenizer(Tokenizer):
def __init__(self, lang_dir: Path, oov="<unk>", sep=""):
assert (
lang_dir / "tokens.txt"
).exists(), f"tokens.txt could not be found in {lang_dir}."
token_table = SymbolTable.from_file(lang_dir / "tokens.txt")
assert (
"#0" not in token_table
), "This tokenizer does not support disambig symbols."
self._id2sym = token_table._id2sym
self._sym2id = token_table._sym2id
self.oov = oov
self.oov_id = self._sym2id[oov]
self.sep = sep
if self.sep:
self.text2word = lambda x: x.split(self.sep)
else:
self.text2word = lambda x: list(x.replace(" ", ""))
def piece_to_id(self, piece: str) -> int:
try:
return self._sym2id[piece]
except KeyError:
return self.oov_id
def id_to_piece(self, id: int) -> str:
return self._id2sym[id]
def get_piece_size(self) -> int:
return len(self._sym2id)
def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]:
return [[self.piece_to_id(i) for i in self.text2word(text)] for text in input]
def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]:
return [
[i if i in self._sym2id else self.oov for i in self.text2word(text)]
for text in input
]
def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]:
return [self.sep.join(self.id_to_piece(i) for i in text) for text in input]
def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]:
return [self.sep.join(text) for text in input]
def SplitBatch(self, input: List[str]) -> List[List[str]]:
return [self.text2word(text) for text in input]
def test_CharTokenizer():
test_single_string = "こんにちは"
test_multiple_string = [
"今日はいい天気ですよね",
"諏訪湖は綺麗でしょう",
"这在词表外",
"分かち 書き に し た 文章 です",
"",
]
test_empty_string = ""
sp = Tokenizer.load(Path("lang_char"), "char", oov="<unk>")
splitter = sp.split
print(sp.encode(test_single_string, out_type=str))
print(sp.encode(test_single_string, out_type=int))
print(sp.encode(test_multiple_string, out_type=str))
print(sp.encode(test_multiple_string, out_type=int))
print(sp.encode(test_empty_string, out_type=str))
print(sp.encode(test_empty_string, out_type=int))
print(sp.decode(sp.encode(test_single_string, out_type=str)))
print(sp.decode(sp.encode(test_single_string, out_type=int)))
print(sp.decode(sp.encode(test_multiple_string, out_type=str)))
print(sp.decode(sp.encode(test_multiple_string, out_type=int)))
print(sp.decode(sp.encode(test_empty_string, out_type=str)))
print(sp.decode(sp.encode(test_empty_string, out_type=int)))
print(splitter(test_single_string))
print(splitter(test_multiple_string))
print(splitter(test_empty_string))
if __name__ == "__main__":
test_CharTokenizer()

View File

@@ -32,7 +32,7 @@
# - speech
#
# By default, this script produces the original transcript like kaldi and espnet. Optionally, you
# can generate other transcript formats by supplying your own config files. A few examples of these
# can add other transcript formats by supplying your own config files. A few examples of these
# config files can be found in local/conf.
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
@@ -44,10 +44,10 @@ nj=8
stage=-1
stop_stage=100
csj_dir=/mnt/minami_data_server/t2131178/corpus/CSJ
musan_dir=/mnt/minami_data_server/t2131178/corpus/musan/musan
trans_dir=$csj_dir/retranscript
csj_fbank_dir=/mnt/host/csj_data/fbank
csj_dir=/mnt/host/corpus/csj
musan_dir=/mnt/host/corpus/musan/musan
trans_dir=$csj_dir/transcript
csj_fbank_dir=/mnt/host/corpus/csj/fbank
musan_fbank_dir=$musan_dir/fbank
csj_manifest_dir=data/manifests
musan_manifest_dir=$musan_dir/manifests
@@ -63,12 +63,8 @@ log() {
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare CSJ manifest"
# If you want to generate more transcript modes, append the path to those config files at c.
# Example: lhotse prepare csj $csj_dir $trans_dir $csj_manifest_dir -c local/conf/disfluent.ini
# NOTE: In case multiple config files are supplied, the second config file and onwards will inherit
# the segment boundaries of the first config file.
if [ ! -e $csj_manifest_dir/.csj.done ]; then
lhotse prepare csj $csj_dir $trans_dir $csj_manifest_dir -j 4
lhotse prepare csj $csj_dir $csj_manifest_dir -t $trans_dir -j 16
touch $csj_manifest_dir/.csj.done
fi
fi
@@ -88,32 +84,24 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
python local/compute_fbank_csj.py --manifest-dir $csj_manifest_dir \
--fbank-dir $csj_fbank_dir
parts=(
train
valid
eval1
eval2
eval3
valid
excluded
train
)
for part in ${parts[@]}; do
python local/validate_manifest.py --manifest $csj_manifest_dir/csj_cuts_$part.jsonl.gz
python local/validate_manifest.py --manifest $csj_fbank_dir/csj_cuts_$part.jsonl.gz
done
touch $csj_fbank_dir/.csj-validated.done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Prepare CSJ lang"
modes=disfluent
# If you want to prepare the lang directory for other transcript modes, just append
# the names of those modes. An example is shown below:
# modes="$modes fluent symbol number"
for mode in ${modes[@]}; do
python local/prepare_lang_char.py --trans-mode $mode \
--train-cut $csj_manifest_dir/csj_cuts_train.jsonl.gz \
--lang-dir lang_char_$mode
done
log "Stage 4: Prepare CSJ lang_char"
python local/prepare_lang_char.py $csj_fbank_dir/csj_cuts_train.jsonl.gz
python local/add_transcript_mode.py -f $csj_fbank_dir -c local/conf/fluent.ini local/conf/number.ini
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
@@ -128,6 +116,6 @@ fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Show manifest statistics"
python local/display_manifest_statistics.py --manifest-dir $csj_manifest_dir > $csj_manifest_dir/manifest_statistics.txt
cat $csj_manifest_dir/manifest_statistics.txt
python local/display_manifest_statistics.py --manifest-dir $csj_fbank_dir > $csj_fbank_dir/manifest_statistics.txt
cat $csj_fbank_dir/manifest_statistics.txt
fi

View File

@@ -0,0 +1,76 @@
import logging
from configparser import ConfigParser
import requests
def escape_html(text: str):
"""
Escapes all html characters in text
:param str text:
:rtype: str
"""
return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
class TelegramStreamIO(logging.Handler):
API_ENDPOINT = "https://api.telegram.org"
MAX_MESSAGE_LEN = 4096
formatter = logging.Formatter(
"%(asctime)s - %(levelname)s at %(funcName)s "
"(line %(lineno)s):\n\n%(message)s"
)
def __init__(self, tg_configfile: str):
super(TelegramStreamIO, self).__init__()
config = ConfigParser()
if not config.read(tg_configfile):
raise FileNotFoundError(
f"{tg_configfile} not found. " "Retry without --telegram-cred flag."
)
config = config["TELEGRAM"]
token = config["token"]
self.chat_id = config["chat_id"]
self.url = f"{self.API_ENDPOINT}/bot{token}/sendMessage"
@staticmethod
def setup_logger(params):
if not params.telegram_cred:
return
formatter = logging.Formatter(
f"{params.exp_dir.name} %(asctime)s \n%(message)s"
)
tg = TelegramStreamIO(params.telegram_cred)
tg.setLevel(logging.WARN)
tg.setFormatter(formatter)
logging.getLogger("").addHandler(tg)
def emit(self, record: logging.LogRecord):
"""
Emit a record.
Send the record to the Telegram Bot API as a JSON payload.
"""
data = {
"chat_id": self.chat_id,
"text": self.format(self.mapLogRecord(record)),
"parse_mode": "HTML",
}
try:
requests.get(self.url, json=data)
# return response.json()
except Exception as e:
logging.error(f"Failed to send telegram message: {repr(e)}")
pass
def mapLogRecord(self, record):
"""
Default implementation of mapping the log record into a dict
that is sent as the CGI data. Overwrite in your class.
Contributed by Franz Glasner.
"""
for k, v in record.__dict__.items():
if isinstance(v, str):
setattr(record, k, escape_html(v))
return record
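# --------------------------------------------------------------------------
# Illustrative usage sketch (not part of the recipe). The config path and its
# contents below are assumptions; the file only needs a [TELEGRAM] section
# with `token` and `chat_id`, as read in __init__ above.
#
#   [TELEGRAM]
#   token = 123456:your-bot-token
#   chat_id = 123456789
def _example_attach_telegram_handler(config_path: str = "telegram.ini"):
    handler = TelegramStreamIO(config_path)
    handler.setLevel(logging.WARNING)
    handler.setFormatter(TelegramStreamIO.formatter)
    logging.getLogger("").addHandler(handler)
    logging.warning("Training started")  # forwarded to the Telegram chat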

View File

@@ -0,0 +1 @@
../local/utils/asr_datamodule.py

View File

@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py

View File

@@ -0,0 +1,852 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--lang data/lang_char \
--decoding-method greedy_search
(2) beam search (not recommended)
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method beam_search \
--lang data/lang_char \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method modified_beam_search \
--lang data/lang_char \
--beam-size 4
(4) fast beam search (one best)
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--lang data/lang_char \
--max-states 64
(5) fast beam search (nbest)
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--lang data/lang_char \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--lang data/lang_char \
--nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless7_streaming/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--lang data/lang_char \
--max-states 64
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import CSJAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from tokenizer import Tokenizer
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--gpu",
type=int,
default=0,
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7_streaming/exp",
help="The experiment dir",
)
parser.add_argument(
"--res-dir",
type=Path,
default=None,
help="The path to save results.",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_char",
help="The lang dir. It should contain at least a word table.",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--decoding-graph",
type=str,
default="",
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--pad-feature",
type=int,
default=30,
help="""
Number of frames to pad at the end.
""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: Tokenizer,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The tokenizer (BPE or char).
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
if params.pad_feature:
feature_lens += params.pad_feature
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.pad_feature),
value=LOG_EPS,
)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(sp.text2word(hyp))
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(sp.text2word(hyp))
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(sp.text2word(hyp))
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(sp.text2word(hyp))
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(sp.text2word(hyp))
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.text2word(sp.decode(hyp)))
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: Tokenizer,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The tokenizer (BPE or char).
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains three elements:
the cut id, the reference transcript, and the predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = sp.text2word(ref_text)
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
return test_set_wers
@torch.no_grad()
def main():
parser = get_parser()
CSJAsrDataModule.add_arguments(parser)
Tokenizer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
if not params.res_dir:
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", params.gpu)
logging.info(f"Device: {device}")
sp = Tokenizer.load(params.lang, params.lang_type)
# <blk> and <unk> are defined in local/prepare_lang_char.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, (
model.encoder.decode_chunk_size,
params.decode_chunk_len,
)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
decoding_graph = None
word_table = None
if params.decoding_graph:
decoding_graph = k2.Fsa.from_dict(
torch.load(params.decoding_graph, map_location=device)
)
elif "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
csj_corpus = CSJAsrDataModule(args)
for subdir in ["eval1", "eval2", "eval3", "excluded", "valid"]:
results_dict = decode_dataset(
dl=csj_corpus.test_dataloaders(getattr(csj_corpus, f"{subdir}_cuts")()),
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
tot_err = save_results(
params=params,
test_set_name=subdir,
results_dict=results_dict,
)
with (
params.res_dir
/ (
f"{subdir}-{params.decode_chunk_len}_{params.beam_size}"
f"_{params.avg}_{params.epoch}.cer"
)
).open("w") as fout:
if len(tot_err) == 1:
fout.write(f"{tot_err[0][1]}")
else:
fout.write("\n".join(f"{k}\t{v}" for k, v in tot_err))
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py

View File

@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py

View File

@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py

View File

@@ -0,0 +1,313 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
(1) Export to torchscript model using torch.jit.script()
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--lang data/lang_char \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("cpu_jit.pt")`.
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
Check
https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
(2) Export `model.state_dict()`
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--lang data/lang_char \
--epoch 20 \
--avg 10
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
To use the generated file with `pruned_transducer_stateless7_streaming/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/csj/ASR
./pruned_transducer_stateless7_streaming/decode.py \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--lang data/lang_char
Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208
# You will find the pre-trained model in icefall-asr-csj-pruned-transducer-stateless7-streaming-230208/exp_fluent
"""
import argparse
import logging
from pathlib import Path
import torch
from scaling_converter import convert_scaled_to_non_scaled
from tokenizer import Tokenizer
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7_streaming/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
It will generate a file named cpu_jit.pt
Check ./jit_pretrained.py for how to use it.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
@torch.no_grad()
def main():
parser = get_parser()
Tokenizer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = Tokenizer.load(params.lang, params.lang_type)
# <blk> is defined in local/prepare_lang_char.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
if params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torchscript. Export model.state_dict()")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
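# --------------------------------------------------------------------------
# Illustrative sketch (not part of the recipe) of loading the torchscript
# model produced with --jit 1. The path is an assumption for illustration.
def _example_load_cpu_jit():
    model = torch.jit.load("pruned_transducer_stateless7_streaming/exp/cpu_jit.pt")
    model.eval()
    return model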

View File

@@ -0,0 +1,308 @@
#!/usr/bin/env python3
"""
Usage:
# use -O to skip assertions and avoid some of the TracerWarnings
python -O pruned_transducer_stateless7_streaming/jit_trace_export.py \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--lang data/lang_char \
--epoch 30 \
--avg 10 \
--use-averaged-model=True \
--decode-chunk-len 32
"""
import argparse
import logging
from pathlib import Path
import torch
from scaling_converter import convert_scaled_to_non_scaled
from tokenizer import Tokenizer
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import AttributeDict, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, logs, etc., are saved
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
add_model_arguments(parser)
return parser
def export_encoder_model_jit_trace(
encoder_model: torch.nn.Module,
encoder_filename: str,
params: AttributeDict,
) -> None:
"""Export the given encoder model with torch.jit.trace()
Note: The warmup argument is fixed to 1.
Args:
encoder_model:
The input encoder model.
encoder_filename:
The filename to save the exported model.
params:
The configuration object; only `decode_chunk_len` is used here.
"""
decode_chunk_len = params.decode_chunk_len # before subsampling
pad_length = 7
s = f"decode_chunk_len: {decode_chunk_len}"
logging.info(s)
assert encoder_model.decode_chunk_size == decode_chunk_len // 2, (
encoder_model.decode_chunk_size,
decode_chunk_len,
)
T = decode_chunk_len + pad_length
x = torch.zeros(1, T, 80, dtype=torch.float32)
x_lens = torch.full((1,), T, dtype=torch.int32)
states = encoder_model.get_init_state(device=x.device)
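# torch.jit.trace() only records forward(), so point it at streaming_forward()
# before tracing the streaming computation.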
encoder_model.__class__.forward = encoder_model.__class__.streaming_forward
traced_model = torch.jit.trace(encoder_model, (x, x_lens, states))
traced_model.save(encoder_filename)
logging.info(f"Saved to {encoder_filename}")
def export_decoder_model_jit_trace(
decoder_model: torch.nn.Module,
decoder_filename: str,
) -> None:
"""Export the given decoder model with torch.jit.trace()
Note: The argument need_pad is fixed to False.
Args:
decoder_model:
The input decoder model
decoder_filename:
The filename to save the exported model.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
need_pad = torch.tensor([False])
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
traced_model.save(decoder_filename)
logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_jit_trace(
joiner_model: torch.nn.Module,
joiner_filename: str,
) -> None:
"""Export the given joiner model with torch.jit.trace()
Note: The argument project_input is fixed to True, so users should not
project encoder_out/decoder_out themselves; the exported joiner does the
projection internally.
Args:
joiner_model:
The input joiner model
joiner_filename:
The filename to save the exported model.
"""
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
traced_model.save(joiner_filename)
logging.info(f"Saved to {joiner_filename}")
@torch.no_grad()
def main():
parser = get_parser()
Tokenizer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
logging.info(f"device: {device}")
sp = Tokenizer.load(params.lang, params.lang_type)
# <blk> is defined in local/prepare_lang_char.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True)
logging.info("Using torch.jit.trace()")
logging.info("Exporting encoder")
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
export_encoder_model_jit_trace(model.encoder, encoder_filename, params)
logging.info("Exporting decoder")
decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
export_decoder_model_jit_trace(model.decoder, decoder_filename)
logging.info("Exporting joiner")
joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
export_joiner_model_jit_trace(model.joiner, joiner_filename)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()


@ -0,0 +1,286 @@
#!/usr/bin/env python3
# flake8: noqa
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models exported by `torch.jit.trace()`
and uses them to decode waves.
You can use the following command to get the exported models:
./pruned_transducer_stateless7_streaming/jit_trace_export.py \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--lang data/lang_char \
--epoch 30 \
--avg 10 \
--use-averaged-model=True \
--decode-chunk-len 32
Usage of this script:
./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \
--encoder-model-filename ./pruned_transducer_stateless7_streaming/exp/encoder_jit_trace.pt \
--decoder-model-filename ./pruned_transducer_stateless7_streaming/exp/decoder_jit_trace.pt \
--joiner-model-filename ./pruned_transducer_stateless7_streaming/exp/joiner_jit_trace.pt \
--lang data/lang_char \
--decode-chunk-len 32 \
/path/to/foo.wav \
"""
import argparse
import logging
from typing import List, Optional
import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
from tokenizer import Tokenizer
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder torchscript model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder torchscript model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner torchscript model. ",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--decode-chunk-len",
type=int,
default=32,
help="The chunk size for decoding (in frames before subsampling)",
)
parser.add_argument(
"sound_file",
type=str,
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
def greedy_search(
decoder: torch.jit.ScriptModule,
joiner: torch.jit.ScriptModule,
encoder_out: torch.Tensor,
decoder_out: Optional[torch.Tensor] = None,
hyp: Optional[List[int]] = None,
):
assert encoder_out.ndim == 2
context_size = 2
blank_id = 0
if decoder_out is None:
assert hyp is None, hyp
hyp = [blank_id] * context_size
decoder_input = torch.tensor(hyp, dtype=torch.int32).unsqueeze(0)
# decoder_input.shape: (1, context_size)
decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1)
else:
assert decoder_out.ndim == 2
assert hyp is not None, hyp
T = encoder_out.size(0)
for i in range(T):
cur_encoder_out = encoder_out[i : i + 1]
joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0)
y = joiner_out.argmax(dim=0).item()
if y != blank_id:
hyp.append(y)
decoder_input = hyp[-context_size:]
decoder_input = torch.tensor(decoder_input, dtype=torch.int32).unsqueeze(0)
decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1)
return hyp, decoder_out
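# Note: greedy_search() is called once per encoder chunk; `hyp` and `decoder_out`
# carry the decoding state, so the caller threads the returned values back in on
# the next call (see the loop in main()).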
def create_streaming_feature_extractor(sample_rate) -> OnlineFeature:
"""Create a CPU streaming feature extractor.
At present, we assume it returns a fbank feature extractor with
fixed options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = sample_rate
opts.mel_opts.num_bins = 80
return OnlineFbank(opts)
@torch.no_grad()
def main():
parser = get_parser()
Tokenizer.add_arguments(parser)
args = parser.parse_args()
logging.info(vars(args))
device = torch.device("cpu")
logging.info(f"device: {device}")
encoder = torch.jit.load(args.encoder_model_filename)
decoder = torch.jit.load(args.decoder_model_filename)
joiner = torch.jit.load(args.joiner_model_filename)
encoder.eval()
decoder.eval()
joiner.eval()
encoder.to(device)
decoder.to(device)
joiner.to(device)
sp = Tokenizer.load(args.lang, args.lang_type)
logging.info("Constructing Fbank computer")
online_fbank = create_streaming_feature_extractor(args.sample_rate)
logging.info(f"Reading sound files: {args.sound_file}")
wave_samples = read_sound_files(
filenames=[args.sound_file],
expected_sample_rate=args.sample_rate,
)[0]
logging.info(wave_samples.shape)
logging.info("Decoding started")
chunk_length = args.decode_chunk_len
assert encoder.decode_chunk_size == chunk_length // 2, (
encoder.decode_chunk_size,
chunk_length,
)
# we subsample features with ((x_len - 7) // 2 + 1) // 2
pad_length = 7
T = chunk_length + pad_length
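# For example, with the default --decode-chunk-len 32: T = 39 feature frames
# are fed per chunk, and the subsampling above yields
# ((39 - 7) // 2 + 1) // 2 = 8 encoder frames.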
logging.info(f"chunk_length: {chunk_length}")
states = encoder.get_init_state(device)
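# Append ~0.3 s of silence so the tail of the utterance is flushed through the
# encoder as complete chunks.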
tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32)
wave_samples = torch.cat([wave_samples, tail_padding])
chunk = int(0.25 * args.sample_rate)  # 0.25 seconds
num_processed_frames = 0
hyp = None
decoder_out = None
start = 0
while start < wave_samples.numel():
logging.info(f"{start}/{wave_samples.numel()}")
end = min(start + chunk, wave_samples.numel())
samples = wave_samples[start:end]
start += chunk
online_fbank.accept_waveform(
sampling_rate=args.sample_rate,
waveform=samples,
)
while online_fbank.num_frames_ready - num_processed_frames >= T:
frames = []
for i in range(T):
frames.append(online_fbank.get_frame(num_processed_frames + i))
frames = torch.cat(frames, dim=0).unsqueeze(0)
x_lens = torch.tensor([T], dtype=torch.int32)
encoder_out, out_lens, states = encoder(
x=frames,
x_lens=x_lens,
states=states,
)
num_processed_frames += chunk_length
hyp, decoder_out = greedy_search(
decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp
)
context_size = 2
logging.info(args.sound_file)
logging.info(sp.decode(hyp[context_size:]))
logging.info("Decoding Done")
torch.set_num_threads(4)
torch.set_num_interop_threads(1)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()


@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/joiner.py


@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/model.py


@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/optim.py


@ -0,0 +1,347 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a checkpoint and uses it to decode waves.
You can generate the checkpoint with the following command:
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--lang data/lang_char \
--epoch 20 \
--avg 10
Usage of this script:
(1) greedy search
./pruned_transducer_stateless7_streaming/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
--lang data/lang_char \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) beam search
./pruned_transducer_stateless7_streaming/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
--lang data/lang_char \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) modified beam search
./pruned_transducer_stateless7_streaming/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
--lang data/lang_char \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./pruned_transducer_stateless7_streaming/pretrained.py \
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
--lang data/lang_char \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./pruned_transducer_stateless7_streaming/exp/epoch-xx.pt`.
Note: ./pruned_transducer_stateless7_streaming/exp/pretrained.pt is generated by
./pruned_transducer_stateless7_streaming/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from tokenizer import Tokenizer
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
Tokenizer.add_arguments(parser)
args = parser.parse_args()
params = get_params()
params.update(vars(args))
sp = Tokenizer.load(params.lang, params.lang_type)
# <blk> is defined in local/prepare_lang_char.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
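# Pad the feature batch with log(1e-10), i.e. essentially silence in the
# log-mel domain.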
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
num_waves = encoder_out.size(0)
hyps = []
msg = f"Using {params.method}"
if params.method == "beam_search":
msg += f" with beam size {params.beam_size}"
logging.info(msg)
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()


@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling.py


@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py


@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py


@ -0,0 +1,597 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./pruned_transducer_stateless7_streaming/streaming_decode.py \
--epoch 28 \
--avg 15 \
--decode-chunk-len 32 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--decoding-method greedy_search \
--lang data/lang_char \
--num-decode-streams 2000
"""
import argparse
import logging
import math
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import torch
import torch.nn as nn
from asr_datamodule import CSJAsrDataModule
from decode import save_results
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from streaming_beam_search import (
fast_beam_search_one_best,
greedy_search,
modified_beam_search,
)
from tokenizer import Tokenizer
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from zipformer import stack_states, unstack_states
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import AttributeDict, setup_logger, str2bool
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--gpu",
type=int,
default=0,
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Supported decoding methods are:
greedy_search
modified_beam_search
fast_beam_search
""",
)
parser.add_argument(
"--decoding-graph",
type=str,
default="",
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--num_active_paths",
type=int,
default=4,
help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=32,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--num-decode-streams",
type=int,
default=2000,
help="The number of streams that can be decoded parallel.",
)
parser.add_argument(
"--res-dir",
type=Path,
default=None,
help="The path to save results.",
)
add_model_arguments(parser)
return parser
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
decode_streams: List[DecodeStream],
) -> List[int]:
"""Decode one chunk frames of features for each decode_streams and
return the indexes of finished streams in a List.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
decode_streams:
A List of DecodeStream, each belonging to a utterance.
Returns:
Return a List containing which DecodeStreams are finished.
"""
device = model.device
features = []
feature_lens = []
states = []
processed_lens = []
for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(params.decode_chunk_len)
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
processed_lens.append(stream.done_frames)
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
# We subsample features with ((x_len - 7) // 2 + 1) // 2, and the maximum
# downsampling factor inside the encoder is 8, so each chunk needs at least
# 23 feature frames: after the feature embedding, (23 - 7) // 2 = 8 frames remain.
tail_length = 23
if features.size(1) < tail_length:
pad_length = tail_length - features.size(1)
feature_lens += pad_length
features = torch.nn.functional.pad(
features,
(0, 0, 0, pad_length),
mode="constant",
value=LOG_EPS,
)
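# Batch the per-stream encoder states so that a single streaming_forward() call
# advances all active streams; they are unstacked again after the call.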
states = stack_states(states)
processed_lens = torch.tensor(processed_lens, device=device)
encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward(
x=features,
x_lens=feature_lens,
states=states,
)
encoder_out = model.joiner.encoder_proj(encoder_out)
if params.decoding_method == "greedy_search":
greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
elif params.decoding_method == "fast_beam_search":
processed_lens = processed_lens + encoder_out_lens
fast_beam_search_one_best(
model=model,
encoder_out=encoder_out,
processed_lens=processed_lens,
streams=decode_streams,
beam=params.beam,
max_states=params.max_states,
max_contexts=params.max_contexts,
)
elif params.decoding_method == "modified_beam_search":
modified_beam_search(
model=model,
streams=decode_streams,
encoder_out=encoder_out,
num_active_paths=params.num_active_paths,
)
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
states = unstack_states(new_states)
finished_streams = []
for i in range(len(decode_streams)):
decode_streams[i].states = states[i]
decode_streams[i].done_frames += encoder_out_lens[i]
if decode_streams[i].done:
finished_streams.append(i)
return finished_streams
def decode_dataset(
cuts: CutSet,
params: AttributeDict,
model: nn.Module,
sp: Tokenizer,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
cuts:
Lhotse Cutset containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The tokenizer used to map token IDs back to text.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or an HLG. Used
only when --decoding-method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
device = model.device
opts = FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
log_interval = 50
decode_results = []
# Contain decode streams currently running.
decode_streams = []
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
initial_states = model.encoder.get_init_state(device=device)
decode_stream = DecodeStream(
params=params,
cut_id=cut.id,
initial_states=initial_states,
decoding_graph=decoding_graph,
device=device,
)
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
assert audio.max() <= 1, "Should be normalized to [-1, 1]"
samples = torch.from_numpy(audio).squeeze(0)
fbank = Fbank(opts)
feature = fbank(samples.to(device))
decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len)
decode_stream.ground_truth = cut.supervisions[0].custom[params.transcript_mode]
decode_streams.append(decode_stream)
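# Keep at most num_decode_streams utterances in flight: once the limit is
# reached, decode chunks until at least one stream finishes before admitting
# the next cut.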
while len(decode_streams) >= params.num_decode_streams:
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
sp.text2word(decode_streams[i].ground_truth),
sp.text2word(sp.decode(decode_streams[i].decoding_result())),
)
)
del decode_streams[i]
if num % log_interval == 0:
logging.info(f"Cuts processed until now is {num}.")
# decode final chunks of last sequences
while len(decode_streams):
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
sp.text2word(decode_streams[i].ground_truth),
sp.text2word(sp.decode(decode_streams[i].decoding_result())),
)
)
del decode_streams[i]
if params.decoding_method == "greedy_search":
key = "greedy_search"
elif params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
elif params.decoding_method == "modified_beam_search":
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
return {key: decode_results}
@torch.no_grad()
def main():
parser = get_parser()
CSJAsrDataModule.add_arguments(parser)
Tokenizer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
if not params.res_dir:
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
# for streaming
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
# for fast_beam_search
if params.decoding_method == "fast_beam_search":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", params.gpu)
logging.info(f"Device: {device}")
sp = Tokenizer.load(params.lang, params.lang_type)
# <blk> and <unk> are defined in local/prepare_lang_char.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
model.device = device
decoding_graph = None
if params.decoding_graph:
decoding_graph = k2.Fsa.from_dict(
torch.load(params.decoding_graph, map_location=device)
)
elif params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
args.return_cuts = True
csj_corpus = CSJAsrDataModule(args)
for subdir in ["eval1", "eval2", "eval3", "excluded", "valid"]:
results_dict = decode_dataset(
cuts=getattr(csj_corpus, f"{subdir}_cuts")(),
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
tot_err = save_results(
params=params, test_set_name=subdir, results_dict=results_dict
)
with (
params.res_dir
/ (
f"{subdir}-{params.decode_chunk_len}"
f"_{params.avg}_{params.epoch}.cer"
)
).open("w") as fout:
if len(tot_err) == 1:
fout.write(f"{tot_err[0][1]}")
else:
fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err)
logging.info("Done!")
if __name__ == "__main__":
main()


@ -0,0 +1,150 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/csj/ASR
python ./pruned_transducer_stateless7_streaming/test_model.py
"""
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model
def test_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.num_encoder_layers = "2,4,3,2,4"
params.feedforward_dims = "1024,1024,2048,2048,1024"
params.nhead = "8,8,8,8,8"
params.encoder_dims = "384,384,384,384,384"
params.attention_dims = "192,192,192,192,192"
params.encoder_unmasked_dims = "256,256,256,256,256"
params.zipformer_downsampling_factors = "1,2,4,8,2"
params.cnn_module_kernels = "31,31,31,31,31"
params.decoder_dim = 512
params.joiner_dim = 512
params.num_left_chunks = 4
params.short_chunk_size = 50
params.decode_chunk_len = 32
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
# Test jit script
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
print("Using torch.jit.script")
model = torch.jit.script(model)
def test_model_jit_trace():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.num_encoder_layers = "2,4,3,2,4"
params.feedforward_dims = "1024,1024,2048,2048,1024"
params.nhead = "8,8,8,8,8"
params.encoder_dims = "384,384,384,384,384"
params.attention_dims = "192,192,192,192,192"
params.encoder_unmasked_dims = "256,256,256,256,256"
params.zipformer_downsampling_factors = "1,2,4,8,2"
params.cnn_module_kernels = "31,31,31,31,31"
params.decoder_dim = 512
params.joiner_dim = 512
params.num_left_chunks = 4
params.short_chunk_size = 50
params.decode_chunk_len = 32
model = get_transducer_model(params)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
convert_scaled_to_non_scaled(model, inplace=True)
# Test encoder
def _test_encoder():
encoder = model.encoder
assert encoder.decode_chunk_size == params.decode_chunk_len // 2, (
encoder.decode_chunk_size,
params.decode_chunk_len,
)
T = params.decode_chunk_len + 7
x = torch.zeros(1, T, 80, dtype=torch.float32)
x_lens = torch.full((1,), T, dtype=torch.int32)
states = encoder.get_init_state(device=x.device)
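# As in jit_trace_export.py, point forward() at streaming_forward() so that
# torch.jit.trace() records the streaming path.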
encoder.__class__.forward = encoder.__class__.streaming_forward
traced_encoder = torch.jit.trace(encoder, (x, x_lens, states))
states1 = encoder.get_init_state(device=x.device)
states2 = traced_encoder.get_init_state(device=x.device)
for i in range(5):
x = torch.randn(1, T, 80, dtype=torch.float32)
x_lens = torch.full((1,), T, dtype=torch.int32)
y1, _, states1 = encoder.streaming_forward(x, x_lens, states1)
y2, _, states2 = traced_encoder(x, x_lens, states2)
assert torch.allclose(y1, y2, atol=1e-6), (i, (y1 - y2).abs().mean())
# Test decoder
def _test_decoder():
decoder = model.decoder
y = torch.zeros(10, decoder.context_size, dtype=torch.int64)
need_pad = torch.tensor([False])
traced_decoder = torch.jit.trace(decoder, (y, need_pad))
d1 = decoder(y, need_pad)
d2 = traced_decoder(y, need_pad)
assert torch.equal(d1, d2), (d1 - d2).abs().mean()
# Test joiner
def _test_joiner():
joiner = model.joiner
encoder_out_dim = joiner.encoder_proj.weight.shape[1]
decoder_out_dim = joiner.decoder_proj.weight.shape[1]
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
traced_joiner = torch.jit.trace(joiner, (encoder_out, decoder_out))
j1 = joiner(encoder_out, decoder_out)
j2 = traced_joiner(encoder_out, decoder_out)
assert torch.equal(j1, j2), (j1 - j2).abs().mean()
_test_encoder()
_test_decoder()
_test_joiner()
def main():
test_model()
test_model_jit_trace()
if __name__ == "__main__":
main()


@ -0,0 +1 @@
../local/utils/tokenizer.py

File diff suppressed because it is too large.


@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py