mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
CSJ pruned_transducer_stateless7_streaming (#892)
* update manifest stats * update transcript configs * lang_char and compute_fbanks * save cuts in fbank_dir * add core codes * update decode.py * Create local/utils * tidy up * parse raw in prepare_lang_char.py * update manifest stats * update transcript configs * lang_char and compute_fbanks * save cuts in fbank_dir * add core codes * update decode.py * Create local/utils * tidy up * parse raw in prepare_lang_char.py * working train * Add compare_cer_transcript.py * fix tokenizer decode, allow d2f only * comment cleanup * add export files and READMEs * reword average column * fix comments * Update new results
This commit is contained in:
parent
25ee50e27c
commit
e63a8c27f8
11
egs/csj/ASR/README.md
Normal file
11
egs/csj/ASR/README.md
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
# Introduction
|
||||||
|
|
||||||
|
[./RESULTS.md](./RESULTS.md) contains the latest results.
|
||||||
|
|
||||||
|
# Transducers
|
||||||
|
|
||||||
|
These are the types of architectures currently available.
|
||||||
|
|
||||||
|
| | Encoder | Decoder | Comment |
|
||||||
|
|---------------------------------------|---------------------|--------------------|---------------------------------------------------|
|
||||||
|
| `pruned_transducer_stateless7_streaming` | Streaming Zipformer | Embedding + Conv1d | Adapted from librispeech pruned_transducer_stateless7_streaming |
|
200
egs/csj/ASR/RESULTS.md
Normal file
200
egs/csj/ASR/RESULTS.md
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
# Results
|
||||||
|
|
||||||
|
## Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer)
|
||||||
|
|
||||||
|
### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming)
|
||||||
|
|
||||||
|
See <https://github.com/k2-fsa/icefall/pull/892> for more details.
|
||||||
|
|
||||||
|
You can find a pretrained model, training logs, decoding logs, and decoding results at:
|
||||||
|
<https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208>
|
||||||
|
|
||||||
|
Number of model parameters: 75688409, i.e. 75.7M.
|
||||||
|
|
||||||
|
#### training on disfluent transcript
|
||||||
|
|
||||||
|
The CERs are:
|
||||||
|
|
||||||
|
| decoding method | chunk size | eval1 | eval2 | eval3 | excluded | valid | average | decoding mode |
|
||||||
|
| --------------- | ---------- | ----- | ----- | ----- | -------- | ----- | ------- | ------------- |
|
||||||
|
| fast beam search | 320ms | 5.39 | 4.08 | 4.16 | 5.4 | 5.02 | --epoch 30 --avg 17 | simulated streaming |
|
||||||
|
| fast beam search | 320ms | 5.34 | 4.1 | 4.26 | 5.61 | 4.91 | --epoch 30 --avg 17 | chunk-wise |
|
||||||
|
| greedy search | 320ms | 5.43 | 4.14 | 4.31 | 5.48 | 4.88 | --epoch 30 --avg 17 | simulated streaming |
|
||||||
|
| greedy search | 320ms | 5.44 | 4.14 | 4.39 | 5.7 | 4.98 | --epoch 30 --avg 17 | chunk-wise |
|
||||||
|
| modified beam search | 320ms | 5.2 | 3.95 | 4.09 | 5.12 | 4.75 | --epoch 30 --avg 17 | simulated streaming |
|
||||||
|
| modified beam search | 320ms | 5.18 | 4.07 | 4.12 | 5.36 | 4.77 | --epoch 30 --avg 17 | chunk-wise |
|
||||||
|
| fast beam search | 640ms | 5.01 | 3.78 | 3.96 | 4.85 | 4.6 | --epoch 30 --avg 17 | simulated streaming |
|
||||||
|
| fast beam search | 640ms | 4.97 | 3.88 | 3.96 | 4.91 | 4.61 | --epoch 30 --avg 17 | chunk-wise |
|
||||||
|
| greedy search | 640ms | 5.02 | 3.84 | 4.14 | 5.02 | 4.59 | --epoch 30 --avg 17 | simulated streaming |
|
||||||
|
| greedy search | 640ms | 5.32 | 4.22 | 4.33 | 5.39 | 4.99 | --epoch 30 --avg 17 | chunk-wise |
|
||||||
|
| modified beam search | 640ms | 4.78 | 3.66 | 3.85 | 4.72 | 4.42 | --epoch 30 --avg 17 | simulated streaming |
|
||||||
|
| modified beam search | 640ms | 5.77 | 4.72 | 4.73 | 5.85 | 5.36 | --epoch 30 --avg 17 | chunk-wise |
|
||||||
|
|
||||||
|
Note: `simulated streaming` indicates feeding full utterance during decoding using `decode.py`,
|
||||||
|
while `chunk-size` indicates feeding certain number of frames at each time using `streaming_decode.py`.
|
||||||
|
|
||||||
|
The training command was:
|
||||||
|
```bash
|
||||||
|
./pruned_transducer_stateless7_streaming/train.py \
|
||||||
|
--feedforward-dims "1024,1024,2048,2048,1024" \
|
||||||
|
--world-size 8 \
|
||||||
|
--num-epochs 30 \
|
||||||
|
--start-epoch 1 \
|
||||||
|
--use-fp16 1 \
|
||||||
|
--exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \
|
||||||
|
--max-duration 375 \
|
||||||
|
--transcript-mode disfluent \
|
||||||
|
--lang data/lang_char \
|
||||||
|
--manifest-dir /mnt/host/corpus/csj/fbank \
|
||||||
|
--pad-feature 30 \
|
||||||
|
--musan-dir /mnt/host/corpus/musan/musan/fbank
|
||||||
|
```
|
||||||
|
|
||||||
|
The simulated streaming decoding command was:
|
||||||
|
```bash
|
||||||
|
for chunk in 64 32; do
|
||||||
|
for m in greedy_search fast_beam_search modified_beam_search; do
|
||||||
|
python pruned_transducer_stateless7_streaming/decode.py \
|
||||||
|
--feedforward-dims "1024,1024,2048,2048,1024" \
|
||||||
|
--exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 17 \
|
||||||
|
--max-duration 350 \
|
||||||
|
--decoding-method $m \
|
||||||
|
--manifest-dir /mnt/host/corpus/csj/fbank \
|
||||||
|
--lang data/lang_char \
|
||||||
|
--transcript-mode disfluent \
|
||||||
|
--res-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30/github/sim_"$chunk"_"$m" \
|
||||||
|
--decode-chunk-len $chunk \
|
||||||
|
--pad-feature 30 \
|
||||||
|
--gpu 0
|
||||||
|
done
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
|
The streaming chunk-wise decoding command was:
|
||||||
|
```bash
|
||||||
|
for chunk in 64 32; do
|
||||||
|
for m in greedy_search fast_beam_search modified_beam_search; do
|
||||||
|
python pruned_transducer_stateless7_streaming/streaming_decode.py \
|
||||||
|
--feedforward-dims "1024,1024,2048,2048,1024" \
|
||||||
|
--exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 17 \
|
||||||
|
--max-duration 350 \
|
||||||
|
--decoding-method $m \
|
||||||
|
--manifest-dir /mnt/host/corpus/csj/fbank \
|
||||||
|
--lang data/lang_char \
|
||||||
|
--transcript-mode disfluent \
|
||||||
|
--res-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30/github/stream_"$chunk"_"$m" \
|
||||||
|
--decode-chunk-len $chunk \
|
||||||
|
--gpu 2 \
|
||||||
|
--num-decode-streams 40
|
||||||
|
done
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
|
#### training on fluent transcript
|
||||||
|
|
||||||
|
The CERs are:
|
||||||
|
|
||||||
|
| decoding method | chunk size | eval1 | eval2 | eval3 | excluded | valid | average | decoding mode |
|
||||||
|
| --------------- | ---------- | ----- | ----- | ----- | -------- | ----- | ------- | ------------- |
|
||||||
|
| fast beam search | 320ms | 4.19 | 3.63 | 3.77 | 4.43 | 4.09 | --epoch 30 --avg 12 | simulated streaming |
|
||||||
|
| fast beam search | 320ms | 4.06 | 3.55 | 3.66 | 4.70 | 4.04 | --epoch 30 --avg 12 | chunk-wise |
|
||||||
|
| greedy search | 320ms | 4.22 | 3.62 | 3.82 | 4.45 | 3.98 | --epoch 30 --avg 12 | simulated streaming |
|
||||||
|
| greedy search | 320ms | 4.13 | 3.61 | 3.85 | 4.67 | 4.05 | --epoch 30 --avg 12 | chunk-wise |
|
||||||
|
| modified beam search | 320ms | 4.02 | 3.43 | 3.62 | 4.43 | 3.81 | --epoch 30 --avg 12 | simulated streaming |
|
||||||
|
| modified beam search | 320ms | 3.97 | 3.43 | 3.59 | 4.99 | 3.88 | --epoch 30 --avg 12 | chunk-wise |
|
||||||
|
| fast beam search | 640ms | 3.80 | 3.31 | 3.55 | 4.16 | 3.90 | --epoch 30 --avg 12 | simulated streaming |
|
||||||
|
| fast beam search | 640ms | 3.81 | 3.34 | 3.46 | 4.58 | 3.85 | --epoch 30 --avg 12 | chunk-wise |
|
||||||
|
| greedy search | 640ms | 3.92 | 3.38 | 3.65 | 4.31 | 3.88 | --epoch 30 --avg 12 | simulated streaming |
|
||||||
|
| greedy search | 640ms | 3.98 | 3.38 | 3.64 | 4.54 | 4.01 | --epoch 30 --avg 12 | chunk-wise |
|
||||||
|
| modified beam search | 640ms | 3.72 | 3.26 | 3.39 | 4.10 | 3.65 | --epoch 30 --avg 12 | simulated streaming |
|
||||||
|
| modified beam search | 640ms | 3.78 | 3.32 | 3.45 | 4.81 | 3.81 | --epoch 30 --avg 12 | chunk-wise |
|
||||||
|
|
||||||
|
Note: `simulated streaming` indicates feeding full utterance during decoding using `decode.py`,
|
||||||
|
while `chunk-size` indicates feeding certain number of frames at each time using `streaming_decode.py`.
|
||||||
|
|
||||||
|
The training command was:
|
||||||
|
```bash
|
||||||
|
./pruned_transducer_stateless7_streaming/train.py \
|
||||||
|
--feedforward-dims "1024,1024,2048,2048,1024" \
|
||||||
|
--world-size 8 \
|
||||||
|
--num-epochs 30 \
|
||||||
|
--start-epoch 1 \
|
||||||
|
--use-fp16 1 \
|
||||||
|
--exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \
|
||||||
|
--max-duration 375 \
|
||||||
|
--transcript-mode fluent \
|
||||||
|
--lang data/lang_char \
|
||||||
|
--manifest-dir /mnt/host/corpus/csj/fbank \
|
||||||
|
--pad-feature 30 \
|
||||||
|
--musan-dir /mnt/host/corpus/musan/musan/fbank
|
||||||
|
```
|
||||||
|
|
||||||
|
The simulated streaming decoding command was:
|
||||||
|
```bash
|
||||||
|
for chunk in 64 32; do
|
||||||
|
for m in greedy_search fast_beam_search modified_beam_search; do
|
||||||
|
python pruned_transducer_stateless7_streaming/decode.py \
|
||||||
|
--feedforward-dims "1024,1024,2048,2048,1024" \
|
||||||
|
--exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 12 \
|
||||||
|
--max-duration 350 \
|
||||||
|
--decoding-method $m \
|
||||||
|
--manifest-dir /mnt/host/corpus/csj/fbank \
|
||||||
|
--lang data/lang_char \
|
||||||
|
--transcript-mode fluent \
|
||||||
|
--res-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30/github/sim_"$chunk"_"$m" \
|
||||||
|
--decode-chunk-len $chunk \
|
||||||
|
--pad-feature 30 \
|
||||||
|
--gpu 1
|
||||||
|
done
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
|
The streaming chunk-wise decoding command was:
|
||||||
|
```bash
|
||||||
|
for chunk in 64 32; do
|
||||||
|
for m in greedy_search fast_beam_search modified_beam_search; do
|
||||||
|
python pruned_transducer_stateless7_streaming/streaming_decode.py \
|
||||||
|
--feedforward-dims "1024,1024,2048,2048,1024" \
|
||||||
|
--exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 12 \
|
||||||
|
--max-duration 350 \
|
||||||
|
--decoding-method $m \
|
||||||
|
--manifest-dir /mnt/host/corpus/csj/fbank \
|
||||||
|
--lang data/lang_char \
|
||||||
|
--transcript-mode fluent \
|
||||||
|
--res-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30/github/stream_"$chunk"_"$m" \
|
||||||
|
--decode-chunk-len $chunk \
|
||||||
|
--gpu 3 \
|
||||||
|
--num-decode-streams 40
|
||||||
|
done
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Comparing disfluent to fluent
|
||||||
|
|
||||||
|
$$ \texttt{CER}^{f}_d = \frac{\texttt{sub}_f + \texttt{ins} + \texttt{del}_f}{N_f} $$
|
||||||
|
|
||||||
|
This comparison evaluates the disfluent model on the fluent transcript (calculated by `disfluent_recogs_to_fluent.py`), forgiving the disfluent model's mistakes on fillers and partial words. It is meant as an illustrative metric only, so that the disfluent and fluent models can be compared.
|
||||||
|
|
||||||
|
| decoding method | chunk size | eval1 (d vs f) | eval2 (d vs f) | eval3 (d vs f) | excluded (d vs f) | valid (d vs f) | decoding mode |
|
||||||
|
| --------------- | ---------- | -------------- | --------------- | -------------- | -------------------- | --------------- | ----------- |
|
||||||
|
| fast beam search | 320ms | 4.54 vs 4.19 | 3.44 vs 3.63 | 3.56 vs 3.77 | 4.22 vs 4.43 | 4.22 vs 4.09 | simulated streaming |
|
||||||
|
| fast beam search | 320ms | 4.48 vs 4.06 | 3.41 vs 3.55 | 3.65 vs 3.66 | 4.26 vs 4.7 | 4.08 vs 4.04 | chunk-wise |
|
||||||
|
| greedy search | 320ms | 4.53 vs 4.22 | 3.48 vs 3.62 | 3.69 vs 3.82 | 4.38 vs 4.45 | 4.05 vs 3.98 | simulated streaming |
|
||||||
|
| greedy search | 320ms | 4.53 vs 4.13 | 3.46 vs 3.61 | 3.71 vs 3.85 | 4.48 vs 4.67 | 4.12 vs 4.05 | chunk-wise |
|
||||||
|
| modified beam search | 320ms | 4.45 vs 4.02 | 3.38 vs 3.43 | 3.57 vs 3.62 | 4.19 vs 4.43 | 4.04 vs 3.81 | simulated streaming |
|
||||||
|
| modified beam search | 320ms | 4.44 vs 3.97 | 3.47 vs 3.43 | 3.56 vs 3.59 | 4.28 vs 4.99 | 4.04 vs 3.88 | chunk-wise |
|
||||||
|
| fast beam search | 640ms | 4.14 vs 3.8 | 3.12 vs 3.31 | 3.38 vs 3.55 | 3.72 vs 4.16 | 3.81 vs 3.9 | simulated streaming |
|
||||||
|
| fast beam search | 640ms | 4.05 vs 3.81 | 3.23 vs 3.34 | 3.36 vs 3.46 | 3.65 vs 4.58 | 3.78 vs 3.85 | chunk-wise |
|
||||||
|
| greedy search | 640ms | 4.1 vs 3.92 | 3.17 vs 3.38 | 3.5 vs 3.65 | 3.87 vs 4.31 | 3.77 vs 3.88 | simulated streaming |
|
||||||
|
| greedy search | 640ms | 4.41 vs 3.98 | 3.56 vs 3.38 | 3.69 vs 3.64 | 4.26 vs 4.54 | 4.16 vs 4.01 | chunk-wise |
|
||||||
|
| modified beam search | 640ms | 4 vs 3.72 | 3.08 vs 3.26 | 3.33 vs 3.39 | 3.75 vs 4.1 | 3.71 vs 3.65 | simulated streaming |
|
||||||
|
| modified beam search | 640ms | 5.05 vs 3.78 | 4.22 vs 3.32 | 4.26 vs 3.45 | 5.02 vs 4.81 | 4.73 vs 3.81 | chunk-wise |
|
||||||
|
| average (d - f) | | 0.43 | -0.02 | -0.02 | -0.34 | 0.13 | |
|
94
egs/csj/ASR/local/add_transcript_mode.py
Normal file
94
egs/csj/ASR/local/add_transcript_mode.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from configparser import ConfigParser
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from lhotse import CutSet, SupervisionSet
|
||||||
|
from lhotse.recipes.csj import CSJSDBParser
|
||||||
|
|
||||||
|
ARGPARSE_DESCRIPTION = """
|
||||||
|
This script adds transcript modes to an existing CutSet or SupervisionSet.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
|
description=ARGPARSE_DESCRIPTION,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-f",
|
||||||
|
"--fbank-dir",
|
||||||
|
type=Path,
|
||||||
|
help="Path to directory where manifests are stored.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-c",
|
||||||
|
"--config",
|
||||||
|
type=Path,
|
||||||
|
nargs="+",
|
||||||
|
help="Path to config file for transcript parsing.",
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def get_CSJParsers(config_files: List[Path]) -> List[CSJSDBParser]:
|
||||||
|
parsers = []
|
||||||
|
for config_file in config_files:
|
||||||
|
config = ConfigParser()
|
||||||
|
config.optionxform = str
|
||||||
|
assert config.read(config_file), f"{config_file} could not be found."
|
||||||
|
decisions = {}
|
||||||
|
for k, v in config["DECISIONS"].items():
|
||||||
|
try:
|
||||||
|
decisions[k] = int(v)
|
||||||
|
except ValueError:
|
||||||
|
decisions[k] = v
|
||||||
|
parsers.append(
|
||||||
|
(config["CONSTANTS"].get("MODE"), CSJSDBParser(decisions=decisions))
|
||||||
|
)
|
||||||
|
return parsers
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
logging.basicConfig(
|
||||||
|
format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"),
|
||||||
|
level=logging.INFO,
|
||||||
|
)
|
||||||
|
parsers = get_CSJParsers(args.config)
|
||||||
|
config = ConfigParser()
|
||||||
|
config.optionxform = str
|
||||||
|
assert config.read(args.config), args.config
|
||||||
|
decisions = {}
|
||||||
|
for k, v in config["DECISIONS"].items():
|
||||||
|
try:
|
||||||
|
decisions[k] = int(v)
|
||||||
|
except ValueError:
|
||||||
|
decisions[k] = v
|
||||||
|
|
||||||
|
logging.info(f"Adding {', '.join(x[0] for x in parsers)} transcript mode.")
|
||||||
|
|
||||||
|
manifests = args.fbank_dir.glob("csj_cuts_*.jsonl.gz")
|
||||||
|
assert manifests, f"No cuts to be found in {args.fbank_dir}"
|
||||||
|
|
||||||
|
for manifest in manifests:
|
||||||
|
results = []
|
||||||
|
logging.info(f"Adding transcript modes to {manifest.name} now.")
|
||||||
|
cutset = CutSet.from_file(manifest)
|
||||||
|
for cut in cutset:
|
||||||
|
for name, parser in parsers:
|
||||||
|
cut.supervisions[0].custom[name] = parser.parse(
|
||||||
|
cut.supervisions[0].custom["raw"]
|
||||||
|
)
|
||||||
|
cut.supervisions[0].text = ""
|
||||||
|
results.append(cut)
|
||||||
|
results = CutSet.from_items(results)
|
||||||
|
res_file = manifest.as_posix()
|
||||||
|
manifest.replace(manifest.parent / ("bak." + manifest.name))
|
||||||
|
results.to_file(res_file)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -1,5 +1,5 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# Copyright 2022 The University of Electro-Communications (Author: Teo Wen Shen) # noqa
|
# Copyright 2023 The University of Electro-Communications (Author: Teo Wen Shen) # noqa
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -19,9 +19,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from itertools import islice
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from random import Random
|
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -35,20 +33,10 @@ from lhotse import ( # See the following for why LilcomChunkyWriter is preferre
|
|||||||
RecordingSet,
|
RecordingSet,
|
||||||
SupervisionSet,
|
SupervisionSet,
|
||||||
)
|
)
|
||||||
|
from lhotse.recipes.csj import concat_csj_supervisions
|
||||||
|
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
ARGPARSE_DESCRIPTION = """
|
|
||||||
This script follows the espnet method of splitting the remaining core+noncore
|
|
||||||
utterances into valid and train cutsets at an index which is by default 4000.
|
|
||||||
|
|
||||||
In other words, the core+noncore utterances are shuffled, where 4000 utterances
|
|
||||||
of the shuffled set go to the `valid` cutset and are not subject to speed
|
|
||||||
perturbation. The remaining utterances become the `train` cutset and are speed-
|
|
||||||
perturbed (0.9x, 1.0x, 1.1x).
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
# it wastes a lot of CPU and slow things down.
|
# it wastes a lot of CPU and slow things down.
|
||||||
# Do this outside of main() in case it needs to take effect
|
# Do this outside of main() in case it needs to take effect
|
||||||
@ -57,66 +45,101 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
RNG_SEED = 42
|
RNG_SEED = 42
|
||||||
|
# concat_params_train = [
|
||||||
|
# {"gap": 1.0, "maxlen": 10.0},
|
||||||
|
# {"gap": 1.5, "maxlen": 8.0},
|
||||||
|
# {"gap": 1.0, "maxlen": 18.0},
|
||||||
|
# ]
|
||||||
|
|
||||||
|
concat_params = {"gap": 1.0, "maxlen": 10.0}
|
||||||
|
|
||||||
|
|
||||||
def make_cutset_blueprints(
|
def make_cutset_blueprints(
|
||||||
manifest_dir: Path,
|
manifest_dir: Path,
|
||||||
split: int,
|
|
||||||
) -> List[Tuple[str, CutSet]]:
|
) -> List[Tuple[str, CutSet]]:
|
||||||
|
|
||||||
cut_sets = []
|
cut_sets = []
|
||||||
|
logging.info("Creating non-train cuts.")
|
||||||
|
|
||||||
# Create eval datasets
|
# Create eval datasets
|
||||||
logging.info("Creating eval cuts.")
|
|
||||||
for i in range(1, 4):
|
for i in range(1, 4):
|
||||||
|
sps = sorted(
|
||||||
|
SupervisionSet.from_file(
|
||||||
|
manifest_dir / f"csj_supervisions_eval{i}.jsonl.gz"
|
||||||
|
),
|
||||||
|
key=lambda x: x.id,
|
||||||
|
)
|
||||||
|
|
||||||
cut_set = CutSet.from_manifests(
|
cut_set = CutSet.from_manifests(
|
||||||
recordings=RecordingSet.from_file(
|
recordings=RecordingSet.from_file(
|
||||||
manifest_dir / f"csj_recordings_eval{i}.jsonl.gz"
|
manifest_dir / f"csj_recordings_eval{i}.jsonl.gz"
|
||||||
),
|
),
|
||||||
supervisions=SupervisionSet.from_file(
|
supervisions=concat_csj_supervisions(sps, **concat_params),
|
||||||
manifest_dir / f"csj_supervisions_eval{i}.jsonl.gz"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
|
cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
|
||||||
cut_sets.append((f"eval{i}", cut_set))
|
cut_sets.append((f"eval{i}", cut_set))
|
||||||
|
|
||||||
# Create train and valid cuts
|
# Create excluded dataset
|
||||||
logging.info("Loading, trimming, and shuffling the remaining core+noncore cuts.")
|
sps = sorted(
|
||||||
recording_set = RecordingSet.from_file(
|
SupervisionSet.from_file(manifest_dir / "csj_supervisions_excluded.jsonl.gz"),
|
||||||
manifest_dir / "csj_recordings_core.jsonl.gz"
|
key=lambda x: x.id,
|
||||||
) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz")
|
)
|
||||||
supervision_set = SupervisionSet.from_file(
|
|
||||||
manifest_dir / "csj_supervisions_core.jsonl.gz"
|
|
||||||
) + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz")
|
|
||||||
|
|
||||||
cut_set = CutSet.from_manifests(
|
cut_set = CutSet.from_manifests(
|
||||||
recordings=recording_set,
|
recordings=RecordingSet.from_file(
|
||||||
supervisions=supervision_set,
|
manifest_dir / "csj_recordings_excluded.jsonl.gz"
|
||||||
|
),
|
||||||
|
supervisions=concat_csj_supervisions(sps, **concat_params),
|
||||||
)
|
)
|
||||||
cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
|
cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
|
||||||
cut_set = cut_set.shuffle(Random(RNG_SEED))
|
cut_sets.append(("excluded", cut_set))
|
||||||
|
|
||||||
logging.info(
|
# Create valid dataset
|
||||||
"Creating valid and train cuts from core and noncore, split at {split}."
|
sps = sorted(
|
||||||
|
SupervisionSet.from_file(manifest_dir / "csj_supervisions_valid.jsonl.gz"),
|
||||||
|
key=lambda x: x.id,
|
||||||
)
|
)
|
||||||
valid_set = CutSet.from_cuts(islice(cut_set, 0, split))
|
cut_set = CutSet.from_manifests(
|
||||||
|
recordings=RecordingSet.from_file(
|
||||||
|
manifest_dir / "csj_recordings_valid.jsonl.gz"
|
||||||
|
),
|
||||||
|
supervisions=concat_csj_supervisions(sps, **concat_params),
|
||||||
|
)
|
||||||
|
cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
|
||||||
|
cut_sets.append(("valid", cut_set))
|
||||||
|
|
||||||
train_set = CutSet.from_cuts(islice(cut_set, split, None))
|
logging.info("Creating train cuts.")
|
||||||
|
|
||||||
|
# Create train dataset
|
||||||
|
sps = sorted(
|
||||||
|
SupervisionSet.from_file(manifest_dir / "csj_supervisions_core.jsonl.gz")
|
||||||
|
+ SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz"),
|
||||||
|
key=lambda x: x.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
recording = RecordingSet.from_file(
|
||||||
|
manifest_dir / "csj_recordings_core.jsonl.gz"
|
||||||
|
) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz")
|
||||||
|
|
||||||
|
train_set = CutSet.from_manifests(
|
||||||
|
recordings=recording, supervisions=concat_csj_supervisions(sps, **concat_params)
|
||||||
|
).trim_to_supervisions(keep_overlapping=False)
|
||||||
train_set = train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1)
|
train_set = train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1)
|
||||||
|
|
||||||
cut_sets.extend([("valid", valid_set), ("train", train_set)])
|
cut_sets.append(("train", train_set))
|
||||||
|
|
||||||
return cut_sets
|
return cut_sets
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description=ARGPARSE_DESCRIPTION,
|
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests")
|
"-m", "--manifest-dir", type=Path, help="Path to save manifests"
|
||||||
parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features")
|
)
|
||||||
parser.add_argument("--split", type=int, default=4000, help="Split at this index")
|
parser.add_argument(
|
||||||
|
"-f", "--fbank-dir", type=Path, help="Path to save fbank features"
|
||||||
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
@ -138,7 +161,7 @@ def main():
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
cut_sets = make_cutset_blueprints(args.manifest_dir, args.split)
|
cut_sets = make_cutset_blueprints(args.manifest_dir)
|
||||||
for part, cut_set in cut_sets:
|
for part, cut_set in cut_sets:
|
||||||
logging.info(f"Processing {part}")
|
logging.info(f"Processing {part}")
|
||||||
cut_set = cut_set.compute_and_store_features(
|
cut_set = cut_set.compute_and_store_features(
|
||||||
@ -147,7 +170,7 @@ def main():
|
|||||||
storage_path=(args.fbank_dir / f"feats_{part}").as_posix(),
|
storage_path=(args.fbank_dir / f"feats_{part}").as_posix(),
|
||||||
storage_type=LilcomChunkyWriter,
|
storage_type=LilcomChunkyWriter,
|
||||||
)
|
)
|
||||||
cut_set.to_file(args.manifest_dir / f"csj_cuts_{part}.jsonl.gz")
|
cut_set.to_file(args.fbank_dir / f"csj_cuts_{part}.jsonl.gz")
|
||||||
|
|
||||||
logging.info("All fbank computed for CSJ.")
|
logging.info("All fbank computed for CSJ.")
|
||||||
(args.fbank_dir / ".done").touch()
|
(args.fbank_dir / ".done").touch()
|
||||||
|
@ -28,9 +28,7 @@ from icefall.utils import get_executor
|
|||||||
|
|
||||||
ARGPARSE_DESCRIPTION = """
|
ARGPARSE_DESCRIPTION = """
|
||||||
This file computes fbank features of the musan dataset.
|
This file computes fbank features of the musan dataset.
|
||||||
It looks for manifests in the directory data/manifests.
|
|
||||||
|
|
||||||
The generated fbank features are saved in data/fbank.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
@ -42,8 +40,6 @@ torch.set_num_interop_threads(1)
|
|||||||
|
|
||||||
|
|
||||||
def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path):
|
def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path):
|
||||||
# src_dir = Path("data/manifests")
|
|
||||||
# output_dir = Path("data/fbank")
|
|
||||||
num_jobs = min(15, os.cpu_count())
|
num_jobs = min(15, os.cpu_count())
|
||||||
num_mel_bins = 80
|
num_mel_bins = 80
|
||||||
|
|
||||||
@ -104,8 +100,12 @@ def get_args():
|
|||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests")
|
parser.add_argument(
|
||||||
parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features")
|
"-m", "--manifest-dir", type=Path, help="Path to save manifests"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-f", "--fbank-dir", type=Path, help="Path to save fbank features"
|
||||||
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
@ -1,320 +1,79 @@
|
|||||||
; # This section is ignored if this file is not supplied as the first config file to
|
|
||||||
; # lhotse prepare csj
|
|
||||||
[SEGMENTS]
|
|
||||||
; # Allowed period of nonverbal noise. If exceeded, a new segment is created.
|
|
||||||
gap = 0.5
|
|
||||||
; # Maximum length of segment (s).
|
|
||||||
maxlen = 10
|
|
||||||
; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently.
|
|
||||||
minlen = 0.02
|
|
||||||
; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`.
|
|
||||||
; # Pass an empty string to avoid adding any symbol. It was "<sp>" in kaldi.
|
|
||||||
; # If you intend to use a multicharacter string for gap_sym, remember to register the
|
|
||||||
; # multicharacter string as part of userdef-string in prepare_lang_char.py.
|
|
||||||
gap_sym =
|
|
||||||
|
|
||||||
[CONSTANTS]
|
[CONSTANTS]
|
||||||
; # Name of this mode
|
; # Name of this mode
|
||||||
MODE = disfluent
|
MODE = disfluent
|
||||||
; # Suffixes to use after the word surface (no longer used)
|
|
||||||
MORPH = pos1 cForm cType2 pos2
|
|
||||||
; # Used to differentiate between A tag and A_num tag
|
|
||||||
JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 .
|
|
||||||
; # Dummy character to delineate multiline words
|
|
||||||
PLUS = +
|
|
||||||
|
|
||||||
[DECISIONS]
|
[DECISIONS]
|
||||||
; # TAG+'^'とは、タグが一つの転記単位に独立していない場合
|
|
||||||
; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries
|
|
||||||
|
|
||||||
; # フィラー、感情表出系感動詞
|
; # フィラー、感情表出系感動詞
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(F ぎょっ)'
|
; # Example: '(F ぎょっ)'
|
||||||
F = 0
|
F = 0
|
||||||
; # Example: '(L (F ン))', '比べ(F えー)る'
|
|
||||||
F^ = 0
|
|
||||||
; # 言い直し、いいよどみなどによる語断片
|
; # 言い直し、いいよどみなどによる語断片
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(D だ)(D だいが) 大学の学部の会議'
|
; # Example: '(D だ)(D だいが) 大学の学部の会議'
|
||||||
D = 0
|
D = 0
|
||||||
; # Example: '(L (D ドゥ)+(D ヒ))'
|
|
||||||
D^ = 0
|
|
||||||
; # 助詞、助動詞、接辞の言い直し
|
; # 助詞、助動詞、接辞の言い直し
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか'
|
; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか'
|
||||||
D2 = 0
|
D2 = 0
|
||||||
; # Example: '(X (D2 ノ))'
|
|
||||||
D2^ = 0
|
|
||||||
; # 聞き取りや語彙の判断に自信がない場合
|
; # 聞き取りや語彙の判断に自信がない場合
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: (? 字数) の
|
; # Example: (? 字数) の
|
||||||
; # If no option: empty string is returned regardless of output
|
; # If no option: empty string is returned regardless of output
|
||||||
; # Example: '(?) で'
|
; # Example: '(?) で'
|
||||||
? = 0
|
? = 0
|
||||||
; # Example: '(D (? すー))+そう+です+よ+ね'
|
|
||||||
?^ = 0
|
|
||||||
; # タグ?で、値は複数の候補が想定される場合
|
; # タグ?で、値は複数の候補が想定される場合
|
||||||
; # 0 for main guess with matching morph info, 1 for second guess
|
; # 0 for main guess with matching morph info, 1 for second guess
|
||||||
; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)'
|
; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)'
|
||||||
?, = 0
|
?, = 0
|
||||||
; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))'
|
|
||||||
?,^ = 0
|
|
||||||
; # 音や言葉に関するメタ的な引用
|
; # 音や言葉に関するメタ的な引用
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)'
|
; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)'
|
||||||
M = 0
|
M = 0
|
||||||
; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))'
|
|
||||||
M^ = 0
|
|
||||||
; # 外国語や古語、方言など
|
; # 外国語や古語、方言など
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(O ザッツファイン)'
|
; # Example: '(O ザッツファイン)'
|
||||||
O = 0
|
O = 0
|
||||||
; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))'
|
|
||||||
O^ = 0
|
|
||||||
; # 講演者の名前、差別語、誹謗中傷など
|
; # 講演者の名前、差別語、誹謗中傷など
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '国語研の (R ××) です'
|
; # Example: '国語研の (R ××) です'
|
||||||
R = 0
|
R = 0
|
||||||
R^ = 0
|
|
||||||
; # 非朗読対象発話(朗読における言い間違い等)
|
; # 非朗読対象発話(朗読における言い間違い等)
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(X 実際は) 実際には'
|
; # Example: '(X 実際は) 実際には'
|
||||||
X = 0
|
X = 0
|
||||||
; # Example: '(L (X (D2 ニ)))'
|
|
||||||
X^ = 0
|
|
||||||
; # アルファベットや算用数字、記号の表記
|
; # アルファベットや算用数字、記号の表記
|
||||||
; # 0 to use Japanese form, 1 to use alphabet form
|
; # 0 to use Japanese form, 1 to use alphabet form
|
||||||
; # Example: '(A シーディーアール;CD-R)'
|
; # Example: '(A シーディーアール;CD-R)'
|
||||||
A = 1
|
A = 1
|
||||||
; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)')
|
|
||||||
A^ = 1
|
|
||||||
; # タグAで、単語は算用数字の場合
|
; # タグAで、単語は算用数字の場合
|
||||||
; # 0 to use Japanese form, 1 to use Arabic numerals
|
; # 0 to use Japanese form, 1 to use Arabic numerals
|
||||||
; # Example: (A 二千;2000)
|
; # Example: (A 二千;2000)
|
||||||
A_num = eval:self.notag
|
A_num = 0
|
||||||
A_num^ = eval:self.notag
|
|
||||||
; # 何らかの原因で漢字表記できなくなった場合
|
; # 何らかの原因で漢字表記できなくなった場合
|
||||||
; # 0 to use broken form, 1 to use orthodox form
|
; # 0 to use broken form, 1 to use orthodox form
|
||||||
; # Example: '(K たち (F えー) ばな;橘)'
|
; # Example: '(K たち (F えー) ばな;橘)'
|
||||||
K = 1
|
K = 1
|
||||||
; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)'
|
|
||||||
K^ = 1
|
|
||||||
; # 転訛、発音の怠けなど、一時的な発音エラー
|
; # 転訛、発音の怠けなど、一時的な発音エラー
|
||||||
; # 0 to use wrong form, 1 to use orthodox form
|
; # 0 to use wrong form, 1 to use orthodox form
|
||||||
; # Example: '(W ギーツ;ギジュツ)'
|
; # Example: '(W ギーツ;ギジュツ)'
|
||||||
W = 1
|
W = 1
|
||||||
; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)'
|
|
||||||
W^ = 1
|
|
||||||
; # 語の読みに関する知識レベルのいい間違い
|
; # 語の読みに関する知識レベルのいい間違い
|
||||||
; # 0 to use wrong form, 1 to use orthodox form
|
; # 0 to use wrong form, 1 to use orthodox form
|
||||||
; # Example: '(B シブタイ;ジュータイ)'
|
; # Example: '(B シブタイ;ジュータイ)'
|
||||||
B = 0
|
B = 0
|
||||||
; # Example: 'データー(B カズ;スー)'
|
|
||||||
B^ = 0
|
|
||||||
; # 笑いながら発話
|
; # 笑いながら発話
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)'
|
; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)'
|
||||||
笑 = 0
|
笑 = 0
|
||||||
; # Example: 'コク(笑 サイ+(D オン))',
|
|
||||||
笑^ = 0
|
|
||||||
; # 泣きながら発話
|
; # 泣きながら発話
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(泣 ドンナニ)'
|
; # Example: '(泣 ドンナニ)'
|
||||||
泣 = 0
|
泣 = 0
|
||||||
泣^ = 0
|
|
||||||
; # 咳をしながら発話
|
; # 咳をしながら発話
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: 'シャ(咳 リン) ノ'
|
; # Example: 'シャ(咳 リン) ノ'
|
||||||
咳 = 0
|
咳 = 0
|
||||||
; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)'
|
|
||||||
咳^ = 0
|
|
||||||
; # ささやき声や独り言などの小さな声
|
; # ささやき声や独り言などの小さな声
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
|
; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
|
||||||
L = 0
|
L = 0
|
||||||
; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト'
|
|
||||||
L^ = 0
|
|
||||||
|
|
||||||
[REPLACEMENTS]
|
|
||||||
; # ボーカルフライなどで母音が同定できない場合
|
|
||||||
<FV> =
|
|
||||||
; # 「うん/うーん/ふーん」の音の特定が困難な場合
|
|
||||||
<VN> =
|
|
||||||
; # 非語彙的な母音の引き延ばし
|
|
||||||
<H> =
|
|
||||||
; # 非語彙的な子音の引き延ばし
|
|
||||||
<Q> =
|
|
||||||
; # 言語音と独立に講演者の笑いが生じている場合
|
|
||||||
<笑> =
|
|
||||||
; # 言語音と独立に講演者の咳が生じている場合
|
|
||||||
<咳> =
|
|
||||||
; # 言語音と独立に講演者の息が生じている場合
|
|
||||||
<息> =
|
|
||||||
; # 講演者の泣き声
|
|
||||||
<泣> =
|
|
||||||
; # 聴衆(司会者なども含む)の発話
|
|
||||||
<フロア発話> =
|
|
||||||
; # 聴衆の笑い
|
|
||||||
<フロア笑> =
|
|
||||||
; # 聴衆の拍手
|
|
||||||
<拍手> =
|
|
||||||
; # 講演者が発表中に用いたデモンストレーションの音声
|
|
||||||
<デモ> =
|
|
||||||
; # 学会講演に発表時間を知らせるためにならすベルの音
|
|
||||||
<ベル> =
|
|
||||||
; # 転記単位全体が再度読み直された場合
|
|
||||||
<朗読間違い> =
|
|
||||||
; # 上記以外の音で特に目立った音
|
|
||||||
<雑音> =
|
|
||||||
; # 0.2秒以上のポーズ
|
|
||||||
<P> =
|
|
||||||
; # Redacted information, for R
|
|
||||||
; # It is \x00D7 multiplication sign, not your normal 'x'
|
|
||||||
× = ×
|
|
||||||
|
|
||||||
[FIELDS]
|
|
||||||
; # Time information for segment
|
|
||||||
time = 3
|
|
||||||
; # Word surface
|
|
||||||
surface = 5
|
|
||||||
; # Word surface root form without CSJ tags
|
|
||||||
notag = 9
|
|
||||||
; # Part Of Speech
|
|
||||||
pos1 = 11
|
|
||||||
; # Conjugated Form
|
|
||||||
cForm = 12
|
|
||||||
; # Conjugation Type
|
|
||||||
cType1 = 13
|
|
||||||
; # Subcategory of POS
|
|
||||||
pos2 = 14
|
|
||||||
; # Euphonic Change / Subcategory of Conjugation Type
|
|
||||||
cType2 = 15
|
|
||||||
; # Other information
|
|
||||||
other = 16
|
|
||||||
; # Pronunciation for lexicon
|
|
||||||
pron = 10
|
|
||||||
; # Speaker ID
|
|
||||||
spk_id = 2
|
|
||||||
|
|
||||||
[KATAKANA2ROMAJI]
|
|
||||||
ア = 'a
|
|
||||||
イ = 'i
|
|
||||||
ウ = 'u
|
|
||||||
エ = 'e
|
|
||||||
オ = 'o
|
|
||||||
カ = ka
|
|
||||||
キ = ki
|
|
||||||
ク = ku
|
|
||||||
ケ = ke
|
|
||||||
コ = ko
|
|
||||||
ガ = ga
|
|
||||||
ギ = gi
|
|
||||||
グ = gu
|
|
||||||
ゲ = ge
|
|
||||||
ゴ = go
|
|
||||||
サ = sa
|
|
||||||
シ = si
|
|
||||||
ス = su
|
|
||||||
セ = se
|
|
||||||
ソ = so
|
|
||||||
ザ = za
|
|
||||||
ジ = zi
|
|
||||||
ズ = zu
|
|
||||||
ゼ = ze
|
|
||||||
ゾ = zo
|
|
||||||
タ = ta
|
|
||||||
チ = ti
|
|
||||||
ツ = tu
|
|
||||||
テ = te
|
|
||||||
ト = to
|
|
||||||
ダ = da
|
|
||||||
ヂ = di
|
|
||||||
ヅ = du
|
|
||||||
デ = de
|
|
||||||
ド = do
|
|
||||||
ナ = na
|
|
||||||
ニ = ni
|
|
||||||
ヌ = nu
|
|
||||||
ネ = ne
|
|
||||||
ノ = no
|
|
||||||
ハ = ha
|
|
||||||
ヒ = hi
|
|
||||||
フ = hu
|
|
||||||
ヘ = he
|
|
||||||
ホ = ho
|
|
||||||
バ = ba
|
|
||||||
ビ = bi
|
|
||||||
ブ = bu
|
|
||||||
ベ = be
|
|
||||||
ボ = bo
|
|
||||||
パ = pa
|
|
||||||
ピ = pi
|
|
||||||
プ = pu
|
|
||||||
ペ = pe
|
|
||||||
ポ = po
|
|
||||||
マ = ma
|
|
||||||
ミ = mi
|
|
||||||
ム = mu
|
|
||||||
メ = me
|
|
||||||
モ = mo
|
|
||||||
ヤ = ya
|
|
||||||
ユ = yu
|
|
||||||
ヨ = yo
|
|
||||||
ラ = ra
|
|
||||||
リ = ri
|
|
||||||
ル = ru
|
|
||||||
レ = re
|
|
||||||
ロ = ro
|
|
||||||
ワ = wa
|
|
||||||
ヰ = we
|
|
||||||
ヱ = wi
|
|
||||||
ヲ = wo
|
|
||||||
ン = ŋ
|
|
||||||
ッ = q
|
|
||||||
ー = -
|
|
||||||
キャ = kǐa
|
|
||||||
キュ = kǐu
|
|
||||||
キョ = kǐo
|
|
||||||
ギャ = gǐa
|
|
||||||
ギュ = gǐu
|
|
||||||
ギョ = gǐo
|
|
||||||
シャ = sǐa
|
|
||||||
シュ = sǐu
|
|
||||||
ショ = sǐo
|
|
||||||
ジャ = zǐa
|
|
||||||
ジュ = zǐu
|
|
||||||
ジョ = zǐo
|
|
||||||
チャ = tǐa
|
|
||||||
チュ = tǐu
|
|
||||||
チョ = tǐo
|
|
||||||
ヂャ = dǐa
|
|
||||||
ヂュ = dǐu
|
|
||||||
ヂョ = dǐo
|
|
||||||
ニャ = nǐa
|
|
||||||
ニュ = nǐu
|
|
||||||
ニョ = nǐo
|
|
||||||
ヒャ = hǐa
|
|
||||||
ヒュ = hǐu
|
|
||||||
ヒョ = hǐo
|
|
||||||
ビャ = bǐa
|
|
||||||
ビュ = bǐu
|
|
||||||
ビョ = bǐo
|
|
||||||
ピャ = pǐa
|
|
||||||
ピュ = pǐu
|
|
||||||
ピョ = pǐo
|
|
||||||
ミャ = mǐa
|
|
||||||
ミュ = mǐu
|
|
||||||
ミョ = mǐo
|
|
||||||
リャ = rǐa
|
|
||||||
リュ = rǐu
|
|
||||||
リョ = rǐo
|
|
||||||
ァ = a
|
|
||||||
ィ = i
|
|
||||||
ゥ = u
|
|
||||||
ェ = e
|
|
||||||
ォ = o
|
|
||||||
ヮ = ʍ
|
|
||||||
ヴ = vu
|
|
||||||
ャ = ǐa
|
|
||||||
ュ = ǐu
|
|
||||||
ョ = ǐo
|
|
||||||
|
@ -1,320 +1,79 @@
|
|||||||
; # This section is ignored if this file is not supplied as the first config file to
|
|
||||||
; # lhotse prepare csj
|
|
||||||
[SEGMENTS]
|
|
||||||
; # Allowed period of nonverbal noise. If exceeded, a new segment is created.
|
|
||||||
gap = 0.5
|
|
||||||
; # Maximum length of segment (s).
|
|
||||||
maxlen = 10
|
|
||||||
; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently.
|
|
||||||
minlen = 0.02
|
|
||||||
; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`.
|
|
||||||
; # Pass an empty string to avoid adding any symbol. It was "<sp>" in kaldi.
|
|
||||||
; # If you intend to use a multicharacter string for gap_sym, remember to register the
|
|
||||||
; # multicharacter string as part of userdef-string in prepare_lang_char.py.
|
|
||||||
gap_sym =
|
|
||||||
|
|
||||||
[CONSTANTS]
|
[CONSTANTS]
|
||||||
; # Name of this mode
|
; # Name of this mode
|
||||||
MODE = fluent
|
MODE = fluent
|
||||||
; # Suffixes to use after the word surface (no longer used)
|
|
||||||
MORPH = pos1 cForm cType2 pos2
|
|
||||||
; # Used to differentiate between A tag and A_num tag
|
|
||||||
JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 .
|
|
||||||
; # Dummy character to delineate multiline words
|
|
||||||
PLUS = +
|
|
||||||
|
|
||||||
[DECISIONS]
|
[DECISIONS]
|
||||||
; # TAG+'^'とは、タグが一つの転記単位に独立していない場合
|
|
||||||
; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries
|
|
||||||
|
|
||||||
; # フィラー、感情表出系感動詞
|
; # フィラー、感情表出系感動詞
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(F ぎょっ)'
|
; # Example: '(F ぎょっ)'
|
||||||
F = 1
|
F = 1
|
||||||
; # Example: '(L (F ン))', '比べ(F えー)る'
|
|
||||||
F^ = 1
|
|
||||||
; # 言い直し、いいよどみなどによる語断片
|
; # 言い直し、いいよどみなどによる語断片
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(D だ)(D だいが) 大学の学部の会議'
|
; # Example: '(D だ)(D だいが) 大学の学部の会議'
|
||||||
D = 1
|
D = 1
|
||||||
; # Example: '(L (D ドゥ)+(D ヒ))'
|
|
||||||
D^ = 1
|
|
||||||
; # 助詞、助動詞、接辞の言い直し
|
; # 助詞、助動詞、接辞の言い直し
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか'
|
; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか'
|
||||||
D2 = 1
|
D2 = 1
|
||||||
; # Example: '(X (D2 ノ))'
|
|
||||||
D2^ = 1
|
|
||||||
; # 聞き取りや語彙の判断に自信がない場合
|
; # 聞き取りや語彙の判断に自信がない場合
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: (? 字数) の
|
; # Example: (? 字数) の
|
||||||
; # If no option: empty string is returned regardless of output
|
; # If no option: empty string is returned regardless of output
|
||||||
; # Example: '(?) で'
|
; # Example: '(?) で'
|
||||||
? = 0
|
? = 0
|
||||||
; # Example: '(D (? すー))+そう+です+よ+ね'
|
|
||||||
?^ = 0
|
|
||||||
; # タグ?で、値は複数の候補が想定される場合
|
; # タグ?で、値は複数の候補が想定される場合
|
||||||
; # 0 for main guess with matching morph info, 1 for second guess
|
; # 0 for main guess with matching morph info, 1 for second guess
|
||||||
; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)'
|
; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)'
|
||||||
?, = 0
|
?, = 0
|
||||||
; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))'
|
|
||||||
?,^ = 0
|
|
||||||
; # 音や言葉に関するメタ的な引用
|
; # 音や言葉に関するメタ的な引用
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)'
|
; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)'
|
||||||
M = 0
|
M = 0
|
||||||
; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))'
|
|
||||||
M^ = 0
|
|
||||||
; # 外国語や古語、方言など
|
; # 外国語や古語、方言など
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(O ザッツファイン)'
|
; # Example: '(O ザッツファイン)'
|
||||||
O = 0
|
O = 0
|
||||||
; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))'
|
|
||||||
O^ = 0
|
|
||||||
; # 講演者の名前、差別語、誹謗中傷など
|
; # 講演者の名前、差別語、誹謗中傷など
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '国語研の (R ××) です'
|
; # Example: '国語研の (R ××) です'
|
||||||
R = 0
|
R = 0
|
||||||
R^ = 0
|
|
||||||
; # 非朗読対象発話(朗読における言い間違い等)
|
; # 非朗読対象発話(朗読における言い間違い等)
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(X 実際は) 実際には'
|
; # Example: '(X 実際は) 実際には'
|
||||||
X = 0
|
X = 0
|
||||||
; # Example: '(L (X (D2 ニ)))'
|
|
||||||
X^ = 0
|
|
||||||
; # アルファベットや算用数字、記号の表記
|
; # アルファベットや算用数字、記号の表記
|
||||||
; # 0 to use Japanese form, 1 to use alphabet form
|
; # 0 to use Japanese form, 1 to use alphabet form
|
||||||
; # Example: '(A シーディーアール;CD-R)'
|
; # Example: '(A シーディーアール;CD-R)'
|
||||||
A = 1
|
A = 1
|
||||||
; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)')
|
|
||||||
A^ = 1
|
|
||||||
; # タグAで、単語は算用数字の場合
|
; # タグAで、単語は算用数字の場合
|
||||||
; # 0 to use Japanese form, 1 to use Arabic numerals
|
; # 0 to use Japanese form, 1 to use Arabic numerals
|
||||||
; # Example: (A 二千;2000)
|
; # Example: (A 二千;2000)
|
||||||
A_num = eval:self.notag
|
A_num = 0
|
||||||
A_num^ = eval:self.notag
|
|
||||||
; # 何らかの原因で漢字表記できなくなった場合
|
; # 何らかの原因で漢字表記できなくなった場合
|
||||||
; # 0 to use broken form, 1 to use orthodox form
|
; # 0 to use broken form, 1 to use orthodox form
|
||||||
; # Example: '(K たち (F えー) ばな;橘)'
|
; # Example: '(K たち (F えー) ばな;橘)'
|
||||||
K = 1
|
K = 1
|
||||||
; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)'
|
|
||||||
K^ = 1
|
|
||||||
; # 転訛、発音の怠けなど、一時的な発音エラー
|
; # 転訛、発音の怠けなど、一時的な発音エラー
|
||||||
; # 0 to use wrong form, 1 to use orthodox form
|
; # 0 to use wrong form, 1 to use orthodox form
|
||||||
; # Example: '(W ギーツ;ギジュツ)'
|
; # Example: '(W ギーツ;ギジュツ)'
|
||||||
W = 1
|
W = 1
|
||||||
; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)'
|
|
||||||
W^ = 1
|
|
||||||
; # 語の読みに関する知識レベルのいい間違い
|
; # 語の読みに関する知識レベルのいい間違い
|
||||||
; # 0 to use wrong form, 1 to use orthodox form
|
; # 0 to use wrong form, 1 to use orthodox form
|
||||||
; # Example: '(B シブタイ;ジュータイ)'
|
; # Example: '(B シブタイ;ジュータイ)'
|
||||||
B = 0
|
B = 0
|
||||||
; # Example: 'データー(B カズ;スー)'
|
|
||||||
B^ = 0
|
|
||||||
; # 笑いながら発話
|
; # 笑いながら発話
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)'
|
; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)'
|
||||||
笑 = 0
|
笑 = 0
|
||||||
; # Example: 'コク(笑 サイ+(D オン))',
|
|
||||||
笑^ = 0
|
|
||||||
; # 泣きながら発話
|
; # 泣きながら発話
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(泣 ドンナニ)'
|
; # Example: '(泣 ドンナニ)'
|
||||||
泣 = 0
|
泣 = 0
|
||||||
泣^ = 0
|
|
||||||
; # 咳をしながら発話
|
; # 咳をしながら発話
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: 'シャ(咳 リン) ノ'
|
; # Example: 'シャ(咳 リン) ノ'
|
||||||
咳 = 0
|
咳 = 0
|
||||||
; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)'
|
|
||||||
咳^ = 0
|
|
||||||
; # ささやき声や独り言などの小さな声
|
; # ささやき声や独り言などの小さな声
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
|
; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
|
||||||
L = 0
|
L = 0
|
||||||
; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト'
|
|
||||||
L^ = 0
|
|
||||||
|
|
||||||
[REPLACEMENTS]
|
|
||||||
; # ボーカルフライなどで母音が同定できない場合
|
|
||||||
<FV> =
|
|
||||||
; # 「うん/うーん/ふーん」の音の特定が困難な場合
|
|
||||||
<VN> =
|
|
||||||
; # 非語彙的な母音の引き延ばし
|
|
||||||
<H> =
|
|
||||||
; # 非語彙的な子音の引き延ばし
|
|
||||||
<Q> =
|
|
||||||
; # 言語音と独立に講演者の笑いが生じている場合
|
|
||||||
<笑> =
|
|
||||||
; # 言語音と独立に講演者の咳が生じている場合
|
|
||||||
<咳> =
|
|
||||||
; # 言語音と独立に講演者の息が生じている場合
|
|
||||||
<息> =
|
|
||||||
; # 講演者の泣き声
|
|
||||||
<泣> =
|
|
||||||
; # 聴衆(司会者なども含む)の発話
|
|
||||||
<フロア発話> =
|
|
||||||
; # 聴衆の笑い
|
|
||||||
<フロア笑> =
|
|
||||||
; # 聴衆の拍手
|
|
||||||
<拍手> =
|
|
||||||
; # 講演者が発表中に用いたデモンストレーションの音声
|
|
||||||
<デモ> =
|
|
||||||
; # 学会講演に発表時間を知らせるためにならすベルの音
|
|
||||||
<ベル> =
|
|
||||||
; # 転記単位全体が再度読み直された場合
|
|
||||||
<朗読間違い> =
|
|
||||||
; # 上記以外の音で特に目立った音
|
|
||||||
<雑音> =
|
|
||||||
; # 0.2秒以上のポーズ
|
|
||||||
<P> =
|
|
||||||
; # Redacted information, for R
|
|
||||||
; # It is \x00D7 multiplication sign, not your normal 'x'
|
|
||||||
× = ×
|
|
||||||
|
|
||||||
[FIELDS]
|
|
||||||
; # Time information for segment
|
|
||||||
time = 3
|
|
||||||
; # Word surface
|
|
||||||
surface = 5
|
|
||||||
; # Word surface root form without CSJ tags
|
|
||||||
notag = 9
|
|
||||||
; # Part Of Speech
|
|
||||||
pos1 = 11
|
|
||||||
; # Conjugated Form
|
|
||||||
cForm = 12
|
|
||||||
; # Conjugation Type
|
|
||||||
cType1 = 13
|
|
||||||
; # Subcategory of POS
|
|
||||||
pos2 = 14
|
|
||||||
; # Euphonic Change / Subcategory of Conjugation Type
|
|
||||||
cType2 = 15
|
|
||||||
; # Other information
|
|
||||||
other = 16
|
|
||||||
; # Pronunciation for lexicon
|
|
||||||
pron = 10
|
|
||||||
; # Speaker ID
|
|
||||||
spk_id = 2
|
|
||||||
|
|
||||||
[KATAKANA2ROMAJI]
|
|
||||||
ア = 'a
|
|
||||||
イ = 'i
|
|
||||||
ウ = 'u
|
|
||||||
エ = 'e
|
|
||||||
オ = 'o
|
|
||||||
カ = ka
|
|
||||||
キ = ki
|
|
||||||
ク = ku
|
|
||||||
ケ = ke
|
|
||||||
コ = ko
|
|
||||||
ガ = ga
|
|
||||||
ギ = gi
|
|
||||||
グ = gu
|
|
||||||
ゲ = ge
|
|
||||||
ゴ = go
|
|
||||||
サ = sa
|
|
||||||
シ = si
|
|
||||||
ス = su
|
|
||||||
セ = se
|
|
||||||
ソ = so
|
|
||||||
ザ = za
|
|
||||||
ジ = zi
|
|
||||||
ズ = zu
|
|
||||||
ゼ = ze
|
|
||||||
ゾ = zo
|
|
||||||
タ = ta
|
|
||||||
チ = ti
|
|
||||||
ツ = tu
|
|
||||||
テ = te
|
|
||||||
ト = to
|
|
||||||
ダ = da
|
|
||||||
ヂ = di
|
|
||||||
ヅ = du
|
|
||||||
デ = de
|
|
||||||
ド = do
|
|
||||||
ナ = na
|
|
||||||
ニ = ni
|
|
||||||
ヌ = nu
|
|
||||||
ネ = ne
|
|
||||||
ノ = no
|
|
||||||
ハ = ha
|
|
||||||
ヒ = hi
|
|
||||||
フ = hu
|
|
||||||
ヘ = he
|
|
||||||
ホ = ho
|
|
||||||
バ = ba
|
|
||||||
ビ = bi
|
|
||||||
ブ = bu
|
|
||||||
ベ = be
|
|
||||||
ボ = bo
|
|
||||||
パ = pa
|
|
||||||
ピ = pi
|
|
||||||
プ = pu
|
|
||||||
ペ = pe
|
|
||||||
ポ = po
|
|
||||||
マ = ma
|
|
||||||
ミ = mi
|
|
||||||
ム = mu
|
|
||||||
メ = me
|
|
||||||
モ = mo
|
|
||||||
ヤ = ya
|
|
||||||
ユ = yu
|
|
||||||
ヨ = yo
|
|
||||||
ラ = ra
|
|
||||||
リ = ri
|
|
||||||
ル = ru
|
|
||||||
レ = re
|
|
||||||
ロ = ro
|
|
||||||
ワ = wa
|
|
||||||
ヰ = we
|
|
||||||
ヱ = wi
|
|
||||||
ヲ = wo
|
|
||||||
ン = ŋ
|
|
||||||
ッ = q
|
|
||||||
ー = -
|
|
||||||
キャ = kǐa
|
|
||||||
キュ = kǐu
|
|
||||||
キョ = kǐo
|
|
||||||
ギャ = gǐa
|
|
||||||
ギュ = gǐu
|
|
||||||
ギョ = gǐo
|
|
||||||
シャ = sǐa
|
|
||||||
シュ = sǐu
|
|
||||||
ショ = sǐo
|
|
||||||
ジャ = zǐa
|
|
||||||
ジュ = zǐu
|
|
||||||
ジョ = zǐo
|
|
||||||
チャ = tǐa
|
|
||||||
チュ = tǐu
|
|
||||||
チョ = tǐo
|
|
||||||
ヂャ = dǐa
|
|
||||||
ヂュ = dǐu
|
|
||||||
ヂョ = dǐo
|
|
||||||
ニャ = nǐa
|
|
||||||
ニュ = nǐu
|
|
||||||
ニョ = nǐo
|
|
||||||
ヒャ = hǐa
|
|
||||||
ヒュ = hǐu
|
|
||||||
ヒョ = hǐo
|
|
||||||
ビャ = bǐa
|
|
||||||
ビュ = bǐu
|
|
||||||
ビョ = bǐo
|
|
||||||
ピャ = pǐa
|
|
||||||
ピュ = pǐu
|
|
||||||
ピョ = pǐo
|
|
||||||
ミャ = mǐa
|
|
||||||
ミュ = mǐu
|
|
||||||
ミョ = mǐo
|
|
||||||
リャ = rǐa
|
|
||||||
リュ = rǐu
|
|
||||||
リョ = rǐo
|
|
||||||
ァ = a
|
|
||||||
ィ = i
|
|
||||||
ゥ = u
|
|
||||||
ェ = e
|
|
||||||
ォ = o
|
|
||||||
ヮ = ʍ
|
|
||||||
ヴ = vu
|
|
||||||
ャ = ǐa
|
|
||||||
ュ = ǐu
|
|
||||||
ョ = ǐo
|
|
||||||
|
@ -1,320 +1,79 @@
|
|||||||
; # This section is ignored if this file is not supplied as the first config file to
|
|
||||||
; # lhotse prepare csj
|
|
||||||
[SEGMENTS]
|
|
||||||
; # Allowed period of nonverbal noise. If exceeded, a new segment is created.
|
|
||||||
gap = 0.5
|
|
||||||
; # Maximum length of segment (s).
|
|
||||||
maxlen = 10
|
|
||||||
; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently.
|
|
||||||
minlen = 0.02
|
|
||||||
; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`.
|
|
||||||
; # Pass an empty string to avoid adding any symbol. It was "<sp>" in kaldi.
|
|
||||||
; # If you intend to use a multicharacter string for gap_sym, remember to register the
|
|
||||||
; # multicharacter string as part of userdef-string in prepare_lang_char.py.
|
|
||||||
gap_sym =
|
|
||||||
|
|
||||||
[CONSTANTS]
|
[CONSTANTS]
|
||||||
; # Name of this mode
|
; # Name of this mode
|
||||||
MODE = number
|
MODE = number
|
||||||
; # Suffixes to use after the word surface (no longer used)
|
|
||||||
MORPH = pos1 cForm cType2 pos2
|
|
||||||
; # Used to differentiate between A tag and A_num tag
|
|
||||||
JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 .
|
|
||||||
; # Dummy character to delineate multiline words
|
|
||||||
PLUS = +
|
|
||||||
|
|
||||||
[DECISIONS]
|
[DECISIONS]
|
||||||
; # TAG+'^'とは、タグが一つの転記単位に独立していない場合
|
|
||||||
; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries
|
|
||||||
|
|
||||||
; # フィラー、感情表出系感動詞
|
; # フィラー、感情表出系感動詞
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(F ぎょっ)'
|
; # Example: '(F ぎょっ)'
|
||||||
F = 1
|
F = 1
|
||||||
; # Example: '(L (F ン))', '比べ(F えー)る'
|
|
||||||
F^ = 1
|
|
||||||
; # 言い直し、いいよどみなどによる語断片
|
; # 言い直し、いいよどみなどによる語断片
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(D だ)(D だいが) 大学の学部の会議'
|
; # Example: '(D だ)(D だいが) 大学の学部の会議'
|
||||||
D = 1
|
D = 1
|
||||||
; # Example: '(L (D ドゥ)+(D ヒ))'
|
|
||||||
D^ = 1
|
|
||||||
; # 助詞、助動詞、接辞の言い直し
|
; # 助詞、助動詞、接辞の言い直し
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか'
|
; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか'
|
||||||
D2 = 1
|
D2 = 1
|
||||||
; # Example: '(X (D2 ノ))'
|
|
||||||
D2^ = 1
|
|
||||||
; # 聞き取りや語彙の判断に自信がない場合
|
; # 聞き取りや語彙の判断に自信がない場合
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: (? 字数) の
|
; # Example: (? 字数) の
|
||||||
; # If no option: empty string is returned regardless of output
|
; # If no option: empty string is returned regardless of output
|
||||||
; # Example: '(?) で'
|
; # Example: '(?) で'
|
||||||
? = 0
|
? = 0
|
||||||
; # Example: '(D (? すー))+そう+です+よ+ね'
|
|
||||||
?^ = 0
|
|
||||||
; # タグ?で、値は複数の候補が想定される場合
|
; # タグ?で、値は複数の候補が想定される場合
|
||||||
; # 0 for main guess with matching morph info, 1 for second guess
|
; # 0 for main guess with matching morph info, 1 for second guess
|
||||||
; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)'
|
; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)'
|
||||||
?, = 0
|
?, = 0
|
||||||
; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))'
|
|
||||||
?,^ = 0
|
|
||||||
; # 音や言葉に関するメタ的な引用
|
; # 音や言葉に関するメタ的な引用
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)'
|
; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)'
|
||||||
M = 0
|
M = 0
|
||||||
; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))'
|
|
||||||
M^ = 0
|
|
||||||
; # 外国語や古語、方言など
|
; # 外国語や古語、方言など
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(O ザッツファイン)'
|
; # Example: '(O ザッツファイン)'
|
||||||
O = 0
|
O = 0
|
||||||
; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))'
|
|
||||||
O^ = 0
|
|
||||||
; # 講演者の名前、差別語、誹謗中傷など
|
; # 講演者の名前、差別語、誹謗中傷など
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '国語研の (R ××) です'
|
; # Example: '国語研の (R ××) です'
|
||||||
R = 0
|
R = 0
|
||||||
R^ = 0
|
|
||||||
; # 非朗読対象発話(朗読における言い間違い等)
|
; # 非朗読対象発話(朗読における言い間違い等)
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(X 実際は) 実際には'
|
; # Example: '(X 実際は) 実際には'
|
||||||
X = 0
|
X = 0
|
||||||
; # Example: '(L (X (D2 ニ)))'
|
|
||||||
X^ = 0
|
|
||||||
; # アルファベットや算用数字、記号の表記
|
; # アルファベットや算用数字、記号の表記
|
||||||
; # 0 to use Japanese form, 1 to use alphabet form
|
; # 0 to use Japanese form, 1 to use alphabet form
|
||||||
; # Example: '(A シーディーアール;CD-R)'
|
; # Example: '(A シーディーアール;CD-R)'
|
||||||
A = 1
|
A = 1
|
||||||
; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)')
|
|
||||||
A^ = 1
|
|
||||||
; # タグAで、単語は算用数字の場合
|
; # タグAで、単語は算用数字の場合
|
||||||
; # 0 to use Japanese form, 1 to use Arabic numerals
|
; # 0 to use Japanese form, 1 to use Arabic numerals
|
||||||
; # Example: (A 二千;2000)
|
; # Example: (A 二千;2000)
|
||||||
A_num = 1
|
A_num = 1
|
||||||
A_num^ = 1
|
|
||||||
; # 何らかの原因で漢字表記できなくなった場合
|
; # 何らかの原因で漢字表記できなくなった場合
|
||||||
; # 0 to use broken form, 1 to use orthodox form
|
; # 0 to use broken form, 1 to use orthodox form
|
||||||
; # Example: '(K たち (F えー) ばな;橘)'
|
; # Example: '(K たち (F えー) ばな;橘)'
|
||||||
K = 1
|
K = 1
|
||||||
; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)'
|
|
||||||
K^ = 1
|
|
||||||
; # 転訛、発音の怠けなど、一時的な発音エラー
|
; # 転訛、発音の怠けなど、一時的な発音エラー
|
||||||
; # 0 to use wrong form, 1 to use orthodox form
|
; # 0 to use wrong form, 1 to use orthodox form
|
||||||
; # Example: '(W ギーツ;ギジュツ)'
|
; # Example: '(W ギーツ;ギジュツ)'
|
||||||
W = 1
|
W = 1
|
||||||
; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)'
|
|
||||||
W^ = 1
|
|
||||||
; # 語の読みに関する知識レベルのいい間違い
|
; # 語の読みに関する知識レベルのいい間違い
|
||||||
; # 0 to use wrong form, 1 to use orthodox form
|
; # 0 to use wrong form, 1 to use orthodox form
|
||||||
; # Example: '(B シブタイ;ジュータイ)'
|
; # Example: '(B シブタイ;ジュータイ)'
|
||||||
B = 0
|
B = 0
|
||||||
; # Example: 'データー(B カズ;スー)'
|
|
||||||
B^ = 0
|
|
||||||
; # 笑いながら発話
|
; # 笑いながら発話
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)'
|
; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)'
|
||||||
笑 = 0
|
笑 = 0
|
||||||
; # Example: 'コク(笑 サイ+(D オン))',
|
|
||||||
笑^ = 0
|
|
||||||
; # 泣きながら発話
|
; # 泣きながら発話
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(泣 ドンナニ)'
|
; # Example: '(泣 ドンナニ)'
|
||||||
泣 = 0
|
泣 = 0
|
||||||
泣^ = 0
|
|
||||||
; # 咳をしながら発話
|
; # 咳をしながら発話
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: 'シャ(咳 リン) ノ'
|
; # Example: 'シャ(咳 リン) ノ'
|
||||||
咳 = 0
|
咳 = 0
|
||||||
; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)'
|
|
||||||
咳^ = 0
|
|
||||||
; # ささやき声や独り言などの小さな声
|
; # ささやき声や独り言などの小さな声
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
|
; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
|
||||||
L = 0
|
L = 0
|
||||||
; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト'
|
|
||||||
L^ = 0
|
|
||||||
|
|
||||||
[REPLACEMENTS]
|
|
||||||
; # ボーカルフライなどで母音が同定できない場合
|
|
||||||
<FV> =
|
|
||||||
; # 「うん/うーん/ふーん」の音の特定が困難な場合
|
|
||||||
<VN> =
|
|
||||||
; # 非語彙的な母音の引き延ばし
|
|
||||||
<H> =
|
|
||||||
; # 非語彙的な子音の引き延ばし
|
|
||||||
<Q> =
|
|
||||||
; # 言語音と独立に講演者の笑いが生じている場合
|
|
||||||
<笑> =
|
|
||||||
; # 言語音と独立に講演者の咳が生じている場合
|
|
||||||
<咳> =
|
|
||||||
; # 言語音と独立に講演者の息が生じている場合
|
|
||||||
<息> =
|
|
||||||
; # 講演者の泣き声
|
|
||||||
<泣> =
|
|
||||||
; # 聴衆(司会者なども含む)の発話
|
|
||||||
<フロア発話> =
|
|
||||||
; # 聴衆の笑い
|
|
||||||
<フロア笑> =
|
|
||||||
; # 聴衆の拍手
|
|
||||||
<拍手> =
|
|
||||||
; # 講演者が発表中に用いたデモンストレーションの音声
|
|
||||||
<デモ> =
|
|
||||||
; # 学会講演に発表時間を知らせるためにならすベルの音
|
|
||||||
<ベル> =
|
|
||||||
; # 転記単位全体が再度読み直された場合
|
|
||||||
<朗読間違い> =
|
|
||||||
; # 上記以外の音で特に目立った音
|
|
||||||
<雑音> =
|
|
||||||
; # 0.2秒以上のポーズ
|
|
||||||
<P> =
|
|
||||||
; # Redacted information, for R
|
|
||||||
; # It is \x00D7 multiplication sign, not your normal 'x'
|
|
||||||
× = ×
|
|
||||||
|
|
||||||
[FIELDS]
|
|
||||||
; # Time information for segment
|
|
||||||
time = 3
|
|
||||||
; # Word surface
|
|
||||||
surface = 5
|
|
||||||
; # Word surface root form without CSJ tags
|
|
||||||
notag = 9
|
|
||||||
; # Part Of Speech
|
|
||||||
pos1 = 11
|
|
||||||
; # Conjugated Form
|
|
||||||
cForm = 12
|
|
||||||
; # Conjugation Type
|
|
||||||
cType1 = 13
|
|
||||||
; # Subcategory of POS
|
|
||||||
pos2 = 14
|
|
||||||
; # Euphonic Change / Subcategory of Conjugation Type
|
|
||||||
cType2 = 15
|
|
||||||
; # Other information
|
|
||||||
other = 16
|
|
||||||
; # Pronunciation for lexicon
|
|
||||||
pron = 10
|
|
||||||
; # Speaker ID
|
|
||||||
spk_id = 2
|
|
||||||
|
|
||||||
[KATAKANA2ROMAJI]
|
|
||||||
ア = 'a
|
|
||||||
イ = 'i
|
|
||||||
ウ = 'u
|
|
||||||
エ = 'e
|
|
||||||
オ = 'o
|
|
||||||
カ = ka
|
|
||||||
キ = ki
|
|
||||||
ク = ku
|
|
||||||
ケ = ke
|
|
||||||
コ = ko
|
|
||||||
ガ = ga
|
|
||||||
ギ = gi
|
|
||||||
グ = gu
|
|
||||||
ゲ = ge
|
|
||||||
ゴ = go
|
|
||||||
サ = sa
|
|
||||||
シ = si
|
|
||||||
ス = su
|
|
||||||
セ = se
|
|
||||||
ソ = so
|
|
||||||
ザ = za
|
|
||||||
ジ = zi
|
|
||||||
ズ = zu
|
|
||||||
ゼ = ze
|
|
||||||
ゾ = zo
|
|
||||||
タ = ta
|
|
||||||
チ = ti
|
|
||||||
ツ = tu
|
|
||||||
テ = te
|
|
||||||
ト = to
|
|
||||||
ダ = da
|
|
||||||
ヂ = di
|
|
||||||
ヅ = du
|
|
||||||
デ = de
|
|
||||||
ド = do
|
|
||||||
ナ = na
|
|
||||||
ニ = ni
|
|
||||||
ヌ = nu
|
|
||||||
ネ = ne
|
|
||||||
ノ = no
|
|
||||||
ハ = ha
|
|
||||||
ヒ = hi
|
|
||||||
フ = hu
|
|
||||||
ヘ = he
|
|
||||||
ホ = ho
|
|
||||||
バ = ba
|
|
||||||
ビ = bi
|
|
||||||
ブ = bu
|
|
||||||
ベ = be
|
|
||||||
ボ = bo
|
|
||||||
パ = pa
|
|
||||||
ピ = pi
|
|
||||||
プ = pu
|
|
||||||
ペ = pe
|
|
||||||
ポ = po
|
|
||||||
マ = ma
|
|
||||||
ミ = mi
|
|
||||||
ム = mu
|
|
||||||
メ = me
|
|
||||||
モ = mo
|
|
||||||
ヤ = ya
|
|
||||||
ユ = yu
|
|
||||||
ヨ = yo
|
|
||||||
ラ = ra
|
|
||||||
リ = ri
|
|
||||||
ル = ru
|
|
||||||
レ = re
|
|
||||||
ロ = ro
|
|
||||||
ワ = wa
|
|
||||||
ヰ = we
|
|
||||||
ヱ = wi
|
|
||||||
ヲ = wo
|
|
||||||
ン = ŋ
|
|
||||||
ッ = q
|
|
||||||
ー = -
|
|
||||||
キャ = kǐa
|
|
||||||
キュ = kǐu
|
|
||||||
キョ = kǐo
|
|
||||||
ギャ = gǐa
|
|
||||||
ギュ = gǐu
|
|
||||||
ギョ = gǐo
|
|
||||||
シャ = sǐa
|
|
||||||
シュ = sǐu
|
|
||||||
ショ = sǐo
|
|
||||||
ジャ = zǐa
|
|
||||||
ジュ = zǐu
|
|
||||||
ジョ = zǐo
|
|
||||||
チャ = tǐa
|
|
||||||
チュ = tǐu
|
|
||||||
チョ = tǐo
|
|
||||||
ヂャ = dǐa
|
|
||||||
ヂュ = dǐu
|
|
||||||
ヂョ = dǐo
|
|
||||||
ニャ = nǐa
|
|
||||||
ニュ = nǐu
|
|
||||||
ニョ = nǐo
|
|
||||||
ヒャ = hǐa
|
|
||||||
ヒュ = hǐu
|
|
||||||
ヒョ = hǐo
|
|
||||||
ビャ = bǐa
|
|
||||||
ビュ = bǐu
|
|
||||||
ビョ = bǐo
|
|
||||||
ピャ = pǐa
|
|
||||||
ピュ = pǐu
|
|
||||||
ピョ = pǐo
|
|
||||||
ミャ = mǐa
|
|
||||||
ミュ = mǐu
|
|
||||||
ミョ = mǐo
|
|
||||||
リャ = rǐa
|
|
||||||
リュ = rǐu
|
|
||||||
リョ = rǐo
|
|
||||||
ァ = a
|
|
||||||
ィ = i
|
|
||||||
ゥ = u
|
|
||||||
ェ = e
|
|
||||||
ォ = o
|
|
||||||
ヮ = ʍ
|
|
||||||
ヴ = vu
|
|
||||||
ャ = ǐa
|
|
||||||
ュ = ǐu
|
|
||||||
ョ = ǐo
|
|
||||||
|
@ -1,321 +1,80 @@
|
|||||||
; # This section is ignored if this file is not supplied as the first config file to
|
|
||||||
; # lhotse prepare csj
|
|
||||||
[SEGMENTS]
|
|
||||||
; # Allowed period of nonverbal noise. If exceeded, a new segment is created.
|
|
||||||
gap = 0.5
|
|
||||||
; # Maximum length of segment (s).
|
|
||||||
maxlen = 10
|
|
||||||
; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently.
|
|
||||||
minlen = 0.02
|
|
||||||
; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`.
|
|
||||||
; # Pass an empty string to avoid adding any symbol. It was "<sp>" in kaldi.
|
|
||||||
; # If you intend to use a multicharacter string for gap_sym, remember to register the
|
|
||||||
; # multicharacter string as part of userdef-string in prepare_lang_char.py.
|
|
||||||
gap_sym =
|
|
||||||
|
|
||||||
[CONSTANTS]
|
[CONSTANTS]
|
||||||
; # Name of this mode
|
; # Name of this mode
|
||||||
; # See https://www.isca-speech.org/archive/pdfs/interspeech_2022/horii22_interspeech.pdf
|
; # From https://www.isca-speech.org/archive/pdfs/interspeech_2022/horii22_interspeech.pdf
|
||||||
MODE = symbol
|
MODE = symbol
|
||||||
; # Suffixes to use after the word surface (no longer used)
|
|
||||||
MORPH = pos1 cForm cType2 pos2
|
|
||||||
; # Used to differentiate between A tag and A_num tag
|
|
||||||
JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 .
|
|
||||||
; # Dummy character to delineate multiline words
|
|
||||||
PLUS = +
|
|
||||||
|
|
||||||
[DECISIONS]
|
[DECISIONS]
|
||||||
; # TAG+'^'とは、タグが一つの転記単位に独立していない場合
|
|
||||||
; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries
|
|
||||||
|
|
||||||
; # フィラー、感情表出系感動詞
|
; # フィラー、感情表出系感動詞
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(F ぎょっ)'
|
; # Example: '(F ぎょっ)'
|
||||||
F = #
|
F = "#", ["F"]
|
||||||
; # Example: '(L (F ン))', '比べ(F えー)る'
|
|
||||||
F^ = #
|
|
||||||
; # 言い直し、いいよどみなどによる語断片
|
; # 言い直し、いいよどみなどによる語断片
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(D だ)(D だいが) 大学の学部の会議'
|
; # Example: '(D だ)(D だいが) 大学の学部の会議'
|
||||||
D = @
|
D = "@", ["D"]
|
||||||
; # Example: '(L (D ドゥ)+(D ヒ))'
|
|
||||||
D^ = @
|
|
||||||
; # 助詞、助動詞、接辞の言い直し
|
; # 助詞、助動詞、接辞の言い直し
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか'
|
; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか'
|
||||||
D2 = @
|
D2 = "@", ["D2"]
|
||||||
; # Example: '(X (D2 ノ))'
|
|
||||||
D2^ = @
|
|
||||||
; # 聞き取りや語彙の判断に自信がない場合
|
; # 聞き取りや語彙の判断に自信がない場合
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: (? 字数) の
|
; # Example: (? 字数) の
|
||||||
; # If no option: empty string is returned regardless of output
|
; # If no option: empty string is returned regardless of output
|
||||||
; # Example: '(?) で'
|
; # Example: '(?) で'
|
||||||
? = 0
|
? = 0
|
||||||
; # Example: '(D (? すー))+そう+です+よ+ね'
|
|
||||||
?^ = 0
|
|
||||||
; # タグ?で、値は複数の候補が想定される場合
|
; # タグ?で、値は複数の候補が想定される場合
|
||||||
; # 0 for main guess with matching morph info, 1 for second guess
|
; # 0 for main guess with matching morph info, 1 for second guess
|
||||||
; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)'
|
; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)'
|
||||||
?, = 0
|
?, = 0
|
||||||
; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))'
|
|
||||||
?,^ = 0
|
|
||||||
; # 音や言葉に関するメタ的な引用
|
; # 音や言葉に関するメタ的な引用
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)'
|
; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)'
|
||||||
M = 0
|
M = 0
|
||||||
; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))'
|
|
||||||
M^ = 0
|
|
||||||
; # 外国語や古語、方言など
|
; # 外国語や古語、方言など
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(O ザッツファイン)'
|
; # Example: '(O ザッツファイン)'
|
||||||
O = 0
|
O = 0
|
||||||
; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))'
|
|
||||||
O^ = 0
|
|
||||||
; # 講演者の名前、差別語、誹謗中傷など
|
; # 講演者の名前、差別語、誹謗中傷など
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '国語研の (R ××) です'
|
; # Example: '国語研の (R ××) です'
|
||||||
R = 0
|
R = 0
|
||||||
R^ = 0
|
|
||||||
; # 非朗読対象発話(朗読における言い間違い等)
|
; # 非朗読対象発話(朗読における言い間違い等)
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(X 実際は) 実際には'
|
; # Example: '(X 実際は) 実際には'
|
||||||
X = 0
|
X = 0
|
||||||
; # Example: '(L (X (D2 ニ)))'
|
|
||||||
X^ = 0
|
|
||||||
; # アルファベットや算用数字、記号の表記
|
; # アルファベットや算用数字、記号の表記
|
||||||
; # 0 to use Japanese form, 1 to use alphabet form
|
; # 0 to use Japanese form, 1 to use alphabet form
|
||||||
; # Example: '(A シーディーアール;CD-R)'
|
; # Example: '(A シーディーアール;CD-R)'
|
||||||
A = 1
|
A = 1
|
||||||
; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)')
|
|
||||||
A^ = 1
|
|
||||||
; # タグAで、単語は算用数字の場合
|
; # タグAで、単語は算用数字の場合
|
||||||
; # 0 to use Japanese form, 1 to use Arabic numerals
|
; # 0 to use Japanese form, 1 to use Arabic numerals
|
||||||
; # Example: (A 二千;2000)
|
; # Example: (A 二千;2000)
|
||||||
A_num = eval:self.notag
|
A_num = 1
|
||||||
A_num^ = eval:self.notag
|
|
||||||
; # 何らかの原因で漢字表記できなくなった場合
|
; # 何らかの原因で漢字表記できなくなった場合
|
||||||
; # 0 to use broken form, 1 to use orthodox form
|
; # 0 to use broken form, 1 to use orthodox form
|
||||||
; # Example: '(K たち (F えー) ばな;橘)'
|
; # Example: '(K たち (F えー) ばな;橘)'
|
||||||
K = 1
|
K = 1
|
||||||
; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)'
|
|
||||||
K^ = 1
|
|
||||||
; # 転訛、発音の怠けなど、一時的な発音エラー
|
; # 転訛、発音の怠けなど、一時的な発音エラー
|
||||||
; # 0 to use wrong form, 1 to use orthodox form
|
; # 0 to use wrong form, 1 to use orthodox form
|
||||||
; # Example: '(W ギーツ;ギジュツ)'
|
; # Example: '(W ギーツ;ギジュツ)'
|
||||||
W = 1
|
W = 1
|
||||||
; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)'
|
|
||||||
W^ = 1
|
|
||||||
; # 語の読みに関する知識レベルのいい間違い
|
; # 語の読みに関する知識レベルのいい間違い
|
||||||
; # 0 to use wrong form, 1 to use orthodox form
|
; # 0 to use wrong form, 1 to use orthodox form
|
||||||
; # Example: '(B シブタイ;ジュータイ)'
|
; # Example: '(B シブタイ;ジュータイ)'
|
||||||
B = 0
|
B = 0
|
||||||
; # Example: 'データー(B カズ;スー)'
|
|
||||||
B^ = 0
|
|
||||||
; # 笑いながら発話
|
; # 笑いながら発話
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)'
|
; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)'
|
||||||
笑 = 0
|
笑 = 0
|
||||||
; # Example: 'コク(笑 サイ+(D オン))',
|
|
||||||
笑^ = 0
|
|
||||||
; # 泣きながら発話
|
; # 泣きながら発話
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(泣 ドンナニ)'
|
; # Example: '(泣 ドンナニ)'
|
||||||
泣 = 0
|
泣 = 0
|
||||||
泣^ = 0
|
|
||||||
; # 咳をしながら発話
|
; # 咳をしながら発話
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: 'シャ(咳 リン) ノ'
|
; # Example: 'シャ(咳 リン) ノ'
|
||||||
咳 = 0
|
咳 = 0
|
||||||
; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)'
|
|
||||||
咳^ = 0
|
|
||||||
; # ささやき声や独り言などの小さな声
|
; # ささやき声や独り言などの小さな声
|
||||||
; # 0 to remain, 1 to delete
|
; # 0 to remain, 1 to delete
|
||||||
; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
|
; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))'
|
||||||
L = 0
|
L = 0
|
||||||
; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト'
|
|
||||||
L^ = 0
|
|
||||||
|
|
||||||
[REPLACEMENTS]
|
|
||||||
; # ボーカルフライなどで母音が同定できない場合
|
|
||||||
<FV> =
|
|
||||||
; # 「うん/うーん/ふーん」の音の特定が困難な場合
|
|
||||||
<VN> =
|
|
||||||
; # 非語彙的な母音の引き延ばし
|
|
||||||
<H> =
|
|
||||||
; # 非語彙的な子音の引き延ばし
|
|
||||||
<Q> =
|
|
||||||
; # 言語音と独立に講演者の笑いが生じている場合
|
|
||||||
<笑> =
|
|
||||||
; # 言語音と独立に講演者の咳が生じている場合
|
|
||||||
<咳> =
|
|
||||||
; # 言語音と独立に講演者の息が生じている場合
|
|
||||||
<息> =
|
|
||||||
; # 講演者の泣き声
|
|
||||||
<泣> =
|
|
||||||
; # 聴衆(司会者なども含む)の発話
|
|
||||||
<フロア発話> =
|
|
||||||
; # 聴衆の笑い
|
|
||||||
<フロア笑> =
|
|
||||||
; # 聴衆の拍手
|
|
||||||
<拍手> =
|
|
||||||
; # 講演者が発表中に用いたデモンストレーションの音声
|
|
||||||
<デモ> =
|
|
||||||
; # 学会講演に発表時間を知らせるためにならすベルの音
|
|
||||||
<ベル> =
|
|
||||||
; # 転記単位全体が再度読み直された場合
|
|
||||||
<朗読間違い> =
|
|
||||||
; # 上記以外の音で特に目立った音
|
|
||||||
<雑音> =
|
|
||||||
; # 0.2秒以上のポーズ
|
|
||||||
<P> =
|
|
||||||
; # Redacted information, for R
|
|
||||||
; # It is \x00D7 multiplication sign, not your normal 'x'
|
|
||||||
× = ×
|
|
||||||
|
|
||||||
[FIELDS]
|
|
||||||
; # Time information for segment
|
|
||||||
time = 3
|
|
||||||
; # Word surface
|
|
||||||
surface = 5
|
|
||||||
; # Word surface root form without CSJ tags
|
|
||||||
notag = 9
|
|
||||||
; # Part Of Speech
|
|
||||||
pos1 = 11
|
|
||||||
; # Conjugated Form
|
|
||||||
cForm = 12
|
|
||||||
; # Conjugation Type
|
|
||||||
cType1 = 13
|
|
||||||
; # Subcategory of POS
|
|
||||||
pos2 = 14
|
|
||||||
; # Euphonic Change / Subcategory of Conjugation Type
|
|
||||||
cType2 = 15
|
|
||||||
; # Other information
|
|
||||||
other = 16
|
|
||||||
; # Pronunciation for lexicon
|
|
||||||
pron = 10
|
|
||||||
; # Speaker ID
|
|
||||||
spk_id = 2
|
|
||||||
|
|
||||||
[KATAKANA2ROMAJI]
|
|
||||||
ア = 'a
|
|
||||||
イ = 'i
|
|
||||||
ウ = 'u
|
|
||||||
エ = 'e
|
|
||||||
オ = 'o
|
|
||||||
カ = ka
|
|
||||||
キ = ki
|
|
||||||
ク = ku
|
|
||||||
ケ = ke
|
|
||||||
コ = ko
|
|
||||||
ガ = ga
|
|
||||||
ギ = gi
|
|
||||||
グ = gu
|
|
||||||
ゲ = ge
|
|
||||||
ゴ = go
|
|
||||||
サ = sa
|
|
||||||
シ = si
|
|
||||||
ス = su
|
|
||||||
セ = se
|
|
||||||
ソ = so
|
|
||||||
ザ = za
|
|
||||||
ジ = zi
|
|
||||||
ズ = zu
|
|
||||||
ゼ = ze
|
|
||||||
ゾ = zo
|
|
||||||
タ = ta
|
|
||||||
チ = ti
|
|
||||||
ツ = tu
|
|
||||||
テ = te
|
|
||||||
ト = to
|
|
||||||
ダ = da
|
|
||||||
ヂ = di
|
|
||||||
ヅ = du
|
|
||||||
デ = de
|
|
||||||
ド = do
|
|
||||||
ナ = na
|
|
||||||
ニ = ni
|
|
||||||
ヌ = nu
|
|
||||||
ネ = ne
|
|
||||||
ノ = no
|
|
||||||
ハ = ha
|
|
||||||
ヒ = hi
|
|
||||||
フ = hu
|
|
||||||
ヘ = he
|
|
||||||
ホ = ho
|
|
||||||
バ = ba
|
|
||||||
ビ = bi
|
|
||||||
ブ = bu
|
|
||||||
ベ = be
|
|
||||||
ボ = bo
|
|
||||||
パ = pa
|
|
||||||
ピ = pi
|
|
||||||
プ = pu
|
|
||||||
ペ = pe
|
|
||||||
ポ = po
|
|
||||||
マ = ma
|
|
||||||
ミ = mi
|
|
||||||
ム = mu
|
|
||||||
メ = me
|
|
||||||
モ = mo
|
|
||||||
ヤ = ya
|
|
||||||
ユ = yu
|
|
||||||
ヨ = yo
|
|
||||||
ラ = ra
|
|
||||||
リ = ri
|
|
||||||
ル = ru
|
|
||||||
レ = re
|
|
||||||
ロ = ro
|
|
||||||
ワ = wa
|
|
||||||
ヰ = we
|
|
||||||
ヱ = wi
|
|
||||||
ヲ = wo
|
|
||||||
ン = ŋ
|
|
||||||
ッ = q
|
|
||||||
ー = -
|
|
||||||
キャ = kǐa
|
|
||||||
キュ = kǐu
|
|
||||||
キョ = kǐo
|
|
||||||
ギャ = gǐa
|
|
||||||
ギュ = gǐu
|
|
||||||
ギョ = gǐo
|
|
||||||
シャ = sǐa
|
|
||||||
シュ = sǐu
|
|
||||||
ショ = sǐo
|
|
||||||
ジャ = zǐa
|
|
||||||
ジュ = zǐu
|
|
||||||
ジョ = zǐo
|
|
||||||
チャ = tǐa
|
|
||||||
チュ = tǐu
|
|
||||||
チョ = tǐo
|
|
||||||
ヂャ = dǐa
|
|
||||||
ヂュ = dǐu
|
|
||||||
ヂョ = dǐo
|
|
||||||
ニャ = nǐa
|
|
||||||
ニュ = nǐu
|
|
||||||
ニョ = nǐo
|
|
||||||
ヒャ = hǐa
|
|
||||||
ヒュ = hǐu
|
|
||||||
ヒョ = hǐo
|
|
||||||
ビャ = bǐa
|
|
||||||
ビュ = bǐu
|
|
||||||
ビョ = bǐo
|
|
||||||
ピャ = pǐa
|
|
||||||
ピュ = pǐu
|
|
||||||
ピョ = pǐo
|
|
||||||
ミャ = mǐa
|
|
||||||
ミュ = mǐu
|
|
||||||
ミョ = mǐo
|
|
||||||
リャ = rǐa
|
|
||||||
リュ = rǐu
|
|
||||||
リョ = rǐo
|
|
||||||
ァ = a
|
|
||||||
ィ = i
|
|
||||||
ゥ = u
|
|
||||||
ェ = e
|
|
||||||
ォ = o
|
|
||||||
ヮ = ʍ
|
|
||||||
ヴ = vu
|
|
||||||
ャ = ǐa
|
|
||||||
ュ = ǐu
|
|
||||||
ョ = ǐo
|
|
||||||
|
202
egs/csj/ASR/local/disfluent_recogs_to_fluent.py
Normal file
202
egs/csj/ASR/local/disfluent_recogs_to_fluent.py
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import kaldialign
|
||||||
|
from lhotse import CutSet
|
||||||
|
|
||||||
|
ARGPARSE_DESCRIPTION = """
|
||||||
|
This helper code takes in a disfluent recogs file generated from icefall.utils.store_transcript,
|
||||||
|
compares it against a fluent transcript, and saves the results in a separate directory.
|
||||||
|
This is useful to compare disfluent models with fluent models on the same metric.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
|
description=ARGPARSE_DESCRIPTION,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--recogs",
|
||||||
|
type=Path,
|
||||||
|
required=True,
|
||||||
|
help="Path to the recogs-XXX file generated by icefall.utils.store_transcript.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cut",
|
||||||
|
type=Path,
|
||||||
|
required=True,
|
||||||
|
help="Path to the cut manifest to be compared to. Assumes that disfluent_tag exists in the custom dict.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--res-dir", type=Path, required=True, help="Path to save results"
|
||||||
|
)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def d2f(stats):
|
||||||
|
"""
|
||||||
|
Compare the outputs of a disfluent model against a fluent reference.
|
||||||
|
Indicates a disfluent model's performance only on the content words
|
||||||
|
|
||||||
|
CER^d_f = (sub_f + ins + del_f) / Nf
|
||||||
|
|
||||||
|
"""
|
||||||
|
return stats["base"] / stats["Nf"]
|
||||||
|
|
||||||
|
|
||||||
|
def calc_cer(refs, hyps):
|
||||||
|
subs = {
|
||||||
|
"F": 0,
|
||||||
|
"D": 0,
|
||||||
|
}
|
||||||
|
ins = 0
|
||||||
|
dels = {
|
||||||
|
"F": 0,
|
||||||
|
"D": 0,
|
||||||
|
}
|
||||||
|
cors = {
|
||||||
|
"F": 0,
|
||||||
|
"D": 0,
|
||||||
|
}
|
||||||
|
dis_ref_len = 0
|
||||||
|
flu_ref_len = 0
|
||||||
|
|
||||||
|
for ref, hyp in zip(refs, hyps):
|
||||||
|
assert (
|
||||||
|
ref[0] == hyp[0]
|
||||||
|
), f"Expected ref cut id {ref[0]} to be the same as hyp cut id {hyp[0]}."
|
||||||
|
tag = ref[2].copy()
|
||||||
|
ref = ref[1]
|
||||||
|
dis_ref_len += len(ref)
|
||||||
|
# Remember that the 'D' and 'F' tags here refer to CSJ tags, not disfluent and fluent respectively.
|
||||||
|
flu_ref_len += len([t for t in tag if ("D" not in t and "F" not in t)])
|
||||||
|
hyp = hyp[1]
|
||||||
|
ali = kaldialign.align(ref, hyp, "*")
|
||||||
|
tags = ["*" if r[0] == "*" else tag.pop(0) for r in ali]
|
||||||
|
for tag, (ref_word, hyp_word) in zip(tags, ali):
|
||||||
|
if "D" in tag or "F" in tag:
|
||||||
|
tag = "D"
|
||||||
|
else:
|
||||||
|
tag = "F"
|
||||||
|
|
||||||
|
if ref_word == "*":
|
||||||
|
ins += 1
|
||||||
|
elif hyp_word == "*":
|
||||||
|
dels[tag] += 1
|
||||||
|
elif ref_word != hyp_word:
|
||||||
|
subs[tag] += 1
|
||||||
|
else:
|
||||||
|
cors[tag] += 1
|
||||||
|
|
||||||
|
return {
|
||||||
|
"subs": subs,
|
||||||
|
"ins": ins,
|
||||||
|
"dels": dels,
|
||||||
|
"cors": cors,
|
||||||
|
"dis_ref_len": dis_ref_len,
|
||||||
|
"flu_ref_len": flu_ref_len,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def for_each_recogs(recogs_file: Path, refs, out_dir):
|
||||||
|
hyps = []
|
||||||
|
with recogs_file.open() as fin:
|
||||||
|
for line in fin:
|
||||||
|
if "ref" in line:
|
||||||
|
continue
|
||||||
|
cutid, hyp = line.split(":\thyp=")
|
||||||
|
hyps.append((cutid, eval(hyp)))
|
||||||
|
|
||||||
|
assert len(refs) == len(
|
||||||
|
hyps
|
||||||
|
), f"Expected refs len {len(refs)} and hyps len {len(hyps)} to be equal."
|
||||||
|
stats = calc_cer(refs, hyps)
|
||||||
|
stat_table = ["tag,yes,no"]
|
||||||
|
|
||||||
|
for cer_type in ["subs", "dels", "cors", "ins"]:
|
||||||
|
ret = f"{cer_type}"
|
||||||
|
for df in ["D", "F"]:
|
||||||
|
try:
|
||||||
|
ret += f",{stats[cer_type][df]}"
|
||||||
|
except TypeError:
|
||||||
|
# insertions do not belong to F or D, and is not subscriptable.
|
||||||
|
ret += f",{stats[cer_type]},"
|
||||||
|
break
|
||||||
|
stat_table.append(ret)
|
||||||
|
stat_table = "\n".join(stat_table)
|
||||||
|
|
||||||
|
stats = {
|
||||||
|
"subd": stats["subs"]["D"],
|
||||||
|
"deld": stats["dels"]["D"],
|
||||||
|
"cord": stats["cors"]["D"],
|
||||||
|
"Nf": stats["flu_ref_len"],
|
||||||
|
"base": stats["subs"]["F"] + stats["ins"] + stats["dels"]["F"],
|
||||||
|
}
|
||||||
|
|
||||||
|
cer = d2f(stats)
|
||||||
|
results = [
|
||||||
|
f"{cer:.2%}",
|
||||||
|
f"Nf,{stats['Nf']}",
|
||||||
|
]
|
||||||
|
results = "\n".join(results)
|
||||||
|
|
||||||
|
with (out_dir / (recogs_file.stem + ".dfcer")).open("w") as fout:
|
||||||
|
fout.write(results)
|
||||||
|
fout.write("\n\n")
|
||||||
|
fout.write(stat_table)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
recogs_file: Path = args.recogs
|
||||||
|
assert (
|
||||||
|
recogs_file.is_file() or recogs_file.is_dir()
|
||||||
|
), f"recogs_file cannot be found at {recogs_file}."
|
||||||
|
|
||||||
|
args.res_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
if recogs_file.is_file() and recogs_file.stem.startswith("recogs-"):
|
||||||
|
assert (
|
||||||
|
"csj_cuts" in args.cut.name
|
||||||
|
), f"Expected {args.cut} to be a cuts manifest."
|
||||||
|
|
||||||
|
refs: CutSet = CutSet.from_file(args.cut)
|
||||||
|
refs = sorted(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
e.id,
|
||||||
|
list(e.supervisions[0].custom["disfluent"]),
|
||||||
|
e.supervisions[0].custom["disfluent_tag"].split(","),
|
||||||
|
)
|
||||||
|
for e in refs
|
||||||
|
],
|
||||||
|
key=lambda x: x[0],
|
||||||
|
)
|
||||||
|
for_each_recogs(recogs_file, refs, args.res_dir)
|
||||||
|
|
||||||
|
elif recogs_file.is_dir():
|
||||||
|
recogs_file_path = recogs_file
|
||||||
|
for partname in ["eval1", "eval2", "eval3", "excluded", "valid"]:
|
||||||
|
refs: CutSet = CutSet.from_file(args.cut / f"csj_cuts_{partname}.jsonl.gz")
|
||||||
|
refs = sorted(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
r.id,
|
||||||
|
list(r.supervisions[0].custom["disfluent"]),
|
||||||
|
r.supervisions[0].custom["disfluent_tag"].split(","),
|
||||||
|
)
|
||||||
|
for r in refs
|
||||||
|
],
|
||||||
|
key=lambda x: x[0],
|
||||||
|
)
|
||||||
|
for recogs_file in recogs_file_path.glob(f"recogs-{partname}-*.txt"):
|
||||||
|
for_each_recogs(recogs_file, refs, args.res_dir)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Unrecognised recogs file provided: {recogs_file}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -45,8 +45,8 @@ def get_parser():
|
|||||||
def main():
|
def main():
|
||||||
args = get_parser()
|
args = get_parser()
|
||||||
|
|
||||||
for path in args.manifest_dir.glob("csj_cuts_*.jsonl.gz"):
|
for part in ["eval1", "eval2", "eval3", "valid", "excluded", "train"]:
|
||||||
|
path = args.manifest_dir / f"csj_cuts_{part}.jsonl.gz"
|
||||||
cuts: CutSet = load_manifest(path)
|
cuts: CutSet = load_manifest(path)
|
||||||
|
|
||||||
print("\n---------------------------------\n")
|
print("\n---------------------------------\n")
|
||||||
@ -58,123 +58,271 @@ if __name__ == "__main__":
|
|||||||
main()
|
main()
|
||||||
|
|
||||||
"""
|
"""
|
||||||
## eval1
|
csj_cuts_eval1.jsonl.gz:
|
||||||
Cuts count: 1272
|
Cut statistics:
|
||||||
Total duration (hh:mm:ss): 01:50:07
|
╒═══════════════════════════╤══════════╕
|
||||||
Speech duration (hh:mm:ss): 01:50:07 (100.0%)
|
│ Cuts count: │ 1023 │
|
||||||
Duration statistics (seconds):
|
├───────────────────────────┼──────────┤
|
||||||
mean 5.2
|
│ Total duration (hh:mm:ss) │ 01:55:40 │
|
||||||
std 3.9
|
├───────────────────────────┼──────────┤
|
||||||
min 0.2
|
│ mean │ 6.8 │
|
||||||
25% 1.9
|
├───────────────────────────┼──────────┤
|
||||||
50% 4.0
|
│ std │ 2.7 │
|
||||||
75% 8.1
|
├───────────────────────────┼──────────┤
|
||||||
99% 14.3
|
│ min │ 0.2 │
|
||||||
99.5% 14.7
|
├───────────────────────────┼──────────┤
|
||||||
99.9% 16.0
|
│ 25% │ 4.9 │
|
||||||
max 16.9
|
├───────────────────────────┼──────────┤
|
||||||
Recordings available: 1272
|
│ 50% │ 7.7 │
|
||||||
Features available: 1272
|
├───────────────────────────┼──────────┤
|
||||||
Supervisions available: 1272
|
│ 75% │ 9.0 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ 99% │ 10.0 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ 99.5% │ 10.0 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ 99.9% │ 10.0 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ max │ 10.0 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ Recordings available: │ 1023 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ Features available: │ 0 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ Supervisions available: │ 1023 │
|
||||||
|
╘═══════════════════════════╧══════════╛
|
||||||
SUPERVISION custom fields:
|
SUPERVISION custom fields:
|
||||||
- fluent (in 1272 cuts)
|
Speech duration statistics:
|
||||||
- disfluent (in 1272 cuts)
|
╒══════════════════════════════╤══════════╤══════════════════════╕
|
||||||
- number (in 1272 cuts)
|
│ Total speech duration │ 01:55:40 │ 100.00% of recording │
|
||||||
- symbol (in 1272 cuts)
|
├──────────────────────────────┼──────────┼──────────────────────┤
|
||||||
|
│ Total speaking time duration │ 01:55:40 │ 100.00% of recording │
|
||||||
|
├──────────────────────────────┼──────────┼──────────────────────┤
|
||||||
|
│ Total silence duration │ 00:00:00 │ 0.00% of recording │
|
||||||
|
╘══════════════════════════════╧══════════╧══════════════════════╛
|
||||||
|
|
||||||
## eval2
|
---------------------------------
|
||||||
Cuts count: 1292
|
|
||||||
Total duration (hh:mm:ss): 01:56:50
|
|
||||||
Speech duration (hh:mm:ss): 01:56:50 (100.0%)
|
|
||||||
Duration statistics (seconds):
|
|
||||||
mean 5.4
|
|
||||||
std 3.9
|
|
||||||
min 0.1
|
|
||||||
25% 2.1
|
|
||||||
50% 4.6
|
|
||||||
75% 8.6
|
|
||||||
99% 14.1
|
|
||||||
99.5% 15.2
|
|
||||||
99.9% 16.1
|
|
||||||
max 16.9
|
|
||||||
Recordings available: 1292
|
|
||||||
Features available: 1292
|
|
||||||
Supervisions available: 1292
|
|
||||||
SUPERVISION custom fields:
|
|
||||||
- fluent (in 1292 cuts)
|
|
||||||
- number (in 1292 cuts)
|
|
||||||
- symbol (in 1292 cuts)
|
|
||||||
- disfluent (in 1292 cuts)
|
|
||||||
|
|
||||||
## eval3
|
csj_cuts_eval2.jsonl.gz:
|
||||||
Cuts count: 1385
|
Cut statistics:
|
||||||
Total duration (hh:mm:ss): 01:19:21
|
╒═══════════════════════════╤══════════╕
|
||||||
Speech duration (hh:mm:ss): 01:19:21 (100.0%)
|
│ Cuts count: │ 1025 │
|
||||||
Duration statistics (seconds):
|
├───────────────────────────┼──────────┤
|
||||||
mean 3.4
|
│ Total duration (hh:mm:ss) │ 02:02:07 │
|
||||||
std 3.0
|
├───────────────────────────┼──────────┤
|
||||||
min 0.2
|
│ mean │ 7.1 │
|
||||||
25% 1.2
|
├───────────────────────────┼──────────┤
|
||||||
50% 2.5
|
│ std │ 2.5 │
|
||||||
75% 4.6
|
├───────────────────────────┼──────────┤
|
||||||
99% 12.7
|
│ min │ 0.1 │
|
||||||
99.5% 13.7
|
├───────────────────────────┼──────────┤
|
||||||
99.9% 15.0
|
│ 25% │ 5.9 │
|
||||||
max 15.9
|
├───────────────────────────┼──────────┤
|
||||||
Recordings available: 1385
|
│ 50% │ 7.9 │
|
||||||
Features available: 1385
|
├───────────────────────────┼──────────┤
|
||||||
Supervisions available: 1385
|
│ 75% │ 9.1 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ 99% │ 10.0 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ 99.5% │ 10.0 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ 99.9% │ 10.0 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ max │ 10.0 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ Recordings available: │ 1025 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ Features available: │ 0 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ Supervisions available: │ 1025 │
|
||||||
|
╘═══════════════════════════╧══════════╛
|
||||||
SUPERVISION custom fields:
|
SUPERVISION custom fields:
|
||||||
- number (in 1385 cuts)
|
Speech duration statistics:
|
||||||
- symbol (in 1385 cuts)
|
╒══════════════════════════════╤══════════╤══════════════════════╕
|
||||||
- fluent (in 1385 cuts)
|
│ Total speech duration │ 02:02:07 │ 100.00% of recording │
|
||||||
- disfluent (in 1385 cuts)
|
├──────────────────────────────┼──────────┼──────────────────────┤
|
||||||
|
│ Total speaking time duration │ 02:02:07 │ 100.00% of recording │
|
||||||
|
├──────────────────────────────┼──────────┼──────────────────────┤
|
||||||
|
│ Total silence duration │ 00:00:00 │ 0.00% of recording │
|
||||||
|
╘══════════════════════════════╧══════════╧══════════════════════╛
|
||||||
|
|
||||||
## valid
|
---------------------------------
|
||||||
Cuts count: 4000
|
|
||||||
Total duration (hh:mm:ss): 05:08:09
|
|
||||||
Speech duration (hh:mm:ss): 05:08:09 (100.0%)
|
|
||||||
Duration statistics (seconds):
|
|
||||||
mean 4.6
|
|
||||||
std 3.8
|
|
||||||
min 0.1
|
|
||||||
25% 1.5
|
|
||||||
50% 3.4
|
|
||||||
75% 7.0
|
|
||||||
99% 13.8
|
|
||||||
99.5% 14.8
|
|
||||||
99.9% 16.0
|
|
||||||
max 17.3
|
|
||||||
Recordings available: 4000
|
|
||||||
Features available: 4000
|
|
||||||
Supervisions available: 4000
|
|
||||||
SUPERVISION custom fields:
|
|
||||||
- fluent (in 4000 cuts)
|
|
||||||
- symbol (in 4000 cuts)
|
|
||||||
- disfluent (in 4000 cuts)
|
|
||||||
- number (in 4000 cuts)
|
|
||||||
|
|
||||||
## train
|
csj_cuts_eval3.jsonl.gz:
|
||||||
Cuts count: 1291134
|
Cut statistics:
|
||||||
Total duration (hh:mm:ss): 1596:37:27
|
╒═══════════════════════════╤══════════╕
|
||||||
Speech duration (hh:mm:ss): 1596:37:27 (100.0%)
|
│ Cuts count: │ 865 │
|
||||||
Duration statistics (seconds):
|
├───────────────────────────┼──────────┤
|
||||||
mean 4.5
|
│ Total duration (hh:mm:ss) │ 01:26:44 │
|
||||||
std 3.6
|
├───────────────────────────┼──────────┤
|
||||||
min 0.0
|
│ mean │ 6.0 │
|
||||||
25% 1.6
|
├───────────────────────────┼──────────┤
|
||||||
50% 3.3
|
│ std │ 3.0 │
|
||||||
75% 6.4
|
├───────────────────────────┼──────────┤
|
||||||
99% 14.0
|
│ min │ 0.3 │
|
||||||
99.5% 14.8
|
├───────────────────────────┼──────────┤
|
||||||
99.9% 16.6
|
│ 25% │ 3.3 │
|
||||||
max 27.8
|
├───────────────────────────┼──────────┤
|
||||||
Recordings available: 1291134
|
│ 50% │ 6.8 │
|
||||||
Features available: 1291134
|
├───────────────────────────┼──────────┤
|
||||||
Supervisions available: 1291134
|
│ 75% │ 8.7 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ 99% │ 10.0 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ 99.5% │ 10.0 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ 99.9% │ 10.0 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ max │ 10.0 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ Recordings available: │ 865 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ Features available: │ 0 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ Supervisions available: │ 865 │
|
||||||
|
╘═══════════════════════════╧══════════╛
|
||||||
SUPERVISION custom fields:
|
SUPERVISION custom fields:
|
||||||
- disfluent (in 1291134 cuts)
|
Speech duration statistics:
|
||||||
- fluent (in 1291134 cuts)
|
╒══════════════════════════════╤══════════╤══════════════════════╕
|
||||||
- symbol (in 1291134 cuts)
|
│ Total speech duration │ 01:26:44 │ 100.00% of recording │
|
||||||
- number (in 1291134 cuts)
|
├──────────────────────────────┼──────────┼──────────────────────┤
|
||||||
|
│ Total speaking time duration │ 01:26:44 │ 100.00% of recording │
|
||||||
|
├──────────────────────────────┼──────────┼──────────────────────┤
|
||||||
|
│ Total silence duration │ 00:00:00 │ 0.00% of recording │
|
||||||
|
╘══════════════════════════════╧══════════╧══════════════════════╛
|
||||||
|
|
||||||
|
---------------------------------
|
||||||
|
|
||||||
|
csj_cuts_valid.jsonl.gz:
|
||||||
|
Cut statistics:
|
||||||
|
╒═══════════════════════════╤══════════╕
|
||||||
|
│ Cuts count: │ 3743 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ Total duration (hh:mm:ss) │ 06:40:15 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ mean │ 6.4 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ std │ 3.0 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ min │ 0.1 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ 25% │ 3.9 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ 50% │ 7.4 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ 75% │ 9.0 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ 99% │ 10.0 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ 99.5% │ 10.0 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ 99.9% │ 10.1 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ max │ 11.8 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ Recordings available: │ 3743 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ Features available: │ 0 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ Supervisions available: │ 3743 │
|
||||||
|
╘═══════════════════════════╧══════════╛
|
||||||
|
SUPERVISION custom fields:
|
||||||
|
Speech duration statistics:
|
||||||
|
╒══════════════════════════════╤══════════╤══════════════════════╕
|
||||||
|
│ Total speech duration │ 06:40:15 │ 100.00% of recording │
|
||||||
|
├──────────────────────────────┼──────────┼──────────────────────┤
|
||||||
|
│ Total speaking time duration │ 06:40:15 │ 100.00% of recording │
|
||||||
|
├──────────────────────────────┼──────────┼──────────────────────┤
|
||||||
|
│ Total silence duration │ 00:00:00 │ 0.00% of recording │
|
||||||
|
╘══════════════════════════════╧══════════╧══════════════════════╛
|
||||||
|
|
||||||
|
---------------------------------
|
||||||
|
|
||||||
|
csj_cuts_excluded.jsonl.gz:
|
||||||
|
Cut statistics:
|
||||||
|
╒═══════════════════════════╤══════════╕
|
||||||
|
│ Cuts count: │ 980 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ Total duration (hh:mm:ss) │ 00:56:06 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ mean │ 3.4 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ std │ 3.1 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ min │ 0.1 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ 25% │ 0.8 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ 50% │ 2.2 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ 75% │ 5.8 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ 99% │ 9.9 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ 99.5% │ 9.9 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ 99.9% │ 10.0 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ max │ 10.0 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ Recordings available: │ 980 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ Features available: │ 0 │
|
||||||
|
├───────────────────────────┼──────────┤
|
||||||
|
│ Supervisions available: │ 980 │
|
||||||
|
╘═══════════════════════════╧══════════╛
|
||||||
|
SUPERVISION custom fields:
|
||||||
|
Speech duration statistics:
|
||||||
|
╒══════════════════════════════╤══════════╤══════════════════════╕
|
||||||
|
│ Total speech duration │ 00:56:06 │ 100.00% of recording │
|
||||||
|
├──────────────────────────────┼──────────┼──────────────────────┤
|
||||||
|
│ Total speaking time duration │ 00:56:06 │ 100.00% of recording │
|
||||||
|
├──────────────────────────────┼──────────┼──────────────────────┤
|
||||||
|
│ Total silence duration │ 00:00:00 │ 0.00% of recording │
|
||||||
|
╘══════════════════════════════╧══════════╧══════════════════════╛
|
||||||
|
|
||||||
|
---------------------------------
|
||||||
|
|
||||||
|
csj_cuts_train.jsonl.gz:
|
||||||
|
Cut statistics:
|
||||||
|
╒═══════════════════════════╤════════════╕
|
||||||
|
│ Cuts count: │ 914151 │
|
||||||
|
├───────────────────────────┼────────────┤
|
||||||
|
│ Total duration (hh:mm:ss) │ 1695:29:43 │
|
||||||
|
├───────────────────────────┼────────────┤
|
||||||
|
│ mean │ 6.7 │
|
||||||
|
├───────────────────────────┼────────────┤
|
||||||
|
│ std │ 2.9 │
|
||||||
|
├───────────────────────────┼────────────┤
|
||||||
|
│ min │ 0.1 │
|
||||||
|
├───────────────────────────┼────────────┤
|
||||||
|
│ 25% │ 4.6 │
|
||||||
|
├───────────────────────────┼────────────┤
|
||||||
|
│ 50% │ 7.5 │
|
||||||
|
├───────────────────────────┼────────────┤
|
||||||
|
│ 75% │ 8.9 │
|
||||||
|
├───────────────────────────┼────────────┤
|
||||||
|
│ 99% │ 11.0 │
|
||||||
|
├───────────────────────────┼────────────┤
|
||||||
|
│ 99.5% │ 11.0 │
|
||||||
|
├───────────────────────────┼────────────┤
|
||||||
|
│ 99.9% │ 11.1 │
|
||||||
|
├───────────────────────────┼────────────┤
|
||||||
|
│ max │ 18.0 │
|
||||||
|
├───────────────────────────┼────────────┤
|
||||||
|
│ Recordings available: │ 914151 │
|
||||||
|
├───────────────────────────┼────────────┤
|
||||||
|
│ Features available: │ 0 │
|
||||||
|
├───────────────────────────┼────────────┤
|
||||||
|
│ Supervisions available: │ 914151 │
|
||||||
|
╘═══════════════════════════╧════════════╛
|
||||||
|
SUPERVISION custom fields:
|
||||||
|
Speech duration statistics:
|
||||||
|
╒══════════════════════════════╤════════════╤══════════════════════╕
|
||||||
|
│ Total speech duration │ 1695:29:43 │ 100.00% of recording │
|
||||||
|
├──────────────────────────────┼────────────┼──────────────────────┤
|
||||||
|
│ Total speaking time duration │ 1695:29:43 │ 100.00% of recording │
|
||||||
|
├──────────────────────────────┼────────────┼──────────────────────┤
|
||||||
|
│ Total silence duration │ 00:00:00 │ 0.00% of recording │
|
||||||
|
╘══════════════════════════════╧════════════╧══════════════════════╛
|
||||||
"""
|
"""
|
||||||
|
@ -21,24 +21,14 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from lhotse import CutSet
|
from lhotse import CutSet
|
||||||
|
from lhotse.recipes.csj import CSJSDBParser
|
||||||
|
|
||||||
ARGPARSE_DESCRIPTION = """
|
ARGPARSE_DESCRIPTION = """
|
||||||
This script gathers all training transcripts of the specified {trans_mode} type
|
This script gathers all training transcripts, parses them in disfluent mode, and produces a token list that would be the output set of the ASR system.
|
||||||
and produces a token_list that would be output set of the ASR system.
|
|
||||||
|
|
||||||
It splits transcripts by whitespace into lists, then, for each word in the
|
It outputs 3 files into the lang directory:
|
||||||
list, if the word does not appear in the list of user-defined multicharacter
|
- tokens.txt: a list of tokens in the output set.
|
||||||
strings, it further splits that word into individual characters to be counted
|
- lang_type: a file that contains the string "char"
|
||||||
into the output token set.
|
|
||||||
|
|
||||||
It outputs 4 files into the lang directory:
|
|
||||||
- trans_mode: the name of transcript mode. If trans_mode was not specified,
|
|
||||||
this will be an empty file.
|
|
||||||
- userdef_string: a list of user defined strings that should not be split
|
|
||||||
further into individual characters. By default, it contains "<unk>", "<blk>",
|
|
||||||
"<sos/eos>"
|
|
||||||
- words_len: the total number of tokens in the output set.
|
|
||||||
- words.txt: a list of tokens in the output set. The length matches words_len.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -50,98 +40,52 @@ def get_args():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--train-cut", type=Path, required=True, help="Path to the train cut"
|
"train_cut", metavar="train-cut", type=Path, help="Path to the train cut"
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--trans-mode",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help=(
|
|
||||||
"Name of the transcript mode to use. "
|
|
||||||
"If lang-dir is not set, this will also name the lang-dir"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lang-dir",
|
"--lang-dir",
|
||||||
type=Path,
|
type=Path,
|
||||||
default=None,
|
default=Path("data/lang_char"),
|
||||||
help=(
|
help=(
|
||||||
"Name of lang dir. "
|
"Name of lang dir. "
|
||||||
"If not set, this will default to lang_char_{trans-mode}"
|
"If not set, this will default to lang_char_{trans-mode}"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--userdef-string",
|
|
||||||
type=Path,
|
|
||||||
default=None,
|
|
||||||
help="Multicharacter strings that do not need to be split",
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = get_args()
|
args = get_args()
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"),
|
format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"),
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not args.lang_dir:
|
sysdef_string = set(["<blk>", "<unk>", "<sos/eos>"])
|
||||||
p = "lang_char"
|
|
||||||
if args.trans_mode:
|
|
||||||
p += f"_{args.trans_mode}"
|
|
||||||
args.lang_dir = Path(p)
|
|
||||||
|
|
||||||
if args.userdef_string:
|
# Using disfluent parsing as fluent is a subset of disfluent
|
||||||
args.userdef_string = set(args.userdef_string.read_text().split())
|
parser = CSJSDBParser()
|
||||||
else:
|
|
||||||
args.userdef_string = set()
|
|
||||||
|
|
||||||
sysdef_string = ["<blk>", "<unk>", "<sos/eos>"]
|
token_set = set()
|
||||||
args.userdef_string.update(sysdef_string)
|
logging.info(f"Creating vocabulary from {args.train_cut}.")
|
||||||
|
train_cut: CutSet = CutSet.from_file(args.train_cut)
|
||||||
|
for cut in train_cut:
|
||||||
|
if "_sp" in cut.id:
|
||||||
|
continue
|
||||||
|
|
||||||
train_set: CutSet = CutSet.from_file(args.train_cut)
|
text: str = cut.supervisions[0].custom["raw"]
|
||||||
|
for w in parser.parse(text, sep=" ").split(" "):
|
||||||
words = set()
|
token_set.update(w)
|
||||||
logging.info(
|
|
||||||
f"Creating vocabulary from {args.train_cut.name} at {args.trans_mode} mode."
|
|
||||||
)
|
|
||||||
for cut in train_set:
|
|
||||||
try:
|
|
||||||
text: str = (
|
|
||||||
cut.supervisions[0].custom[args.trans_mode]
|
|
||||||
if args.trans_mode
|
|
||||||
else cut.supervisions[0].text
|
|
||||||
)
|
|
||||||
except KeyError:
|
|
||||||
raise KeyError(
|
|
||||||
f"Could not find {args.trans_mode} in {cut.supervisions[0].custom}"
|
|
||||||
)
|
|
||||||
for t in text.split():
|
|
||||||
if t in args.userdef_string:
|
|
||||||
words.add(t)
|
|
||||||
else:
|
|
||||||
words.update(c for c in list(t))
|
|
||||||
|
|
||||||
words -= set(sysdef_string)
|
|
||||||
words = sorted(words)
|
|
||||||
words = ["<blk>"] + words + ["<unk>", "<sos/eos>"]
|
|
||||||
|
|
||||||
|
token_set = ["<blk>"] + sorted(token_set - sysdef_string) + ["<unk>", "<sos/eos>"]
|
||||||
args.lang_dir.mkdir(parents=True, exist_ok=True)
|
args.lang_dir.mkdir(parents=True, exist_ok=True)
|
||||||
(args.lang_dir / "words.txt").write_text(
|
(args.lang_dir / "tokens.txt").write_text(
|
||||||
"\n".join(f"{word}\t{i}" for i, word in enumerate(words))
|
"\n".join(f"{t}\t{i}" for i, t in enumerate(token_set))
|
||||||
)
|
)
|
||||||
|
|
||||||
(args.lang_dir / "words_len").write_text(f"{len(words)}")
|
(args.lang_dir / "lang_type").write_text("char")
|
||||||
|
|
||||||
(args.lang_dir / "userdef_string").write_text("\n".join(args.userdef_string))
|
|
||||||
|
|
||||||
(args.lang_dir / "trans_mode").write_text(args.trans_mode)
|
|
||||||
logging.info("Done.")
|
logging.info("Done.")
|
||||||
|
|
||||||
|
|
||||||
|
462
egs/csj/ASR/local/utils/asr_datamodule.py
Normal file
462
egs/csj/ASR/local/utils/asr_datamodule.py
Normal file
@ -0,0 +1,462 @@
|
|||||||
|
# Copyright 2021 Piotr Żelasko
|
||||||
|
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||||
|
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||||
|
CutConcatenate,
|
||||||
|
CutMix,
|
||||||
|
DynamicBucketingSampler,
|
||||||
|
K2SpeechRecognitionDataset,
|
||||||
|
PrecomputedFeatures,
|
||||||
|
SingleCutSampler,
|
||||||
|
SpecAugment,
|
||||||
|
)
|
||||||
|
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||||
|
AudioSamples,
|
||||||
|
OnTheFlyFeatures,
|
||||||
|
)
|
||||||
|
from lhotse.utils import fix_random_seed
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
|
class _SeedWorkers:
|
||||||
|
def __init__(self, seed: int):
|
||||||
|
self.seed = seed
|
||||||
|
|
||||||
|
def __call__(self, worker_id: int):
|
||||||
|
fix_random_seed(self.seed + worker_id)
|
||||||
|
|
||||||
|
|
||||||
|
class AsrVariableTranscriptDataset(K2SpeechRecognitionDataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
transcript_mode: str = "",
|
||||||
|
return_cuts: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.transcript_mode = transcript_mode
|
||||||
|
self.return_cuts = True
|
||||||
|
self._return_cuts = return_cuts
|
||||||
|
|
||||||
|
def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
|
||||||
|
batch = super().__getitem__(cuts)
|
||||||
|
|
||||||
|
if self.transcript_mode:
|
||||||
|
batch["supervisions"]["text"] = [
|
||||||
|
supervision.custom[self.transcript_mode]
|
||||||
|
for cut in batch["supervisions"]["cut"]
|
||||||
|
for supervision in cut.supervisions
|
||||||
|
]
|
||||||
|
|
||||||
|
if not self._return_cuts:
|
||||||
|
del batch["supervisions"]["cut"]
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
class CSJAsrDataModule:
|
||||||
|
"""
|
||||||
|
DataModule for k2 ASR experiments.
|
||||||
|
It assumes there is always one train and valid dataloader,
|
||||||
|
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||||
|
and test-other).
|
||||||
|
It contains all the common data pipeline modules used in ASR
|
||||||
|
experiments, e.g.:
|
||||||
|
- dynamic batch size,
|
||||||
|
- bucketing samplers,
|
||||||
|
- cut concatenation,
|
||||||
|
- augmentation,
|
||||||
|
- on-the-fly feature extraction
|
||||||
|
This class should be derived for specific corpora used in ASR tasks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, args: argparse.Namespace):
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||||
|
group = parser.add_argument_group(
|
||||||
|
title="ASR data related options",
|
||||||
|
description="These options are used for the preparation of "
|
||||||
|
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||||
|
"effective batch sizes, sampling strategies, applied data "
|
||||||
|
"augmentations, etc.",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--transcript-mode",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="Mode of transcript in supervision to use.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--manifest-dir",
|
||||||
|
type=Path,
|
||||||
|
default=Path("data/manifests"),
|
||||||
|
help="Path to directory with train/valid/test cuts.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--musan-dir", type=Path, help="Path to directory with musan cuts. "
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--max-duration",
|
||||||
|
type=int,
|
||||||
|
default=200.0,
|
||||||
|
help="Maximum pooled recordings duration (seconds) in a "
|
||||||
|
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--bucketing-sampler",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="When enabled, the batches will come from buckets of "
|
||||||
|
"similar duration (saves padding frames).",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--num-buckets",
|
||||||
|
type=int,
|
||||||
|
default=30,
|
||||||
|
help="The number of buckets for the DynamicBucketingSampler"
|
||||||
|
"(you might want to increase it for larger datasets).",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--concatenate-cuts",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="When enabled, utterances (cuts) will be concatenated "
|
||||||
|
"to minimize the amount of padding.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--duration-factor",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="Determines the maximum duration of a concatenated cut "
|
||||||
|
"relative to the duration of the longest cut in a batch.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--gap",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="The amount of padding (in seconds) inserted between "
|
||||||
|
"concatenated cuts. This padding is filled with noise when "
|
||||||
|
"noise augmentation is used.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--on-the-fly-feats",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="When enabled, use on-the-fly cut mixing and feature "
|
||||||
|
"extraction. Will drop existing precomputed feature manifests "
|
||||||
|
"if available.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--shuffle",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="When enabled (=default), the examples will be "
|
||||||
|
"shuffled for each epoch.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--drop-last",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to drop last batch. Used by sampler.",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--return-cuts",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="When enabled, each batch will have the "
|
||||||
|
"field: batch['supervisions']['cut'] with the cuts that "
|
||||||
|
"were used to construct it.",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--num-workers",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The number of training dataloader workers that "
|
||||||
|
"collect the batches.",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--enable-spec-aug",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="When enabled, use SpecAugment for training dataset.",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--spec-aug-time-warp-factor",
|
||||||
|
type=int,
|
||||||
|
default=80,
|
||||||
|
help="Used only when --enable-spec-aug is True. "
|
||||||
|
"It specifies the factor for time warping in SpecAugment. "
|
||||||
|
"Larger values mean more warping. "
|
||||||
|
"A value less than 1 means to disable time warp.",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--enable-musan",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="When enabled, select noise from MUSAN and mix it"
|
||||||
|
"with training dataset. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--input-strategy",
|
||||||
|
type=str,
|
||||||
|
default="PrecomputedFeatures",
|
||||||
|
help="AudioSamples or PrecomputedFeatures",
|
||||||
|
)
|
||||||
|
|
||||||
|
def train_dataloaders(
|
||||||
|
self,
|
||||||
|
cuts_train: CutSet,
|
||||||
|
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> DataLoader:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
cuts_train:
|
||||||
|
CutSet for training.
|
||||||
|
sampler_state_dict:
|
||||||
|
The state dict for the training sampler.
|
||||||
|
"""
|
||||||
|
transforms = []
|
||||||
|
if self.args.enable_musan:
|
||||||
|
logging.info("Enable MUSAN")
|
||||||
|
logging.info("About to get Musan cuts")
|
||||||
|
cuts_musan = load_manifest(self.args.musan_dir / "musan_cuts.jsonl.gz")
|
||||||
|
transforms.append(
|
||||||
|
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.info("Disable MUSAN")
|
||||||
|
|
||||||
|
if self.args.concatenate_cuts:
|
||||||
|
logging.info(
|
||||||
|
f"Using cut concatenation with duration factor "
|
||||||
|
f"{self.args.duration_factor} and gap {self.args.gap}."
|
||||||
|
)
|
||||||
|
# Cut concatenation should be the first transform in the list,
|
||||||
|
# so that if we e.g. mix noise in, it will fill the gaps between
|
||||||
|
# different utterances.
|
||||||
|
transforms = [
|
||||||
|
CutConcatenate(
|
||||||
|
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||||
|
)
|
||||||
|
] + transforms
|
||||||
|
|
||||||
|
input_transforms = []
|
||||||
|
if self.args.enable_spec_aug:
|
||||||
|
logging.info("Enable SpecAugment")
|
||||||
|
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
|
||||||
|
# Set the value of num_frame_masks according to Lhotse's version.
|
||||||
|
# In different Lhotse's versions, the default of num_frame_masks is
|
||||||
|
# different.
|
||||||
|
num_frame_masks = 10
|
||||||
|
num_frame_masks_parameter = inspect.signature(
|
||||||
|
SpecAugment.__init__
|
||||||
|
).parameters["num_frame_masks"]
|
||||||
|
if num_frame_masks_parameter.default == 1:
|
||||||
|
num_frame_masks = 2
|
||||||
|
logging.info(f"Num frame mask: {num_frame_masks}")
|
||||||
|
input_transforms.append(
|
||||||
|
SpecAugment(
|
||||||
|
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||||
|
num_frame_masks=num_frame_masks,
|
||||||
|
features_mask_size=27,
|
||||||
|
num_feature_masks=2,
|
||||||
|
frames_mask_size=100,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.info("Disable SpecAugment")
|
||||||
|
|
||||||
|
logging.info("About to create train dataset")
|
||||||
|
train = AsrVariableTranscriptDataset(
|
||||||
|
input_strategy=eval(self.args.input_strategy)(),
|
||||||
|
cut_transforms=transforms,
|
||||||
|
input_transforms=input_transforms,
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
transcript_mode=self.args.transcript_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.on_the_fly_feats:
|
||||||
|
# NOTE: the PerturbSpeed transform should be added only if we
|
||||||
|
# remove it from data prep stage.
|
||||||
|
# Add on-the-fly speed perturbation; since originally it would
|
||||||
|
# have increased epoch size by 3, we will apply prob 2/3 and use
|
||||||
|
# 3x more epochs.
|
||||||
|
# Speed perturbation probably should come first before
|
||||||
|
# concatenation, but in principle the transforms order doesn't have
|
||||||
|
# to be strict (e.g. could be randomized)
|
||||||
|
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
||||||
|
# Drop feats to be on the safe side.
|
||||||
|
train = AsrVariableTranscriptDataset(
|
||||||
|
cut_transforms=transforms,
|
||||||
|
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||||
|
input_transforms=input_transforms,
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
transcript_mode=self.args.transcript_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.bucketing_sampler:
|
||||||
|
logging.info("Using DynamicBucketingSampler.")
|
||||||
|
train_sampler = DynamicBucketingSampler(
|
||||||
|
cuts_train,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
shuffle=self.args.shuffle,
|
||||||
|
num_buckets=self.args.num_buckets,
|
||||||
|
drop_last=self.args.drop_last,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.info("Using SingleCutSampler.")
|
||||||
|
train_sampler = SingleCutSampler(
|
||||||
|
cuts_train,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
shuffle=self.args.shuffle,
|
||||||
|
)
|
||||||
|
logging.info("About to create train dataloader")
|
||||||
|
|
||||||
|
if sampler_state_dict is not None:
|
||||||
|
logging.info("Loading sampler state dict")
|
||||||
|
train_sampler.load_state_dict(sampler_state_dict)
|
||||||
|
|
||||||
|
# 'seed' is derived from the current random state, which will have
|
||||||
|
# previously been set in the main process.
|
||||||
|
seed = torch.randint(0, 100000, ()).item()
|
||||||
|
worker_init_fn = _SeedWorkers(seed)
|
||||||
|
|
||||||
|
train_dl = DataLoader(
|
||||||
|
train,
|
||||||
|
sampler=train_sampler,
|
||||||
|
batch_size=None,
|
||||||
|
num_workers=self.args.num_workers,
|
||||||
|
persistent_workers=False,
|
||||||
|
worker_init_fn=worker_init_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
return train_dl
|
||||||
|
|
||||||
|
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||||
|
transforms = []
|
||||||
|
if self.args.concatenate_cuts:
|
||||||
|
transforms = [
|
||||||
|
CutConcatenate(
|
||||||
|
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||||
|
)
|
||||||
|
] + transforms
|
||||||
|
|
||||||
|
logging.info("About to create dev dataset")
|
||||||
|
if self.args.on_the_fly_feats:
|
||||||
|
validate = AsrVariableTranscriptDataset(
|
||||||
|
cut_transforms=transforms,
|
||||||
|
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
transcript_mode=self.args.transcript_mode,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
validate = AsrVariableTranscriptDataset(
|
||||||
|
cut_transforms=transforms,
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
transcript_mode=self.args.transcript_mode,
|
||||||
|
)
|
||||||
|
valid_sampler = DynamicBucketingSampler(
|
||||||
|
cuts_valid,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
logging.info("About to create dev dataloader")
|
||||||
|
valid_dl = DataLoader(
|
||||||
|
validate,
|
||||||
|
sampler=valid_sampler,
|
||||||
|
batch_size=None,
|
||||||
|
num_workers=2,
|
||||||
|
persistent_workers=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return valid_dl
|
||||||
|
|
||||||
|
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||||
|
logging.debug("About to create test dataset")
|
||||||
|
|
||||||
|
test = AsrVariableTranscriptDataset(
|
||||||
|
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||||
|
if self.args.on_the_fly_feats
|
||||||
|
else eval(self.args.input_strategy)(),
|
||||||
|
return_cuts=self.args.return_cuts,
|
||||||
|
transcript_mode=self.args.transcript_mode,
|
||||||
|
)
|
||||||
|
sampler = DynamicBucketingSampler(
|
||||||
|
cuts,
|
||||||
|
max_duration=self.args.max_duration,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.debug("About to create test dataloader")
|
||||||
|
test_dl = DataLoader(
|
||||||
|
test,
|
||||||
|
batch_size=None,
|
||||||
|
sampler=sampler,
|
||||||
|
num_workers=self.args.num_workers,
|
||||||
|
)
|
||||||
|
return test_dl
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def train_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get train cuts")
|
||||||
|
return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_train.jsonl.gz")
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def valid_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get valid cuts")
|
||||||
|
return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_valid.jsonl.gz")
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def excluded_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get excluded cuts")
|
||||||
|
return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_excluded.jsonl.gz")
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def eval1_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get eval1 cuts")
|
||||||
|
return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_eval1.jsonl.gz")
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def eval2_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get eval2 cuts")
|
||||||
|
return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_eval2.jsonl.gz")
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def eval3_cuts(self) -> CutSet:
|
||||||
|
logging.info("About to get eval3 cuts")
|
||||||
|
return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_eval3.jsonl.gz")
|
253
egs/csj/ASR/local/utils/tokenizer.py
Normal file
253
egs/csj/ASR/local/utils/tokenizer.py
Normal file
@ -0,0 +1,253 @@
|
|||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, List, Union
|
||||||
|
|
||||||
|
import sentencepiece as spm
|
||||||
|
from k2 import SymbolTable
|
||||||
|
|
||||||
|
|
||||||
|
class Tokenizer:
|
||||||
|
text2word: Callable[[str], List[str]]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_arguments(parser: argparse.ArgumentParser):
|
||||||
|
group = parser.add_argument_group(title="Lang related options")
|
||||||
|
|
||||||
|
group.add_argument("--lang", type=Path, help="Path to lang directory.")
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--lang-type",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"Either 'bpe' or 'char'. If not provided, it expects lang_dir/lang_type to exists. "
|
||||||
|
"Note: 'bpe' directly loads sentencepiece.SentencePieceProcessor"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def Load(lang_dir: Path, lang_type="", oov="<unk>"):
|
||||||
|
|
||||||
|
if not lang_type:
|
||||||
|
assert (lang_dir / "lang_type").exists(), "lang_type not specified."
|
||||||
|
lang_type = (lang_dir / "lang_type").read_text().strip()
|
||||||
|
|
||||||
|
tokenizer = None
|
||||||
|
|
||||||
|
if lang_type == "bpe":
|
||||||
|
assert (
|
||||||
|
lang_dir / "bpe.model"
|
||||||
|
).exists(), f"No BPE .model could be found in {lang_dir}."
|
||||||
|
tokenizer = spm.SentencePieceProcessor()
|
||||||
|
tokenizer.Load(str(lang_dir / "bpe.model"))
|
||||||
|
elif lang_type == "char":
|
||||||
|
tokenizer = CharTokenizer(lang_dir, oov=oov)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"{lang_type} not supported at the moment.")
|
||||||
|
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
load = Load
|
||||||
|
|
||||||
|
def PieceToId(self, piece: str) -> int:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"You need to implement this function in the child class."
|
||||||
|
)
|
||||||
|
|
||||||
|
piece_to_id = PieceToId
|
||||||
|
|
||||||
|
def IdToPiece(self, id: int) -> str:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"You need to implement this function in the child class."
|
||||||
|
)
|
||||||
|
|
||||||
|
id_to_piece = IdToPiece
|
||||||
|
|
||||||
|
def GetPieceSize(self) -> int:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"You need to implement this function in the child class."
|
||||||
|
)
|
||||||
|
|
||||||
|
get_piece_size = GetPieceSize
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return self.get_piece_size()
|
||||||
|
|
||||||
|
def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"You need to implement this function in the child class."
|
||||||
|
)
|
||||||
|
|
||||||
|
def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"You need to implement this function in the child class."
|
||||||
|
)
|
||||||
|
|
||||||
|
def EncodeAsIds(self, input: str) -> List[int]:
|
||||||
|
return self.EncodeAsIdsBatch([input])[0]
|
||||||
|
|
||||||
|
def EncodeAsPieces(self, input: str) -> List[str]:
|
||||||
|
return self.EncodeAsPiecesBatch([input])[0]
|
||||||
|
|
||||||
|
def Encode(
|
||||||
|
self, input: Union[str, List[str]], out_type=int
|
||||||
|
) -> Union[List, List[List]]:
|
||||||
|
if not input:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if isinstance(input, list):
|
||||||
|
if out_type is int:
|
||||||
|
return self.EncodeAsIdsBatch(input)
|
||||||
|
if out_type is str:
|
||||||
|
return self.EncodeAsPiecesBatch(input)
|
||||||
|
|
||||||
|
if out_type is int:
|
||||||
|
return self.EncodeAsIds(input)
|
||||||
|
if out_type is str:
|
||||||
|
return self.EncodeAsPieces(input)
|
||||||
|
|
||||||
|
encode = Encode
|
||||||
|
|
||||||
|
def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"You need to implement this function in the child class."
|
||||||
|
)
|
||||||
|
|
||||||
|
def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"You need to implement this function in the child class."
|
||||||
|
)
|
||||||
|
|
||||||
|
def DecodeIds(self, input: List[int]) -> str:
|
||||||
|
return self.DecodeIdsBatch([input])[0]
|
||||||
|
|
||||||
|
def DecodePieces(self, input: List[str]) -> str:
|
||||||
|
return self.DecodePiecesBatch([input])[0]
|
||||||
|
|
||||||
|
def Decode(
|
||||||
|
self,
|
||||||
|
input: Union[int, List[int], List[str], List[List[int]], List[List[str]]],
|
||||||
|
) -> Union[List[str], str]:
|
||||||
|
|
||||||
|
if not input:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if isinstance(input, int):
|
||||||
|
return self.id_to_piece(input)
|
||||||
|
elif isinstance(input, str):
|
||||||
|
raise TypeError(
|
||||||
|
"Unlike spm.SentencePieceProcessor, cannot decode from type str."
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(input[0], list):
|
||||||
|
if not input[0] or isinstance(input[0][0], int):
|
||||||
|
return self.DecodeIdsBatch(input)
|
||||||
|
|
||||||
|
if isinstance(input[0][0], str):
|
||||||
|
return self.DecodePiecesBatch(input)
|
||||||
|
|
||||||
|
if isinstance(input[0], int):
|
||||||
|
return self.DecodeIds(input)
|
||||||
|
if isinstance(input[0], str):
|
||||||
|
return self.DecodePieces(input)
|
||||||
|
|
||||||
|
raise RuntimeError("Unknown input type")
|
||||||
|
|
||||||
|
decode = Decode
|
||||||
|
|
||||||
|
def SplitBatch(self, input: List[str]) -> List[List[str]]:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"You need to implement this function in the child class."
|
||||||
|
)
|
||||||
|
|
||||||
|
def Split(self, input: Union[List[str], str]) -> Union[List[List[str]], List[str]]:
|
||||||
|
if isinstance(input, list):
|
||||||
|
return self.SplitBatch(input)
|
||||||
|
elif isinstance(input, str):
|
||||||
|
return self.SplitBatch([input])[0]
|
||||||
|
raise RuntimeError("Unknown input type")
|
||||||
|
|
||||||
|
split = Split
|
||||||
|
|
||||||
|
|
||||||
|
class CharTokenizer(Tokenizer):
|
||||||
|
def __init__(self, lang_dir: Path, oov="<unk>", sep=""):
|
||||||
|
assert (
|
||||||
|
lang_dir / "tokens.txt"
|
||||||
|
).exists(), f"tokens.txt could not be found in {lang_dir}."
|
||||||
|
token_table = SymbolTable.from_file(lang_dir / "tokens.txt")
|
||||||
|
assert (
|
||||||
|
"#0" not in token_table
|
||||||
|
), "This tokenizer does not support disambig symbols."
|
||||||
|
self._id2sym = token_table._id2sym
|
||||||
|
self._sym2id = token_table._sym2id
|
||||||
|
self.oov = oov
|
||||||
|
self.oov_id = self._sym2id[oov]
|
||||||
|
self.sep = sep
|
||||||
|
if self.sep:
|
||||||
|
self.text2word = lambda x: x.split(self.sep)
|
||||||
|
else:
|
||||||
|
self.text2word = lambda x: list(x.replace(" ", ""))
|
||||||
|
|
||||||
|
def piece_to_id(self, piece: str) -> int:
|
||||||
|
try:
|
||||||
|
return self._sym2id[piece]
|
||||||
|
except KeyError:
|
||||||
|
return self.oov_id
|
||||||
|
|
||||||
|
def id_to_piece(self, id: int) -> str:
|
||||||
|
return self._id2sym[id]
|
||||||
|
|
||||||
|
def get_piece_size(self) -> int:
|
||||||
|
return len(self._sym2id)
|
||||||
|
|
||||||
|
def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]:
|
||||||
|
return [[self.piece_to_id(i) for i in self.text2word(text)] for text in input]
|
||||||
|
|
||||||
|
def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]:
|
||||||
|
return [
|
||||||
|
[i if i in self._sym2id else self.oov for i in self.text2word(text)]
|
||||||
|
for text in input
|
||||||
|
]
|
||||||
|
|
||||||
|
def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]:
|
||||||
|
return [self.sep.join(self.id_to_piece(i) for i in text) for text in input]
|
||||||
|
|
||||||
|
def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]:
|
||||||
|
return [self.sep.join(text) for text in input]
|
||||||
|
|
||||||
|
def SplitBatch(self, input: List[str]) -> List[List[str]]:
|
||||||
|
return [self.text2word(text) for text in input]
|
||||||
|
|
||||||
|
|
||||||
|
def test_CharTokenizer():
|
||||||
|
test_single_string = "こんにちは"
|
||||||
|
test_multiple_string = [
|
||||||
|
"今日はいい天気ですよね",
|
||||||
|
"諏訪湖は綺麗でしょう",
|
||||||
|
"这在词表外",
|
||||||
|
"分かち 書き に し た 文章 です",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
test_empty_string = ""
|
||||||
|
sp = Tokenizer.load(Path("lang_char"), "char", oov="<unk>")
|
||||||
|
splitter = sp.split
|
||||||
|
print(sp.encode(test_single_string, out_type=str))
|
||||||
|
print(sp.encode(test_single_string, out_type=int))
|
||||||
|
print(sp.encode(test_multiple_string, out_type=str))
|
||||||
|
print(sp.encode(test_multiple_string, out_type=int))
|
||||||
|
print(sp.encode(test_empty_string, out_type=str))
|
||||||
|
print(sp.encode(test_empty_string, out_type=int))
|
||||||
|
print(sp.decode(sp.encode(test_single_string, out_type=str)))
|
||||||
|
print(sp.decode(sp.encode(test_single_string, out_type=int)))
|
||||||
|
print(sp.decode(sp.encode(test_multiple_string, out_type=str)))
|
||||||
|
print(sp.decode(sp.encode(test_multiple_string, out_type=int)))
|
||||||
|
print(sp.decode(sp.encode(test_empty_string, out_type=str)))
|
||||||
|
print(sp.decode(sp.encode(test_empty_string, out_type=int)))
|
||||||
|
print(splitter(test_single_string))
|
||||||
|
print(splitter(test_multiple_string))
|
||||||
|
print(splitter(test_empty_string))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_CharTokenizer()
|
@ -32,7 +32,7 @@
|
|||||||
# - speech
|
# - speech
|
||||||
#
|
#
|
||||||
# By default, this script produces the original transcript like kaldi and espnet. Optionally, you
|
# By default, this script produces the original transcript like kaldi and espnet. Optionally, you
|
||||||
# can generate other transcript formats by supplying your own config files. A few examples of these
|
# can add other transcript formats by supplying your own config files. A few examples of these
|
||||||
# config files can be found in local/conf.
|
# config files can be found in local/conf.
|
||||||
|
|
||||||
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
|
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
|
||||||
@ -44,10 +44,10 @@ nj=8
|
|||||||
stage=-1
|
stage=-1
|
||||||
stop_stage=100
|
stop_stage=100
|
||||||
|
|
||||||
csj_dir=/mnt/minami_data_server/t2131178/corpus/CSJ
|
csj_dir=/mnt/host/corpus/csj
|
||||||
musan_dir=/mnt/minami_data_server/t2131178/corpus/musan/musan
|
musan_dir=/mnt/host/corpus/musan/musan
|
||||||
trans_dir=$csj_dir/retranscript
|
trans_dir=$csj_dir/transcript
|
||||||
csj_fbank_dir=/mnt/host/csj_data/fbank
|
csj_fbank_dir=/mnt/host/corpus/csj/fbank
|
||||||
musan_fbank_dir=$musan_dir/fbank
|
musan_fbank_dir=$musan_dir/fbank
|
||||||
csj_manifest_dir=data/manifests
|
csj_manifest_dir=data/manifests
|
||||||
musan_manifest_dir=$musan_dir/manifests
|
musan_manifest_dir=$musan_dir/manifests
|
||||||
@ -63,12 +63,8 @@ log() {
|
|||||||
|
|
||||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||||
log "Stage 1: Prepare CSJ manifest"
|
log "Stage 1: Prepare CSJ manifest"
|
||||||
# If you want to generate more transcript modes, append the path to those config files at c.
|
|
||||||
# Example: lhotse prepare csj $csj_dir $trans_dir $csj_manifest_dir -c local/conf/disfluent.ini
|
|
||||||
# NOTE: In case multiple config files are supplied, the second config file and onwards will inherit
|
|
||||||
# the segment boundaries of the first config file.
|
|
||||||
if [ ! -e $csj_manifest_dir/.csj.done ]; then
|
if [ ! -e $csj_manifest_dir/.csj.done ]; then
|
||||||
lhotse prepare csj $csj_dir $trans_dir $csj_manifest_dir -j 4
|
lhotse prepare csj $csj_dir $csj_manifest_dir -t $trans_dir -j 16
|
||||||
touch $csj_manifest_dir/.csj.done
|
touch $csj_manifest_dir/.csj.done
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
@ -88,32 +84,24 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
|||||||
python local/compute_fbank_csj.py --manifest-dir $csj_manifest_dir \
|
python local/compute_fbank_csj.py --manifest-dir $csj_manifest_dir \
|
||||||
--fbank-dir $csj_fbank_dir
|
--fbank-dir $csj_fbank_dir
|
||||||
parts=(
|
parts=(
|
||||||
train
|
|
||||||
valid
|
|
||||||
eval1
|
eval1
|
||||||
eval2
|
eval2
|
||||||
eval3
|
eval3
|
||||||
|
valid
|
||||||
|
excluded
|
||||||
|
train
|
||||||
)
|
)
|
||||||
for part in ${parts[@]}; do
|
for part in ${parts[@]}; do
|
||||||
python local/validate_manifest.py --manifest $csj_manifest_dir/csj_cuts_$part.jsonl.gz
|
python local/validate_manifest.py --manifest $csj_fbank_dir/csj_cuts_$part.jsonl.gz
|
||||||
done
|
done
|
||||||
touch $csj_fbank_dir/.csj-validated.done
|
touch $csj_fbank_dir/.csj-validated.done
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||||
log "Stage 4: Prepare CSJ lang"
|
log "Stage 4: Prepare CSJ lang_char"
|
||||||
modes=disfluent
|
python local/prepare_lang_char.py $csj_fbank_dir/csj_cuts_train.jsonl.gz
|
||||||
|
python local/add_transcript_mode.py -f $csj_fbank_dir -c local/conf/fluent.ini local/conf/number.ini
|
||||||
# If you want prepare the lang directory for other transcript modes, just append
|
|
||||||
# the names of those modes behind. An example is shown as below:-
|
|
||||||
# modes="$modes fluent symbol number"
|
|
||||||
|
|
||||||
for mode in ${modes[@]}; do
|
|
||||||
python local/prepare_lang_char.py --trans-mode $mode \
|
|
||||||
--train-cut $csj_manifest_dir/csj_cuts_train.jsonl.gz \
|
|
||||||
--lang-dir lang_char_$mode
|
|
||||||
done
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||||
@ -128,6 +116,6 @@ fi
|
|||||||
|
|
||||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||||
log "Stage 6: Show manifest statistics"
|
log "Stage 6: Show manifest statistics"
|
||||||
python local/display_manifest_statistics.py --manifest-dir $csj_manifest_dir > $csj_manifest_dir/manifest_statistics.txt
|
python local/display_manifest_statistics.py --manifest-dir $csj_fbank_dir > $csj_fbank_dir/manifest_statistics.txt
|
||||||
cat $csj_manifest_dir/manifest_statistics.txt
|
cat $csj_fbank_dir/manifest_statistics.txt
|
||||||
fi
|
fi
|
||||||
|
@ -0,0 +1,76 @@
|
|||||||
|
import logging
|
||||||
|
from configparser import ConfigParser
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
def escape_html(text: str):
|
||||||
|
"""
|
||||||
|
Escapes all html characters in text
|
||||||
|
:param str text:
|
||||||
|
:rtype: str
|
||||||
|
"""
|
||||||
|
return text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||||
|
|
||||||
|
|
||||||
|
class TelegramStreamIO(logging.Handler):
|
||||||
|
|
||||||
|
API_ENDPOINT = "https://api.telegram.org"
|
||||||
|
MAX_MESSAGE_LEN = 4096
|
||||||
|
formatter = logging.Formatter(
|
||||||
|
"%(asctime)s - %(levelname)s at %(funcName)s "
|
||||||
|
"(line %(lineno)s):\n\n%(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, tg_configfile: str):
|
||||||
|
super(TelegramStreamIO, self).__init__()
|
||||||
|
config = ConfigParser()
|
||||||
|
if not config.read(tg_configfile):
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"{tg_configfile} not found. " "Retry without --telegram-cred flag."
|
||||||
|
)
|
||||||
|
config = config["TELEGRAM"]
|
||||||
|
token = config["token"]
|
||||||
|
self.chat_id = config["chat_id"]
|
||||||
|
self.url = f"{self.API_ENDPOINT}/bot{token}/sendMessage"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def setup_logger(params):
|
||||||
|
if not params.telegram_cred:
|
||||||
|
return
|
||||||
|
formatter = logging.Formatter(
|
||||||
|
f"{params.exp_dir.name} %(asctime)s \n%(message)s"
|
||||||
|
)
|
||||||
|
tg = TelegramStreamIO(params.telegram_cred)
|
||||||
|
tg.setLevel(logging.WARN)
|
||||||
|
tg.setFormatter(formatter)
|
||||||
|
logging.getLogger("").addHandler(tg)
|
||||||
|
|
||||||
|
def emit(self, record: logging.LogRecord):
|
||||||
|
"""
|
||||||
|
Emit a record.
|
||||||
|
Send the record to the Web server as a percent-encoded dictionary
|
||||||
|
"""
|
||||||
|
data = {
|
||||||
|
"chat_id": self.chat_id,
|
||||||
|
"text": self.format(self.mapLogRecord(record)),
|
||||||
|
"parse_mode": "HTML",
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
requests.get(self.url, json=data)
|
||||||
|
# return response.json()
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Failed to send telegram message: {repr(e)}")
|
||||||
|
pass
|
||||||
|
|
||||||
|
def mapLogRecord(self, record):
|
||||||
|
"""
|
||||||
|
Default implementation of mapping the log record into a dict
|
||||||
|
that is sent as the CGI data. Overwrite in your class.
|
||||||
|
Contributed by Franz Glasner.
|
||||||
|
"""
|
||||||
|
|
||||||
|
for k, v in record.__dict__.items():
|
||||||
|
if isinstance(v, str):
|
||||||
|
setattr(record, k, escape_html(v))
|
||||||
|
return record
|
@ -0,0 +1 @@
|
|||||||
|
../local/utils/asr_datamodule.py
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py
|
852
egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py
Executable file
852
egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py
Executable file
@ -0,0 +1,852 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||||
|
# Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
(1) greedy search
|
||||||
|
./pruned_transducer_stateless7_streaming/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
--lang data/lang_char \
|
||||||
|
--decoding-method greedy_search
|
||||||
|
|
||||||
|
(2) beam search (not recommended)
|
||||||
|
./pruned_transducer_stateless7_streaming/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
--decoding-method beam_search \
|
||||||
|
--lang data/lang_char \
|
||||||
|
--beam-size 4
|
||||||
|
|
||||||
|
(3) modified beam search
|
||||||
|
./pruned_transducer_stateless7_streaming/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
--decoding-method modified_beam_search \
|
||||||
|
--lang data/lang_char \
|
||||||
|
--beam-size 4
|
||||||
|
|
||||||
|
(4) fast beam search (one best)
|
||||||
|
./pruned_transducer_stateless7_streaming/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
--decoding-method fast_beam_search \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--lang data/lang_char \
|
||||||
|
--max-states 64
|
||||||
|
|
||||||
|
(5) fast beam search (nbest)
|
||||||
|
./pruned_transducer_stateless7_streaming/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
--decoding-method fast_beam_search_nbest \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--max-states 64 \
|
||||||
|
--num-paths 200 \
|
||||||
|
--lang data/lang_char \
|
||||||
|
--nbest-scale 0.5
|
||||||
|
|
||||||
|
(6) fast beam search (nbest oracle WER)
|
||||||
|
./pruned_transducer_stateless7_streaming/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
--decoding-method fast_beam_search_nbest_oracle \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--max-states 64 \
|
||||||
|
--num-paths 200 \
|
||||||
|
--lang data/lang_char \
|
||||||
|
--nbest-scale 0.5
|
||||||
|
|
||||||
|
(7) fast beam search (with LG)
|
||||||
|
./pruned_transducer_stateless7_streaming/decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
--decoding-method fast_beam_search_nbest_LG \
|
||||||
|
--beam 20.0 \
|
||||||
|
--max-contexts 8 \
|
||||||
|
--lang data/lang_char \
|
||||||
|
--max-states 64
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from asr_datamodule import CSJAsrDataModule
|
||||||
|
from beam_search import (
|
||||||
|
beam_search,
|
||||||
|
fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_LG,
|
||||||
|
fast_beam_search_nbest_oracle,
|
||||||
|
fast_beam_search_one_best,
|
||||||
|
greedy_search,
|
||||||
|
greedy_search_batch,
|
||||||
|
modified_beam_search,
|
||||||
|
)
|
||||||
|
from tokenizer import Tokenizer
|
||||||
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.utils import (
|
||||||
|
AttributeDict,
|
||||||
|
setup_logger,
|
||||||
|
store_transcripts,
|
||||||
|
str2bool,
|
||||||
|
write_error_stats,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG_EPS = math.log(1e-10)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=30,
|
||||||
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
|
Note: Epoch counts from 1.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--gpu",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=9,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="pruned_transducer_stateless7_streaming/exp",
|
||||||
|
help="The experiment dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--res-dir",
|
||||||
|
type=Path,
|
||||||
|
default=None,
|
||||||
|
help="The path to save results.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lang-dir",
|
||||||
|
type=Path,
|
||||||
|
default="data/lang_char",
|
||||||
|
help="The lang dir. It should contain at least a word table.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoding-method",
|
||||||
|
type=str,
|
||||||
|
default="greedy_search",
|
||||||
|
help="""Possible values are:
|
||||||
|
- greedy_search
|
||||||
|
- beam_search
|
||||||
|
- modified_beam_search
|
||||||
|
- fast_beam_search
|
||||||
|
- fast_beam_search_nbest
|
||||||
|
- fast_beam_search_nbest_oracle
|
||||||
|
- fast_beam_search_nbest_LG
|
||||||
|
If you use fast_beam_search_nbest_LG, you have to specify
|
||||||
|
`--lang-dir`, which should contain `LG.pt`.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoding-graph",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam-size",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""An integer indicating how many candidates we will keep for each
|
||||||
|
frame. Used only when --decoding-method is beam_search or
|
||||||
|
modified_beam_search.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam",
|
||||||
|
type=float,
|
||||||
|
default=20.0,
|
||||||
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
|
`beam` in Kaldi.
|
||||||
|
Used only when --decoding-method is fast_beam_search,
|
||||||
|
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||||
|
and fast_beam_search_nbest_oracle
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngram-lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.01,
|
||||||
|
help="""
|
||||||
|
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||||
|
It specifies the scale for n-gram LM scores.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-contexts",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||||
|
and fast_beam_search_nbest_oracle""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-states",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||||
|
and fast_beam_search_nbest_oracle""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-sym-per-frame",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="""Maximum number of symbols per frame.
|
||||||
|
Used only when --decoding_method is greedy_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-paths",
|
||||||
|
type=int,
|
||||||
|
default=200,
|
||||||
|
help="""Number of paths for nbest decoding.
|
||||||
|
Used only when the decoding method is fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--nbest-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help="""Scale applied to lattice scores when computing nbest paths.
|
||||||
|
Used only when the decoding method is fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--pad-feature",
|
||||||
|
type=int,
|
||||||
|
default=30,
|
||||||
|
help="""
|
||||||
|
Number of frames to pad at the end.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def decode_one_batch(
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
sp: Tokenizer,
|
||||||
|
batch: dict,
|
||||||
|
word_table: Optional[k2.SymbolTable] = None,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
) -> Dict[str, List[List[str]]]:
|
||||||
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
|
following format:
|
||||||
|
|
||||||
|
- key: It indicates the setting used for decoding. For example,
|
||||||
|
if greedy_search is used, it would be "greedy_search"
|
||||||
|
If beam search with a beam size of 7 is used, it would be
|
||||||
|
"beam_7"
|
||||||
|
- value: It contains the decoding result. `len(value)` equals to
|
||||||
|
batch size. `value[i]` is the decoding result for the i-th
|
||||||
|
utterance in the given batch.
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
It's the return value of :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
sp:
|
||||||
|
The BPE model.
|
||||||
|
batch:
|
||||||
|
It is the return value from iterating
|
||||||
|
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||||
|
for the format of the `batch`.
|
||||||
|
word_table:
|
||||||
|
The word symbol table.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
|
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||||
|
Returns:
|
||||||
|
Return the decoding result. See above description for the format of
|
||||||
|
the returned dict.
|
||||||
|
"""
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
feature = batch["inputs"]
|
||||||
|
assert feature.ndim == 3
|
||||||
|
|
||||||
|
feature = feature.to(device)
|
||||||
|
# at entry, feature is (N, T, C)
|
||||||
|
|
||||||
|
supervisions = batch["supervisions"]
|
||||||
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
|
if params.pad_feature:
|
||||||
|
feature_lens += params.pad_feature
|
||||||
|
feature = torch.nn.functional.pad(
|
||||||
|
feature,
|
||||||
|
pad=(0, 0, 0, params.pad_feature),
|
||||||
|
value=LOG_EPS,
|
||||||
|
)
|
||||||
|
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||||
|
|
||||||
|
hyps = []
|
||||||
|
|
||||||
|
if params.decoding_method == "fast_beam_search":
|
||||||
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(sp.text2word(hyp))
|
||||||
|
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||||
|
hyp_tokens = fast_beam_search_nbest_LG(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
for hyp in hyp_tokens:
|
||||||
|
hyps.append([word_table[i] for i in hyp])
|
||||||
|
elif params.decoding_method == "fast_beam_search_nbest":
|
||||||
|
hyp_tokens = fast_beam_search_nbest(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(sp.text2word(hyp))
|
||||||
|
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||||
|
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
ref_texts=sp.encode(supervisions["text"]),
|
||||||
|
nbest_scale=params.nbest_scale,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(sp.text2word(hyp))
|
||||||
|
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
|
hyp_tokens = greedy_search_batch(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(sp.text2word(hyp))
|
||||||
|
elif params.decoding_method == "modified_beam_search":
|
||||||
|
hyp_tokens = modified_beam_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam_size,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(sp.text2word(hyp))
|
||||||
|
else:
|
||||||
|
batch_size = encoder_out.size(0)
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
# fmt: off
|
||||||
|
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||||
|
# fmt: on
|
||||||
|
if params.decoding_method == "greedy_search":
|
||||||
|
hyp = greedy_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out_i,
|
||||||
|
max_sym_per_frame=params.max_sym_per_frame,
|
||||||
|
)
|
||||||
|
elif params.decoding_method == "beam_search":
|
||||||
|
hyp = beam_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out_i,
|
||||||
|
beam=params.beam_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
|
)
|
||||||
|
hyps.append(sp.text2word(sp.decode(hyp)))
|
||||||
|
|
||||||
|
if params.decoding_method == "greedy_search":
|
||||||
|
return {"greedy_search": hyps}
|
||||||
|
elif "fast_beam_search" in params.decoding_method:
|
||||||
|
key = f"beam_{params.beam}_"
|
||||||
|
key += f"max_contexts_{params.max_contexts}_"
|
||||||
|
key += f"max_states_{params.max_states}"
|
||||||
|
if "nbest" in params.decoding_method:
|
||||||
|
key += f"_num_paths_{params.num_paths}_"
|
||||||
|
key += f"nbest_scale_{params.nbest_scale}"
|
||||||
|
if "LG" in params.decoding_method:
|
||||||
|
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||||
|
|
||||||
|
return {key: hyps}
|
||||||
|
else:
|
||||||
|
return {f"beam_size_{params.beam_size}": hyps}
|
||||||
|
|
||||||
|
|
||||||
|
def decode_dataset(
|
||||||
|
dl: torch.utils.data.DataLoader,
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
sp: Tokenizer,
|
||||||
|
word_table: Optional[k2.SymbolTable] = None,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||||
|
"""Decode dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dl:
|
||||||
|
PyTorch's dataloader containing the dataset to decode.
|
||||||
|
params:
|
||||||
|
It is returned by :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
sp:
|
||||||
|
The BPE model.
|
||||||
|
word_table:
|
||||||
|
The word symbol table.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
|
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||||
|
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||||
|
Returns:
|
||||||
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
|
Its value is a list of tuples. Each tuple contains two elements:
|
||||||
|
The first is the reference transcript, and the second is the
|
||||||
|
predicted result.
|
||||||
|
"""
|
||||||
|
num_cuts = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
num_batches = len(dl)
|
||||||
|
except TypeError:
|
||||||
|
num_batches = "?"
|
||||||
|
|
||||||
|
if params.decoding_method == "greedy_search":
|
||||||
|
log_interval = 50
|
||||||
|
else:
|
||||||
|
log_interval = 20
|
||||||
|
|
||||||
|
results = defaultdict(list)
|
||||||
|
for batch_idx, batch in enumerate(dl):
|
||||||
|
texts = batch["supervisions"]["text"]
|
||||||
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
|
|
||||||
|
hyps_dict = decode_one_batch(
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
sp=sp,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
word_table=word_table,
|
||||||
|
batch=batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
for name, hyps in hyps_dict.items():
|
||||||
|
this_batch = []
|
||||||
|
assert len(hyps) == len(texts)
|
||||||
|
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||||
|
ref_words = sp.text2word(ref_text)
|
||||||
|
this_batch.append((cut_id, ref_words, hyp_words))
|
||||||
|
|
||||||
|
results[name].extend(this_batch)
|
||||||
|
|
||||||
|
num_cuts += len(texts)
|
||||||
|
|
||||||
|
if batch_idx % log_interval == 0:
|
||||||
|
batch_str = f"{batch_idx}/{num_batches}"
|
||||||
|
|
||||||
|
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def save_results(
|
||||||
|
params: AttributeDict,
|
||||||
|
test_set_name: str,
|
||||||
|
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||||
|
):
|
||||||
|
test_set_wers = dict()
|
||||||
|
for key, results in results_dict.items():
|
||||||
|
recog_path = (
|
||||||
|
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
|
)
|
||||||
|
results = sorted(results)
|
||||||
|
store_transcripts(filename=recog_path, texts=results)
|
||||||
|
|
||||||
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
|
|
||||||
|
# The following prints out WERs, per-word error statistics and aligned
|
||||||
|
# ref/hyp pairs.
|
||||||
|
errs_filename = (
|
||||||
|
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
|
)
|
||||||
|
with open(errs_filename, "w") as f:
|
||||||
|
wer = write_error_stats(
|
||||||
|
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||||
|
)
|
||||||
|
test_set_wers[key] = wer
|
||||||
|
|
||||||
|
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||||
|
|
||||||
|
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||||
|
errs_info = (
|
||||||
|
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
|
)
|
||||||
|
with open(errs_info, "w") as f:
|
||||||
|
print("settings\tWER", file=f)
|
||||||
|
for key, val in test_set_wers:
|
||||||
|
print("{}\t{}".format(key, val), file=f)
|
||||||
|
|
||||||
|
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||||
|
note = "\tbest for {}".format(test_set_name)
|
||||||
|
for key, val in test_set_wers:
|
||||||
|
s += "{}\t{}{}\n".format(key, val, note)
|
||||||
|
note = ""
|
||||||
|
logging.info(s)
|
||||||
|
|
||||||
|
return test_set_wers
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
CSJAsrDataModule.add_arguments(parser)
|
||||||
|
Tokenizer.add_arguments(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
assert params.decoding_method in (
|
||||||
|
"greedy_search",
|
||||||
|
"beam_search",
|
||||||
|
"fast_beam_search",
|
||||||
|
"fast_beam_search_nbest",
|
||||||
|
"fast_beam_search_nbest_LG",
|
||||||
|
"fast_beam_search_nbest_oracle",
|
||||||
|
"modified_beam_search",
|
||||||
|
)
|
||||||
|
if not params.res_dir:
|
||||||
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||||
|
else:
|
||||||
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
|
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
|
||||||
|
|
||||||
|
if "fast_beam_search" in params.decoding_method:
|
||||||
|
params.suffix += f"-beam-{params.beam}"
|
||||||
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
|
if "nbest" in params.decoding_method:
|
||||||
|
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||||
|
params.suffix += f"-num-paths-{params.num_paths}"
|
||||||
|
if "LG" in params.decoding_method:
|
||||||
|
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||||
|
elif "beam_search" in params.decoding_method:
|
||||||
|
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||||
|
else:
|
||||||
|
params.suffix += f"-context-{params.context_size}"
|
||||||
|
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||||
|
|
||||||
|
if params.use_averaged_model:
|
||||||
|
params.suffix += "-use-averaged-model"
|
||||||
|
|
||||||
|
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||||
|
logging.info("Decoding started")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", params.gpu)
|
||||||
|
|
||||||
|
logging.info(f"Device: {device}")
|
||||||
|
|
||||||
|
sp = Tokenizer.load(params.lang, params.lang_type)
|
||||||
|
|
||||||
|
# <blk> and <unk> are defined in local/prepare_lang_char.py
|
||||||
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, (
|
||||||
|
model.encoder.decode_chunk_size,
|
||||||
|
params.decode_chunk_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
decoding_graph = None
|
||||||
|
word_table = None
|
||||||
|
|
||||||
|
if params.decoding_graph:
|
||||||
|
decoding_graph = k2.Fsa.from_dict(
|
||||||
|
torch.load(params.decoding_graph, map_location=device)
|
||||||
|
)
|
||||||
|
elif "fast_beam_search" in params.decoding_method:
|
||||||
|
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||||
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
word_table = lexicon.word_table
|
||||||
|
lg_filename = params.lang_dir / "LG.pt"
|
||||||
|
logging.info(f"Loading {lg_filename}")
|
||||||
|
decoding_graph = k2.Fsa.from_dict(
|
||||||
|
torch.load(lg_filename, map_location=device)
|
||||||
|
)
|
||||||
|
decoding_graph.scores *= params.ngram_lm_scale
|
||||||
|
else:
|
||||||
|
word_table = None
|
||||||
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
# we need cut ids to display recognition results.
|
||||||
|
args.return_cuts = True
|
||||||
|
csj_corpus = CSJAsrDataModule(args)
|
||||||
|
|
||||||
|
for subdir in ["eval1", "eval2", "eval3", "excluded", "valid"]:
|
||||||
|
results_dict = decode_dataset(
|
||||||
|
dl=csj_corpus.test_dataloaders(getattr(csj_corpus, f"{subdir}_cuts")()),
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
sp=sp,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
)
|
||||||
|
tot_err = save_results(
|
||||||
|
params=params,
|
||||||
|
test_set_name=subdir,
|
||||||
|
results_dict=results_dict,
|
||||||
|
)
|
||||||
|
with (
|
||||||
|
params.res_dir
|
||||||
|
/ (
|
||||||
|
f"{subdir}-{params.decode_chunk_len}_{params.beam_size}"
|
||||||
|
f"_{params.avg}_{params.epoch}.cer"
|
||||||
|
)
|
||||||
|
).open("w") as fout:
|
||||||
|
if len(tot_err) == 1:
|
||||||
|
fout.write(f"{tot_err[0][1]}")
|
||||||
|
else:
|
||||||
|
fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err)
|
||||||
|
|
||||||
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py
|
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/decoder.py
Symbolic link
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/decoder.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py
|
313
egs/csj/ASR/pruned_transducer_stateless7_streaming/export.py
Normal file
313
egs/csj/ASR/pruned_transducer_stateless7_streaming/export.py
Normal file
@ -0,0 +1,313 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# This script converts several saved checkpoints
|
||||||
|
# to a single one using model averaging.
|
||||||
|
"""
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
(1) Export to torchscript model using torch.jit.script()
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7_streaming/export.py \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--lang data/lang_char \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 9 \
|
||||||
|
--jit 1
|
||||||
|
|
||||||
|
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
|
||||||
|
load it by `torch.jit.load("cpu_jit.pt")`.
|
||||||
|
|
||||||
|
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
|
||||||
|
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
|
||||||
|
|
||||||
|
Check
|
||||||
|
https://github.com/k2-fsa/sherpa
|
||||||
|
for how to use the exported models outside of icefall.
|
||||||
|
|
||||||
|
(2) Export `model.state_dict()`
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7_streaming/export.py \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--lang data/lang_char \
|
||||||
|
--epoch 20 \
|
||||||
|
--avg 10
|
||||||
|
|
||||||
|
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
|
||||||
|
load it by `icefall.checkpoint.load_checkpoint()`.
|
||||||
|
|
||||||
|
To use the generated file with `pruned_transducer_stateless7_streaming/decode.py`,
|
||||||
|
you can do:
|
||||||
|
|
||||||
|
cd /path/to/exp_dir
|
||||||
|
ln -s pretrained.pt epoch-9999.pt
|
||||||
|
|
||||||
|
cd /path/to/egs/csj/ASR
|
||||||
|
./pruned_transducer_stateless7_streaming/decode.py \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--epoch 9999 \
|
||||||
|
--avg 1 \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method greedy_search \
|
||||||
|
--lang data/lang_char
|
||||||
|
|
||||||
|
Check ./pretrained.py for its usage.
|
||||||
|
|
||||||
|
Note: If you don't want to train a model from scratch, we have
|
||||||
|
provided one for you. You can get it at
|
||||||
|
|
||||||
|
https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208
|
||||||
|
|
||||||
|
with the following commands:
|
||||||
|
|
||||||
|
sudo apt-get install git-lfs
|
||||||
|
git lfs install
|
||||||
|
git clone https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208
|
||||||
|
# You will find the pre-trained model in icefall-asr-csj-pruned-transducer-stateless7-230208/exp_fluent
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
|
from tokenizer import Tokenizer
|
||||||
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=30,
|
||||||
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
|
Note: Epoch counts from 1.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=9,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="pruned_transducer_stateless7_streaming/exp",
|
||||||
|
help="""It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--jit",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""True to save a model after applying torch.jit.script.
|
||||||
|
It will generate a file named cpu_jit.pt
|
||||||
|
|
||||||
|
Check ./jit_pretrained.py for how to use it.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
Tokenizer.add_arguments(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
sp = Tokenizer.load(params.lang, params.lang_type)
|
||||||
|
|
||||||
|
# <blk> is defined in local/prepare_lang_char.py
|
||||||
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to("cpu")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
if params.jit is True:
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
|
# it here.
|
||||||
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
|
# torch scriptabe.
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
|
logging.info("Using torch.jit.script")
|
||||||
|
model = torch.jit.script(model)
|
||||||
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
|
model.save(str(filename))
|
||||||
|
logging.info(f"Saved to {filename}")
|
||||||
|
else:
|
||||||
|
logging.info("Not using torchscript. Export model.state_dict()")
|
||||||
|
# Save it using a format so that it can be loaded
|
||||||
|
# by :func:`load_checkpoint`
|
||||||
|
filename = params.exp_dir / "pretrained.pt"
|
||||||
|
torch.save({"model": model.state_dict()}, str(filename))
|
||||||
|
logging.info(f"Saved to {filename}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
@ -0,0 +1,308 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
# use -O to skip assertions and avoid some of the TracerWarnings
|
||||||
|
python -O pruned_transducer_stateless7_streaming/jit_trace_export.py \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--lang data/lang_char \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 10 \
|
||||||
|
--use-averaged-model=True \
|
||||||
|
--decode-chunk-len 32
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
|
from tokenizer import Tokenizer
|
||||||
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.utils import AttributeDict, str2bool
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=28,
|
||||||
|
help="""It specifies the checkpoint to use for averaging.
|
||||||
|
Note: Epoch counts from 0.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="pruned_transducer_stateless2/exp",
|
||||||
|
help="""It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def export_encoder_model_jit_trace(
|
||||||
|
encoder_model: torch.nn.Module,
|
||||||
|
encoder_filename: str,
|
||||||
|
params: AttributeDict,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given encoder model with torch.jit.trace()
|
||||||
|
|
||||||
|
Note: The warmup argument is fixed to 1.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_model:
|
||||||
|
The input encoder model
|
||||||
|
encoder_filename:
|
||||||
|
The filename to save the exported model.
|
||||||
|
"""
|
||||||
|
decode_chunk_len = params.decode_chunk_len # before subsampling
|
||||||
|
pad_length = 7
|
||||||
|
s = f"decode_chunk_len: {decode_chunk_len}"
|
||||||
|
logging.info(s)
|
||||||
|
assert encoder_model.decode_chunk_size == decode_chunk_len // 2, (
|
||||||
|
encoder_model.decode_chunk_size,
|
||||||
|
decode_chunk_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
T = decode_chunk_len + pad_length
|
||||||
|
|
||||||
|
x = torch.zeros(1, T, 80, dtype=torch.float32)
|
||||||
|
x_lens = torch.full((1,), T, dtype=torch.int32)
|
||||||
|
states = encoder_model.get_init_state(device=x.device)
|
||||||
|
|
||||||
|
encoder_model.__class__.forward = encoder_model.__class__.streaming_forward
|
||||||
|
traced_model = torch.jit.trace(encoder_model, (x, x_lens, states))
|
||||||
|
traced_model.save(encoder_filename)
|
||||||
|
logging.info(f"Saved to {encoder_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_decoder_model_jit_trace(
|
||||||
|
decoder_model: torch.nn.Module,
|
||||||
|
decoder_filename: str,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given decoder model with torch.jit.trace()
|
||||||
|
|
||||||
|
Note: The argument need_pad is fixed to False.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoder_model:
|
||||||
|
The input decoder model
|
||||||
|
decoder_filename:
|
||||||
|
The filename to save the exported model.
|
||||||
|
"""
|
||||||
|
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||||
|
need_pad = torch.tensor([False])
|
||||||
|
|
||||||
|
traced_model = torch.jit.trace(decoder_model, (y, need_pad))
|
||||||
|
traced_model.save(decoder_filename)
|
||||||
|
logging.info(f"Saved to {decoder_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_joiner_model_jit_trace(
|
||||||
|
joiner_model: torch.nn.Module,
|
||||||
|
joiner_filename: str,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given joiner model with torch.jit.trace()
|
||||||
|
|
||||||
|
Note: The argument project_input is fixed to True. A user should not
|
||||||
|
project the encoder_out/decoder_out by himself/herself. The exported joiner
|
||||||
|
will do that for the user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
joiner_model:
|
||||||
|
The input joiner model
|
||||||
|
joiner_filename:
|
||||||
|
The filename to save the exported model.
|
||||||
|
|
||||||
|
"""
|
||||||
|
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||||
|
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||||
|
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||||
|
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||||
|
|
||||||
|
traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
|
||||||
|
traced_model.save(joiner_filename)
|
||||||
|
logging.info(f"Saved to {joiner_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
Tokenizer.add_arguments(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
sp = Tokenizer.load(params.lang, params.lang_type)
|
||||||
|
|
||||||
|
# <blk> is defined in local/prepare_lang_char.py
|
||||||
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to("cpu")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
|
logging.info("Using torch.jit.trace()")
|
||||||
|
|
||||||
|
logging.info("Exporting encoder")
|
||||||
|
encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
|
||||||
|
export_encoder_model_jit_trace(model.encoder, encoder_filename, params)
|
||||||
|
|
||||||
|
logging.info("Exporting decoder")
|
||||||
|
decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
|
||||||
|
export_decoder_model_jit_trace(model.decoder, decoder_filename)
|
||||||
|
|
||||||
|
logging.info("Exporting joiner")
|
||||||
|
joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
|
||||||
|
export_joiner_model_jit_trace(model.joiner, joiner_filename)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
@ -0,0 +1,286 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# flake8: noqa
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
This script loads torchscript models exported by `torch.jit.trace()`
|
||||||
|
and uses them to decode waves.
|
||||||
|
You can use the following command to get the exported models:
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7_streaming/jit_trace_export.py \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--lang data/lang_char \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 10 \
|
||||||
|
--use-averaged-model=True \
|
||||||
|
--decode-chunk-len 32
|
||||||
|
|
||||||
|
Usage of this script:
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \
|
||||||
|
--encoder-model-filename ./pruned_transducer_stateless7_streaming/exp/encoder_jit_trace.pt \
|
||||||
|
--decoder-model-filename ./pruned_transducer_stateless7_streaming/exp/decoder_jit_trace.pt \
|
||||||
|
--joiner-model-filename ./pruned_transducer_stateless7_streaming/exp/joiner_jit_trace.pt \
|
||||||
|
--lang data/lang_char \
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||||
|
from tokenizer import Tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--encoder-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the encoder torchscript model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the decoder torchscript model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--joiner-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the joiner torchscript model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--sample-rate",
|
||||||
|
type=int,
|
||||||
|
default=16000,
|
||||||
|
help="The sample rate of the input sound file",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decode-chunk-len",
|
||||||
|
type=int,
|
||||||
|
default=32,
|
||||||
|
help="The chunk size for decoding (in frames before subsampling)",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"sound_file",
|
||||||
|
type=str,
|
||||||
|
help="The input sound file(s) to transcribe. "
|
||||||
|
"Supported formats are those supported by torchaudio.load(). "
|
||||||
|
"For example, wav and flac are supported. "
|
||||||
|
"The sample rate has to be 16kHz.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def read_sound_files(
|
||||||
|
filenames: List[str], expected_sample_rate: float
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||||
|
Args:
|
||||||
|
filenames:
|
||||||
|
A list of sound filenames.
|
||||||
|
expected_sample_rate:
|
||||||
|
The expected sample rate of the sound files.
|
||||||
|
Returns:
|
||||||
|
Return a list of 1-D float32 torch tensors.
|
||||||
|
"""
|
||||||
|
ans = []
|
||||||
|
for f in filenames:
|
||||||
|
wave, sample_rate = torchaudio.load(f)
|
||||||
|
assert (
|
||||||
|
sample_rate == expected_sample_rate
|
||||||
|
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||||
|
# We use only the first channel
|
||||||
|
ans.append(wave[0])
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def greedy_search(
|
||||||
|
decoder: torch.jit.ScriptModule,
|
||||||
|
joiner: torch.jit.ScriptModule,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
decoder_out: Optional[torch.Tensor] = None,
|
||||||
|
hyp: Optional[List[int]] = None,
|
||||||
|
):
|
||||||
|
assert encoder_out.ndim == 2
|
||||||
|
context_size = 2
|
||||||
|
blank_id = 0
|
||||||
|
|
||||||
|
if decoder_out is None:
|
||||||
|
assert hyp is None, hyp
|
||||||
|
hyp = [blank_id] * context_size
|
||||||
|
decoder_input = torch.tensor(hyp, dtype=torch.int32).unsqueeze(0)
|
||||||
|
# decoder_input.shape (1,, 1 context_size)
|
||||||
|
decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1)
|
||||||
|
else:
|
||||||
|
assert decoder_out.ndim == 2
|
||||||
|
assert hyp is not None, hyp
|
||||||
|
|
||||||
|
T = encoder_out.size(0)
|
||||||
|
for i in range(T):
|
||||||
|
cur_encoder_out = encoder_out[i : i + 1]
|
||||||
|
joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0)
|
||||||
|
y = joiner_out.argmax(dim=0).item()
|
||||||
|
|
||||||
|
if y != blank_id:
|
||||||
|
hyp.append(y)
|
||||||
|
decoder_input = hyp[-context_size:]
|
||||||
|
|
||||||
|
decoder_input = torch.tensor(decoder_input, dtype=torch.int32).unsqueeze(0)
|
||||||
|
decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1)
|
||||||
|
|
||||||
|
return hyp, decoder_out
|
||||||
|
|
||||||
|
|
||||||
|
def create_streaming_feature_extractor(sample_rate) -> OnlineFeature:
|
||||||
|
"""Create a CPU streaming feature extractor.
|
||||||
|
|
||||||
|
At present, we assume it returns a fbank feature extractor with
|
||||||
|
fixed options. In the future, we will support passing in the options
|
||||||
|
from outside.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a CPU streaming feature extractor.
|
||||||
|
"""
|
||||||
|
opts = FbankOptions()
|
||||||
|
opts.device = "cpu"
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.frame_opts.snip_edges = False
|
||||||
|
opts.frame_opts.samp_freq = sample_rate
|
||||||
|
opts.mel_opts.num_bins = 80
|
||||||
|
return OnlineFbank(opts)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
Tokenizer.add_arguments(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
logging.info(vars(args))
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
encoder = torch.jit.load(args.encoder_model_filename)
|
||||||
|
decoder = torch.jit.load(args.decoder_model_filename)
|
||||||
|
joiner = torch.jit.load(args.joiner_model_filename)
|
||||||
|
|
||||||
|
encoder.eval()
|
||||||
|
decoder.eval()
|
||||||
|
joiner.eval()
|
||||||
|
|
||||||
|
encoder.to(device)
|
||||||
|
decoder.to(device)
|
||||||
|
joiner.to(device)
|
||||||
|
|
||||||
|
sp = Tokenizer.load(args.lang, args.lang_type)
|
||||||
|
|
||||||
|
logging.info("Constructing Fbank computer")
|
||||||
|
online_fbank = create_streaming_feature_extractor(args.sample_rate)
|
||||||
|
|
||||||
|
logging.info(f"Reading sound files: {args.sound_file}")
|
||||||
|
wave_samples = read_sound_files(
|
||||||
|
filenames=[args.sound_file],
|
||||||
|
expected_sample_rate=args.sample_rate,
|
||||||
|
)[0]
|
||||||
|
logging.info(wave_samples.shape)
|
||||||
|
|
||||||
|
logging.info("Decoding started")
|
||||||
|
chunk_length = args.decode_chunk_len
|
||||||
|
assert encoder.decode_chunk_size == chunk_length // 2, (
|
||||||
|
encoder.decode_chunk_size,
|
||||||
|
chunk_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
# we subsample features with ((x_len - 7) // 2 + 1) // 2
|
||||||
|
pad_length = 7
|
||||||
|
T = chunk_length + pad_length
|
||||||
|
|
||||||
|
logging.info(f"chunk_length: {chunk_length}")
|
||||||
|
|
||||||
|
states = encoder.get_init_state(device)
|
||||||
|
|
||||||
|
tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32)
|
||||||
|
|
||||||
|
wave_samples = torch.cat([wave_samples, tail_padding])
|
||||||
|
|
||||||
|
chunk = int(0.25 * args.sample_rate) # 0.2 second
|
||||||
|
num_processed_frames = 0
|
||||||
|
|
||||||
|
hyp = None
|
||||||
|
decoder_out = None
|
||||||
|
|
||||||
|
start = 0
|
||||||
|
while start < wave_samples.numel():
|
||||||
|
logging.info(f"{start}/{wave_samples.numel()}")
|
||||||
|
end = min(start + chunk, wave_samples.numel())
|
||||||
|
samples = wave_samples[start:end]
|
||||||
|
start += chunk
|
||||||
|
online_fbank.accept_waveform(
|
||||||
|
sampling_rate=args.sample_rate,
|
||||||
|
waveform=samples,
|
||||||
|
)
|
||||||
|
while online_fbank.num_frames_ready - num_processed_frames >= T:
|
||||||
|
frames = []
|
||||||
|
for i in range(T):
|
||||||
|
frames.append(online_fbank.get_frame(num_processed_frames + i))
|
||||||
|
frames = torch.cat(frames, dim=0).unsqueeze(0)
|
||||||
|
x_lens = torch.tensor([T], dtype=torch.int32)
|
||||||
|
encoder_out, out_lens, states = encoder(
|
||||||
|
x=frames,
|
||||||
|
x_lens=x_lens,
|
||||||
|
states=states,
|
||||||
|
)
|
||||||
|
num_processed_frames += chunk_length
|
||||||
|
|
||||||
|
hyp, decoder_out = greedy_search(
|
||||||
|
decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp
|
||||||
|
)
|
||||||
|
|
||||||
|
context_size = 2
|
||||||
|
logging.info(args.sound_file)
|
||||||
|
logging.info(sp.decode(hyp[context_size:]))
|
||||||
|
|
||||||
|
logging.info("Decoding Done")
|
||||||
|
|
||||||
|
|
||||||
|
torch.set_num_threads(4)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
|
torch._C._jit_set_profiling_executor(False)
|
||||||
|
torch._C._jit_set_profiling_mode(False)
|
||||||
|
torch._C._set_graph_executor_optimize(False)
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/joiner.py
Symbolic link
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/joiner.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/joiner.py
|
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/model.py
Symbolic link
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/model.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/model.py
|
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/optim.py
Symbolic link
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/optim.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/optim.py
|
347
egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py
Normal file
347
egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py
Normal file
@ -0,0 +1,347 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
This script loads a checkpoint and uses it to decode waves.
|
||||||
|
You can generate the checkpoint with the following command:
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7_streaming/export.py \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--lang data/lang_char \
|
||||||
|
--epoch 20 \
|
||||||
|
--avg 10
|
||||||
|
|
||||||
|
Usage of this script:
|
||||||
|
|
||||||
|
(1) greedy search
|
||||||
|
./pruned_transducer_stateless7_streaming/pretrained.py \
|
||||||
|
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
|
||||||
|
--lang data/lang_char \
|
||||||
|
--method greedy_search \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav
|
||||||
|
|
||||||
|
(2) beam search
|
||||||
|
./pruned_transducer_stateless7_streaming/pretrained.py \
|
||||||
|
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
|
||||||
|
--lang data/lang_char \
|
||||||
|
--method beam_search \
|
||||||
|
--beam-size 4 \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav
|
||||||
|
|
||||||
|
(3) modified beam search
|
||||||
|
./pruned_transducer_stateless7_streaming/pretrained.py \
|
||||||
|
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
|
||||||
|
--lang data/lang_char \
|
||||||
|
--method modified_beam_search \
|
||||||
|
--beam-size 4 \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav
|
||||||
|
|
||||||
|
(4) fast beam search
|
||||||
|
./pruned_transducer_stateless7_streaming/pretrained.py \
|
||||||
|
--checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \
|
||||||
|
--lang data/lang_char \
|
||||||
|
--method fast_beam_search \
|
||||||
|
--beam-size 4 \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav
|
||||||
|
|
||||||
|
You can also use `./pruned_transducer_stateless7_streaming/exp/epoch-xx.pt`.
|
||||||
|
|
||||||
|
Note: ./pruned_transducer_stateless7_streaming/exp/pretrained.pt is generated by
|
||||||
|
./pruned_transducer_stateless7_streaming/export.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import kaldifeat
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
from beam_search import (
|
||||||
|
beam_search,
|
||||||
|
fast_beam_search_one_best,
|
||||||
|
greedy_search,
|
||||||
|
greedy_search_batch,
|
||||||
|
modified_beam_search,
|
||||||
|
)
|
||||||
|
from tokenizer import Tokenizer
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpoint",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the checkpoint. "
|
||||||
|
"The checkpoint is assumed to be saved by "
|
||||||
|
"icefall.checkpoint.save_checkpoint().",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--method",
|
||||||
|
type=str,
|
||||||
|
default="greedy_search",
|
||||||
|
help="""Possible values are:
|
||||||
|
- greedy_search
|
||||||
|
- beam_search
|
||||||
|
- modified_beam_search
|
||||||
|
- fast_beam_search
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"sound_files",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
help="The input sound file(s) to transcribe. "
|
||||||
|
"Supported formats are those supported by torchaudio.load(). "
|
||||||
|
"For example, wav and flac are supported. "
|
||||||
|
"The sample rate has to be 16kHz.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--sample-rate",
|
||||||
|
type=int,
|
||||||
|
default=16000,
|
||||||
|
help="The sample rate of the input sound file",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam-size",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""An integer indicating how many candidates we will keep for each
|
||||||
|
frame. Used only when --method is beam_search or
|
||||||
|
modified_beam_search.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam",
|
||||||
|
type=float,
|
||||||
|
default=4,
|
||||||
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
|
`beam` in Kaldi.
|
||||||
|
Used only when --method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-contexts",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""Used only when --method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-states",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="""Used only when --method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-sym-per-frame",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="""Maximum number of symbols per frame. Used only when
|
||||||
|
--method is greedy_search.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def read_sound_files(
|
||||||
|
filenames: List[str], expected_sample_rate: float
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||||
|
Args:
|
||||||
|
filenames:
|
||||||
|
A list of sound filenames.
|
||||||
|
expected_sample_rate:
|
||||||
|
The expected sample rate of the sound files.
|
||||||
|
Returns:
|
||||||
|
Return a list of 1-D float32 torch tensors.
|
||||||
|
"""
|
||||||
|
ans = []
|
||||||
|
for f in filenames:
|
||||||
|
wave, sample_rate = torchaudio.load(f)
|
||||||
|
assert (
|
||||||
|
sample_rate == expected_sample_rate
|
||||||
|
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||||
|
# We use only the first channel
|
||||||
|
ans.append(wave[0])
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
Tokenizer.add_arguments(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
sp = Tokenizer.load(params.lang, params.lang_type)
|
||||||
|
|
||||||
|
# <blk> is defined in local/prepare_lang_char.py
|
||||||
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
|
logging.info(f"{params}")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
logging.info("Creating model")
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||||
|
model.load_state_dict(checkpoint["model"], strict=False)
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
model.device = device
|
||||||
|
|
||||||
|
logging.info("Constructing Fbank computer")
|
||||||
|
opts = kaldifeat.FbankOptions()
|
||||||
|
opts.device = device
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.frame_opts.snip_edges = False
|
||||||
|
opts.frame_opts.samp_freq = params.sample_rate
|
||||||
|
opts.mel_opts.num_bins = params.feature_dim
|
||||||
|
|
||||||
|
fbank = kaldifeat.Fbank(opts)
|
||||||
|
|
||||||
|
logging.info(f"Reading sound files: {params.sound_files}")
|
||||||
|
waves = read_sound_files(
|
||||||
|
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||||
|
)
|
||||||
|
waves = [w.to(device) for w in waves]
|
||||||
|
|
||||||
|
logging.info("Decoding started")
|
||||||
|
features = fbank(waves)
|
||||||
|
feature_lengths = [f.size(0) for f in features]
|
||||||
|
|
||||||
|
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||||
|
|
||||||
|
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||||
|
|
||||||
|
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
|
||||||
|
|
||||||
|
num_waves = encoder_out.size(0)
|
||||||
|
hyps = []
|
||||||
|
msg = f"Using {params.method}"
|
||||||
|
if params.method == "beam_search":
|
||||||
|
msg += f" with beam size {params.beam_size}"
|
||||||
|
logging.info(msg)
|
||||||
|
|
||||||
|
if params.method == "fast_beam_search":
|
||||||
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
|
model=model,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
max_states=params.max_states,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
elif params.method == "modified_beam_search":
|
||||||
|
hyp_tokens = modified_beam_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
beam=params.beam_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
|
hyp_tokens = greedy_search_batch(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
)
|
||||||
|
for hyp in sp.decode(hyp_tokens):
|
||||||
|
hyps.append(hyp.split())
|
||||||
|
else:
|
||||||
|
for i in range(num_waves):
|
||||||
|
# fmt: off
|
||||||
|
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||||
|
# fmt: on
|
||||||
|
if params.method == "greedy_search":
|
||||||
|
hyp = greedy_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out_i,
|
||||||
|
max_sym_per_frame=params.max_sym_per_frame,
|
||||||
|
)
|
||||||
|
elif params.method == "beam_search":
|
||||||
|
hyp = beam_search(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out_i,
|
||||||
|
beam=params.beam_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
|
|
||||||
|
hyps.append(sp.decode(hyp).split())
|
||||||
|
|
||||||
|
s = "\n"
|
||||||
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
|
words = " ".join(hyp)
|
||||||
|
s += f"{filename}:\n{words}\n\n"
|
||||||
|
logging.info(s)
|
||||||
|
|
||||||
|
logging.info("Decoding Done")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling.py
Symbolic link
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling.py
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py
|
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py
|
597
egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
Executable file
597
egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
Executable file
@ -0,0 +1,597 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
./pruned_transducer_stateless7_streaming/streaming_decode.py \
|
||||||
|
--epoch 28 \
|
||||||
|
--avg 15 \
|
||||||
|
--decode-chunk-len 32 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||||
|
--decoding_method greedy_search \
|
||||||
|
--lang data/lang_char \
|
||||||
|
--num-decode-streams 2000
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from asr_datamodule import CSJAsrDataModule
|
||||||
|
from decode import save_results
|
||||||
|
from decode_stream import DecodeStream
|
||||||
|
from kaldifeat import Fbank, FbankOptions
|
||||||
|
from lhotse import CutSet
|
||||||
|
from streaming_beam_search import (
|
||||||
|
fast_beam_search_one_best,
|
||||||
|
greedy_search,
|
||||||
|
modified_beam_search,
|
||||||
|
)
|
||||||
|
from tokenizer import Tokenizer
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
from zipformer import stack_states, unstack_states
|
||||||
|
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.utils import AttributeDict, setup_logger, str2bool
|
||||||
|
|
||||||
|
LOG_EPS = math.log(1e-10)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=28,
|
||||||
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
|
Note: Epoch counts from 0.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--gpu",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="pruned_transducer_stateless2/exp",
|
||||||
|
help="The experiment dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-model",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_bpe_500/bpe.model",
|
||||||
|
help="Path to the BPE model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoding-method",
|
||||||
|
type=str,
|
||||||
|
default="greedy_search",
|
||||||
|
help="""Supported decoding methods are:
|
||||||
|
greedy_search
|
||||||
|
modified_beam_search
|
||||||
|
fast_beam_search
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoding-graph",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_active_paths",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""An interger indicating how many candidates we will keep for each
|
||||||
|
frame. Used only when --decoding-method is modified_beam_search.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--beam",
|
||||||
|
type=float,
|
||||||
|
default=4.0,
|
||||||
|
help="""A floating point value to calculate the cutoff score during beam
|
||||||
|
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||||
|
`beam` in Kaldi.
|
||||||
|
Used only when --decoding-method is fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-contexts",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-states",
|
||||||
|
type=int,
|
||||||
|
default=32,
|
||||||
|
help="""Used only when --decoding-method is
|
||||||
|
fast_beam_search""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-decode-streams",
|
||||||
|
type=int,
|
||||||
|
default=2000,
|
||||||
|
help="The number of streams that can be decoded parallel.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--res-dir",
|
||||||
|
type=Path,
|
||||||
|
default=None,
|
||||||
|
help="The path to save results.",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def decode_one_chunk(
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
decode_streams: List[DecodeStream],
|
||||||
|
) -> List[int]:
|
||||||
|
"""Decode one chunk frames of features for each decode_streams and
|
||||||
|
return the indexes of finished streams in a List.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
It's the return value of :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
decode_streams:
|
||||||
|
A List of DecodeStream, each belonging to a utterance.
|
||||||
|
Returns:
|
||||||
|
Return a List containing which DecodeStreams are finished.
|
||||||
|
"""
|
||||||
|
device = model.device
|
||||||
|
|
||||||
|
features = []
|
||||||
|
feature_lens = []
|
||||||
|
states = []
|
||||||
|
processed_lens = []
|
||||||
|
|
||||||
|
for stream in decode_streams:
|
||||||
|
feat, feat_len = stream.get_feature_frames(params.decode_chunk_len)
|
||||||
|
features.append(feat)
|
||||||
|
feature_lens.append(feat_len)
|
||||||
|
states.append(stream.states)
|
||||||
|
processed_lens.append(stream.done_frames)
|
||||||
|
|
||||||
|
feature_lens = torch.tensor(feature_lens, device=device)
|
||||||
|
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
|
||||||
|
|
||||||
|
# We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling
|
||||||
|
# factor in encoders is 8.
|
||||||
|
# After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8.
|
||||||
|
tail_length = 23
|
||||||
|
if features.size(1) < tail_length:
|
||||||
|
pad_length = tail_length - features.size(1)
|
||||||
|
feature_lens += pad_length
|
||||||
|
features = torch.nn.functional.pad(
|
||||||
|
features,
|
||||||
|
(0, 0, 0, pad_length),
|
||||||
|
mode="constant",
|
||||||
|
value=LOG_EPS,
|
||||||
|
)
|
||||||
|
|
||||||
|
states = stack_states(states)
|
||||||
|
processed_lens = torch.tensor(processed_lens, device=device)
|
||||||
|
|
||||||
|
encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward(
|
||||||
|
x=features,
|
||||||
|
x_lens=feature_lens,
|
||||||
|
states=states,
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||||
|
|
||||||
|
if params.decoding_method == "greedy_search":
|
||||||
|
greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
|
||||||
|
elif params.decoding_method == "fast_beam_search":
|
||||||
|
processed_lens = processed_lens + encoder_out_lens
|
||||||
|
fast_beam_search_one_best(
|
||||||
|
model=model,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
processed_lens=processed_lens,
|
||||||
|
streams=decode_streams,
|
||||||
|
beam=params.beam,
|
||||||
|
max_states=params.max_states,
|
||||||
|
max_contexts=params.max_contexts,
|
||||||
|
)
|
||||||
|
elif params.decoding_method == "modified_beam_search":
|
||||||
|
modified_beam_search(
|
||||||
|
model=model,
|
||||||
|
streams=decode_streams,
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
num_active_paths=params.num_active_paths,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
|
||||||
|
|
||||||
|
states = unstack_states(new_states)
|
||||||
|
|
||||||
|
finished_streams = []
|
||||||
|
for i in range(len(decode_streams)):
|
||||||
|
decode_streams[i].states = states[i]
|
||||||
|
decode_streams[i].done_frames += encoder_out_lens[i]
|
||||||
|
if decode_streams[i].done:
|
||||||
|
finished_streams.append(i)
|
||||||
|
|
||||||
|
return finished_streams
|
||||||
|
|
||||||
|
|
||||||
|
def decode_dataset(
|
||||||
|
cuts: CutSet,
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
sp: Tokenizer,
|
||||||
|
decoding_graph: Optional[k2.Fsa] = None,
|
||||||
|
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||||
|
"""Decode dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cuts:
|
||||||
|
Lhotse Cutset containing the dataset to decode.
|
||||||
|
params:
|
||||||
|
It is returned by :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
sp:
|
||||||
|
The BPE model.
|
||||||
|
decoding_graph:
|
||||||
|
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||||
|
only when --decoding_method is fast_beam_search.
|
||||||
|
Returns:
|
||||||
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
|
Its value is a list of tuples. Each tuple contains two elements:
|
||||||
|
The first is the reference transcript, and the second is the
|
||||||
|
predicted result.
|
||||||
|
"""
|
||||||
|
device = model.device
|
||||||
|
|
||||||
|
opts = FbankOptions()
|
||||||
|
opts.device = device
|
||||||
|
opts.frame_opts.dither = 0
|
||||||
|
opts.frame_opts.snip_edges = False
|
||||||
|
opts.frame_opts.samp_freq = 16000
|
||||||
|
opts.mel_opts.num_bins = 80
|
||||||
|
|
||||||
|
log_interval = 50
|
||||||
|
|
||||||
|
decode_results = []
|
||||||
|
# Contain decode streams currently running.
|
||||||
|
decode_streams = []
|
||||||
|
for num, cut in enumerate(cuts):
|
||||||
|
# each utterance has a DecodeStream.
|
||||||
|
initial_states = model.encoder.get_init_state(device=device)
|
||||||
|
decode_stream = DecodeStream(
|
||||||
|
params=params,
|
||||||
|
cut_id=cut.id,
|
||||||
|
initial_states=initial_states,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
audio: np.ndarray = cut.load_audio()
|
||||||
|
# audio.shape: (1, num_samples)
|
||||||
|
assert len(audio.shape) == 2
|
||||||
|
assert audio.shape[0] == 1, "Should be single channel"
|
||||||
|
assert audio.dtype == np.float32, audio.dtype
|
||||||
|
|
||||||
|
# The trained model is using normalized samples
|
||||||
|
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
|
||||||
|
|
||||||
|
samples = torch.from_numpy(audio).squeeze(0)
|
||||||
|
|
||||||
|
fbank = Fbank(opts)
|
||||||
|
feature = fbank(samples.to(device))
|
||||||
|
decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len)
|
||||||
|
decode_stream.ground_truth = cut.supervisions[0].custom[params.transcript_mode]
|
||||||
|
|
||||||
|
decode_streams.append(decode_stream)
|
||||||
|
|
||||||
|
while len(decode_streams) >= params.num_decode_streams:
|
||||||
|
finished_streams = decode_one_chunk(
|
||||||
|
params=params, model=model, decode_streams=decode_streams
|
||||||
|
)
|
||||||
|
for i in sorted(finished_streams, reverse=True):
|
||||||
|
decode_results.append(
|
||||||
|
(
|
||||||
|
decode_streams[i].id,
|
||||||
|
sp.text2word(decode_streams[i].ground_truth),
|
||||||
|
sp.text2word(sp.decode(decode_streams[i].decoding_result())),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
del decode_streams[i]
|
||||||
|
|
||||||
|
if num % log_interval == 0:
|
||||||
|
logging.info(f"Cuts processed until now is {num}.")
|
||||||
|
|
||||||
|
# decode final chunks of last sequences
|
||||||
|
while len(decode_streams):
|
||||||
|
finished_streams = decode_one_chunk(
|
||||||
|
params=params, model=model, decode_streams=decode_streams
|
||||||
|
)
|
||||||
|
for i in sorted(finished_streams, reverse=True):
|
||||||
|
decode_results.append(
|
||||||
|
(
|
||||||
|
decode_streams[i].id,
|
||||||
|
sp.text2word(decode_streams[i].ground_truth),
|
||||||
|
sp.text2word(sp.decode(decode_streams[i].decoding_result())),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
del decode_streams[i]
|
||||||
|
|
||||||
|
if params.decoding_method == "greedy_search":
|
||||||
|
key = "greedy_search"
|
||||||
|
elif params.decoding_method == "fast_beam_search":
|
||||||
|
key = (
|
||||||
|
f"beam_{params.beam}_"
|
||||||
|
f"max_contexts_{params.max_contexts}_"
|
||||||
|
f"max_states_{params.max_states}"
|
||||||
|
)
|
||||||
|
elif params.decoding_method == "modified_beam_search":
|
||||||
|
key = f"num_active_paths_{params.num_active_paths}"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
|
||||||
|
return {key: decode_results}
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
CSJAsrDataModule.add_arguments(parser)
|
||||||
|
Tokenizer.add_arguments(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
if not params.res_dir:
|
||||||
|
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||||
|
else:
|
||||||
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
|
# for streaming
|
||||||
|
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
|
||||||
|
|
||||||
|
# for fast_beam_search
|
||||||
|
if params.decoding_method == "fast_beam_search":
|
||||||
|
params.suffix += f"-beam-{params.beam}"
|
||||||
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
|
params.suffix += f"-max-states-{params.max_states}"
|
||||||
|
|
||||||
|
if params.use_averaged_model:
|
||||||
|
params.suffix += "-use-averaged-model"
|
||||||
|
|
||||||
|
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||||
|
logging.info("Decoding started")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", params.gpu)
|
||||||
|
|
||||||
|
logging.info(f"Device: {device}")
|
||||||
|
|
||||||
|
sp = Tokenizer.load(params.lang, params.lang_type)
|
||||||
|
|
||||||
|
# <blk> and <unk> is defined in local/prepare_lang_char.py
|
||||||
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if start >= 0:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
model.device = device
|
||||||
|
|
||||||
|
decoding_graph = None
|
||||||
|
if params.decoding_graph:
|
||||||
|
decoding_graph = k2.Fsa.from_dict(
|
||||||
|
torch.load(params.decoding_graph, map_location=device)
|
||||||
|
)
|
||||||
|
elif params.decoding_method == "fast_beam_search":
|
||||||
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
args.return_cuts = True
|
||||||
|
csj_corpus = CSJAsrDataModule(args)
|
||||||
|
|
||||||
|
for subdir in ["eval1", "eval2", "eval3", "excluded", "valid"]:
|
||||||
|
results_dict = decode_dataset(
|
||||||
|
cuts=getattr(csj_corpus, f"{subdir}_cuts")(),
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
sp=sp,
|
||||||
|
decoding_graph=decoding_graph,
|
||||||
|
)
|
||||||
|
tot_err = save_results(
|
||||||
|
params=params, test_set_name=subdir, results_dict=results_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
params.res_dir
|
||||||
|
/ (
|
||||||
|
f"{subdir}-{params.decode_chunk_len}"
|
||||||
|
f"_{params.avg}_{params.epoch}.cer"
|
||||||
|
)
|
||||||
|
).open("w") as fout:
|
||||||
|
if len(tot_err) == 1:
|
||||||
|
fout.write(f"{tot_err[0][1]}")
|
||||||
|
else:
|
||||||
|
fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err)
|
||||||
|
|
||||||
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
150
egs/csj/ASR/pruned_transducer_stateless7_streaming/test_model.py
Executable file
150
egs/csj/ASR/pruned_transducer_stateless7_streaming/test_model.py
Executable file
@ -0,0 +1,150 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
To run this file, do:
|
||||||
|
|
||||||
|
cd icefall/egs/csj/ASR
|
||||||
|
python ./pruned_transducer_stateless7_streaming/test_model.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
|
|
||||||
|
def test_model():
|
||||||
|
params = get_params()
|
||||||
|
params.vocab_size = 500
|
||||||
|
params.blank_id = 0
|
||||||
|
params.context_size = 2
|
||||||
|
params.num_encoder_layers = "2,4,3,2,4"
|
||||||
|
params.feedforward_dims = "1024,1024,2048,2048,1024"
|
||||||
|
params.nhead = "8,8,8,8,8"
|
||||||
|
params.encoder_dims = "384,384,384,384,384"
|
||||||
|
params.attention_dims = "192,192,192,192,192"
|
||||||
|
params.encoder_unmasked_dims = "256,256,256,256,256"
|
||||||
|
params.zipformer_downsampling_factors = "1,2,4,8,2"
|
||||||
|
params.cnn_module_kernels = "31,31,31,31,31"
|
||||||
|
params.decoder_dim = 512
|
||||||
|
params.joiner_dim = 512
|
||||||
|
params.num_left_chunks = 4
|
||||||
|
params.short_chunk_size = 50
|
||||||
|
params.decode_chunk_len = 32
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
print(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
# Test jit script
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
|
# it here.
|
||||||
|
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||||
|
# torch scriptabe.
|
||||||
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
|
print("Using torch.jit.script")
|
||||||
|
model = torch.jit.script(model)
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_jit_trace():
|
||||||
|
params = get_params()
|
||||||
|
params.vocab_size = 500
|
||||||
|
params.blank_id = 0
|
||||||
|
params.context_size = 2
|
||||||
|
params.num_encoder_layers = "2,4,3,2,4"
|
||||||
|
params.feedforward_dims = "1024,1024,2048,2048,1024"
|
||||||
|
params.nhead = "8,8,8,8,8"
|
||||||
|
params.encoder_dims = "384,384,384,384,384"
|
||||||
|
params.attention_dims = "192,192,192,192,192"
|
||||||
|
params.encoder_unmasked_dims = "256,256,256,256,256"
|
||||||
|
params.zipformer_downsampling_factors = "1,2,4,8,2"
|
||||||
|
params.cnn_module_kernels = "31,31,31,31,31"
|
||||||
|
params.decoder_dim = 512
|
||||||
|
params.joiner_dim = 512
|
||||||
|
params.num_left_chunks = 4
|
||||||
|
params.short_chunk_size = 50
|
||||||
|
params.decode_chunk_len = 32
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
print(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
|
|
||||||
|
# Test encoder
|
||||||
|
def _test_encoder():
|
||||||
|
encoder = model.encoder
|
||||||
|
assert encoder.decode_chunk_size == params.decode_chunk_len // 2, (
|
||||||
|
encoder.decode_chunk_size,
|
||||||
|
params.decode_chunk_len,
|
||||||
|
)
|
||||||
|
T = params.decode_chunk_len + 7
|
||||||
|
|
||||||
|
x = torch.zeros(1, T, 80, dtype=torch.float32)
|
||||||
|
x_lens = torch.full((1,), T, dtype=torch.int32)
|
||||||
|
states = encoder.get_init_state(device=x.device)
|
||||||
|
encoder.__class__.forward = encoder.__class__.streaming_forward
|
||||||
|
traced_encoder = torch.jit.trace(encoder, (x, x_lens, states))
|
||||||
|
|
||||||
|
states1 = encoder.get_init_state(device=x.device)
|
||||||
|
states2 = traced_encoder.get_init_state(device=x.device)
|
||||||
|
for i in range(5):
|
||||||
|
x = torch.randn(1, T, 80, dtype=torch.float32)
|
||||||
|
x_lens = torch.full((1,), T, dtype=torch.int32)
|
||||||
|
y1, _, states1 = encoder.streaming_forward(x, x_lens, states1)
|
||||||
|
y2, _, states2 = traced_encoder(x, x_lens, states2)
|
||||||
|
assert torch.allclose(y1, y2, atol=1e-6), (i, (y1 - y2).abs().mean())
|
||||||
|
|
||||||
|
# Test decoder
|
||||||
|
def _test_decoder():
|
||||||
|
decoder = model.decoder
|
||||||
|
y = torch.zeros(10, decoder.context_size, dtype=torch.int64)
|
||||||
|
need_pad = torch.tensor([False])
|
||||||
|
|
||||||
|
traced_decoder = torch.jit.trace(decoder, (y, need_pad))
|
||||||
|
d1 = decoder(y, need_pad)
|
||||||
|
d2 = traced_decoder(y, need_pad)
|
||||||
|
assert torch.equal(d1, d2), (d1 - d2).abs().mean()
|
||||||
|
|
||||||
|
# Test joiner
|
||||||
|
def _test_joiner():
|
||||||
|
joiner = model.joiner
|
||||||
|
encoder_out_dim = joiner.encoder_proj.weight.shape[1]
|
||||||
|
decoder_out_dim = joiner.decoder_proj.weight.shape[1]
|
||||||
|
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||||
|
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||||
|
|
||||||
|
traced_joiner = torch.jit.trace(joiner, (encoder_out, decoder_out))
|
||||||
|
j1 = joiner(encoder_out, decoder_out)
|
||||||
|
j2 = traced_joiner(encoder_out, decoder_out)
|
||||||
|
assert torch.equal(j1, j2), (j1 - j2).abs().mean()
|
||||||
|
|
||||||
|
_test_encoder()
|
||||||
|
_test_decoder()
|
||||||
|
_test_joiner()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
test_model()
|
||||||
|
test_model_jit_trace()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/tokenizer.py
Symbolic link
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/tokenizer.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../local/utils/tokenizer.py
|
1304
egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py
Executable file
1304
egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py
Executable file
File diff suppressed because it is too large
Load Diff
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer.py
Symbolic link
1
egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py
|
Loading…
x
Reference in New Issue
Block a user