mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
CSJ pruned_transducer_stateless7_streaming (#892)
* update manifest stats * update transcript configs * lang_char and compute_fbanks * save cuts in fbank_dir * add core codes * update decode.py * Create local/utils * tidy up * parse raw in prepare_lang_char.py * update manifest stats * update transcript configs * lang_char and compute_fbanks * save cuts in fbank_dir * add core codes * update decode.py * Create local/utils * tidy up * parse raw in prepare_lang_char.py * working train * Add compare_cer_transcript.py * fix tokenizer decode, allow d2f only * comment cleanup * add export files and READMEs * reword average column * fix comments * Update new results
This commit is contained in:
parent
25ee50e27c
commit
e63a8c27f8
11
egs/csj/ASR/README.md
Normal file
11
egs/csj/ASR/README.md
Normal file
@ -0,0 +1,11 @@
|
||||
# Introduction
|
||||
|
||||
[./RESULTS.md](./RESULTS.md) contains the latest results.
|
||||
|
||||
# Transducers
|
||||
|
||||
These are the types of architectures currently available.
|
||||
|
||||
| | Encoder | Decoder | Comment |
|
||||
|---------------------------------------|---------------------|--------------------|---------------------------------------------------|
|
||||
| `pruned_transducer_stateless7_streaming` | Streaming Zipformer | Embedding + Conv1d | Adapted from librispeech pruned_transducer_stateless7_streaming |
|
200
egs/csj/ASR/RESULTS.md
Normal file
200
egs/csj/ASR/RESULTS.md
Normal file
@ -0,0 +1,200 @@
|
||||
# Results
|
||||
|
||||
## Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer)
|
||||
|
||||
### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming)
|
||||
|
||||
See <https://github.com/k2-fsa/icefall/pull/892> for more details.
|
||||
|
||||
You can find a pretrained model, training logs, decoding logs, and decoding results at:
|
||||
<https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208>
|
||||
|
||||
Number of model parameters: 75688409, i.e. 75.7M.
|
||||
|
||||
#### training on disfluent transcript
|
||||
|
||||
The CERs are:
|
||||
|
||||
| decoding method | chunk size | eval1 | eval2 | eval3 | excluded | valid | average | decoding mode |
|
||||
| --------------- | ---------- | ----- | ----- | ----- | -------- | ----- | ------- | ------------- |
|
||||
| fast beam search | 320ms | 5.39 | 4.08 | 4.16 | 5.4 | 5.02 | --epoch 30 --avg 17 | simulated streaming |
|
||||
| fast beam search | 320ms | 5.34 | 4.1 | 4.26 | 5.61 | 4.91 | --epoch 30 --avg 17 | chunk-wise |
|
||||
| greedy search | 320ms | 5.43 | 4.14 | 4.31 | 5.48 | 4.88 | --epoch 30 --avg 17 | simulated streaming |
|
||||
| greedy search | 320ms | 5.44 | 4.14 | 4.39 | 5.7 | 4.98 | --epoch 30 --avg 17 | chunk-wise |
|
||||
| modified beam search | 320ms | 5.2 | 3.95 | 4.09 | 5.12 | 4.75 | --epoch 30 --avg 17 | simulated streaming |
|
||||
| modified beam search | 320ms | 5.18 | 4.07 | 4.12 | 5.36 | 4.77 | --epoch 30 --avg 17 | chunk-wise |
|
||||
| fast beam search | 640ms | 5.01 | 3.78 | 3.96 | 4.85 | 4.6 | --epoch 30 --avg 17 | simulated streaming |
|
||||
| fast beam search | 640ms | 4.97 | 3.88 | 3.96 | 4.91 | 4.61 | --epoch 30 --avg 17 | chunk-wise |
|
||||
| greedy search | 640ms | 5.02 | 3.84 | 4.14 | 5.02 | 4.59 | --epoch 30 --avg 17 | simulated streaming |
|
||||
| greedy search | 640ms | 5.32 | 4.22 | 4.33 | 5.39 | 4.99 | --epoch 30 --avg 17 | chunk-wise |
|
||||
| modified beam search | 640ms | 4.78 | 3.66 | 3.85 | 4.72 | 4.42 | --epoch 30 --avg 17 | simulated streaming |
|
||||
| modified beam search | 640ms | 5.77 | 4.72 | 4.73 | 5.85 | 5.36 | --epoch 30 --avg 17 | chunk-wise |
|
||||
|
||||
Note: `simulated streaming` indicates feeding full utterance during decoding using `decode.py`,
|
||||
while `chunk-size` indicates feeding certain number of frames at each time using `streaming_decode.py`.
|
||||
|
||||
The training command was:
|
||||
```bash
|
||||
./pruned_transducer_stateless7_streaming/train.py \
|
||||
--feedforward-dims "1024,1024,2048,2048,1024" \
|
||||
--world-size 8 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \
|
||||
--max-duration 375 \
|
||||
--transcript-mode disfluent \
|
||||
--lang data/lang_char \
|
||||
--manifest-dir /mnt/host/corpus/csj/fbank \
|
||||
--pad-feature 30 \
|
||||
--musan-dir /mnt/host/corpus/musan/musan/fbank
|
||||
```
|
||||
|
||||
The simulated streaming decoding command was:
|
||||
```bash
|
||||
for chunk in 64 32; do
|
||||
for m in greedy_search fast_beam_search modified_beam_search; do
|
||||
python pruned_transducer_stateless7_streaming/decode.py \
|
||||
--feedforward-dims "1024,1024,2048,2048,1024" \
|
||||
--exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \
|
||||
--epoch 30 \
|
||||
--avg 17 \
|
||||
--max-duration 350 \
|
||||
--decoding-method $m \
|
||||
--manifest-dir /mnt/host/corpus/csj/fbank \
|
||||
--lang data/lang_char \
|
||||
--transcript-mode disfluent \
|
||||
--res-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30/github/sim_"$chunk"_"$m" \
|
||||
--decode-chunk-len $chunk \
|
||||
--pad-feature 30 \
|
||||
--gpu 0
|
||||
done
|
||||
done
|
||||
```
|
||||
|
||||
The streaming chunk-wise decoding command was:
|
||||
```bash
|
||||
for chunk in 64 32; do
|
||||
for m in greedy_search fast_beam_search modified_beam_search; do
|
||||
python pruned_transducer_stateless7_streaming/streaming_decode.py \
|
||||
--feedforward-dims "1024,1024,2048,2048,1024" \
|
||||
--exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \
|
||||
--epoch 30 \
|
||||
--avg 17 \
|
||||
--max-duration 350 \
|
||||
--decoding-method $m \
|
||||
--manifest-dir /mnt/host/corpus/csj/fbank \
|
||||
--lang data/lang_char \
|
||||
--transcript-mode disfluent \
|
||||
--res-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30/github/stream_"$chunk"_"$m" \
|
||||
--decode-chunk-len $chunk \
|
||||
--gpu 2 \
|
||||
--num-decode-streams 40
|
||||
done
|
||||
done
|
||||
```
|
||||
|
||||
#### training on fluent transcript
|
||||
|
||||
The CERs are:
|
||||
|
||||
| decoding method | chunk size | eval1 | eval2 | eval3 | excluded | valid | average | decoding mode |
|
||||
| --------------- | ---------- | ----- | ----- | ----- | -------- | ----- | ------- | ------------- |
|
||||
| fast beam search | 320ms | 4.19 | 3.63 | 3.77 | 4.43 | 4.09 | --epoch 30 --avg 12 | simulated streaming |
|
||||
| fast beam search | 320ms | 4.06 | 3.55 | 3.66 | 4.70 | 4.04 | --epoch 30 --avg 12 | chunk-wise |
|
||||
| greedy search | 320ms | 4.22 | 3.62 | 3.82 | 4.45 | 3.98 | --epoch 30 --avg 12 | simulated streaming |
|
||||
| greedy search | 320ms | 4.13 | 3.61 | 3.85 | 4.67 | 4.05 | --epoch 30 --avg 12 | chunk-wise |
|
||||
| modified beam search | 320ms | 4.02 | 3.43 | 3.62 | 4.43 | 3.81 | --epoch 30 --avg 12 | simulated streaming |
|
||||
| modified beam search | 320ms | 3.97 | 3.43 | 3.59 | 4.99 | 3.88 | --epoch 30 --avg 12 | chunk-wise |
|
||||
| fast beam search | 640ms | 3.80 | 3.31 | 3.55 | 4.16 | 3.90 | --epoch 30 --avg 12 | simulated streaming |
|
||||
| fast beam search | 640ms | 3.81 | 3.34 | 3.46 | 4.58 | 3.85 | --epoch 30 --avg 12 | chunk-wise |
|
||||
| greedy search | 640ms | 3.92 | 3.38 | 3.65 | 4.31 | 3.88 | --epoch 30 --avg 12 | simulated streaming |
|
||||
| greedy search | 640ms | 3.98 | 3.38 | 3.64 | 4.54 | 4.01 | --epoch 30 --avg 12 | chunk-wise |
|
||||
| modified beam search | 640ms | 3.72 | 3.26 | 3.39 | 4.10 | 3.65 | --epoch 30 --avg 12 | simulated streaming |
|
||||
| modified beam search | 640ms | 3.78 | 3.32 | 3.45 | 4.81 | 3.81 | --epoch 30 --avg 12 | chunk-wise |
|
||||
|
||||
Note: `simulated streaming` indicates feeding full utterance during decoding using `decode.py`,
|
||||
while `chunk-size` indicates feeding certain number of frames at each time using `streaming_decode.py`.
|
||||
|
||||
The training command was:
|
||||
```bash
|
||||
./pruned_transducer_stateless7_streaming/train.py \
|
||||
--feedforward-dims "1024,1024,2048,2048,1024" \
|
||||
--world-size 8 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \
|
||||
--max-duration 375 \
|
||||
--transcript-mode fluent \
|
||||
--lang data/lang_char \
|
||||
--manifest-dir /mnt/host/corpus/csj/fbank \
|
||||
--pad-feature 30 \
|
||||
--musan-dir /mnt/host/corpus/musan/musan/fbank
|
||||
```
|
||||
|
||||
The simulated streaming decoding command was:
|
||||
```bash
|
||||
for chunk in 64 32; do
|
||||
for m in greedy_search fast_beam_search modified_beam_search; do
|
||||
python pruned_transducer_stateless7_streaming/decode.py \
|
||||
--feedforward-dims "1024,1024,2048,2048,1024" \
|
||||
--exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \
|
||||
--epoch 30 \
|
||||
--avg 12 \
|
||||
--max-duration 350 \
|
||||
--decoding-method $m \
|
||||
--manifest-dir /mnt/host/corpus/csj/fbank \
|
||||
--lang data/lang_char \
|
||||
--transcript-mode fluent \
|
||||
--res-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30/github/sim_"$chunk"_"$m" \
|
||||
--decode-chunk-len $chunk \
|
||||
--pad-feature 30 \
|
||||
--gpu 1
|
||||
done
|
||||
done
|
||||
```
|
||||
|
||||
The streaming chunk-wise decoding command was:
|
||||
```bash
|
||||
for chunk in 64 32; do
|
||||
for m in greedy_search fast_beam_search modified_beam_search; do
|
||||
python pruned_transducer_stateless7_streaming/streaming_decode.py \
|
||||
--feedforward-dims "1024,1024,2048,2048,1024" \
|
||||
--exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \
|
||||
--epoch 30 \
|
||||
--avg 12 \
|
||||
--max-duration 350 \
|
||||
--decoding-method $m \
|
||||
--manifest-dir /mnt/host/corpus/csj/fbank \
|
||||
--lang data/lang_char \
|
||||
--transcript-mode fluent \
|
||||
--res-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30/github/stream_"$chunk"_"$m" \
|
||||
--decode-chunk-len $chunk \
|
||||
--gpu 3 \
|
||||
--num-decode-streams 40
|
||||
done
|
||||
done
|
||||
```
|
||||
|
||||
#### Comparing disfluent to fluent
|
||||
|
||||
$$ \texttt{CER}^{f}_d = \frac{\texttt{sub}_f + \texttt{ins} + \texttt{del}_f}{N_f} $$
|
||||
|
||||
This comparison evaluates the disfluent model on the fluent transcript (calculated by `disfluent_recogs_to_fluent.py`), forgiving the disfluent model's mistakes on fillers and partial words. It is meant as an illustrative metric only, so that the disfluent and fluent models can be compared.
|
||||
|
||||
| decoding method | chunk size | eval1 (d vs f) | eval2 (d vs f) | eval3 (d vs f) | excluded (d vs f) | valid (d vs f) | decoding mode |
|
||||
| --------------- | ---------- | -------------- | --------------- | -------------- | -------------------- | --------------- | ----------- |
|
||||
| fast beam search | 320ms | 4.54 vs 4.19 | 3.44 vs 3.63 | 3.56 vs 3.77 | 4.22 vs 4.43 | 4.22 vs 4.09 | simulated streaming |
|
||||
| fast beam search | 320ms | 4.48 vs 4.06 | 3.41 vs 3.55 | 3.65 vs 3.66 | 4.26 vs 4.7 | 4.08 vs 4.04 | chunk-wise |
|
||||
| greedy search | 320ms | 4.53 vs 4.22 | 3.48 vs 3.62 | 3.69 vs 3.82 | 4.38 vs 4.45 | 4.05 vs 3.98 | simulated streaming |
|
||||
| greedy search | 320ms | 4.53 vs 4.13 | 3.46 vs 3.61 | 3.71 vs 3.85 | 4.48 vs 4.67 | 4.12 vs 4.05 | chunk-wise |
|
||||
| modified beam search | 320ms | 4.45 vs 4.02 | 3.38 vs 3.43 | 3.57 vs 3.62 | 4.19 vs 4.43 | 4.04 vs 3.81 | simulated streaming |
|
||||
| modified beam search | 320ms | 4.44 vs 3.97 | 3.47 vs 3.43 | 3.56 vs 3.59 | 4.28 vs 4.99 | 4.04 vs 3.88 | chunk-wise |
|
||||
| fast beam search | 640ms | 4.14 vs 3.8 | 3.12 vs 3.31 | 3.38 vs 3.55 | 3.72 vs 4.16 | 3.81 vs 3.9 | simulated streaming |
|
||||
| fast beam search | 640ms | 4.05 vs 3.81 | 3.23 vs 3.34 | 3.36 vs 3.46 | 3.65 vs 4.58 | 3.78 vs 3.85 | chunk-wise |
|
||||
| greedy search | 640ms | 4.1 vs 3.92 | 3.17 vs 3.38 | 3.5 vs 3.65 | 3.87 vs 4.31 | 3.77 vs 3.88 | simulated streaming |
|
||||
| greedy search | 640ms | 4.41 vs 3.98 | 3.56 vs 3.38 | 3.69 vs 3.64 | 4.26 vs 4.54 | 4.16 vs 4.01 | chunk-wise |
|
||||
| modified beam search | 640ms | 4 vs 3.72 | 3.08 vs 3.26 | 3.33 vs 3.39 | 3.75 vs 4.1 | 3.71 vs 3.65 | simulated streaming |
|
||||
| modified beam search | 640ms | 5.05 vs 3.78 | 4.22 vs 3.32 | 4.26 vs 3.45 | 5.02 vs 4.81 | 4.73 vs 3.81 | chunk-wise |
|
||||
| average (d - f) | | 0.43 | -0.02 | -0.02 | -0.34 | 0.13 | |
|
94
egs/csj/ASR/local/add_transcript_mode.py
Normal file
94
egs/csj/ASR/local/add_transcript_mode.py
Normal file
@ -0,0 +1,94 @@
|
||||
import argparse
|
||||
import logging
|
||||
from configparser import ConfigParser
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from lhotse import CutSet, SupervisionSet
|
||||
from lhotse.recipes.csj import CSJSDBParser
|
||||
|
||||
ARGPARSE_DESCRIPTION = """
|
||||
This script adds transcript modes to an existing CutSet or SupervisionSet.
|
||||
"""
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
description=ARGPARSE_DESCRIPTION,
|
||||
)
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--fbank-dir",
|
||||
type=Path,
|
||||
help="Path to directory where manifests are stored.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--config",
|
||||
type=Path,
|
||||
nargs="+",
|
||||
help="Path to config file for transcript parsing.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def get_CSJParsers(config_files: List[Path]) -> List[CSJSDBParser]:
|
||||
parsers = []
|
||||
for config_file in config_files:
|
||||
config = ConfigParser()
|
||||
config.optionxform = str
|
||||
assert config.read(config_file), f"{config_file} could not be found."
|
||||
decisions = {}
|
||||
for k, v in config["DECISIONS"].items():
|
||||
try:
|
||||
decisions[k] = int(v)
|
||||
except ValueError:
|
||||
decisions[k] = v
|
||||
parsers.append(
|
||||
(config["CONSTANTS"].get("MODE"), CSJSDBParser(decisions=decisions))
|
||||
)
|
||||
return parsers
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
logging.basicConfig(
|
||||
format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"),
|
||||
level=logging.INFO,
|
||||
)
|
||||
parsers = get_CSJParsers(args.config)
|
||||
config = ConfigParser()
|
||||
config.optionxform = str
|
||||
assert config.read(args.config), args.config
|
||||
decisions = {}
|
||||
for k, v in config["DECISIONS"].items():
|
||||
try:
|
||||
decisions[k] = int(v)
|
||||
except ValueError:
|
||||
decisions[k] = v
|
||||
|
||||
logging.info(f"Adding {', '.join(x[0] for x in parsers)} transcript mode.")
|
||||
|
||||
manifests = args.fbank_dir.glob("csj_cuts_*.jsonl.gz")
|
||||
assert manifests, f"No cuts to be found in {args.fbank_dir}"
|
||||
|
||||
for manifest in manifests:
|
||||
results = []
|
||||
logging.info(f"Adding transcript modes to {manifest.name} now.")
|
||||
cutset = CutSet.from_file(manifest)
|
||||
for cut in cutset:
|
||||
for name, parser in parsers:
|
||||
cut.supervisions[0].custom[name] = parser.parse(
|
||||
cut.supervisions[0].custom["raw"]
|
||||
)
|
||||
cut.supervisions[0].text = ""
|
||||
results.append(cut)
|
||||
results = CutSet.from_items(results)
|
||||
res_file = manifest.as_posix()
|
||||
manifest.replace(manifest.parent / ("bak." + manifest.name))
|
||||
results.to_file(res_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,5 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 The University of Electro-Communications (Author: Teo Wen Shen) # noqa
|
||||
# Copyright 2023 The University of Electro-Communications (Author: Teo Wen Shen) # noqa
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
@ -19,9 +19,7 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from itertools import islice
|
||||
from pathlib import Path
|
||||
from random import Random
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
@ -35,20 +33,10 @@ from lhotse import ( # See the following for why LilcomChunkyWriter is preferre
|
||||
RecordingSet,
|
||||
SupervisionSet,
|
||||
)
|
||||
from lhotse.recipes.csj import concat_csj_supervisions
|
||||
|
||||
# fmt: on
|
||||
|
||||
ARGPARSE_DESCRIPTION = """
|
||||
This script follows the espnet method of splitting the remaining core+noncore
|
||||
utterances into valid and train cutsets at an index which is by default 4000.
|
||||
|
||||
In other words, the core+noncore utterances are shuffled, where 4000 utterances
|
||||
of the shuffled set go to the `valid` cutset and are not subject to speed
|
||||
perturbation. The remaining utterances become the `train` cutset and are speed-
|
||||
perturbed (0.9x, 1.0x, 1.1x).
|
||||
|
||||
"""
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
@ -57,66 +45,101 @@ torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
RNG_SEED = 42
|
||||
# concat_params_train = [
|
||||
# {"gap": 1.0, "maxlen": 10.0},
|
||||
# {"gap": 1.5, "maxlen": 8.0},
|
||||
# {"gap": 1.0, "maxlen": 18.0},
|
||||
# ]
|
||||
|
||||
concat_params = {"gap": 1.0, "maxlen": 10.0}
|
||||
|
||||
|
||||
def make_cutset_blueprints(
|
||||
manifest_dir: Path,
|
||||
split: int,
|
||||
) -> List[Tuple[str, CutSet]]:
|
||||
|
||||
cut_sets = []
|
||||
logging.info("Creating non-train cuts.")
|
||||
|
||||
# Create eval datasets
|
||||
logging.info("Creating eval cuts.")
|
||||
for i in range(1, 4):
|
||||
sps = sorted(
|
||||
SupervisionSet.from_file(
|
||||
manifest_dir / f"csj_supervisions_eval{i}.jsonl.gz"
|
||||
),
|
||||
key=lambda x: x.id,
|
||||
)
|
||||
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=RecordingSet.from_file(
|
||||
manifest_dir / f"csj_recordings_eval{i}.jsonl.gz"
|
||||
),
|
||||
supervisions=SupervisionSet.from_file(
|
||||
manifest_dir / f"csj_supervisions_eval{i}.jsonl.gz"
|
||||
),
|
||||
supervisions=concat_csj_supervisions(sps, **concat_params),
|
||||
)
|
||||
cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
|
||||
cut_sets.append((f"eval{i}", cut_set))
|
||||
|
||||
# Create train and valid cuts
|
||||
logging.info("Loading, trimming, and shuffling the remaining core+noncore cuts.")
|
||||
recording_set = RecordingSet.from_file(
|
||||
manifest_dir / "csj_recordings_core.jsonl.gz"
|
||||
) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz")
|
||||
supervision_set = SupervisionSet.from_file(
|
||||
manifest_dir / "csj_supervisions_core.jsonl.gz"
|
||||
) + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz")
|
||||
|
||||
# Create excluded dataset
|
||||
sps = sorted(
|
||||
SupervisionSet.from_file(manifest_dir / "csj_supervisions_excluded.jsonl.gz"),
|
||||
key=lambda x: x.id,
|
||||
)
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=recording_set,
|
||||
supervisions=supervision_set,
|
||||
recordings=RecordingSet.from_file(
|
||||
manifest_dir / "csj_recordings_excluded.jsonl.gz"
|
||||
),
|
||||
supervisions=concat_csj_supervisions(sps, **concat_params),
|
||||
)
|
||||
cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
|
||||
cut_set = cut_set.shuffle(Random(RNG_SEED))
|
||||
cut_sets.append(("excluded", cut_set))
|
||||
|
||||
logging.info(
|
||||
"Creating valid and train cuts from core and noncore, split at {split}."
|
||||
# Create valid dataset
|
||||
sps = sorted(
|
||||
SupervisionSet.from_file(manifest_dir / "csj_supervisions_valid.jsonl.gz"),
|
||||
key=lambda x: x.id,
|
||||
)
|
||||
valid_set = CutSet.from_cuts(islice(cut_set, 0, split))
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=RecordingSet.from_file(
|
||||
manifest_dir / "csj_recordings_valid.jsonl.gz"
|
||||
),
|
||||
supervisions=concat_csj_supervisions(sps, **concat_params),
|
||||
)
|
||||
cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
|
||||
cut_sets.append(("valid", cut_set))
|
||||
|
||||
train_set = CutSet.from_cuts(islice(cut_set, split, None))
|
||||
logging.info("Creating train cuts.")
|
||||
|
||||
# Create train dataset
|
||||
sps = sorted(
|
||||
SupervisionSet.from_file(manifest_dir / "csj_supervisions_core.jsonl.gz")
|
||||
+ SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz"),
|
||||
key=lambda x: x.id,
|
||||
)
|
||||
|
||||
recording = RecordingSet.from_file(
|
||||
manifest_dir / "csj_recordings_core.jsonl.gz"
|
||||
) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz")
|
||||
|
||||
train_set = CutSet.from_manifests(
|
||||
recordings=recording, supervisions=concat_csj_supervisions(sps, **concat_params)
|
||||
).trim_to_supervisions(keep_overlapping=False)
|
||||
train_set = train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1)
|
||||
|
||||
cut_sets.extend([("valid", valid_set), ("train", train_set)])
|
||||
cut_sets.append(("train", train_set))
|
||||
|
||||
return cut_sets
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description=ARGPARSE_DESCRIPTION,
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
|
||||
parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests")
|
||||
parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features")
|
||||
parser.add_argument("--split", type=int, default=4000, help="Split at this index")
|
||||
parser.add_argument(
|
||||
"-m", "--manifest-dir", type=Path, help="Path to save manifests"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-f", "--fbank-dir", type=Path, help="Path to save fbank features"
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
@ -138,7 +161,7 @@ def main():
|
||||
)
|
||||
return
|
||||
else:
|
||||
cut_sets = make_cutset_blueprints(args.manifest_dir, args.split)
|
||||
cut_sets = make_cutset_blueprints(args.manifest_dir)
|
||||
for part, cut_set in cut_sets:
|
||||
logging.info(f"Processing {part}")
|
||||
cut_set = cut_set.compute_and_store_features(
|
||||
@ -147,7 +170,7 @@ def main():
|
||||
storage_path=(args.fbank_dir / f"feats_{part}").as_posix(),
|
||||
storage_type=LilcomChunkyWriter,
|
||||
)
|
||||
cut_set.to_file(args.manifest_dir / f"csj_cuts_{part}.jsonl.gz")
|
||||
cut_set.to_file(args.fbank_dir / f"csj_cuts_{part}.jsonl.gz")
|
||||
|
||||
logging.info("All fbank computed for CSJ.")
|
||||
(args.fbank_dir / ".done").touch()
|
||||
|
@ -28,9 +28,7 @@ from icefall.utils import get_executor
|
||||
|
||||
ARGPARSE_DESCRIPTION = """
|
||||
This file computes fbank features of the musan dataset.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
@ -42,8 +40,6 @@ torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path):
|
||||
# src_dir = Path("data/manifests")
|
||||
# output_dir = Path("data/fbank")
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
num_mel_bins = 80
|
||||
|
||||
@ -104,8 +100,12 @@ def get_args():
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
|
||||
parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests")
|
||||
parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features")
|
||||
parser.add_argument(
|
||||
"-m", "--manifest-dir", type=Path, help="Path to save manifests"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-f", "--fbank-dir", type=Path, help="Path to save fbank features"
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
@ -1,320 +1,79 @@
|
||||
; # This section is ignored if this file is not supplied as the first config file to
|
||||
; # lhotse prepare csj
|
||||
[SEGMENTS]
|
||||
; # Allowed period of nonverbal noise. If exceeded, a new segment is created.
|
||||
gap = 0.5
|
||||
; # Maximum length of segment (s).
|
||||
maxlen = 10
|
||||
; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently.
|
||||
minlen = 0.02
|
||||
; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`.
|
||||
; # Pass an empty string to avoid adding any symbol. It was "<sp>" in kaldi.
|
||||
; # If you intend to use a multicharacter string for gap_sym, remember to register the
|
||||
; # multicharacter string as part of userdef-string in prepare_lang_char.py.
|
||||
gap_sym =
|
||||
|
||||
[CONSTANTS]
|
||||
; # Name of this mode
|
||||
MODE = disfluent
|
||||
; # Suffixes to use after the word surface (no longer used)
|
||||
MORPH = pos1 cForm cType2 pos2
|
||||
; # Used to differentiate between A tag and A_num tag
|
||||
JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 .
|
||||
; # Dummy character to delineate multiline words
|
||||
PLUS = +
|
||||
|
||||
[DECISIONS]
|
||||
; # TAG+'^'とは、タグが一つの転記単位に独立していない場合
|
||||
; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries
|
||||
|
||||
; # フィラー、感情表出系感動詞
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(F ぎょっ)'
|
||||
F = 0
|
||||
; # Example: '(L (F ン))', '比べ(F えー)る'
|
||||
F^ = 0
|
||||
; # 言い直し、いいよどみなどによる語断片
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(D だ)(D だいが) 大学の学部の会議'
|
||||
D = 0
|
||||
; # Example: '(L (D ドゥ)+(D ヒ))'
|
||||
D^ = 0
|
||||
; # 助詞、助動詞、接辞の言い直し
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか'
|
||||
D2 = 0
|
||||
; # Example: '(X (D2 ノ))'
|
||||
D2^ = 0
|
||||
; # 聞き取りや語彙の判断に自信がない場合
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: (? 字数) の
|
||||
; # If no option: empty string is returned regardless of output
|
||||
; # Example: '(?) で'
|
||||
? = 0
|
||||
; # Example: '(D (? すー))+そう+です+よ+ね'
|
||||
?^ = 0
|
||||
; # タグ?で、値は複数の候補が想定される場合
|
||||
; # 0 for main guess with matching morph info, 1 for second guess
|
||||
; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)'
|
||||
?, = 0
|
||||
; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))'
|
||||
?,^ = 0
|
||||
; # 音や言葉に関するメタ的な引用
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)'
|
||||
M = 0
|
||||
; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))'
|
||||
M^ = 0
|
||||
; # 外国語や古語、方言など
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(O ザッツファイン)'
|
||||
O = 0
|
||||
; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))'
|
||||
O^ = 0
|
||||
; # 講演者の名前、差別語、誹謗中傷など
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '国語研の (R ××) です'
|
||||
R = 0
|
||||
R^ = 0
|
||||
; # 非朗読対象発話(朗読における言い間違い等)
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(X 実際は) 実際には'
|
||||
X = 0
|
||||
; # Example: '(L (X (D2 ニ)))'
|
||||
X^ = 0
|
||||
; # アルファベットや算用数字、記号の表記
|
||||
; # 0 to use Japanese form, 1 to use alphabet form
|
||||
; # Example: '(A シーディーアール;CD-R)'
|
||||
A = 1
|
||||
; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)')
|
||||
A^ = 1
|
||||
; # タグAで、単語は算用数字の場合
|
||||
; # 0 to use Japanese form, 1 to use Arabic numerals
|
||||
; # Example: (A 二千;2000)
|
||||
A_num = eval:self.notag
|
||||
A_num^ = eval:self.notag
|
||||
A_num = 0
|
||||
; # 何らかの原因で漢字表記できなくなった場合
|
||||
; # 0 to use broken form, 1 to use orthodox form
|
||||
; # Example: '(K たち (F えー) ばな;橘)'
|
||||
K = 1
|
||||
; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)'
|
||||
K^ = 1
|
||||
; # 転訛、発音の怠けなど、一時的な発音エラー
|
||||
; # 0 to use wrong form, 1 to use orthodox form
|
||||
; # Example: '(W ギーツ;ギジュツ)'
|
||||
W = 1
|
||||
; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)'
|
||||
W^ = 1
|
||||
; # 語の読みに関する知識レベルのいい間違い
|
||||
; # 0 to use wrong form, 1 to use orthodox form
|
||||
; # Example: '(B シブタイ;ジュータイ)'
|
||||
B = 0
|
||||
; # Example: 'データー(B カズ;スー)'
|
||||
B^ = 0
|
||||
; # 笑いながら発話
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)'
|
||||
笑 = 0
|
||||
; # Example: 'コク(笑 サイ+(D オン))',
|
||||
笑^ = 0
|
||||
; # 泣きながら発話
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(泣 ドンナニ)'
|
||||
泣 = 0
|
||||
泣^ = 0
|
||||
; # 咳をしながら発話
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: 'シャ(咳 リン) ノ'
|
||||
咳 = 0
|
||||
; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)'
|
||||
咳^ = 0
|
||||
; # ささやき声や独り言などの小さな声
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
|
||||
L = 0
|
||||
; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト'
|
||||
L^ = 0
|
||||
|
||||
[REPLACEMENTS]
|
||||
; # ボーカルフライなどで母音が同定できない場合
|
||||
<FV> =
|
||||
; # 「うん/うーん/ふーん」の音の特定が困難な場合
|
||||
<VN> =
|
||||
; # 非語彙的な母音の引き延ばし
|
||||
<H> =
|
||||
; # 非語彙的な子音の引き延ばし
|
||||
<Q> =
|
||||
; # 言語音と独立に講演者の笑いが生じている場合
|
||||
<笑> =
|
||||
; # 言語音と独立に講演者の咳が生じている場合
|
||||
<咳> =
|
||||
; # 言語音と独立に講演者の息が生じている場合
|
||||
<息> =
|
||||
; # 講演者の泣き声
|
||||
<泣> =
|
||||
; # 聴衆(司会者なども含む)の発話
|
||||
<フロア発話> =
|
||||
; # 聴衆の笑い
|
||||
<フロア笑> =
|
||||
; # 聴衆の拍手
|
||||
<拍手> =
|
||||
; # 講演者が発表中に用いたデモンストレーションの音声
|
||||
<デモ> =
|
||||
; # 学会講演に発表時間を知らせるためにならすベルの音
|
||||
<ベル> =
|
||||
; # 転記単位全体が再度読み直された場合
|
||||
<朗読間違い> =
|
||||
; # 上記以外の音で特に目立った音
|
||||
<雑音> =
|
||||
; # 0.2秒以上のポーズ
|
||||
<P> =
|
||||
; # Redacted information, for R
|
||||
; # It is \x00D7 multiplication sign, not your normal 'x'
|
||||
× = ×
|
||||
|
||||
[FIELDS]
|
||||
; # Time information for segment
|
||||
time = 3
|
||||
; # Word surface
|
||||
surface = 5
|
||||
; # Word surface root form without CSJ tags
|
||||
notag = 9
|
||||
; # Part Of Speech
|
||||
pos1 = 11
|
||||
; # Conjugated Form
|
||||
cForm = 12
|
||||
; # Conjugation Type
|
||||
cType1 = 13
|
||||
; # Subcategory of POS
|
||||
pos2 = 14
|
||||
; # Euphonic Change / Subcategory of Conjugation Type
|
||||
cType2 = 15
|
||||
; # Other information
|
||||
other = 16
|
||||
; # Pronunciation for lexicon
|
||||
pron = 10
|
||||
; # Speaker ID
|
||||
spk_id = 2
|
||||
|
||||
[KATAKANA2ROMAJI]
|
||||
ア = 'a
|
||||
イ = 'i
|
||||
ウ = 'u
|
||||
エ = 'e
|
||||
オ = 'o
|
||||
カ = ka
|
||||
キ = ki
|
||||
ク = ku
|
||||
ケ = ke
|
||||
コ = ko
|
||||
ガ = ga
|
||||
ギ = gi
|
||||
グ = gu
|
||||
ゲ = ge
|
||||
ゴ = go
|
||||
サ = sa
|
||||
シ = si
|
||||
ス = su
|
||||
セ = se
|
||||
ソ = so
|
||||
ザ = za
|
||||
ジ = zi
|
||||
ズ = zu
|
||||
ゼ = ze
|
||||
ゾ = zo
|
||||
タ = ta
|
||||
チ = ti
|
||||
ツ = tu
|
||||
テ = te
|
||||
ト = to
|
||||
ダ = da
|
||||
ヂ = di
|
||||
ヅ = du
|
||||
デ = de
|
||||
ド = do
|
||||
ナ = na
|
||||
ニ = ni
|
||||
ヌ = nu
|
||||
ネ = ne
|
||||
ノ = no
|
||||
ハ = ha
|
||||
ヒ = hi
|
||||
フ = hu
|
||||
ヘ = he
|
||||
ホ = ho
|
||||
バ = ba
|
||||
ビ = bi
|
||||
ブ = bu
|
||||
ベ = be
|
||||
ボ = bo
|
||||
パ = pa
|
||||
ピ = pi
|
||||
プ = pu
|
||||
ペ = pe
|
||||
ポ = po
|
||||
マ = ma
|
||||
ミ = mi
|
||||
ム = mu
|
||||
メ = me
|
||||
モ = mo
|
||||
ヤ = ya
|
||||
ユ = yu
|
||||
ヨ = yo
|
||||
ラ = ra
|
||||
リ = ri
|
||||
ル = ru
|
||||
レ = re
|
||||
ロ = ro
|
||||
ワ = wa
|
||||
ヰ = we
|
||||
ヱ = wi
|
||||
ヲ = wo
|
||||
ン = ŋ
|
||||
ッ = q
|
||||
ー = -
|
||||
キャ = kǐa
|
||||
キュ = kǐu
|
||||
キョ = kǐo
|
||||
ギャ = gǐa
|
||||
ギュ = gǐu
|
||||
ギョ = gǐo
|
||||
シャ = sǐa
|
||||
シュ = sǐu
|
||||
ショ = sǐo
|
||||
ジャ = zǐa
|
||||
ジュ = zǐu
|
||||
ジョ = zǐo
|
||||
チャ = tǐa
|
||||
チュ = tǐu
|
||||
チョ = tǐo
|
||||
ヂャ = dǐa
|
||||
ヂュ = dǐu
|
||||
ヂョ = dǐo
|
||||
ニャ = nǐa
|
||||
ニュ = nǐu
|
||||
ニョ = nǐo
|
||||
ヒャ = hǐa
|
||||
ヒュ = hǐu
|
||||
ヒョ = hǐo
|
||||
ビャ = bǐa
|
||||
ビュ = bǐu
|
||||
ビョ = bǐo
|
||||
ピャ = pǐa
|
||||
ピュ = pǐu
|
||||
ピョ = pǐo
|
||||
ミャ = mǐa
|
||||
ミュ = mǐu
|
||||
ミョ = mǐo
|
||||
リャ = rǐa
|
||||
リュ = rǐu
|
||||
リョ = rǐo
|
||||
ァ = a
|
||||
ィ = i
|
||||
ゥ = u
|
||||
ェ = e
|
||||
ォ = o
|
||||
ヮ = ʍ
|
||||
ヴ = vu
|
||||
ャ = ǐa
|
||||
ュ = ǐu
|
||||
ョ = ǐo
|
||||
|
@ -1,320 +1,79 @@
|
||||
; # This section is ignored if this file is not supplied as the first config file to
|
||||
; # lhotse prepare csj
|
||||
[SEGMENTS]
|
||||
; # Allowed period of nonverbal noise. If exceeded, a new segment is created.
|
||||
gap = 0.5
|
||||
; # Maximum length of segment (s).
|
||||
maxlen = 10
|
||||
; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently.
|
||||
minlen = 0.02
|
||||
; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`.
|
||||
; # Pass an empty string to avoid adding any symbol. It was "<sp>" in kaldi.
|
||||
; # If you intend to use a multicharacter string for gap_sym, remember to register the
|
||||
; # multicharacter string as part of userdef-string in prepare_lang_char.py.
|
||||
gap_sym =
|
||||
|
||||
[CONSTANTS]
|
||||
; # Name of this mode
|
||||
MODE = fluent
|
||||
; # Suffixes to use after the word surface (no longer used)
|
||||
MORPH = pos1 cForm cType2 pos2
|
||||
; # Used to differentiate between A tag and A_num tag
|
||||
JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 .
|
||||
; # Dummy character to delineate multiline words
|
||||
PLUS = +
|
||||
|
||||
[DECISIONS]
|
||||
; # TAG+'^'とは、タグが一つの転記単位に独立していない場合
|
||||
; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries
|
||||
|
||||
; # フィラー、感情表出系感動詞
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(F ぎょっ)'
|
||||
F = 1
|
||||
; # Example: '(L (F ン))', '比べ(F えー)る'
|
||||
F^ = 1
|
||||
; # 言い直し、いいよどみなどによる語断片
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(D だ)(D だいが) 大学の学部の会議'
|
||||
D = 1
|
||||
; # Example: '(L (D ドゥ)+(D ヒ))'
|
||||
D^ = 1
|
||||
; # 助詞、助動詞、接辞の言い直し
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか'
|
||||
D2 = 1
|
||||
; # Example: '(X (D2 ノ))'
|
||||
D2^ = 1
|
||||
; # 聞き取りや語彙の判断に自信がない場合
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: (? 字数) の
|
||||
; # If no option: empty string is returned regardless of output
|
||||
; # Example: '(?) で'
|
||||
? = 0
|
||||
; # Example: '(D (? すー))+そう+です+よ+ね'
|
||||
?^ = 0
|
||||
; # タグ?で、値は複数の候補が想定される場合
|
||||
; # 0 for main guess with matching morph info, 1 for second guess
|
||||
; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)'
|
||||
?, = 0
|
||||
; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))'
|
||||
?,^ = 0
|
||||
; # 音や言葉に関するメタ的な引用
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)'
|
||||
M = 0
|
||||
; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))'
|
||||
M^ = 0
|
||||
; # 外国語や古語、方言など
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(O ザッツファイン)'
|
||||
O = 0
|
||||
; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))'
|
||||
O^ = 0
|
||||
; # 講演者の名前、差別語、誹謗中傷など
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '国語研の (R ××) です'
|
||||
R = 0
|
||||
R^ = 0
|
||||
; # 非朗読対象発話(朗読における言い間違い等)
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(X 実際は) 実際には'
|
||||
X = 0
|
||||
; # Example: '(L (X (D2 ニ)))'
|
||||
X^ = 0
|
||||
; # アルファベットや算用数字、記号の表記
|
||||
; # 0 to use Japanese form, 1 to use alphabet form
|
||||
; # Example: '(A シーディーアール;CD-R)'
|
||||
A = 1
|
||||
; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)')
|
||||
A^ = 1
|
||||
; # タグAで、単語は算用数字の場合
|
||||
; # 0 to use Japanese form, 1 to use Arabic numerals
|
||||
; # Example: (A 二千;2000)
|
||||
A_num = eval:self.notag
|
||||
A_num^ = eval:self.notag
|
||||
A_num = 0
|
||||
; # 何らかの原因で漢字表記できなくなった場合
|
||||
; # 0 to use broken form, 1 to use orthodox form
|
||||
; # Example: '(K たち (F えー) ばな;橘)'
|
||||
K = 1
|
||||
; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)'
|
||||
K^ = 1
|
||||
; # 転訛、発音の怠けなど、一時的な発音エラー
|
||||
; # 0 to use wrong form, 1 to use orthodox form
|
||||
; # Example: '(W ギーツ;ギジュツ)'
|
||||
W = 1
|
||||
; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)'
|
||||
W^ = 1
|
||||
; # 語の読みに関する知識レベルのいい間違い
|
||||
; # 0 to use wrong form, 1 to use orthodox form
|
||||
; # Example: '(B シブタイ;ジュータイ)'
|
||||
B = 0
|
||||
; # Example: 'データー(B カズ;スー)'
|
||||
B^ = 0
|
||||
; # 笑いながら発話
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)'
|
||||
笑 = 0
|
||||
; # Example: 'コク(笑 サイ+(D オン))',
|
||||
笑^ = 0
|
||||
; # 泣きながら発話
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(泣 ドンナニ)'
|
||||
泣 = 0
|
||||
泣^ = 0
|
||||
; # 咳をしながら発話
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: 'シャ(咳 リン) ノ'
|
||||
咳 = 0
|
||||
; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)'
|
||||
咳^ = 0
|
||||
; # ささやき声や独り言などの小さな声
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
|
||||
L = 0
|
||||
; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト'
|
||||
L^ = 0
|
||||
|
||||
[REPLACEMENTS]
|
||||
; # ボーカルフライなどで母音が同定できない場合
|
||||
<FV> =
|
||||
; # 「うん/うーん/ふーん」の音の特定が困難な場合
|
||||
<VN> =
|
||||
; # 非語彙的な母音の引き延ばし
|
||||
<H> =
|
||||
; # 非語彙的な子音の引き延ばし
|
||||
<Q> =
|
||||
; # 言語音と独立に講演者の笑いが生じている場合
|
||||
<笑> =
|
||||
; # 言語音と独立に講演者の咳が生じている場合
|
||||
<咳> =
|
||||
; # 言語音と独立に講演者の息が生じている場合
|
||||
<息> =
|
||||
; # 講演者の泣き声
|
||||
<泣> =
|
||||
; # 聴衆(司会者なども含む)の発話
|
||||
<フロア発話> =
|
||||
; # 聴衆の笑い
|
||||
<フロア笑> =
|
||||
; # 聴衆の拍手
|
||||
<拍手> =
|
||||
; # 講演者が発表中に用いたデモンストレーションの音声
|
||||
<デモ> =
|
||||
; # 学会講演に発表時間を知らせるためにならすベルの音
|
||||
<ベル> =
|
||||
; # 転記単位全体が再度読み直された場合
|
||||
<朗読間違い> =
|
||||
; # 上記以外の音で特に目立った音
|
||||
<雑音> =
|
||||
; # 0.2秒以上のポーズ
|
||||
<P> =
|
||||
; # Redacted information, for R
|
||||
; # It is \x00D7 multiplication sign, not your normal 'x'
|
||||
× = ×
|
||||
|
||||
[FIELDS]
|
||||
; # Time information for segment
|
||||
time = 3
|
||||
; # Word surface
|
||||
surface = 5
|
||||
; # Word surface root form without CSJ tags
|
||||
notag = 9
|
||||
; # Part Of Speech
|
||||
pos1 = 11
|
||||
; # Conjugated Form
|
||||
cForm = 12
|
||||
; # Conjugation Type
|
||||
cType1 = 13
|
||||
; # Subcategory of POS
|
||||
pos2 = 14
|
||||
; # Euphonic Change / Subcategory of Conjugation Type
|
||||
cType2 = 15
|
||||
; # Other information
|
||||
other = 16
|
||||
; # Pronunciation for lexicon
|
||||
pron = 10
|
||||
; # Speaker ID
|
||||
spk_id = 2
|
||||
|
||||
[KATAKANA2ROMAJI]
|
||||
ア = 'a
|
||||
イ = 'i
|
||||
ウ = 'u
|
||||
エ = 'e
|
||||
オ = 'o
|
||||
カ = ka
|
||||
キ = ki
|
||||
ク = ku
|
||||
ケ = ke
|
||||
コ = ko
|
||||
ガ = ga
|
||||
ギ = gi
|
||||
グ = gu
|
||||
ゲ = ge
|
||||
ゴ = go
|
||||
サ = sa
|
||||
シ = si
|
||||
ス = su
|
||||
セ = se
|
||||
ソ = so
|
||||
ザ = za
|
||||
ジ = zi
|
||||
ズ = zu
|
||||
ゼ = ze
|
||||
ゾ = zo
|
||||
タ = ta
|
||||
チ = ti
|
||||
ツ = tu
|
||||
テ = te
|
||||
ト = to
|
||||
ダ = da
|
||||
ヂ = di
|
||||
ヅ = du
|
||||
デ = de
|
||||
ド = do
|
||||
ナ = na
|
||||
ニ = ni
|
||||
ヌ = nu
|
||||
ネ = ne
|
||||
ノ = no
|
||||
ハ = ha
|
||||
ヒ = hi
|
||||
フ = hu
|
||||
ヘ = he
|
||||
ホ = ho
|
||||
バ = ba
|
||||
ビ = bi
|
||||
ブ = bu
|
||||
ベ = be
|
||||
ボ = bo
|
||||
パ = pa
|
||||
ピ = pi
|
||||
プ = pu
|
||||
ペ = pe
|
||||
ポ = po
|
||||
マ = ma
|
||||
ミ = mi
|
||||
ム = mu
|
||||
メ = me
|
||||
モ = mo
|
||||
ヤ = ya
|
||||
ユ = yu
|
||||
ヨ = yo
|
||||
ラ = ra
|
||||
リ = ri
|
||||
ル = ru
|
||||
レ = re
|
||||
ロ = ro
|
||||
ワ = wa
|
||||
ヰ = we
|
||||
ヱ = wi
|
||||
ヲ = wo
|
||||
ン = ŋ
|
||||
ッ = q
|
||||
ー = -
|
||||
キャ = kǐa
|
||||
キュ = kǐu
|
||||
キョ = kǐo
|
||||
ギャ = gǐa
|
||||
ギュ = gǐu
|
||||
ギョ = gǐo
|
||||
シャ = sǐa
|
||||
シュ = sǐu
|
||||
ショ = sǐo
|
||||
ジャ = zǐa
|
||||
ジュ = zǐu
|
||||
ジョ = zǐo
|
||||
チャ = tǐa
|
||||
チュ = tǐu
|
||||
チョ = tǐo
|
||||
ヂャ = dǐa
|
||||
ヂュ = dǐu
|
||||
ヂョ = dǐo
|
||||
ニャ = nǐa
|
||||
ニュ = nǐu
|
||||
ニョ = nǐo
|
||||
ヒャ = hǐa
|
||||
ヒュ = hǐu
|
||||
ヒョ = hǐo
|
||||
ビャ = bǐa
|
||||
ビュ = bǐu
|
||||
ビョ = bǐo
|
||||
ピャ = pǐa
|
||||
ピュ = pǐu
|
||||
ピョ = pǐo
|
||||
ミャ = mǐa
|
||||
ミュ = mǐu
|
||||
ミョ = mǐo
|
||||
リャ = rǐa
|
||||
リュ = rǐu
|
||||
リョ = rǐo
|
||||
ァ = a
|
||||
ィ = i
|
||||
ゥ = u
|
||||
ェ = e
|
||||
ォ = o
|
||||
ヮ = ʍ
|
||||
ヴ = vu
|
||||
ャ = ǐa
|
||||
ュ = ǐu
|
||||
ョ = ǐo
|
||||
|
@ -1,320 +1,79 @@
|
||||
; # This section is ignored if this file is not supplied as the first config file to
|
||||
; # lhotse prepare csj
|
||||
[SEGMENTS]
|
||||
; # Allowed period of nonverbal noise. If exceeded, a new segment is created.
|
||||
gap = 0.5
|
||||
; # Maximum length of segment (s).
|
||||
maxlen = 10
|
||||
; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently.
|
||||
minlen = 0.02
|
||||
; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`.
|
||||
; # Pass an empty string to avoid adding any symbol. It was "<sp>" in kaldi.
|
||||
; # If you intend to use a multicharacter string for gap_sym, remember to register the
|
||||
; # multicharacter string as part of userdef-string in prepare_lang_char.py.
|
||||
gap_sym =
|
||||
|
||||
[CONSTANTS]
|
||||
; # Name of this mode
|
||||
MODE = number
|
||||
; # Suffixes to use after the word surface (no longer used)
|
||||
MORPH = pos1 cForm cType2 pos2
|
||||
; # Used to differentiate between A tag and A_num tag
|
||||
JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 .
|
||||
; # Dummy character to delineate multiline words
|
||||
PLUS = +
|
||||
|
||||
[DECISIONS]
|
||||
; # TAG+'^'とは、タグが一つの転記単位に独立していない場合
|
||||
; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries
|
||||
|
||||
; # フィラー、感情表出系感動詞
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(F ぎょっ)'
|
||||
F = 1
|
||||
; # Example: '(L (F ン))', '比べ(F えー)る'
|
||||
F^ = 1
|
||||
; # 言い直し、いいよどみなどによる語断片
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(D だ)(D だいが) 大学の学部の会議'
|
||||
D = 1
|
||||
; # Example: '(L (D ドゥ)+(D ヒ))'
|
||||
D^ = 1
|
||||
; # 助詞、助動詞、接辞の言い直し
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか'
|
||||
D2 = 1
|
||||
; # Example: '(X (D2 ノ))'
|
||||
D2^ = 1
|
||||
; # 聞き取りや語彙の判断に自信がない場合
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: (? 字数) の
|
||||
; # If no option: empty string is returned regardless of output
|
||||
; # Example: '(?) で'
|
||||
? = 0
|
||||
; # Example: '(D (? すー))+そう+です+よ+ね'
|
||||
?^ = 0
|
||||
; # タグ?で、値は複数の候補が想定される場合
|
||||
; # 0 for main guess with matching morph info, 1 for second guess
|
||||
; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)'
|
||||
?, = 0
|
||||
; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))'
|
||||
?,^ = 0
|
||||
; # 音や言葉に関するメタ的な引用
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)'
|
||||
M = 0
|
||||
; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))'
|
||||
M^ = 0
|
||||
; # 外国語や古語、方言など
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(O ザッツファイン)'
|
||||
O = 0
|
||||
; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))'
|
||||
O^ = 0
|
||||
; # 講演者の名前、差別語、誹謗中傷など
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '国語研の (R ××) です'
|
||||
R = 0
|
||||
R^ = 0
|
||||
; # 非朗読対象発話(朗読における言い間違い等)
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(X 実際は) 実際には'
|
||||
X = 0
|
||||
; # Example: '(L (X (D2 ニ)))'
|
||||
X^ = 0
|
||||
; # アルファベットや算用数字、記号の表記
|
||||
; # 0 to use Japanese form, 1 to use alphabet form
|
||||
; # Example: '(A シーディーアール;CD-R)'
|
||||
A = 1
|
||||
; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)')
|
||||
A^ = 1
|
||||
; # タグAで、単語は算用数字の場合
|
||||
; # 0 to use Japanese form, 1 to use Arabic numerals
|
||||
; # Example: (A 二千;2000)
|
||||
A_num = 1
|
||||
A_num^ = 1
|
||||
; # 何らかの原因で漢字表記できなくなった場合
|
||||
; # 0 to use broken form, 1 to use orthodox form
|
||||
; # Example: '(K たち (F えー) ばな;橘)'
|
||||
K = 1
|
||||
; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)'
|
||||
K^ = 1
|
||||
; # 転訛、発音の怠けなど、一時的な発音エラー
|
||||
; # 0 to use wrong form, 1 to use orthodox form
|
||||
; # Example: '(W ギーツ;ギジュツ)'
|
||||
W = 1
|
||||
; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)'
|
||||
W^ = 1
|
||||
; # 語の読みに関する知識レベルのいい間違い
|
||||
; # 0 to use wrong form, 1 to use orthodox form
|
||||
; # Example: '(B シブタイ;ジュータイ)'
|
||||
B = 0
|
||||
; # Example: 'データー(B カズ;スー)'
|
||||
B^ = 0
|
||||
; # 笑いながら発話
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)'
|
||||
笑 = 0
|
||||
; # Example: 'コク(笑 サイ+(D オン))',
|
||||
笑^ = 0
|
||||
; # 泣きながら発話
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(泣 ドンナニ)'
|
||||
泣 = 0
|
||||
泣^ = 0
|
||||
; # 咳をしながら発話
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: 'シャ(咳 リン) ノ'
|
||||
咳 = 0
|
||||
; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)'
|
||||
咳^ = 0
|
||||
; # ささやき声や独り言などの小さな声
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
|
||||
L = 0
|
||||
; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト'
|
||||
L^ = 0
|
||||
|
||||
[REPLACEMENTS]
|
||||
; # ボーカルフライなどで母音が同定できない場合
|
||||
<FV> =
|
||||
; # 「うん/うーん/ふーん」の音の特定が困難な場合
|
||||
<VN> =
|
||||
; # 非語彙的な母音の引き延ばし
|
||||
<H> =
|
||||
; # 非語彙的な子音の引き延ばし
|
||||
<Q> =
|
||||
; # 言語音と独立に講演者の笑いが生じている場合
|
||||
<笑> =
|
||||
; # 言語音と独立に講演者の咳が生じている場合
|
||||
<咳> =
|
||||
; # 言語音と独立に講演者の息が生じている場合
|
||||
<息> =
|
||||
; # 講演者の泣き声
|
||||
<泣> =
|
||||
; # 聴衆(司会者なども含む)の発話
|
||||
<フロア発話> =
|
||||
; # 聴衆の笑い
|
||||
<フロア笑> =
|
||||
; # 聴衆の拍手
|
||||
<拍手> =
|
||||
; # 講演者が発表中に用いたデモンストレーションの音声
|
||||
<デモ> =
|
||||
; # 学会講演に発表時間を知らせるためにならすベルの音
|
||||
<ベル> =
|
||||
; # 転記単位全体が再度読み直された場合
|
||||
<朗読間違い> =
|
||||
; # 上記以外の音で特に目立った音
|
||||
<雑音> =
|
||||
; # 0.2秒以上のポーズ
|
||||
<P> =
|
||||
; # Redacted information, for R
|
||||
; # It is \x00D7 multiplication sign, not your normal 'x'
|
||||
× = ×
|
||||
|
||||
[FIELDS]
|
||||
; # Time information for segment
|
||||
time = 3
|
||||
; # Word surface
|
||||
surface = 5
|
||||
; # Word surface root form without CSJ tags
|
||||
notag = 9
|
||||
; # Part Of Speech
|
||||
pos1 = 11
|
||||
; # Conjugated Form
|
||||
cForm = 12
|
||||
; # Conjugation Type
|
||||
cType1 = 13
|
||||
; # Subcategory of POS
|
||||
pos2 = 14
|
||||
; # Euphonic Change / Subcategory of Conjugation Type
|
||||
cType2 = 15
|
||||
; # Other information
|
||||
other = 16
|
||||
; # Pronunciation for lexicon
|
||||
pron = 10
|
||||
; # Speaker ID
|
||||
spk_id = 2
|
||||
|
||||
[KATAKANA2ROMAJI]
|
||||
ア = 'a
|
||||
イ = 'i
|
||||
ウ = 'u
|
||||
エ = 'e
|
||||
オ = 'o
|
||||
カ = ka
|
||||
キ = ki
|
||||
ク = ku
|
||||
ケ = ke
|
||||
コ = ko
|
||||
ガ = ga
|
||||
ギ = gi
|
||||
グ = gu
|
||||
ゲ = ge
|
||||
ゴ = go
|
||||
サ = sa
|
||||
シ = si
|
||||
ス = su
|
||||
セ = se
|
||||
ソ = so
|
||||
ザ = za
|
||||
ジ = zi
|
||||
ズ = zu
|
||||
ゼ = ze
|
||||
ゾ = zo
|
||||
タ = ta
|
||||
チ = ti
|
||||
ツ = tu
|
||||
テ = te
|
||||
ト = to
|
||||
ダ = da
|
||||
ヂ = di
|
||||
ヅ = du
|
||||
デ = de
|
||||
ド = do
|
||||
ナ = na
|
||||
ニ = ni
|
||||
ヌ = nu
|
||||
ネ = ne
|
||||
ノ = no
|
||||
ハ = ha
|
||||
ヒ = hi
|
||||
フ = hu
|
||||
ヘ = he
|
||||
ホ = ho
|
||||
バ = ba
|
||||
ビ = bi
|
||||
ブ = bu
|
||||
ベ = be
|
||||
ボ = bo
|
||||
パ = pa
|
||||
ピ = pi
|
||||
プ = pu
|
||||
ペ = pe
|
||||
ポ = po
|
||||
マ = ma
|
||||
ミ = mi
|
||||
ム = mu
|
||||
メ = me
|
||||
モ = mo
|
||||
ヤ = ya
|
||||
ユ = yu
|
||||
ヨ = yo
|
||||
ラ = ra
|
||||
リ = ri
|
||||
ル = ru
|
||||
レ = re
|
||||
ロ = ro
|
||||
ワ = wa
|
||||
ヰ = we
|
||||
ヱ = wi
|
||||
ヲ = wo
|
||||
ン = ŋ
|
||||
ッ = q
|
||||
ー = -
|
||||
キャ = kǐa
|
||||
キュ = kǐu
|
||||
キョ = kǐo
|
||||
ギャ = gǐa
|
||||
ギュ = gǐu
|
||||
ギョ = gǐo
|
||||
シャ = sǐa
|
||||
シュ = sǐu
|
||||
ショ = sǐo
|
||||
ジャ = zǐa
|
||||
ジュ = zǐu
|
||||
ジョ = zǐo
|
||||
チャ = tǐa
|
||||
チュ = tǐu
|
||||
チョ = tǐo
|
||||
ヂャ = dǐa
|
||||
ヂュ = dǐu
|
||||
ヂョ = dǐo
|
||||
ニャ = nǐa
|
||||
ニュ = nǐu
|
||||
ニョ = nǐo
|
||||
ヒャ = hǐa
|
||||
ヒュ = hǐu
|
||||
ヒョ = hǐo
|
||||
ビャ = bǐa
|
||||
ビュ = bǐu
|
||||
ビョ = bǐo
|
||||
ピャ = pǐa
|
||||
ピュ = pǐu
|
||||
ピョ = pǐo
|
||||
ミャ = mǐa
|
||||
ミュ = mǐu
|
||||
ミョ = mǐo
|
||||
リャ = rǐa
|
||||
リュ = rǐu
|
||||
リョ = rǐo
|
||||
ァ = a
|
||||
ィ = i
|
||||
ゥ = u
|
||||
ェ = e
|
||||
ォ = o
|
||||
ヮ = ʍ
|
||||
ヴ = vu
|
||||
ャ = ǐa
|
||||
ュ = ǐu
|
||||
ョ = ǐo
|
||||
|
@ -1,321 +1,80 @@
|
||||
; # This section is ignored if this file is not supplied as the first config file to
|
||||
; # lhotse prepare csj
|
||||
[SEGMENTS]
|
||||
; # Allowed period of nonverbal noise. If exceeded, a new segment is created.
|
||||
gap = 0.5
|
||||
; # Maximum length of segment (s).
|
||||
maxlen = 10
|
||||
; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently.
|
||||
minlen = 0.02
|
||||
; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`.
|
||||
; # Pass an empty string to avoid adding any symbol. It was "<sp>" in kaldi.
|
||||
; # If you intend to use a multicharacter string for gap_sym, remember to register the
|
||||
; # multicharacter string as part of userdef-string in prepare_lang_char.py.
|
||||
gap_sym =
|
||||
|
||||
[CONSTANTS]
|
||||
; # Name of this mode
|
||||
; # See https://www.isca-speech.org/archive/pdfs/interspeech_2022/horii22_interspeech.pdf
|
||||
; # From https://www.isca-speech.org/archive/pdfs/interspeech_2022/horii22_interspeech.pdf
|
||||
MODE = symbol
|
||||
; # Suffixes to use after the word surface (no longer used)
|
||||
MORPH = pos1 cForm cType2 pos2
|
||||
; # Used to differentiate between A tag and A_num tag
|
||||
JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 .
|
||||
; # Dummy character to delineate multiline words
|
||||
PLUS = +
|
||||
|
||||
[DECISIONS]
|
||||
; # TAG+'^'とは、タグが一つの転記単位に独立していない場合
|
||||
; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries
|
||||
|
||||
; # フィラー、感情表出系感動詞
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(F ぎょっ)'
|
||||
F = #
|
||||
; # Example: '(L (F ン))', '比べ(F えー)る'
|
||||
F^ = #
|
||||
F = "#", ["F"]
|
||||
; # 言い直し、いいよどみなどによる語断片
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(D だ)(D だいが) 大学の学部の会議'
|
||||
D = @
|
||||
; # Example: '(L (D ドゥ)+(D ヒ))'
|
||||
D^ = @
|
||||
D = "@", ["D"]
|
||||
; # 助詞、助動詞、接辞の言い直し
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか'
|
||||
D2 = @
|
||||
; # Example: '(X (D2 ノ))'
|
||||
D2^ = @
|
||||
D2 = "@", ["D2"]
|
||||
; # 聞き取りや語彙の判断に自信がない場合
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: (? 字数) の
|
||||
; # If no option: empty string is returned regardless of output
|
||||
; # Example: '(?) で'
|
||||
? = 0
|
||||
; # Example: '(D (? すー))+そう+です+よ+ね'
|
||||
?^ = 0
|
||||
; # タグ?で、値は複数の候補が想定される場合
|
||||
; # 0 for main guess with matching morph info, 1 for second guess
|
||||
; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)'
|
||||
?, = 0
|
||||
; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))'
|
||||
?,^ = 0
|
||||
; # 音や言葉に関するメタ的な引用
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)'
|
||||
M = 0
|
||||
; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))'
|
||||
M^ = 0
|
||||
; # 外国語や古語、方言など
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(O ザッツファイン)'
|
||||
O = 0
|
||||
; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))'
|
||||
O^ = 0
|
||||
; # 講演者の名前、差別語、誹謗中傷など
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '国語研の (R ××) です'
|
||||
R = 0
|
||||
R^ = 0
|
||||
; # 非朗読対象発話(朗読における言い間違い等)
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(X 実際は) 実際には'
|
||||
X = 0
|
||||
; # Example: '(L (X (D2 ニ)))'
|
||||
X^ = 0
|
||||
; # アルファベットや算用数字、記号の表記
|
||||
; # 0 to use Japanese form, 1 to use alphabet form
|
||||
; # Example: '(A シーディーアール;CD-R)'
|
||||
A = 1
|
||||
; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)')
|
||||
A^ = 1
|
||||
; # タグAで、単語は算用数字の場合
|
||||
; # 0 to use Japanese form, 1 to use Arabic numerals
|
||||
; # Example: (A 二千;2000)
|
||||
A_num = eval:self.notag
|
||||
A_num^ = eval:self.notag
|
||||
A_num = 1
|
||||
; # 何らかの原因で漢字表記できなくなった場合
|
||||
; # 0 to use broken form, 1 to use orthodox form
|
||||
; # Example: '(K たち (F えー) ばな;橘)'
|
||||
K = 1
|
||||
; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)'
|
||||
K^ = 1
|
||||
; # 転訛、発音の怠けなど、一時的な発音エラー
|
||||
; # 0 to use wrong form, 1 to use orthodox form
|
||||
; # Example: '(W ギーツ;ギジュツ)'
|
||||
W = 1
|
||||
; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)'
|
||||
W^ = 1
|
||||
; # 語の読みに関する知識レベルのいい間違い
|
||||
; # 0 to use wrong form, 1 to use orthodox form
|
||||
; # Example: '(B シブタイ;ジュータイ)'
|
||||
B = 0
|
||||
; # Example: 'データー(B カズ;スー)'
|
||||
B^ = 0
|
||||
; # 笑いながら発話
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)'
|
||||
笑 = 0
|
||||
; # Example: 'コク(笑 サイ+(D オン))',
|
||||
笑^ = 0
|
||||
; # 泣きながら発話
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(泣 ドンナニ)'
|
||||
泣 = 0
|
||||
泣^ = 0
|
||||
; # 咳をしながら発話
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: 'シャ(咳 リン) ノ'
|
||||
咳 = 0
|
||||
; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)'
|
||||
咳^ = 0
|
||||
; # ささやき声や独り言などの小さな声
|
||||
; # 0 to remain, 1 to delete
|
||||
; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
|
||||
L = 0
|
||||
; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト'
|
||||
L^ = 0
|
||||
|
||||
[REPLACEMENTS]
|
||||
; # ボーカルフライなどで母音が同定できない場合
|
||||
<FV> =
|
||||
; # 「うん/うーん/ふーん」の音の特定が困難な場合
|
||||
<VN> =
|
||||
; # 非語彙的な母音の引き延ばし
|
||||
<H> =
|
||||
; # 非語彙的な子音の引き延ばし
|
||||
<Q> =
|
||||
; # 言語音と独立に講演者の笑いが生じている場合
|
||||
<笑> =
|
||||
; # 言語音と独立に講演者の咳が生じている場合
|
||||
<咳> =
|
||||
; # 言語音と独立に講演者の息が生じている場合
|
||||
<息> =
|
||||
; # 講演者の泣き声
|
||||
<泣> =
|
||||
; # 聴衆(司会者なども含む)の発話
|
||||
<フロア発話> =
|
||||
; # 聴衆の笑い
|
||||
<フロア笑> =
|
||||
; # 聴衆の拍手
|
||||
<拍手> =
|
||||
; # 講演者が発表中に用いたデモンストレーションの音声
|
||||
<デモ> =
|
||||
; # 学会講演に発表時間を知らせるためにならすベルの音
|
||||
<ベル> =
|
||||
; # 転記単位全体が再度読み直された場合
|
||||
<朗読間違い> =
|
||||
; # 上記以外の音で特に目立った音
|
||||
<雑音> =
|
||||
; # 0.2秒以上のポーズ
|
||||
<P> =
|
||||
; # Redacted information, for R
|
||||
; # It is \x00D7 multiplication sign, not your normal 'x'
|
||||
× = ×
|
||||
|
||||
[FIELDS]
|
||||
; # Time information for segment
|
||||
time = 3
|
||||
; # Word surface
|
||||
surface = 5
|
||||
; # Word surface root form without CSJ tags
|
||||
notag = 9
|
||||
; # Part Of Speech
|
||||
pos1 = 11
|
||||
; # Conjugated Form
|
||||
cForm = 12
|
||||
; # Conjugation Type
|
||||
cType1 = 13
|
||||
; # Subcategory of POS
|
||||
pos2 = 14
|
||||
; # Euphonic Change / Subcategory of Conjugation Type
|
||||
cType2 = 15
|
||||
; # Other information
|
||||
other = 16
|
||||
; # Pronunciation for lexicon
|
||||
pron = 10
|
||||
; # Speaker ID
|
||||
spk_id = 2
|
||||
|
||||
[KATAKANA2ROMAJI]
|
||||
ア = 'a
|
||||
イ = 'i
|
||||
ウ = 'u
|
||||
エ = 'e
|
||||
オ = 'o
|
||||
カ = ka
|
||||
キ = ki
|
||||
ク = ku
|
||||
ケ = ke
|
||||
コ = ko
|
||||
ガ = ga
|
||||
ギ = gi
|
||||
グ = gu
|
||||
ゲ = ge
|
||||
ゴ = go
|
||||
サ = sa
|
||||
シ = si
|
||||
ス = su
|
||||
セ = se
|
||||
ソ = so
|
||||
ザ = za
|
||||
ジ = zi
|
||||
ズ = zu
|
||||
ゼ = ze
|
||||
ゾ = zo
|
||||
タ = ta
|
||||
チ = ti
|
||||
ツ = tu
|
||||
テ = te
|
||||
ト = to
|
||||
ダ = da
|
||||
ヂ = di
|
||||
ヅ = du
|
||||
デ = de
|
||||
ド = do
|
||||
ナ = na
|
||||
ニ = ni
|
||||
ヌ = nu
|
||||
ネ = ne
|
||||
ノ = no
|
||||
ハ = ha
|
||||
ヒ = hi
|
||||
フ = hu
|
||||
ヘ = he
|
||||
ホ = ho
|
||||
バ = ba
|
||||
ビ = bi
|
||||
ブ = bu
|
||||
ベ = be
|
||||
ボ = bo
|
||||
パ = pa
|
||||
ピ = pi
|
||||
プ = pu
|
||||
ペ = pe
|
||||
ポ = po
|
||||
マ = ma
|
||||
ミ = mi
|
||||
ム = mu
|
||||
メ = me
|
||||
モ = mo
|
||||
ヤ = ya
|
||||
ユ = yu
|
||||
ヨ = yo
|
||||
ラ = ra
|
||||
リ = ri
|
||||
ル = ru
|
||||
レ = re
|
||||
ロ = ro
|
||||
ワ = wa
|
||||
ヰ = we
|
||||
ヱ = wi
|
||||
ヲ = wo
|
||||
ン = ŋ
|
||||
ッ = q
|
||||
ー = -
|
||||
キャ = kǐa
|
||||
キュ = kǐu
|
||||
キョ = kǐo
|
||||
ギャ = gǐa
|
||||
ギュ = gǐu
|
||||
ギョ = gǐo
|
||||
シャ = sǐa
|
||||
シュ = sǐu
|
||||
ショ = sǐo
|
||||
ジャ = zǐa
|
||||
ジュ = zǐu
|
||||
ジョ = zǐo
|
||||
チャ = tǐa
|
||||
チュ = tǐu
|
||||
チョ = tǐo
|
||||
ヂャ = dǐa
|
||||
ヂュ = dǐu
|
||||
ヂョ = dǐo
|
||||
ニャ = nǐa
|
||||
ニュ = nǐu
|
||||
ニョ = nǐo
|
||||
ヒャ = hǐa
|
||||
ヒュ = hǐu
|
||||
ヒョ = hǐo
|
||||
ビャ = bǐa
|
||||
ビュ = bǐu
|
||||
ビョ = bǐo
|
||||
ピャ = pǐa
|
||||
ピュ = pǐu
|
||||
ピョ = pǐo
|
||||
ミャ = mǐa
|
||||
ミュ = mǐu
|
||||
ミョ = mǐo
|
||||
リャ = rǐa
|
||||
リュ = rǐu
|
||||
リョ = rǐo
|
||||
ァ = a
|
||||
ィ = i
|
||||
ゥ = u
|
||||
ェ = e
|
||||
ォ = o
|
||||
ヮ = ʍ
|
||||
ヴ = vu
|
||||
ャ = ǐa
|
||||
ュ = ǐu
|
||||
ョ = ǐo
|
||||
|
202
egs/csj/ASR/local/disfluent_recogs_to_fluent.py
Normal file
202
egs/csj/ASR/local/disfluent_recogs_to_fluent.py
Normal file
@ -0,0 +1,202 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import kaldialign
|
||||
from lhotse import CutSet
|
||||
|
||||
ARGPARSE_DESCRIPTION = """
|
||||
This helper code takes in a disfluent recogs file generated from icefall.utils.store_transcript,
|
||||
compares it against a fluent transcript, and saves the results in a separate directory.
|
||||
This is useful to compare disfluent models with fluent models on the same metric.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
description=ARGPARSE_DESCRIPTION,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recogs",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Path to the recogs-XXX file generated by icefall.utils.store_transcript.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cut",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Path to the cut manifest to be compared to. Assumes that disfluent_tag exists in the custom dict.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--res-dir", type=Path, required=True, help="Path to save results"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def d2f(stats):
|
||||
"""
|
||||
Compare the outputs of a disfluent model against a fluent reference.
|
||||
Indicates a disfluent model's performance only on the content words
|
||||
|
||||
CER^d_f = (sub_f + ins + del_f) / Nf
|
||||
|
||||
"""
|
||||
return stats["base"] / stats["Nf"]
|
||||
|
||||
|
||||
def calc_cer(refs, hyps):
|
||||
subs = {
|
||||
"F": 0,
|
||||
"D": 0,
|
||||
}
|
||||
ins = 0
|
||||
dels = {
|
||||
"F": 0,
|
||||
"D": 0,
|
||||
}
|
||||
cors = {
|
||||
"F": 0,
|
||||
"D": 0,
|
||||
}
|
||||
dis_ref_len = 0
|
||||
flu_ref_len = 0
|
||||
|
||||
for ref, hyp in zip(refs, hyps):
|
||||
assert (
|
||||
ref[0] == hyp[0]
|
||||
), f"Expected ref cut id {ref[0]} to be the same as hyp cut id {hyp[0]}."
|
||||
tag = ref[2].copy()
|
||||
ref = ref[1]
|
||||
dis_ref_len += len(ref)
|
||||
# Remember that the 'D' and 'F' tags here refer to CSJ tags, not disfluent and fluent respectively.
|
||||
flu_ref_len += len([t for t in tag if ("D" not in t and "F" not in t)])
|
||||
hyp = hyp[1]
|
||||
ali = kaldialign.align(ref, hyp, "*")
|
||||
tags = ["*" if r[0] == "*" else tag.pop(0) for r in ali]
|
||||
for tag, (ref_word, hyp_word) in zip(tags, ali):
|
||||
if "D" in tag or "F" in tag:
|
||||
tag = "D"
|
||||
else:
|
||||
tag = "F"
|
||||
|
||||
if ref_word == "*":
|
||||
ins += 1
|
||||
elif hyp_word == "*":
|
||||
dels[tag] += 1
|
||||
elif ref_word != hyp_word:
|
||||
subs[tag] += 1
|
||||
else:
|
||||
cors[tag] += 1
|
||||
|
||||
return {
|
||||
"subs": subs,
|
||||
"ins": ins,
|
||||
"dels": dels,
|
||||
"cors": cors,
|
||||
"dis_ref_len": dis_ref_len,
|
||||
"flu_ref_len": flu_ref_len,
|
||||
}
|
||||
|
||||
|
||||
def for_each_recogs(recogs_file: Path, refs, out_dir):
|
||||
hyps = []
|
||||
with recogs_file.open() as fin:
|
||||
for line in fin:
|
||||
if "ref" in line:
|
||||
continue
|
||||
cutid, hyp = line.split(":\thyp=")
|
||||
hyps.append((cutid, eval(hyp)))
|
||||
|
||||
assert len(refs) == len(
|
||||
hyps
|
||||
), f"Expected refs len {len(refs)} and hyps len {len(hyps)} to be equal."
|
||||
stats = calc_cer(refs, hyps)
|
||||
stat_table = ["tag,yes,no"]
|
||||
|
||||
for cer_type in ["subs", "dels", "cors", "ins"]:
|
||||
ret = f"{cer_type}"
|
||||
for df in ["D", "F"]:
|
||||
try:
|
||||
ret += f",{stats[cer_type][df]}"
|
||||
except TypeError:
|
||||
# insertions do not belong to F or D, and is not subscriptable.
|
||||
ret += f",{stats[cer_type]},"
|
||||
break
|
||||
stat_table.append(ret)
|
||||
stat_table = "\n".join(stat_table)
|
||||
|
||||
stats = {
|
||||
"subd": stats["subs"]["D"],
|
||||
"deld": stats["dels"]["D"],
|
||||
"cord": stats["cors"]["D"],
|
||||
"Nf": stats["flu_ref_len"],
|
||||
"base": stats["subs"]["F"] + stats["ins"] + stats["dels"]["F"],
|
||||
}
|
||||
|
||||
cer = d2f(stats)
|
||||
results = [
|
||||
f"{cer:.2%}",
|
||||
f"Nf,{stats['Nf']}",
|
||||
]
|
||||
results = "\n".join(results)
|
||||
|
||||
with (out_dir / (recogs_file.stem + ".dfcer")).open("w") as fout:
|
||||
fout.write(results)
|
||||
fout.write("\n\n")
|
||||
fout.write(stat_table)
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
recogs_file: Path = args.recogs
|
||||
assert (
|
||||
recogs_file.is_file() or recogs_file.is_dir()
|
||||
), f"recogs_file cannot be found at {recogs_file}."
|
||||
|
||||
args.res_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if recogs_file.is_file() and recogs_file.stem.startswith("recogs-"):
|
||||
assert (
|
||||
"csj_cuts" in args.cut.name
|
||||
), f"Expected {args.cut} to be a cuts manifest."
|
||||
|
||||
refs: CutSet = CutSet.from_file(args.cut)
|
||||
refs = sorted(
|
||||
[
|
||||
(
|
||||
e.id,
|
||||
list(e.supervisions[0].custom["disfluent"]),
|
||||
e.supervisions[0].custom["disfluent_tag"].split(","),
|
||||
)
|
||||
for e in refs
|
||||
],
|
||||
key=lambda x: x[0],
|
||||
)
|
||||
for_each_recogs(recogs_file, refs, args.res_dir)
|
||||
|
||||
elif recogs_file.is_dir():
|
||||
recogs_file_path = recogs_file
|
||||
for partname in ["eval1", "eval2", "eval3", "excluded", "valid"]:
|
||||
refs: CutSet = CutSet.from_file(args.cut / f"csj_cuts_{partname}.jsonl.gz")
|
||||
refs = sorted(
|
||||
[
|
||||
(
|
||||
r.id,
|
||||
list(r.supervisions[0].custom["disfluent"]),
|
||||
r.supervisions[0].custom["disfluent_tag"].split(","),
|
||||
)
|
||||
for r in refs
|
||||
],
|
||||
key=lambda x: x[0],
|
||||
)
|
||||
for recogs_file in recogs_file_path.glob(f"recogs-{partname}-*.txt"):
|
||||
for_each_recogs(recogs_file, refs, args.res_dir)
|
||||
|
||||
else:
|
||||
raise TypeError(f"Unrecognised recogs file provided: {recogs_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -45,8 +45,8 @@ def get_parser():
|
||||
def main():
|
||||
args = get_parser()
|
||||
|
||||
for path in args.manifest_dir.glob("csj_cuts_*.jsonl.gz"):
|
||||
|
||||
for part in ["eval1", "eval2", "eval3", "valid", "excluded", "train"]:
|
||||
path = args.manifest_dir / f"csj_cuts_{part}.jsonl.gz"
|
||||
cuts: CutSet = load_manifest(path)
|
||||
|
||||
print("\n---------------------------------\n")
|
||||
@ -58,123 +58,271 @@ if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
"""
|
||||
## eval1
|
||||
Cuts count: 1272
|
||||
Total duration (hh:mm:ss): 01:50:07
|
||||
Speech duration (hh:mm:ss): 01:50:07 (100.0%)
|
||||
Duration statistics (seconds):
|
||||
mean 5.2
|
||||
std 3.9
|
||||
min 0.2
|
||||
25% 1.9
|
||||
50% 4.0
|
||||
75% 8.1
|
||||
99% 14.3
|
||||
99.5% 14.7
|
||||
99.9% 16.0
|
||||
max 16.9
|
||||
Recordings available: 1272
|
||||
Features available: 1272
|
||||
Supervisions available: 1272
|
||||
csj_cuts_eval1.jsonl.gz:
|
||||
Cut statistics:
|
||||
╒═══════════════════════════╤══════════╕
|
||||
│ Cuts count: │ 1023 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ Total duration (hh:mm:ss) │ 01:55:40 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ mean │ 6.8 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ std │ 2.7 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ min │ 0.2 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 25% │ 4.9 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 50% │ 7.7 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 75% │ 9.0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 99% │ 10.0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 99.5% │ 10.0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 99.9% │ 10.0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ max │ 10.0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ Recordings available: │ 1023 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ Features available: │ 0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ Supervisions available: │ 1023 │
|
||||
╘═══════════════════════════╧══════════╛
|
||||
SUPERVISION custom fields:
|
||||
- fluent (in 1272 cuts)
|
||||
- disfluent (in 1272 cuts)
|
||||
- number (in 1272 cuts)
|
||||
- symbol (in 1272 cuts)
|
||||
Speech duration statistics:
|
||||
╒══════════════════════════════╤══════════╤══════════════════════╕
|
||||
│ Total speech duration │ 01:55:40 │ 100.00% of recording │
|
||||
├──────────────────────────────┼──────────┼──────────────────────┤
|
||||
│ Total speaking time duration │ 01:55:40 │ 100.00% of recording │
|
||||
├──────────────────────────────┼──────────┼──────────────────────┤
|
||||
│ Total silence duration │ 00:00:00 │ 0.00% of recording │
|
||||
╘══════════════════════════════╧══════════╧══════════════════════╛
|
||||
|
||||
## eval2
|
||||
Cuts count: 1292
|
||||
Total duration (hh:mm:ss): 01:56:50
|
||||
Speech duration (hh:mm:ss): 01:56:50 (100.0%)
|
||||
Duration statistics (seconds):
|
||||
mean 5.4
|
||||
std 3.9
|
||||
min 0.1
|
||||
25% 2.1
|
||||
50% 4.6
|
||||
75% 8.6
|
||||
99% 14.1
|
||||
99.5% 15.2
|
||||
99.9% 16.1
|
||||
max 16.9
|
||||
Recordings available: 1292
|
||||
Features available: 1292
|
||||
Supervisions available: 1292
|
||||
SUPERVISION custom fields:
|
||||
- fluent (in 1292 cuts)
|
||||
- number (in 1292 cuts)
|
||||
- symbol (in 1292 cuts)
|
||||
- disfluent (in 1292 cuts)
|
||||
---------------------------------
|
||||
|
||||
## eval3
|
||||
Cuts count: 1385
|
||||
Total duration (hh:mm:ss): 01:19:21
|
||||
Speech duration (hh:mm:ss): 01:19:21 (100.0%)
|
||||
Duration statistics (seconds):
|
||||
mean 3.4
|
||||
std 3.0
|
||||
min 0.2
|
||||
25% 1.2
|
||||
50% 2.5
|
||||
75% 4.6
|
||||
99% 12.7
|
||||
99.5% 13.7
|
||||
99.9% 15.0
|
||||
max 15.9
|
||||
Recordings available: 1385
|
||||
Features available: 1385
|
||||
Supervisions available: 1385
|
||||
csj_cuts_eval2.jsonl.gz:
|
||||
Cut statistics:
|
||||
╒═══════════════════════════╤══════════╕
|
||||
│ Cuts count: │ 1025 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ Total duration (hh:mm:ss) │ 02:02:07 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ mean │ 7.1 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ std │ 2.5 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ min │ 0.1 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 25% │ 5.9 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 50% │ 7.9 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 75% │ 9.1 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 99% │ 10.0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 99.5% │ 10.0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 99.9% │ 10.0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ max │ 10.0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ Recordings available: │ 1025 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ Features available: │ 0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ Supervisions available: │ 1025 │
|
||||
╘═══════════════════════════╧══════════╛
|
||||
SUPERVISION custom fields:
|
||||
- number (in 1385 cuts)
|
||||
- symbol (in 1385 cuts)
|
||||
- fluent (in 1385 cuts)
|
||||
- disfluent (in 1385 cuts)
|
||||
Speech duration statistics:
|
||||
╒══════════════════════════════╤══════════╤══════════════════════╕
|
||||
│ Total speech duration │ 02:02:07 │ 100.00% of recording │
|
||||
├──────────────────────────────┼──────────┼──────────────────────┤
|
||||
│ Total speaking time duration │ 02:02:07 │ 100.00% of recording │
|
||||
├──────────────────────────────┼──────────┼──────────────────────┤
|
||||
│ Total silence duration │ 00:00:00 │ 0.00% of recording │
|
||||
╘══════════════════════════════╧══════════╧══════════════════════╛
|
||||
|
||||
## valid
|
||||
Cuts count: 4000
|
||||
Total duration (hh:mm:ss): 05:08:09
|
||||
Speech duration (hh:mm:ss): 05:08:09 (100.0%)
|
||||
Duration statistics (seconds):
|
||||
mean 4.6
|
||||
std 3.8
|
||||
min 0.1
|
||||
25% 1.5
|
||||
50% 3.4
|
||||
75% 7.0
|
||||
99% 13.8
|
||||
99.5% 14.8
|
||||
99.9% 16.0
|
||||
max 17.3
|
||||
Recordings available: 4000
|
||||
Features available: 4000
|
||||
Supervisions available: 4000
|
||||
SUPERVISION custom fields:
|
||||
- fluent (in 4000 cuts)
|
||||
- symbol (in 4000 cuts)
|
||||
- disfluent (in 4000 cuts)
|
||||
- number (in 4000 cuts)
|
||||
---------------------------------
|
||||
|
||||
## train
|
||||
Cuts count: 1291134
|
||||
Total duration (hh:mm:ss): 1596:37:27
|
||||
Speech duration (hh:mm:ss): 1596:37:27 (100.0%)
|
||||
Duration statistics (seconds):
|
||||
mean 4.5
|
||||
std 3.6
|
||||
min 0.0
|
||||
25% 1.6
|
||||
50% 3.3
|
||||
75% 6.4
|
||||
99% 14.0
|
||||
99.5% 14.8
|
||||
99.9% 16.6
|
||||
max 27.8
|
||||
Recordings available: 1291134
|
||||
Features available: 1291134
|
||||
Supervisions available: 1291134
|
||||
csj_cuts_eval3.jsonl.gz:
|
||||
Cut statistics:
|
||||
╒═══════════════════════════╤══════════╕
|
||||
│ Cuts count: │ 865 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ Total duration (hh:mm:ss) │ 01:26:44 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ mean │ 6.0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ std │ 3.0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ min │ 0.3 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 25% │ 3.3 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 50% │ 6.8 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 75% │ 8.7 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 99% │ 10.0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 99.5% │ 10.0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 99.9% │ 10.0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ max │ 10.0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ Recordings available: │ 865 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ Features available: │ 0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ Supervisions available: │ 865 │
|
||||
╘═══════════════════════════╧══════════╛
|
||||
SUPERVISION custom fields:
|
||||
- disfluent (in 1291134 cuts)
|
||||
- fluent (in 1291134 cuts)
|
||||
- symbol (in 1291134 cuts)
|
||||
- number (in 1291134 cuts)
|
||||
Speech duration statistics:
|
||||
╒══════════════════════════════╤══════════╤══════════════════════╕
|
||||
│ Total speech duration │ 01:26:44 │ 100.00% of recording │
|
||||
├──────────────────────────────┼──────────┼──────────────────────┤
|
||||
│ Total speaking time duration │ 01:26:44 │ 100.00% of recording │
|
||||
├──────────────────────────────┼──────────┼──────────────────────┤
|
||||
│ Total silence duration │ 00:00:00 │ 0.00% of recording │
|
||||
╘══════════════════════════════╧══════════╧══════════════════════╛
|
||||
|
||||
---------------------------------
|
||||
|
||||
csj_cuts_valid.jsonl.gz:
|
||||
Cut statistics:
|
||||
╒═══════════════════════════╤══════════╕
|
||||
│ Cuts count: │ 3743 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ Total duration (hh:mm:ss) │ 06:40:15 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ mean │ 6.4 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ std │ 3.0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ min │ 0.1 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 25% │ 3.9 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 50% │ 7.4 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 75% │ 9.0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 99% │ 10.0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 99.5% │ 10.0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 99.9% │ 10.1 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ max │ 11.8 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ Recordings available: │ 3743 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ Features available: │ 0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ Supervisions available: │ 3743 │
|
||||
╘═══════════════════════════╧══════════╛
|
||||
SUPERVISION custom fields:
|
||||
Speech duration statistics:
|
||||
╒══════════════════════════════╤══════════╤══════════════════════╕
|
||||
│ Total speech duration │ 06:40:15 │ 100.00% of recording │
|
||||
├──────────────────────────────┼──────────┼──────────────────────┤
|
||||
│ Total speaking time duration │ 06:40:15 │ 100.00% of recording │
|
||||
├──────────────────────────────┼──────────┼──────────────────────┤
|
||||
│ Total silence duration │ 00:00:00 │ 0.00% of recording │
|
||||
╘══════════════════════════════╧══════════╧══════════════════════╛
|
||||
|
||||
---------------------------------
|
||||
|
||||
csj_cuts_excluded.jsonl.gz:
|
||||
Cut statistics:
|
||||
╒═══════════════════════════╤══════════╕
|
||||
│ Cuts count: │ 980 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ Total duration (hh:mm:ss) │ 00:56:06 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ mean │ 3.4 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ std │ 3.1 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ min │ 0.1 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 25% │ 0.8 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 50% │ 2.2 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 75% │ 5.8 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 99% │ 9.9 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 99.5% │ 9.9 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ 99.9% │ 10.0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ max │ 10.0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ Recordings available: │ 980 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ Features available: │ 0 │
|
||||
├───────────────────────────┼──────────┤
|
||||
│ Supervisions available: │ 980 │
|
||||
╘═══════════════════════════╧══════════╛
|
||||
SUPERVISION custom fields:
|
||||
Speech duration statistics:
|
||||
╒══════════════════════════════╤══════════╤══════════════════════╕
|
||||
│ Total speech duration │ 00:56:06 │ 100.00% of recording │
|
||||
├──────────────────────────────┼──────────┼──────────────────────┤
|
||||
│ Total speaking time duration │ 00:56:06 │ 100.00% of recording │
|
||||
├──────────────────────────────┼──────────┼──────────────────────┤
|
||||
│ Total silence duration │ 00:00:00 │ 0.00% of recording │
|
||||
╘══════════════════════════════╧══════════╧══════════════════════╛
|
||||
|
||||
---------------------------------
|
||||
|
||||
csj_cuts_train.jsonl.gz:
|
||||
Cut statistics:
|
||||
╒═══════════════════════════╤════════════╕
|
||||
│ Cuts count: │ 914151 │
|
||||
├───────────────────────────┼────────────┤
|
||||
│ Total duration (hh:mm:ss) │ 1695:29:43 │
|
||||
├───────────────────────────┼────────────┤
|
||||
│ mean │ 6.7 │
|
||||
├───────────────────────────┼────────────┤
|
||||
│ std │ 2.9 │
|
||||
├───────────────────────────┼────────────┤
|
||||
│ min │ 0.1 │
|
||||
├───────────────────────────┼────────────┤
|
||||
│ 25% │ 4.6 │
|
||||
├───────────────────────────┼────────────┤
|
||||
│ 50% │ 7.5 │
|
||||
├───────────────────────────┼────────────┤
|
||||
│ 75% │ 8.9 │
|
||||
├───────────────────────────┼────────────┤
|
||||
│ 99% │ 11.0 │
|
||||
├───────────────────────────┼────────────┤
|
||||
│ 99.5% │ 11.0 │
|
||||
├───────────────────────────┼────────────┤
|
||||
│ 99.9% │ 11.1 │
|
||||
├───────────────────────────┼────────────┤
|
||||
│ max │ 18.0 │
|
||||
├───────────────────────────┼────────────┤
|
||||
│ Recordings available: │ 914151 │
|
||||
├───────────────────────────┼────────────┤
|
||||
│ Features available: │ 0 │
|
||||
├───────────────────────────┼────────────┤
|
||||
│ Supervisions available: │ 914151 │
|
||||
╘═══════════════════════════╧════════════╛
|
||||
SUPERVISION custom fields:
|
||||
Speech duration statistics:
|
||||
╒══════════════════════════════╤════════════╤══════════════════════╕
|
||||
│ Total speech duration │ 1695:29:43 │ 100.00% of recording │
|
||||
├──────────────────────────────┼────────────┼──────────────────────┤
|
||||
│ Total speaking time duration │ 1695:29:43 │ 100.00% of recording │
|
||||
├──────────────────────────────┼────────────┼──────────────────────┤
|
||||
│ Total silence duration │ 00:00:00 │ 0.00% of recording │
|
||||
╘══════════════════════════════╧════════════╧══════════════════════╛
|
||||
"""
|
||||
|
@ -21,24 +21,14 @@ import logging
|
||||
from pathlib import Path
|
||||
|
||||
from lhotse import CutSet
|
||||
from lhotse.recipes.csj import CSJSDBParser
|
||||
|
||||
ARGPARSE_DESCRIPTION = """
|
||||
This script gathers all training transcripts of the specified {trans_mode} type
|
||||
and produces a token_list that would be output set of the ASR system.
|
||||
This script gathers all training transcripts, parses them in disfluent mode, and produces a token list that would be the output set of the ASR system.
|
||||
|
||||
It splits transcripts by whitespace into lists, then, for each word in the
|
||||
list, if the word does not appear in the list of user-defined multicharacter
|
||||
strings, it further splits that word into individual characters to be counted
|
||||
into the output token set.
|
||||
|
||||
It outputs 4 files into the lang directory:
|
||||
- trans_mode: the name of transcript mode. If trans_mode was not specified,
|
||||
this will be an empty file.
|
||||
- userdef_string: a list of user defined strings that should not be split
|
||||
further into individual characters. By default, it contains "<unk>", "<blk>",
|
||||
"<sos/eos>"
|
||||
- words_len: the total number of tokens in the output set.
|
||||
- words.txt: a list of tokens in the output set. The length matches words_len.
|
||||
It outputs 3 files into the lang directory:
|
||||
- tokens.txt: a list of tokens in the output set.
|
||||
- lang_type: a file that contains the string "char"
|
||||
|
||||
"""
|
||||
|
||||
@ -50,98 +40,52 @@ def get_args():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--train-cut", type=Path, required=True, help="Path to the train cut"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--trans-mode",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Name of the transcript mode to use. "
|
||||
"If lang-dir is not set, this will also name the lang-dir"
|
||||
),
|
||||
"train_cut", metavar="train-cut", type=Path, help="Path to the train cut"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
default=Path("data/lang_char"),
|
||||
help=(
|
||||
"Name of lang dir. "
|
||||
"If not set, this will default to lang_char_{trans-mode}"
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--userdef-string",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Multicharacter strings that do not need to be split",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
|
||||
logging.basicConfig(
|
||||
format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"),
|
||||
level=logging.INFO,
|
||||
)
|
||||
|
||||
if not args.lang_dir:
|
||||
p = "lang_char"
|
||||
if args.trans_mode:
|
||||
p += f"_{args.trans_mode}"
|
||||
args.lang_dir = Path(p)
|
||||
sysdef_string = set(["<blk>", "<unk>", "<sos/eos>"])
|
||||
|
||||
if args.userdef_string:
|
||||
args.userdef_string = set(args.userdef_string.read_text().split())
|
||||
else:
|
||||
args.userdef_string = set()
|
||||
# Using disfluent parsing as fluent is a subset of disfluent
|
||||
parser = CSJSDBParser()
|
||||
|
||||
sysdef_string = ["<blk>", "<unk>", "<sos/eos>"]
|
||||
args.userdef_string.update(sysdef_string)
|
||||
token_set = set()
|
||||
logging.info(f"Creating vocabulary from {args.train_cut}.")
|
||||
train_cut: CutSet = CutSet.from_file(args.train_cut)
|
||||
for cut in train_cut:
|
||||
if "_sp" in cut.id:
|
||||
continue
|
||||
|
||||
train_set: CutSet = CutSet.from_file(args.train_cut)
|
||||
|
||||
words = set()
|
||||
logging.info(
|
||||
f"Creating vocabulary from {args.train_cut.name} at {args.trans_mode} mode."
|
||||
)
|
||||
for cut in train_set:
|
||||
try:
|
||||
text: str = (
|
||||
cut.supervisions[0].custom[args.trans_mode]
|
||||
if args.trans_mode
|
||||
else cut.supervisions[0].text
|
||||
)
|
||||
except KeyError:
|
||||
raise KeyError(
|
||||
f"Could not find {args.trans_mode} in {cut.supervisions[0].custom}"
|
||||
)
|
||||
for t in text.split():
|
||||
if t in args.userdef_string:
|
||||
words.add(t)
|
||||
else:
|
||||
words.update(c for c in list(t))
|
||||
|
||||
words -= set(sysdef_string)
|
||||
words = sorted(words)
|
||||
words = ["<blk>"] + words + ["<unk>", "<sos/eos>"]
|
||||
text: str = cut.supervisions[0].custom["raw"]
|
||||
for w in parser.parse(text, sep=" ").split(" "):
|
||||
token_set.update(w)
|
||||
|
||||
token_set = ["<blk>"] + sorted(token_set - sysdef_string) + ["<unk>", "<sos/eos>"]
|
||||
args.lang_dir.mkdir(parents=True, exist_ok=True)
|
||||
(args.lang_dir / "words.txt").write_text(
|
||||
"\n".join(f"{word}\t{i}" for i, word in enumerate(words))
|
||||
(args.lang_dir / "tokens.txt").write_text(
|
||||
"\n".join(f"{t}\t{i}" for i, t in enumerate(token_set))
|
||||
)
|
||||
|
||||
(args.lang_dir / "words_len").write_text(f"{len(words)}")
|
||||
|
||||
(args.lang_dir / "userdef_string").write_text("\n".join(args.userdef_string))
|
||||
|
||||
(args.lang_dir / "trans_mode").write_text(args.trans_mode)
|
||||
(args.lang_dir / "lang_type").write_text("char")
|
||||
logging.info("Done.")
|
||||
|
||||
|
||||
|
462
egs/csj/ASR/local/utils/asr_datamodule.py
Normal file
462
egs/csj/ASR/local/utils/asr_datamodule.py
Normal file
@ -0,0 +1,462 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
AudioSamples,
|
||||
OnTheFlyFeatures,
|
||||
)
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class AsrVariableTranscriptDataset(K2SpeechRecognitionDataset):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
transcript_mode: str = "",
|
||||
return_cuts: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.transcript_mode = transcript_mode
|
||||
self.return_cuts = True
|
||||
self._return_cuts = return_cuts
|
||||
|
||||
def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
|
||||
batch = super().__getitem__(cuts)
|
||||
|
||||
if self.transcript_mode:
|
||||
batch["supervisions"]["text"] = [
|
||||
supervision.custom[self.transcript_mode]
|
||||
for cut in batch["supervisions"]["cut"]
|
||||
for supervision in cut.supervisions
|
||||
]
|
||||
|
||||
if not self._return_cuts:
|
||||
del batch["supervisions"]["cut"]
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
class CSJAsrDataModule:
|
||||
"""
|
||||
DataModule for k2 ASR experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--transcript-mode",
|
||||
type=str,
|
||||
default="",
|
||||
help="Mode of transcript in supervision to use.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/manifests"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--musan-dir", type=Path, help="Path to directory with musan cuts. "
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--concatenate-cuts",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, utterances (cuts) will be concatenated "
|
||||
"to minimize the amount of padding.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--duration-factor",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Determines the maximum duration of a concatenated cut "
|
||||
"relative to the duration of the longest cut in a batch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gap",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The amount of padding (in seconds) inserted between "
|
||||
"concatenated cuts. This padding is filled with noise when "
|
||||
"noise augmentation is used.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-spec-aug",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use SpecAugment for training dataset.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--spec-aug-time-warp-factor",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Used only when --enable-spec-aug is True. "
|
||||
"It specifies the factor for time warping in SpecAugment. "
|
||||
"Larger values mean more warping. "
|
||||
"A value less than 1 means to disable time warp.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-musan",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, select noise from MUSAN and mix it"
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--input-strategy",
|
||||
type=str,
|
||||
default="PrecomputedFeatures",
|
||||
help="AudioSamples or PrecomputedFeatures",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.musan_dir / "musan_cuts.jsonl.gz")
|
||||
transforms.append(
|
||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
|
||||
if self.args.concatenate_cuts:
|
||||
logging.info(
|
||||
f"Using cut concatenation with duration factor "
|
||||
f"{self.args.duration_factor} and gap {self.args.gap}."
|
||||
)
|
||||
# Cut concatenation should be the first transform in the list,
|
||||
# so that if we e.g. mix noise in, it will fill the gaps between
|
||||
# different utterances.
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
num_frame_masks = 10
|
||||
num_frame_masks_parameter = inspect.signature(
|
||||
SpecAugment.__init__
|
||||
).parameters["num_frame_masks"]
|
||||
if num_frame_masks_parameter.default == 1:
|
||||
num_frame_masks = 2
|
||||
logging.info(f"Num frame mask: {num_frame_masks}")
|
||||
input_transforms.append(
|
||||
SpecAugment(
|
||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||
num_frame_masks=num_frame_masks,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable SpecAugment")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = AsrVariableTranscriptDataset(
|
||||
input_strategy=eval(self.args.input_strategy)(),
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
transcript_mode=self.args.transcript_mode,
|
||||
)
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
# NOTE: the PerturbSpeed transform should be added only if we
|
||||
# remove it from data prep stage.
|
||||
# Add on-the-fly speed perturbation; since originally it would
|
||||
# have increased epoch size by 3, we will apply prob 2/3 and use
|
||||
# 3x more epochs.
|
||||
# Speed perturbation probably should come first before
|
||||
# concatenation, but in principle the transforms order doesn't have
|
||||
# to be strict (e.g. could be randomized)
|
||||
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
||||
# Drop feats to be on the safe side.
|
||||
train = AsrVariableTranscriptDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
transcript_mode=self.args.transcript_mode,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
transforms = []
|
||||
if self.args.concatenate_cuts:
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
validate = AsrVariableTranscriptDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
return_cuts=self.args.return_cuts,
|
||||
transcript_mode=self.args.transcript_mode,
|
||||
)
|
||||
else:
|
||||
validate = AsrVariableTranscriptDataset(
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
transcript_mode=self.args.transcript_mode,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
|
||||
test = AsrVariableTranscriptDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
transcript_mode=self.args.transcript_mode,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_train.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def valid_cuts(self) -> CutSet:
|
||||
logging.info("About to get valid cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_valid.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def excluded_cuts(self) -> CutSet:
|
||||
logging.info("About to get excluded cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_excluded.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def eval1_cuts(self) -> CutSet:
|
||||
logging.info("About to get eval1 cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_eval1.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def eval2_cuts(self) -> CutSet:
|
||||
logging.info("About to get eval2 cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_eval2.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def eval3_cuts(self) -> CutSet:
|
||||
logging.info("About to get eval3 cuts")
|
||||
return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_eval3.jsonl.gz")
|
253
egs/csj/ASR/local/utils/tokenizer.py
Normal file
253
egs/csj/ASR/local/utils/tokenizer.py
Normal file
@ -0,0 +1,253 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Union
|
||||
|
||||
import sentencepiece as spm
|
||||
from k2 import SymbolTable
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
text2word: Callable[[str], List[str]]
|
||||
|
||||
@staticmethod
|
||||
def add_arguments(parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(title="Lang related options")
|
||||
|
||||
group.add_argument("--lang", type=Path, help="Path to lang directory.")
|
||||
|
||||
group.add_argument(
|
||||
"--lang-type",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Either 'bpe' or 'char'. If not provided, it expects lang_dir/lang_type to exists. "
|
||||
"Note: 'bpe' directly loads sentencepiece.SentencePieceProcessor"
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def Load(lang_dir: Path, lang_type="", oov="<unk>"):
|
||||
|
||||
if not lang_type:
|
||||
assert (lang_dir / "lang_type").exists(), "lang_type not specified."
|
||||
lang_type = (lang_dir / "lang_type").read_text().strip()
|
||||
|
||||
tokenizer = None
|
||||
|
||||
if lang_type == "bpe":
|
||||
assert (
|
||||
lang_dir / "bpe.model"
|
||||
).exists(), f"No BPE .model could be found in {lang_dir}."
|
||||
tokenizer = spm.SentencePieceProcessor()
|
||||
tokenizer.Load(str(lang_dir / "bpe.model"))
|
||||
elif lang_type == "char":
|
||||
tokenizer = CharTokenizer(lang_dir, oov=oov)
|
||||
else:
|
||||
raise NotImplementedError(f"{lang_type} not supported at the moment.")
|
||||
|
||||
return tokenizer
|
||||
|
||||
load = Load
|
||||
|
||||
def PieceToId(self, piece: str) -> int:
|
||||
raise NotImplementedError(
|
||||
"You need to implement this function in the child class."
|
||||
)
|
||||
|
||||
piece_to_id = PieceToId
|
||||
|
||||
def IdToPiece(self, id: int) -> str:
|
||||
raise NotImplementedError(
|
||||
"You need to implement this function in the child class."
|
||||
)
|
||||
|
||||
id_to_piece = IdToPiece
|
||||
|
||||
def GetPieceSize(self) -> int:
|
||||
raise NotImplementedError(
|
||||
"You need to implement this function in the child class."
|
||||
)
|
||||
|
||||
get_piece_size = GetPieceSize
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.get_piece_size()
|
||||
|
||||
def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]:
|
||||
raise NotImplementedError(
|
||||
"You need to implement this function in the child class."
|
||||
)
|
||||
|
||||
def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]:
|
||||
raise NotImplementedError(
|
||||
"You need to implement this function in the child class."
|
||||
)
|
||||
|
||||
def EncodeAsIds(self, input: str) -> List[int]:
|
||||
return self.EncodeAsIdsBatch([input])[0]
|
||||
|
||||
def EncodeAsPieces(self, input: str) -> List[str]:
|
||||
return self.EncodeAsPiecesBatch([input])[0]
|
||||
|
||||
def Encode(
|
||||
self, input: Union[str, List[str]], out_type=int
|
||||
) -> Union[List, List[List]]:
|
||||
if not input:
|
||||
return []
|
||||
|
||||
if isinstance(input, list):
|
||||
if out_type is int:
|
||||
return self.EncodeAsIdsBatch(input)
|
||||
if out_type is str:
|
||||
return self.EncodeAsPiecesBatch(input)
|
||||
|
||||
if out_type is int:
|
||||
return self.EncodeAsIds(input)
|
||||
if out_type is str:
|
||||
return self.EncodeAsPieces(input)
|
||||
|
||||
encode = Encode
|
||||
|
||||
def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]:
|
||||
raise NotImplementedError(
|
||||
"You need to implement this function in the child class."
|
||||
)
|
||||
|
||||
def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]:
|
||||
raise NotImplementedError(
|
||||
"You need to implement this function in the child class."
|
||||
)
|
||||
|
||||
def DecodeIds(self, input: List[int]) -> str:
|
||||
return self.DecodeIdsBatch([input])[0]
|
||||
|
||||
def DecodePieces(self, input: List[str]) -> str:
|
||||
return self.DecodePiecesBatch([input])[0]
|
||||
|
||||
def Decode(
|
||||
self,
|
||||
input: Union[int, List[int], List[str], List[List[int]], List[List[str]]],
|
||||
) -> Union[List[str], str]:
|
||||
|
||||
if not input:
|
||||
return ""
|
||||
|
||||
if isinstance(input, int):
|
||||
return self.id_to_piece(input)
|
||||
elif isinstance(input, str):
|
||||
raise TypeError(
|
||||
"Unlike spm.SentencePieceProcessor, cannot decode from type str."
|
||||
)
|
||||
|
||||
if isinstance(input[0], list):
|
||||
if not input[0] or isinstance(input[0][0], int):
|
||||
return self.DecodeIdsBatch(input)
|
||||
|
||||
if isinstance(input[0][0], str):
|
||||
return self.DecodePiecesBatch(input)
|
||||
|
||||
if isinstance(input[0], int):
|
||||
return self.DecodeIds(input)
|
||||
if isinstance(input[0], str):
|
||||
return self.DecodePieces(input)
|
||||
|
||||
raise RuntimeError("Unknown input type")
|
||||
|
||||
decode = Decode
|
||||
|
||||
def SplitBatch(self, input: List[str]) -> List[List[str]]:
|
||||
raise NotImplementedError(
|
||||
"You need to implement this function in the child class."
|
||||
)
|
||||
|
||||
def Split(self, input: Union[List[str], str]) -> Union[List[List[str]], List[str]]:
|
||||
if isinstance(input, list):
|
||||
return self.SplitBatch(input)
|
||||
elif isinstance(input, str):
|
||||
return self.SplitBatch([input])[0]
|
||||
raise RuntimeError("Unknown input type")
|
||||
|
||||
split = Split
|
||||
|
||||
|
||||
class CharTokenizer(Tokenizer):
|
||||
def __init__(self, lang_dir: Path, oov="<unk>", sep=""):
|
||||
assert (
|
||||
lang_dir / "tokens.txt"
|
||||
).exists(), f"tokens.txt could not be found in {lang_dir}."
|
||||
token_table = SymbolTable.from_file(lang_dir / "tokens.txt")
|
||||
assert (
|
||||
"#0" not in token_table
|
||||
), "This tokenizer does not support disambig symbols."
|
||||
self._id2sym = token_table._id2sym
|
||||
self._sym2id = token_table._sym2id
|
||||
self.oov = oov
|
||||
self.oov_id = self._sym2id[oov]
|
||||
self.sep = sep
|
||||
if self.sep:
|
||||
self.text2word = lambda x: x.split(self.sep)
|
||||
else:
|
||||
self.text2word = lambda x: list(x.replace(" ", ""))
|
||||
|
||||
def piece_to_id(self, piece: str) -> int:
|
||||
try:
|
||||
return self._sym2id[piece]
|
||||
except KeyError:
|
||||
return self.oov_id
|
||||
|
||||
def id_to_piece(self, id: int) -> str:
|
||||
return self._id2sym[id]
|
||||
|
||||
def get_piece_size(self) -> int:
|
||||
return len(self._sym2id)
|
||||
|
||||
def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]:
|
||||
return [[self.piece_to_id(i) for i in self.text2word(text)] for text in input]
|
||||
|
||||
def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]:
|
||||
return [
|
||||
[i if i in self._sym2id else self.oov for i in self.text2word(text)]
|
||||
for text in input
|
||||
]
|
||||
|
||||
def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]:
|
||||
return [self.sep.join(self.id_to_piece(i) for i in text) for text in input]
|
||||
|
||||
def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]:
|
||||
return [self.sep.join(text) for text in input]
|
||||
|
||||
def SplitBatch(self, input: List[str]) -> List[List[str]]:
|
||||
return [self.text2word(text) for text in input]
|
||||
|
||||
|
||||
def test_CharTokenizer():
|
||||
test_single_string = "こんにちは"
|
||||
test_multiple_string = [
|
||||
"今日はいい天気ですよね",
|
||||
"諏訪湖は綺麗でしょう",
|
||||
"这在词表外",
|
||||
"分かち 書き に し た 文章 です",
|
||||
"",
|
||||
]
|
||||
test_empty_string = ""
|
||||
sp = Tokenizer.load(Path("lang_char"), "char", oov="<unk>")
|
||||
splitter = sp.split
|
||||
print(sp.encode(test_single_string, out_type=str))
|
||||
print(sp.encode(test_single_string, out_type=int))
|
||||
print(sp.encode(test_multiple_string, out_type=str))
|
||||
print(sp.encode(test_multiple_string, out_type=int))
|
||||
print(sp.encode(test_empty_string, out_type=str))
|
||||
print(sp.encode(test_empty_string, out_type=int))
|
||||
print(sp.decode(sp.encode(test_single_string, out_type=str)))
|
||||
print(sp.decode(sp.encode(test_single_string, out_type=int)))
|
||||
print(sp.decode(sp.encode(test_multiple_string, out_type=str)))
|
||||
print(sp.decode(sp.encode(test_multiple_string, out_type=int)))
|
||||
print(sp.decode(sp.encode(test_empty_string, out_type=str)))
|
||||
print(sp.decode(sp.encode(test_empty_string, out_type=int)))
|
||||
print(splitter(test_single_string))
|
||||
print(splitter(test_multiple_string))
|
||||
print(splitter(test_empty_string))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_CharTokenizer()
|
@ -32,7 +32,7 @@
|
||||
# - speech
|
||||
#
|
||||
# By default, this script produces the original transcript like kaldi and espnet. Optionally, you
|
||||
# can generate other transcript formats by supplying your own config files. A few examples of these
|
||||
# can add other transcript formats by supplying your own config files. A few examples of these
|
||||
# config files can be found in local/conf.
|
||||
|
||||
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
|
||||
@ -44,10 +44,10 @@ nj=8
|
||||
stage=-1
|
||||
stop_stage=100
|
||||
|
||||
csj_dir=/mnt/minami_data_server/t2131178/corpus/CSJ
|
||||
musan_dir=/mnt/minami_data_server/t2131178/corpus/musan/musan
|
||||
trans_dir=$csj_dir/retranscript
|
||||
csj_fbank_dir=/mnt/host/csj_data/fbank
|
||||
csj_dir=/mnt/host/corpus/csj
|
||||
musan_dir=/mnt/host/corpus/musan/musan
|
||||
trans_dir=$csj_dir/transcript
|
||||
csj_fbank_dir=/mnt/host/corpus/csj/fbank
|
||||
musan_fbank_dir=$musan_dir/fbank
|
||||
csj_manifest_dir=data/manifests
|
||||
musan_manifest_dir=$musan_dir/manifests
|
||||
@ -63,12 +63,8 @@ log() {
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
log "Stage 1: Prepare CSJ manifest"
|
||||
# If you want to generate more transcript modes, append the path to those config files at c.
|
||||
# Example: lhotse prepare csj $csj_dir $trans_dir $csj_manifest_dir -c local/conf/disfluent.ini
|
||||
# NOTE: In case multiple config files are supplied, the second config file and onwards will inherit
|
||||
# the segment boundaries of the first config file.
|
||||
if [ ! -e $csj_manifest_dir/.csj.done ]; then
|
||||
lhotse prepare csj $csj_dir $trans_dir $csj_manifest_dir -j 4
|
||||
lhotse prepare csj $csj_dir $csj_manifest_dir -t $trans_dir -j 16
|
||||
touch $csj_manifest_dir/.csj.done
|
||||
fi
|
||||
fi
|
||||
@ -88,32 +84,24 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
python local/compute_fbank_csj.py --manifest-dir $csj_manifest_dir \
|
||||
--fbank-dir $csj_fbank_dir
|
||||
parts=(
|
||||
train
|
||||
valid
|
||||
eval1
|
||||
eval2
|
||||
eval3
|
||||
valid
|
||||
excluded
|
||||
train
|
||||
)
|
||||
for part in ${parts[@]}; do
|
||||
python local/validate_manifest.py --manifest $csj_manifest_dir/csj_cuts_$part.jsonl.gz
|
||||
python local/validate_manifest.py --manifest $csj_fbank_dir/csj_cuts_$part.jsonl.gz
|
||||
done
|
||||
touch $csj_fbank_dir/.csj-validated.done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "Stage 4: Prepare CSJ lang"
|
||||
modes=disfluent
|
||||
|
||||
# If you want prepare the lang directory for other transcript modes, just append
|
||||
# the names of those modes behind. An example is shown as below:-
|
||||
# modes="$modes fluent symbol number"
|
||||
|
||||
for mode in ${modes[@]}; do
|
||||
python local/prepare_lang_char.py --trans-mode $mode \
|
||||
--train-cut $csj_manifest_dir/csj_cuts_train.jsonl.gz \
|
||||
--lang-dir lang_char_$mode
|
||||
done
|
||||
log "Stage 4: Prepare CSJ lang_char"
|
||||
python local/prepare_lang_char.py $csj_fbank_dir/csj_cuts_train.jsonl.gz
|
||||
python local/add_transcript_mode.py -f $csj_fbank_dir -c local/conf/fluent.ini local/conf/number.ini
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
@ -128,6 +116,6 @@ fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Stage 6: Show manifest statistics"
|
||||
python local/display_manifest_statistics.py --manifest-dir $csj_manifest_dir > $csj_manifest_dir/manifest_statistics.txt
|
||||
cat $csj_manifest_dir/manifest_statistics.txt
|
||||
python local/display_manifest_statistics.py --manifest-dir $csj_fbank_dir > $csj_fbank_dir/manifest_statistics.txt
|
||||
cat $csj_fbank_dir/manifest_statistics.txt
|
||||
fi
|
||||
|
@ -0,0 +1,76 @@
|
||||
import logging
|
||||
from configparser import ConfigParser
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def escape_html(text: str):
|
||||
"""
|
||||
Escapes all html characters in text
|
||||
:param str text:
|
||||
:rtype: str
|
||||
"""
|
||||
return text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
|
||||
|
||||
class TelegramStreamIO(logging.Handler):
|
||||
|
||||
API_ENDPOINT = "https://api.telegram.org"
|
||||
MAX_MESSAGE_LEN = 4096
|
||||
formatter = logging.Formatter(
|
||||
"%(asctime)s - %(levelname)s at %(funcName)s "
|
||||
"(line %(lineno)s):\n\n%(message)s"
|
||||
)
|
||||
|
||||
def __init__(self, tg_configfile: str):
|
||||
super(TelegramStreamIO, self).__init__()
|
||||
config = ConfigParser()
|
||||
if not config.read(tg_configfile):
|
||||
raise FileNotFoundError(
|
||||
f"{tg_configfile} not found. " "Retry without --telegram-cred flag."
|
||||
)
|
||||
config = config["TELEGRAM"]
|
||||
token = config["token"]
|
||||
self.chat_id = config["chat_id"]
|
||||
self.url = f"{self.API_ENDPOINT}/bot{token}/sendMessage"
|
||||
|
||||
@staticmethod
|
||||
def setup_logger(params):
|
||||
if not params.telegram_cred:
|
||||
return
|
||||
formatter = logging.Formatter(
|
||||
f"{params.exp_dir.name} %(asctime)s \n%(message)s"
|
||||
)
|
||||
tg = TelegramStreamIO(params.telegram_cred)
|
||||
tg.setLevel(logging.WARN)
|
||||
tg.setFormatter(formatter)
|
||||
logging.getLogger("").addHandler(tg)
|
||||
|
||||
def emit(self, record: logging.LogRecord):
|
||||
"""
|
||||
Emit a record.
|
||||
Send the record to the Web server as a percent-encoded dictionary
|
||||
"""
|
||||
data = {
|
||||
"chat_id": self.chat_id,
|
||||
"text": self.format(self.mapLogRecord(record)),
|
||||
"parse_mode": "HTML",
|
||||
}
|
||||
try:
|
||||
requests.get(self.url, json=data)
|
||||
# return response.json()
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to send telegram message: {repr(e)}")
|
||||
pass
|
||||
|
||||
def mapLogRecord(self, record):
|
||||
"""
|
||||
Default implementation of mapping the log record into a dict
|
||||
that is sent as the CGI data. Overwrite in your class.
|
||||
Contributed by Franz Glasner.
|
||||
"""
|
||||
|
||||
for k, v in record.__dict__.items():
|
||||
if isinstance(v, str):
|
||||
setattr(record, k, escape_html(v))
|
||||
return record
|
@ -0,0 +1 @@
|
||||
../local/utils/asr_datamodule.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py
|
852
egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py
Executable file
852
egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py
Executable file
@ -0,0 +1,852 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--lang data/lang_char \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search (not recommended)
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method beam_search \
|
||||
--lang data/lang_char \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method modified_beam_search \
|
||||
--lang data/lang_char \
|
||||
--beam-size 4
|
||||
|
||||
(4) fast beam search (one best)
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--lang data/lang_char \
|
||||
--max-states 64
|
||||
|
||||
(5) fast beam search (nbest)
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method fast_beam_search_nbest \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--lang data/lang_char \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(6) fast beam search (nbest oracle WER)
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method fast_beam_search_nbest_oracle \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--lang data/lang_char \
|
||||
--nbest-scale 0.5
|
||||
|
||||
(7) fast beam search (with LG)
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--max-duration 600 \
|
||||
--decode-chunk-len 32 \
|
||||
--decoding-method fast_beam_search_nbest_LG \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--lang data/lang_char \
|
||||
--max-states 64
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import CSJAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG,
|
||||
fast_beam_search_nbest_oracle,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from tokenizer import Tokenizer
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--gpu",
|
||||
type=int,
|
||||
default=0,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7_streaming/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--res-dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="The path to save results.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/lang_char",
|
||||
help="The lang dir. It should contain at least a word table.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
If you use fast_beam_search_nbest_LG, you have to specify
|
||||
`--lang-dir`, which should contain `LG.pt`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-graph",
|
||||
type=str,
|
||||
default="",
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=200,
|
||||
help="""Number of paths for nbest decoding.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""Scale applied to lattice scores when computing nbest paths.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--pad-feature",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""
|
||||
Number of frames to pad at the end.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: Tokenizer,
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
if params.pad_feature:
|
||||
feature_lens += params.pad_feature
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, params.pad_feature),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(sp.text2word(hyp))
|
||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
hyp_tokens = fast_beam_search_nbest_LG(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest":
|
||||
hyp_tokens = fast_beam_search_nbest(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(sp.text2word(hyp))
|
||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=sp.encode(supervisions["text"]),
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(sp.text2word(hyp))
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(sp.text2word(hyp))
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(sp.text2word(hyp))
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.text2word(sp.decode(hyp)))
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
key += f"max_states_{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
key += f"_num_paths_{params.num_paths}_"
|
||||
key += f"nbest_scale_{params.nbest_scale}"
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: Tokenizer,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
word_table=word_table,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = sp.text2word(ref_text)
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
return test_set_wers
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
CSJAsrDataModule.add_arguments(parser)
|
||||
Tokenizer.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_nbest",
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
)
|
||||
if not params.res_dir:
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||
params.suffix += f"-num-paths-{params.num_paths}"
|
||||
if "LG" in params.decoding_method:
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", params.gpu)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = Tokenizer.load(params.lang, params.lang_type)
|
||||
|
||||
# <blk> and <unk> are defined in local/prepare_lang_char.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, (
|
||||
model.encoder.decode_chunk_size,
|
||||
params.decode_chunk_len,
|
||||
)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
decoding_graph = None
|
||||
word_table = None
|
||||
|
||||
if params.decoding_graph:
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(params.decoding_graph, map_location=device)
|
||||
)
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
word_table = lexicon.word_table
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
word_table = None
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
csj_corpus = CSJAsrDataModule(args)
|
||||
|
||||
for subdir in ["eval1", "eval2", "eval3", "excluded", "valid"]:
|
||||
results_dict = decode_dataset(
|
||||
dl=csj_corpus.test_dataloaders(getattr(csj_corpus, f"{subdir}_cuts")()),
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
tot_err = save_results(
|
||||
params=params,
|
||||
test_set_name=subdir,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
with (
|
||||
params.res_dir
|
||||
/ (
|
||||
f"{subdir}-{params.decode_chunk_len}_{params.beam_size}"
|
||||
f"_{params.avg}_{params.epoch}.cer"
|
||||
)
|
||||
).open("w") as fout:
|
||||
if len(tot_err) == 1:
|
||||
fout.write(f"{tot_err[0][1]}")
|
||||
else:
|
||||
fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py
|
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/decoder.py
Symbolic link
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/decoder.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py
|
313
egs/csj/ASR/pruned_transducer_stateless7_streaming/export.py
Normal file
313
egs/csj/ASR/pruned_transducer_stateless7_streaming/export.py
Normal file
@ -0,0 +1,313 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
|
||||
Usage:
|
||||
|
||||
(1) Export to torchscript model using torch.jit.script()
|
||||
|
||||
./pruned_transducer_stateless7_streaming/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--lang data/lang_char \
|
||||
--epoch 30 \
|
||||
--avg 9 \
|
||||
--jit 1
|
||||
|
||||
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
|
||||
load it by `torch.jit.load("cpu_jit.pt")`.
|
||||
|
||||
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
|
||||
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
|
||||
|
||||
Check
|
||||
https://github.com/k2-fsa/sherpa
|
||||
for how to use the exported models outside of icefall.
|
||||
|
||||
(2) Export `model.state_dict()`
|
||||
|
||||
./pruned_transducer_stateless7_streaming/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--lang data/lang_char \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
|
||||
load it by `icefall.checkpoint.load_checkpoint()`.
|
||||
|
||||
To use the generated file with `pruned_transducer_stateless7_streaming/decode.py`,
|
||||
you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/csj/ASR
|
||||
./pruned_transducer_stateless7_streaming/decode.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search \
|
||||
--lang data/lang_char
|
||||
|
||||
Check ./pretrained.py for its usage.
|
||||
|
||||
Note: If you don't want to train a model from scratch, we have
|
||||
provided one for you. You can get it at
|
||||
|
||||
https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208
|
||||
|
||||
with the following commands:
|
||||
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
git clone https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208
|
||||
# You will find the pre-trained model in icefall-asr-csj-pruned-transducer-stateless7-230208/exp_fluent
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from tokenizer import Tokenizer
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7_streaming/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to save a model after applying torch.jit.script.
|
||||
It will generate a file named cpu_jit.pt
|
||||
|
||||
Check ./jit_pretrained.py for how to use it.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
Tokenizer.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
sp = Tokenizer.load(params.lang, params.lang_type)
|
||||
|
||||
# <blk> is defined in local/prepare_lang_char.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
model.to(device)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
if params.jit is True:
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptabe.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
model.save(str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
else:
|
||||
logging.info("Not using torchscript. Export model.state_dict()")
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = params.exp_dir / "pretrained.pt"
|
||||
torch.save({"model": model.state_dict()}, str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -0,0 +1,308 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
Usage:
|
||||
# use -O to skip assertions and avoid some of the TracerWarnings
|
||||
python -O pruned_transducer_stateless7_streaming/jit_trace_export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--lang data/lang_char \
|
||||
--epoch 30 \
|
||||
--avg 10 \
|
||||
--use-averaged-model=True \
|
||||
--decode-chunk-len 32
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from tokenizer import Tokenizer
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import AttributeDict, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for averaging.
|
||||
Note: Epoch counts from 0.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless2/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def export_encoder_model_jit_trace(
|
||||
encoder_model: torch.nn.Module,
|
||||
encoder_filename: str,
|
||||
params: AttributeDict,
|
||||
) -> None:
|
||||
"""Export the given encoder model with torch.jit.trace()
|
||||
|
||||
Note: The warmup argument is fixed to 1.
|
||||
|
||||
Args:
|
||||
encoder_model:
|
||||
The input encoder model
|
||||
encoder_filename:
|
||||
The filename to save the exported model.
|
||||
"""
|
||||
decode_chunk_len = params.decode_chunk_len # before subsampling
|
||||
pad_length = 7
|
||||
s = f"decode_chunk_len: {decode_chunk_len}"
|
||||
logging.info(s)
|
||||
assert encoder_model.decode_chunk_size == decode_chunk_len // 2, (
|
||||
encoder_model.decode_chunk_size,
|
||||
decode_chunk_len,
|
||||
)
|
||||
|
||||
T = decode_chunk_len + pad_length
|
||||
|
||||
x = torch.zeros(1, T, 80, dtype=torch.float32)
|
||||
x_lens = torch.full((1,), T, dtype=torch.int32)
|
||||
states = encoder_model.get_init_state(device=x.device)
|
||||
|
||||
encoder_model.__class__.forward = encoder_model.__class__.streaming_forward
|
||||
traced_model = torch.jit.trace(encoder_model, (x, x_lens, states))
|
||||
traced_model.save(encoder_filename)
|
||||
logging.info(f"Saved to {encoder_filename}")
|
||||
|
||||
|
||||
def export_decoder_model_jit_trace(
|
||||
decoder_model: torch.nn.Module,
|
||||
decoder_filename: str,
|
||||
) -> None:
|
||||
"""Export the given decoder model with torch.jit.trace()
|
||||
|
||||
Note: The argument need_pad is fixed to False.
|
||||
|
||||
Args:
|
||||
decoder_model:
|
||||
The input decoder model
|
||||
decoder_filename:
|
||||
The filename to save the exported model.
|
||||
"""
|
||||
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||
need_pad = torch.tensor([False])
|
||||
|
||||
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
|
||||
traced_model.save(decoder_filename)
|
||||
logging.info(f"Saved to {decoder_filename}")
|
||||
|
||||
|
||||
def export_joiner_model_jit_trace(
|
||||
joiner_model: torch.nn.Module,
|
||||
joiner_filename: str,
|
||||
) -> None:
|
||||
"""Export the given joiner model with torch.jit.trace()
|
||||
|
||||
Note: The argument project_input is fixed to True. A user should not
|
||||
project the encoder_out/decoder_out by himself/herself. The exported joiner
|
||||
will do that for the user.
|
||||
|
||||
Args:
|
||||
joiner_model:
|
||||
The input joiner model
|
||||
joiner_filename:
|
||||
The filename to save the exported model.
|
||||
|
||||
"""
|
||||
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||
|
||||
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
|
||||
traced_model.save(joiner_filename)
|
||||
logging.info(f"Saved to {joiner_filename}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
Tokenizer.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
sp = Tokenizer.load(params.lang, params.lang_type)
|
||||
|
||||
# <blk> is defined in local/prepare_lang_char.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
logging.info("Using torch.jit.trace()")
|
||||
|
||||
logging.info("Exporting encoder")
|
||||
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
|
||||
export_encoder_model_jit_trace(model.encoder, encoder_filename, params)
|
||||
|
||||
logging.info("Exporting decoder")
|
||||
decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
|
||||
export_decoder_model_jit_trace(model.decoder, decoder_filename)
|
||||
|
||||
logging.info("Exporting joiner")
|
||||
joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
|
||||
export_joiner_model_jit_trace(model.joiner, joiner_filename)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -0,0 +1,286 @@
|
||||
#!/usr/bin/env python3
|
||||
# flake8: noqa
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads torchscript models exported by `torch.jit.trace()`
|
||||
and uses them to decode waves.
|
||||
You can use the following command to get the exported models:
|
||||
|
||||
./pruned_transducer_stateless7_streaming/jit_trace_export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--lang data/lang_char \
|
||||
--epoch 30 \
|
||||
--avg 10 \
|
||||
--use-averaged-model=True \
|
||||
--decode-chunk-len 32
|
||||
|
||||
Usage of this script:
|
||||
|
||||
./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \
|
||||
--encoder-model-filename ./pruned_transducer_stateless7_streaming/exp/encoder_jit_trace.pt \
|
||||
--decoder-model-filename ./pruned_transducer_stateless7_streaming/exp/decoder_jit_trace.pt \
|
||||
--joiner-model-filename ./pruned_transducer_stateless7_streaming/exp/joiner_jit_trace.pt \
|
||||
--lang data/lang_char \
|
||||
--decode-chunk-len 32 \
|
||||
/path/to/foo.wav \
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||
from tokenizer import Tokenizer
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the encoder torchscript model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the decoder torchscript model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the joiner torchscript model. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decode-chunk-len",
|
||||
type=int,
|
||||
default=32,
|
||||
help="The chunk size for decoding (in frames before subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_file",
|
||||
type=str,
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
def greedy_search(
|
||||
decoder: torch.jit.ScriptModule,
|
||||
joiner: torch.jit.ScriptModule,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: Optional[torch.Tensor] = None,
|
||||
hyp: Optional[List[int]] = None,
|
||||
):
|
||||
assert encoder_out.ndim == 2
|
||||
context_size = 2
|
||||
blank_id = 0
|
||||
|
||||
if decoder_out is None:
|
||||
assert hyp is None, hyp
|
||||
hyp = [blank_id] * context_size
|
||||
decoder_input = torch.tensor(hyp, dtype=torch.int32).unsqueeze(0)
|
||||
# decoder_input.shape (1,, 1 context_size)
|
||||
decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1)
|
||||
else:
|
||||
assert decoder_out.ndim == 2
|
||||
assert hyp is not None, hyp
|
||||
|
||||
T = encoder_out.size(0)
|
||||
for i in range(T):
|
||||
cur_encoder_out = encoder_out[i : i + 1]
|
||||
joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0)
|
||||
y = joiner_out.argmax(dim=0).item()
|
||||
|
||||
if y != blank_id:
|
||||
hyp.append(y)
|
||||
decoder_input = hyp[-context_size:]
|
||||
|
||||
decoder_input = torch.tensor(decoder_input, dtype=torch.int32).unsqueeze(0)
|
||||
decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1)
|
||||
|
||||
return hyp, decoder_out
|
||||
|
||||
|
||||
def create_streaming_feature_extractor(sample_rate) -> OnlineFeature:
|
||||
"""Create a CPU streaming feature extractor.
|
||||
|
||||
At present, we assume it returns a fbank feature extractor with
|
||||
fixed options. In the future, we will support passing in the options
|
||||
from outside.
|
||||
|
||||
Returns:
|
||||
Return a CPU streaming feature extractor.
|
||||
"""
|
||||
opts = FbankOptions()
|
||||
opts.device = "cpu"
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = sample_rate
|
||||
opts.mel_opts.num_bins = 80
|
||||
return OnlineFbank(opts)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
Tokenizer.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
encoder = torch.jit.load(args.encoder_model_filename)
|
||||
decoder = torch.jit.load(args.decoder_model_filename)
|
||||
joiner = torch.jit.load(args.joiner_model_filename)
|
||||
|
||||
encoder.eval()
|
||||
decoder.eval()
|
||||
joiner.eval()
|
||||
|
||||
encoder.to(device)
|
||||
decoder.to(device)
|
||||
joiner.to(device)
|
||||
|
||||
sp = Tokenizer.load(args.lang, args.lang_type)
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
online_fbank = create_streaming_feature_extractor(args.sample_rate)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_file}")
|
||||
wave_samples = read_sound_files(
|
||||
filenames=[args.sound_file],
|
||||
expected_sample_rate=args.sample_rate,
|
||||
)[0]
|
||||
logging.info(wave_samples.shape)
|
||||
|
||||
logging.info("Decoding started")
|
||||
chunk_length = args.decode_chunk_len
|
||||
assert encoder.decode_chunk_size == chunk_length // 2, (
|
||||
encoder.decode_chunk_size,
|
||||
chunk_length,
|
||||
)
|
||||
|
||||
# we subsample features with ((x_len - 7) // 2 + 1) // 2
|
||||
pad_length = 7
|
||||
T = chunk_length + pad_length
|
||||
|
||||
logging.info(f"chunk_length: {chunk_length}")
|
||||
|
||||
states = encoder.get_init_state(device)
|
||||
|
||||
tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32)
|
||||
|
||||
wave_samples = torch.cat([wave_samples, tail_padding])
|
||||
|
||||
chunk = int(0.25 * args.sample_rate) # 0.2 second
|
||||
num_processed_frames = 0
|
||||
|
||||
hyp = None
|
||||
decoder_out = None
|
||||
|
||||
start = 0
|
||||
while start < wave_samples.numel():
|
||||
logging.info(f"{start}/{wave_samples.numel()}")
|
||||
end = min(start + chunk, wave_samples.numel())
|
||||
samples = wave_samples[start:end]
|
||||
start += chunk
|
||||
online_fbank.accept_waveform(
|
||||
sampling_rate=args.sample_rate,
|
||||
waveform=samples,
|
||||
)
|
||||
while online_fbank.num_frames_ready - num_processed_frames >= T:
|
||||
frames = []
|
||||
for i in range(T):
|
||||
frames.append(online_fbank.get_frame(num_processed_frames + i))
|
||||
frames = torch.cat(frames, dim=0).unsqueeze(0)
|
||||
x_lens = torch.tensor([T], dtype=torch.int32)
|
||||
encoder_out, out_lens, states = encoder(
|
||||
x=frames,
|
||||
x_lens=x_lens,
|
||||
states=states,
|
||||
)
|
||||
num_processed_frames += chunk_length
|
||||
|
||||
hyp, decoder_out = greedy_search(
|
||||
decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp
|
||||
)
|
||||
|
||||
context_size = 2
|
||||
logging.info(args.sound_file)
|
||||
logging.info(sp.decode(hyp[context_size:]))
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
torch.set_num_threads(4)
|
||||
torch.set_num_interop_threads(1)
|
||||
torch._C._jit_set_profiling_executor(False)
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
torch._C._set_graph_executor_optimize(False)
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/joiner.py
Symbolic link
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/joiner.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/joiner.py
|
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/model.py
Symbolic link
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/model.py
|
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/optim.py
Symbolic link
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/optim.py
|
347
egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py
Normal file
347
egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py
Normal file
@ -0,0 +1,347 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads a checkpoint and uses it to decode waves.
|
||||
You can generate the checkpoint with the following command:
|
||||
|
||||
./pruned_transducer_stateless7_streaming/export.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--lang data/lang_char \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
Usage of this script:
|
||||
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless7_streaming/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
|
||||
--lang data/lang_char \
|
||||
--method greedy_search \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(2) beam search
|
||||
./pruned_transducer_stateless7_streaming/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
|
||||
--lang data/lang_char \
|
||||
--method beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless7_streaming/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
|
||||
--lang data/lang_char \
|
||||
--method modified_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(4) fast beam search
|
||||
./pruned_transducer_stateless7_streaming/pretrained.py \
|
||||
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
|
||||
--lang data/lang_char \
|
||||
--method fast_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
You can also use `./pruned_transducer_stateless7_streaming/exp/epoch-xx.pt`.
|
||||
|
||||
Note: ./pruned_transducer_stateless7_streaming/exp/pretrained.pt is generated by
|
||||
./pruned_transducer_stateless7_streaming/export.py
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from tokenizer import Tokenizer
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint. "
|
||||
"The checkpoint is assumed to be saved by "
|
||||
"icefall.checkpoint.save_checkpoint().",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame. Used only when
|
||||
--method is greedy_search.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
Tokenizer.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
sp = Tokenizer.load(params.lang, params.lang_type)
|
||||
|
||||
# <blk> is defined in local/prepare_lang_char.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = params.sample_rate
|
||||
opts.mel_opts.num_bins = params.feature_dim
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {params.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
|
||||
|
||||
num_waves = encoder_out.size(0)
|
||||
hyps = []
|
||||
msg = f"Using {params.method}"
|
||||
if params.method == "beam_search":
|
||||
msg += f" with beam size {params.beam_size}"
|
||||
logging.info(msg)
|
||||
|
||||
if params.method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
for i in range(num_waves):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported method: {params.method}")
|
||||
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
words = " ".join(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling.py
Symbolic link
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py
|
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py
|
597
egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
Executable file
597
egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
Executable file
@ -0,0 +1,597 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Usage:
|
||||
./pruned_transducer_stateless7_streaming/streaming_decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--decode-chunk-len 32 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--decoding_method greedy_search \
|
||||
--lang data/lang_char \
|
||||
--num-decode-streams 2000
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import CSJAsrDataModule
|
||||
from decode import save_results
|
||||
from decode_stream import DecodeStream
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from lhotse import CutSet
|
||||
from streaming_beam_search import (
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
modified_beam_search,
|
||||
)
|
||||
from tokenizer import Tokenizer
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from zipformer import stack_states, unstack_states
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import AttributeDict, setup_logger, str2bool
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 0.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--gpu",
|
||||
type=int,
|
||||
default=0,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless2/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Supported decoding methods are:
|
||||
greedy_search
|
||||
modified_beam_search
|
||||
fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-graph",
|
||||
type=str,
|
||||
default="",
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num_active_paths",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An interger indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4.0,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=32,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-decode-streams",
|
||||
type=int,
|
||||
default=2000,
|
||||
help="The number of streams that can be decoded parallel.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--res-dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="The path to save results.",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_chunk(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
decode_streams: List[DecodeStream],
|
||||
) -> List[int]:
|
||||
"""Decode one chunk frames of features for each decode_streams and
|
||||
return the indexes of finished streams in a List.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
decode_streams:
|
||||
A List of DecodeStream, each belonging to a utterance.
|
||||
Returns:
|
||||
Return a List containing which DecodeStreams are finished.
|
||||
"""
|
||||
device = model.device
|
||||
|
||||
features = []
|
||||
feature_lens = []
|
||||
states = []
|
||||
processed_lens = []
|
||||
|
||||
for stream in decode_streams:
|
||||
feat, feat_len = stream.get_feature_frames(params.decode_chunk_len)
|
||||
features.append(feat)
|
||||
feature_lens.append(feat_len)
|
||||
states.append(stream.states)
|
||||
processed_lens.append(stream.done_frames)
|
||||
|
||||
feature_lens = torch.tensor(feature_lens, device=device)
|
||||
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
|
||||
|
||||
# We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling
|
||||
# factor in encoders is 8.
|
||||
# After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8.
|
||||
tail_length = 23
|
||||
if features.size(1) < tail_length:
|
||||
pad_length = tail_length - features.size(1)
|
||||
feature_lens += pad_length
|
||||
features = torch.nn.functional.pad(
|
||||
features,
|
||||
(0, 0, 0, pad_length),
|
||||
mode="constant",
|
||||
value=LOG_EPS,
|
||||
)
|
||||
|
||||
states = stack_states(states)
|
||||
processed_lens = torch.tensor(processed_lens, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward(
|
||||
x=features,
|
||||
x_lens=feature_lens,
|
||||
states=states,
|
||||
)
|
||||
|
||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
processed_lens = processed_lens + encoder_out_lens
|
||||
fast_beam_search_one_best(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
processed_lens=processed_lens,
|
||||
streams=decode_streams,
|
||||
beam=params.beam,
|
||||
max_states=params.max_states,
|
||||
max_contexts=params.max_contexts,
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
modified_beam_search(
|
||||
model=model,
|
||||
streams=decode_streams,
|
||||
encoder_out=encoder_out,
|
||||
num_active_paths=params.num_active_paths,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
|
||||
|
||||
states = unstack_states(new_states)
|
||||
|
||||
finished_streams = []
|
||||
for i in range(len(decode_streams)):
|
||||
decode_streams[i].states = states[i]
|
||||
decode_streams[i].done_frames += encoder_out_lens[i]
|
||||
if decode_streams[i].done:
|
||||
finished_streams.append(i)
|
||||
|
||||
return finished_streams
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
cuts: CutSet,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: Tokenizer,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
cuts:
|
||||
Lhotse Cutset containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
device = model.device
|
||||
|
||||
opts = FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = 16000
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
log_interval = 50
|
||||
|
||||
decode_results = []
|
||||
# Contain decode streams currently running.
|
||||
decode_streams = []
|
||||
for num, cut in enumerate(cuts):
|
||||
# each utterance has a DecodeStream.
|
||||
initial_states = model.encoder.get_init_state(device=device)
|
||||
decode_stream = DecodeStream(
|
||||
params=params,
|
||||
cut_id=cut.id,
|
||||
initial_states=initial_states,
|
||||
decoding_graph=decoding_graph,
|
||||
device=device,
|
||||
)
|
||||
|
||||
audio: np.ndarray = cut.load_audio()
|
||||
# audio.shape: (1, num_samples)
|
||||
assert len(audio.shape) == 2
|
||||
assert audio.shape[0] == 1, "Should be single channel"
|
||||
assert audio.dtype == np.float32, audio.dtype
|
||||
|
||||
# The trained model is using normalized samples
|
||||
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
|
||||
|
||||
samples = torch.from_numpy(audio).squeeze(0)
|
||||
|
||||
fbank = Fbank(opts)
|
||||
feature = fbank(samples.to(device))
|
||||
decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len)
|
||||
decode_stream.ground_truth = cut.supervisions[0].custom[params.transcript_mode]
|
||||
|
||||
decode_streams.append(decode_stream)
|
||||
|
||||
while len(decode_streams) >= params.num_decode_streams:
|
||||
finished_streams = decode_one_chunk(
|
||||
params=params, model=model, decode_streams=decode_streams
|
||||
)
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
decode_results.append(
|
||||
(
|
||||
decode_streams[i].id,
|
||||
sp.text2word(decode_streams[i].ground_truth),
|
||||
sp.text2word(sp.decode(decode_streams[i].decoding_result())),
|
||||
)
|
||||
)
|
||||
del decode_streams[i]
|
||||
|
||||
if num % log_interval == 0:
|
||||
logging.info(f"Cuts processed until now is {num}.")
|
||||
|
||||
# decode final chunks of last sequences
|
||||
while len(decode_streams):
|
||||
finished_streams = decode_one_chunk(
|
||||
params=params, model=model, decode_streams=decode_streams
|
||||
)
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
decode_results.append(
|
||||
(
|
||||
decode_streams[i].id,
|
||||
sp.text2word(decode_streams[i].ground_truth),
|
||||
sp.text2word(sp.decode(decode_streams[i].decoding_result())),
|
||||
)
|
||||
)
|
||||
del decode_streams[i]
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
key = "greedy_search"
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
key = (
|
||||
f"beam_{params.beam}_"
|
||||
f"max_contexts_{params.max_contexts}_"
|
||||
f"max_states_{params.max_states}"
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
key = f"num_active_paths_{params.num_active_paths}"
|
||||
else:
|
||||
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
|
||||
return {key: decode_results}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
CSJAsrDataModule.add_arguments(parser)
|
||||
Tokenizer.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
if not params.res_dir:
|
||||
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
# for streaming
|
||||
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
|
||||
|
||||
# for fast_beam_search
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", params.gpu)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = Tokenizer.load(params.lang, params.lang_type)
|
||||
|
||||
# <blk> and <unk> is defined in local/prepare_lang_char.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if start >= 0:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
decoding_graph = None
|
||||
if params.decoding_graph:
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(params.decoding_graph, map_location=device)
|
||||
)
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
args.return_cuts = True
|
||||
csj_corpus = CSJAsrDataModule(args)
|
||||
|
||||
for subdir in ["eval1", "eval2", "eval3", "excluded", "valid"]:
|
||||
results_dict = decode_dataset(
|
||||
cuts=getattr(csj_corpus, f"{subdir}_cuts")(),
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
tot_err = save_results(
|
||||
params=params, test_set_name=subdir, results_dict=results_dict
|
||||
)
|
||||
|
||||
with (
|
||||
params.res_dir
|
||||
/ (
|
||||
f"{subdir}-{params.decode_chunk_len}"
|
||||
f"_{params.avg}_{params.epoch}.cer"
|
||||
)
|
||||
).open("w") as fout:
|
||||
if len(tot_err) == 1:
|
||||
fout.write(f"{tot_err[0][1]}")
|
||||
else:
|
||||
fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
150
egs/csj/ASR/pruned_transducer_stateless7_streaming/test_model.py
Executable file
150
egs/csj/ASR/pruned_transducer_stateless7_streaming/test_model.py
Executable file
@ -0,0 +1,150 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/csj/ASR
|
||||
python ./pruned_transducer_stateless7_streaming/test_model.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def test_model():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.num_encoder_layers = "2,4,3,2,4"
|
||||
params.feedforward_dims = "1024,1024,2048,2048,1024"
|
||||
params.nhead = "8,8,8,8,8"
|
||||
params.encoder_dims = "384,384,384,384,384"
|
||||
params.attention_dims = "192,192,192,192,192"
|
||||
params.encoder_unmasked_dims = "256,256,256,256,256"
|
||||
params.zipformer_downsampling_factors = "1,2,4,8,2"
|
||||
params.cnn_module_kernels = "31,31,31,31,31"
|
||||
params.decoder_dim = 512
|
||||
params.joiner_dim = 512
|
||||
params.num_left_chunks = 4
|
||||
params.short_chunk_size = 50
|
||||
params.decode_chunk_len = 32
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
# Test jit script
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptabe.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
print("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
|
||||
|
||||
def test_model_jit_trace():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.num_encoder_layers = "2,4,3,2,4"
|
||||
params.feedforward_dims = "1024,1024,2048,2048,1024"
|
||||
params.nhead = "8,8,8,8,8"
|
||||
params.encoder_dims = "384,384,384,384,384"
|
||||
params.attention_dims = "192,192,192,192,192"
|
||||
params.encoder_unmasked_dims = "256,256,256,256,256"
|
||||
params.zipformer_downsampling_factors = "1,2,4,8,2"
|
||||
params.cnn_module_kernels = "31,31,31,31,31"
|
||||
params.decoder_dim = 512
|
||||
params.joiner_dim = 512
|
||||
params.num_left_chunks = 4
|
||||
params.short_chunk_size = 50
|
||||
params.decode_chunk_len = 32
|
||||
model = get_transducer_model(params)
|
||||
model.eval()
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
|
||||
# Test encoder
|
||||
def _test_encoder():
|
||||
encoder = model.encoder
|
||||
assert encoder.decode_chunk_size == params.decode_chunk_len // 2, (
|
||||
encoder.decode_chunk_size,
|
||||
params.decode_chunk_len,
|
||||
)
|
||||
T = params.decode_chunk_len + 7
|
||||
|
||||
x = torch.zeros(1, T, 80, dtype=torch.float32)
|
||||
x_lens = torch.full((1,), T, dtype=torch.int32)
|
||||
states = encoder.get_init_state(device=x.device)
|
||||
encoder.__class__.forward = encoder.__class__.streaming_forward
|
||||
traced_encoder = torch.jit.trace(encoder, (x, x_lens, states))
|
||||
|
||||
states1 = encoder.get_init_state(device=x.device)
|
||||
states2 = traced_encoder.get_init_state(device=x.device)
|
||||
for i in range(5):
|
||||
x = torch.randn(1, T, 80, dtype=torch.float32)
|
||||
x_lens = torch.full((1,), T, dtype=torch.int32)
|
||||
y1, _, states1 = encoder.streaming_forward(x, x_lens, states1)
|
||||
y2, _, states2 = traced_encoder(x, x_lens, states2)
|
||||
assert torch.allclose(y1, y2, atol=1e-6), (i, (y1 - y2).abs().mean())
|
||||
|
||||
# Test decoder
|
||||
def _test_decoder():
|
||||
decoder = model.decoder
|
||||
y = torch.zeros(10, decoder.context_size, dtype=torch.int64)
|
||||
need_pad = torch.tensor([False])
|
||||
|
||||
traced_decoder = torch.jit.trace(decoder, (y, need_pad))
|
||||
d1 = decoder(y, need_pad)
|
||||
d2 = traced_decoder(y, need_pad)
|
||||
assert torch.equal(d1, d2), (d1 - d2).abs().mean()
|
||||
|
||||
# Test joiner
|
||||
def _test_joiner():
|
||||
joiner = model.joiner
|
||||
encoder_out_dim = joiner.encoder_proj.weight.shape[1]
|
||||
decoder_out_dim = joiner.decoder_proj.weight.shape[1]
|
||||
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||
|
||||
traced_joiner = torch.jit.trace(joiner, (encoder_out, decoder_out))
|
||||
j1 = joiner(encoder_out, decoder_out)
|
||||
j2 = traced_joiner(encoder_out, decoder_out)
|
||||
assert torch.equal(j1, j2), (j1 - j2).abs().mean()
|
||||
|
||||
_test_encoder()
|
||||
_test_decoder()
|
||||
_test_joiner()
|
||||
|
||||
|
||||
def main():
|
||||
test_model()
|
||||
test_model_jit_trace()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/tokenizer.py
Symbolic link
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/tokenizer.py
Symbolic link
@ -0,0 +1 @@
|
||||
../local/utils/tokenizer.py
|
1304
egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py
Executable file
1304
egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py
Executable file
File diff suppressed because it is too large
Load Diff
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer.py
Symbolic link
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py
|
Loading…
x
Reference in New Issue
Block a user