From e63a8c27f811bbee321429f8253ff8d1260aa929 Mon Sep 17 00:00:00 2001 From: Teo Wen Shen <36886809+teowenshen@users.noreply.github.com> Date: Mon, 13 Feb 2023 23:19:50 +0900 Subject: [PATCH] CSJ pruned_transducer_stateless7_streaming (#892) * update manifest stats * update transcript configs * lang_char and compute_fbanks * save cuts in fbank_dir * add core codes * update decode.py * Create local/utils * tidy up * parse raw in prepare_lang_char.py * update manifest stats * update transcript configs * lang_char and compute_fbanks * save cuts in fbank_dir * add core codes * update decode.py * Create local/utils * tidy up * parse raw in prepare_lang_char.py * working train * Add compare_cer_transcript.py * fix tokenizer decode, allow d2f only * comment cleanup * add export files and READMEs * reword average column * fix comments * Update new results --- egs/csj/ASR/README.md | 11 + egs/csj/ASR/RESULTS.md | 200 +++ egs/csj/ASR/local/add_transcript_mode.py | 94 ++ egs/csj/ASR/local/compute_fbank_csj.py | 109 +- egs/csj/ASR/local/compute_fbank_musan.py | 12 +- egs/csj/ASR/local/conf/disfluent.ini | 243 +-- egs/csj/ASR/local/conf/fluent.ini | 243 +-- egs/csj/ASR/local/conf/number.ini | 241 --- egs/csj/ASR/local/conf/symbol.ini | 251 +--- .../ASR/local/disfluent_recogs_to_fluent.py | 202 +++ .../ASR/local/display_manifest_statistics.py | 376 +++-- egs/csj/ASR/local/prepare_lang_char.py | 102 +- egs/csj/ASR/local/utils/asr_datamodule.py | 462 ++++++ egs/csj/ASR/local/utils/tokenizer.py | 253 ++++ egs/csj/ASR/prepare.sh | 42 +- .../TelegramStreamIO.py | 76 + .../asr_datamodule.py | 1 + .../beam_search.py | 1 + .../decode.py | 852 +++++++++++ .../decode_stream.py | 1 + .../decoder.py | 1 + .../encoder_interface.py | 1 + .../export.py | 313 ++++ .../jit_trace_export.py | 308 ++++ .../jit_trace_pretrained.py | 286 ++++ .../joiner.py | 1 + .../model.py | 1 + .../optim.py | 1 + .../pretrained.py | 347 +++++ .../scaling.py | 1 + .../scaling_converter.py | 1 + .../streaming_beam_search.py | 1 + .../streaming_decode.py | 597 ++++++++ .../test_model.py | 150 ++ .../tokenizer.py | 1 + .../train.py | 1304 +++++++++++++++++ .../zipformer.py | 1 + 37 files changed, 5847 insertions(+), 1240 deletions(-) create mode 100644 egs/csj/ASR/README.md create mode 100644 egs/csj/ASR/RESULTS.md create mode 100644 egs/csj/ASR/local/add_transcript_mode.py create mode 100644 egs/csj/ASR/local/disfluent_recogs_to_fluent.py create mode 100644 egs/csj/ASR/local/utils/asr_datamodule.py create mode 100644 egs/csj/ASR/local/utils/tokenizer.py create mode 100644 egs/csj/ASR/pruned_transducer_stateless7_streaming/TelegramStreamIO.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/beam_search.py create mode 100755 egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/decode_stream.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/decoder.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py create mode 100644 egs/csj/ASR/pruned_transducer_stateless7_streaming/export.py create mode 100644 egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py create mode 100644 egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/joiner.py create mode 120000 
egs/csj/ASR/pruned_transducer_stateless7_streaming/model.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/optim.py create mode 100644 egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py create mode 100755 egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py create mode 100755 egs/csj/ASR/pruned_transducer_stateless7_streaming/test_model.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/tokenizer.py create mode 100755 egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py create mode 120000 egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer.py diff --git a/egs/csj/ASR/README.md b/egs/csj/ASR/README.md new file mode 100644 index 000000000..95c2ec6ac --- /dev/null +++ b/egs/csj/ASR/README.md @@ -0,0 +1,11 @@ +# Introduction + +[./RESULTS.md](./RESULTS.md) contains the latest results. + +# Transducers + +These are the types of architectures currently available. + +| | Encoder | Decoder | Comment | +|---------------------------------------|---------------------|--------------------|---------------------------------------------------| +| `pruned_transducer_stateless7_streaming` | Streaming Zipformer | Embedding + Conv1d | Adapted from librispeech pruned_transducer_stateless7_streaming | diff --git a/egs/csj/ASR/RESULTS.md b/egs/csj/ASR/RESULTS.md new file mode 100644 index 000000000..56fdb899f --- /dev/null +++ b/egs/csj/ASR/RESULTS.md @@ -0,0 +1,200 @@ +# Results + +## Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer) + +### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming) + +See for more details. + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +Number of model parameters: 75688409, i.e. 75.7M. 
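As an aside, a figure like the 75688409 parameters quoted above can be reproduced for any PyTorch module with a one-line sum; the snippet below is only a minimal sketch with a stand-in module (the real count comes from the Zipformer transducer that `train.py` builds):

```python
import torch


def num_parameters(model: torch.nn.Module) -> int:
    """Count the trainable parameters of a PyTorch module."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


if __name__ == "__main__":
    toy = torch.nn.Linear(80, 512)  # stand-in module for illustration only
    print(num_parameters(toy))      # 80 * 512 + 512 = 41472
```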
+
+#### training on disfluent transcript
+
+The CERs are:
+
+| decoding method | chunk size | eval1 | eval2 | eval3 | excluded | valid | average | decoding mode |
+| --------------- | ---------- | ----- | ----- | ----- | -------- | ----- | ------- | ------------- |
+| fast beam search | 320ms | 5.39 | 4.08 | 4.16 | 5.4 | 5.02 | --epoch 30 --avg 17 | simulated streaming |
+| fast beam search | 320ms | 5.34 | 4.1 | 4.26 | 5.61 | 4.91 | --epoch 30 --avg 17 | chunk-wise |
+| greedy search | 320ms | 5.43 | 4.14 | 4.31 | 5.48 | 4.88 | --epoch 30 --avg 17 | simulated streaming |
+| greedy search | 320ms | 5.44 | 4.14 | 4.39 | 5.7 | 4.98 | --epoch 30 --avg 17 | chunk-wise |
+| modified beam search | 320ms | 5.2 | 3.95 | 4.09 | 5.12 | 4.75 | --epoch 30 --avg 17 | simulated streaming |
+| modified beam search | 320ms | 5.18 | 4.07 | 4.12 | 5.36 | 4.77 | --epoch 30 --avg 17 | chunk-wise |
+| fast beam search | 640ms | 5.01 | 3.78 | 3.96 | 4.85 | 4.6 | --epoch 30 --avg 17 | simulated streaming |
+| fast beam search | 640ms | 4.97 | 3.88 | 3.96 | 4.91 | 4.61 | --epoch 30 --avg 17 | chunk-wise |
+| greedy search | 640ms | 5.02 | 3.84 | 4.14 | 5.02 | 4.59 | --epoch 30 --avg 17 | simulated streaming |
+| greedy search | 640ms | 5.32 | 4.22 | 4.33 | 5.39 | 4.99 | --epoch 30 --avg 17 | chunk-wise |
+| modified beam search | 640ms | 4.78 | 3.66 | 3.85 | 4.72 | 4.42 | --epoch 30 --avg 17 | simulated streaming |
+| modified beam search | 640ms | 5.77 | 4.72 | 4.73 | 5.85 | 5.36 | --epoch 30 --avg 17 | chunk-wise |
+
+Note: `simulated streaming` means the full utterance is fed at once during decoding with `decode.py`,
+while `chunk-wise` means a fixed number of frames is fed at a time with `streaming_decode.py`.
+
+The training command was:
+```bash
+./pruned_transducer_stateless7_streaming/train.py \
+  --feedforward-dims "1024,1024,2048,2048,1024" \
+  --world-size 8 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \
+  --max-duration 375 \
+  --transcript-mode disfluent \
+  --lang data/lang_char \
+  --manifest-dir /mnt/host/corpus/csj/fbank \
+  --pad-feature 30 \
+  --musan-dir /mnt/host/corpus/musan/musan/fbank
+```
+
+The simulated streaming decoding command was:
+```bash
+for chunk in 64 32; do
+  for m in greedy_search fast_beam_search modified_beam_search; do
+    python pruned_transducer_stateless7_streaming/decode.py \
+      --feedforward-dims "1024,1024,2048,2048,1024" \
+      --exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \
+      --epoch 30 \
+      --avg 17 \
+      --max-duration 350 \
+      --decoding-method $m \
+      --manifest-dir /mnt/host/corpus/csj/fbank \
+      --lang data/lang_char \
+      --transcript-mode disfluent \
+      --res-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30/github/sim_"$chunk"_"$m" \
+      --decode-chunk-len $chunk \
+      --pad-feature 30 \
+      --gpu 0
+  done
+done
+```
+
+The streaming chunk-wise decoding command was:
+```bash
+for chunk in 64 32; do
+  for m in greedy_search fast_beam_search modified_beam_search; do
+    python pruned_transducer_stateless7_streaming/streaming_decode.py \
+      --feedforward-dims "1024,1024,2048,2048,1024" \
+      --exp-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30 \
+      --epoch 30 \
+      --avg 17 \
+      --max-duration 350 \
+      --decoding-method $m \
+      --manifest-dir /mnt/host/corpus/csj/fbank \
+      --lang data/lang_char \
+      --transcript-mode disfluent \
+      --res-dir pruned_transducer_stateless7_streaming/exp_disfluent_2_pad30/github/stream_"$chunk"_"$m" \
+      --decode-chunk-len $chunk \
+      --gpu 2 \
+      --num-decode-streams 40
+  done
+done
+```
+
+#### training on fluent transcript
+
+The CERs are:
+
+| decoding method | chunk size | eval1 | eval2 | eval3 | excluded | valid | average | decoding mode |
+| --------------- | ---------- | ----- | ----- | ----- | -------- | ----- | ------- | ------------- |
+| fast beam search | 320ms | 4.19 | 3.63 | 3.77 | 4.43 | 4.09 | --epoch 30 --avg 12 | simulated streaming |
+| fast beam search | 320ms | 4.06 | 3.55 | 3.66 | 4.70 | 4.04 | --epoch 30 --avg 12 | chunk-wise |
+| greedy search | 320ms | 4.22 | 3.62 | 3.82 | 4.45 | 3.98 | --epoch 30 --avg 12 | simulated streaming |
+| greedy search | 320ms | 4.13 | 3.61 | 3.85 | 4.67 | 4.05 | --epoch 30 --avg 12 | chunk-wise |
+| modified beam search | 320ms | 4.02 | 3.43 | 3.62 | 4.43 | 3.81 | --epoch 30 --avg 12 | simulated streaming |
+| modified beam search | 320ms | 3.97 | 3.43 | 3.59 | 4.99 | 3.88 | --epoch 30 --avg 12 | chunk-wise |
+| fast beam search | 640ms | 3.80 | 3.31 | 3.55 | 4.16 | 3.90 | --epoch 30 --avg 12 | simulated streaming |
+| fast beam search | 640ms | 3.81 | 3.34 | 3.46 | 4.58 | 3.85 | --epoch 30 --avg 12 | chunk-wise |
+| greedy search | 640ms | 3.92 | 3.38 | 3.65 | 4.31 | 3.88 | --epoch 30 --avg 12 | simulated streaming |
+| greedy search | 640ms | 3.98 | 3.38 | 3.64 | 4.54 | 4.01 | --epoch 30 --avg 12 | chunk-wise |
+| modified beam search | 640ms | 3.72 | 3.26 | 3.39 | 4.10 | 3.65 | --epoch 30 --avg 12 | simulated streaming |
+| modified beam search | 640ms | 3.78 | 3.32 | 3.45 | 4.81 | 3.81 | --epoch 30 --avg 12 | chunk-wise |
+
+Note: `simulated streaming` means the full utterance is fed at once during decoding with `decode.py`,
+while `chunk-wise` means a fixed number of frames is fed at a time with `streaming_decode.py`.
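To make the note above concrete, the two decoding modes can be pictured with a toy causal encoder. This is only a conceptual sketch; `DummyEncoder` is a stand-in invented for illustration and is not the icefall API:

```python
# "simulated streaming" = one forward pass over the whole utterance;
# "chunk-wise"          = feed --decode-chunk-len frames at a time and carry
#                         the encoder states between calls.
import torch


class DummyEncoder(torch.nn.Module):
    """Stand-in streaming encoder: a single GRU whose hidden state plays the
    role of the cached streaming states."""

    def __init__(self, feat_dim: int = 80, hidden_dim: int = 128):
        super().__init__()
        self.rnn = torch.nn.GRU(feat_dim, hidden_dim, batch_first=True)

    def forward(self, feats: torch.Tensor, states=None):
        out, states = self.rnn(feats, states)
        return out, states


encoder = DummyEncoder()
feats = torch.randn(1, 640, 80)  # (batch, num_frames, feature_dim)

# Simulated streaming: the full utterance in one call.
full_out, _ = encoder(feats)

# Chunk-wise: fixed-size chunks, with the states carried forward.
states = None
chunk_outs = []
for chunk in feats.split(32, dim=1):
    out, states = encoder(chunk, states)
    chunk_outs.append(out)
stream_out = torch.cat(chunk_outs, dim=1)

# A strictly causal toy gives the same output either way; the real streaming
# Zipformer's caching and padding make the two modes differ slightly, which is
# why the tables report both.
print(torch.allclose(full_out, stream_out, atol=1e-5))  # True
```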
+ +The training command was: +```bash +./pruned_transducer_stateless7_streaming/train.py \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --world-size 8 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \ + --max-duration 375 \ + --transcript-mode fluent \ + --lang data/lang_char \ + --manifest-dir /mnt/host/corpus/csj/fbank \ + --pad-feature 30 \ + --musan-dir /mnt/host/corpus/musan/musan/fbank +``` + +The simulated streaming decoding command was: +```bash +for chunk in 64 32; do + for m in greedy_search fast_beam_search modified_beam_search; do + python pruned_transducer_stateless7_streaming/decode.py \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \ + --epoch 30 \ + --avg 12 \ + --max-duration 350 \ + --decoding-method $m \ + --manifest-dir /mnt/host/corpus/csj/fbank \ + --lang data/lang_char \ + --transcript-mode fluent \ + --res-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30/github/sim_"$chunk"_"$m" \ + --decode-chunk-len $chunk \ + --pad-feature 30 \ + --gpu 1 + done +done +``` + +The streaming chunk-wise decoding command was: +```bash +for chunk in 64 32; do + for m in greedy_search fast_beam_search modified_beam_search; do + python pruned_transducer_stateless7_streaming/streaming_decode.py \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --exp-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30 \ + --epoch 30 \ + --avg 12 \ + --max-duration 350 \ + --decoding-method $m \ + --manifest-dir /mnt/host/corpus/csj/fbank \ + --lang data/lang_char \ + --transcript-mode fluent \ + --res-dir pruned_transducer_stateless7_streaming/exp_fluent_2_pad30/github/stream_"$chunk"_"$m" \ + --decode-chunk-len $chunk \ + --gpu 3 \ + --num-decode-streams 40 + done +done +``` + +#### Comparing disfluent to fluent + +$$ \texttt{CER}^{f}_d = \frac{\texttt{sub}_f + \texttt{ins} + \texttt{del}_f}{N_f} $$ + +This comparison evaluates the disfluent model on the fluent transcript (calculated by `disfluent_recogs_to_fluent.py`), forgiving the disfluent model's mistakes on fillers and partial words. It is meant as an illustrative metric only, so that the disfluent and fluent models can be compared. 
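To make the formula concrete, here is a toy calculation with made-up counts (not taken from the tables); the real counts are derived by `disfluent_recogs_to_fluent.py` from a `kaldialign` alignment of the disfluent hypothesis against the disfluent reference, split by CSJ tag:

```python
# Toy numbers only: errors on fillers and word fragments that the fluent
# reference drops are forgiven, so only content-character errors remain.
sub_f = 30    # substitutions on content (fluent) characters
ins = 10      # insertions
del_f = 20    # deletions of content characters
n_f = 1000    # number of characters in the fluent reference

cer_d_on_f = (sub_f + ins + del_f) / n_f
print(f"CER^f_d = {cer_d_on_f:.2%}")  # CER^f_d = 6.00%
```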
+ +| decoding method | chunk size | eval1 (d vs f) | eval2 (d vs f) | eval3 (d vs f) | excluded (d vs f) | valid (d vs f) | decoding mode | +| --------------- | ---------- | -------------- | --------------- | -------------- | -------------------- | --------------- | ----------- | +| fast beam search | 320ms | 4.54 vs 4.19 | 3.44 vs 3.63 | 3.56 vs 3.77 | 4.22 vs 4.43 | 4.22 vs 4.09 | simulated streaming | +| fast beam search | 320ms | 4.48 vs 4.06 | 3.41 vs 3.55 | 3.65 vs 3.66 | 4.26 vs 4.7 | 4.08 vs 4.04 | chunk-wise | +| greedy search | 320ms | 4.53 vs 4.22 | 3.48 vs 3.62 | 3.69 vs 3.82 | 4.38 vs 4.45 | 4.05 vs 3.98 | simulated streaming | +| greedy search | 320ms | 4.53 vs 4.13 | 3.46 vs 3.61 | 3.71 vs 3.85 | 4.48 vs 4.67 | 4.12 vs 4.05 | chunk-wise | +| modified beam search | 320ms | 4.45 vs 4.02 | 3.38 vs 3.43 | 3.57 vs 3.62 | 4.19 vs 4.43 | 4.04 vs 3.81 | simulated streaming | +| modified beam search | 320ms | 4.44 vs 3.97 | 3.47 vs 3.43 | 3.56 vs 3.59 | 4.28 vs 4.99 | 4.04 vs 3.88 | chunk-wise | +| fast beam search | 640ms | 4.14 vs 3.8 | 3.12 vs 3.31 | 3.38 vs 3.55 | 3.72 vs 4.16 | 3.81 vs 3.9 | simulated streaming | +| fast beam search | 640ms | 4.05 vs 3.81 | 3.23 vs 3.34 | 3.36 vs 3.46 | 3.65 vs 4.58 | 3.78 vs 3.85 | chunk-wise | +| greedy search | 640ms | 4.1 vs 3.92 | 3.17 vs 3.38 | 3.5 vs 3.65 | 3.87 vs 4.31 | 3.77 vs 3.88 | simulated streaming | +| greedy search | 640ms | 4.41 vs 3.98 | 3.56 vs 3.38 | 3.69 vs 3.64 | 4.26 vs 4.54 | 4.16 vs 4.01 | chunk-wise | +| modified beam search | 640ms | 4 vs 3.72 | 3.08 vs 3.26 | 3.33 vs 3.39 | 3.75 vs 4.1 | 3.71 vs 3.65 | simulated streaming | +| modified beam search | 640ms | 5.05 vs 3.78 | 4.22 vs 3.32 | 4.26 vs 3.45 | 5.02 vs 4.81 | 4.73 vs 3.81 | chunk-wise | +| average (d - f) | | 0.43 | -0.02 | -0.02 | -0.34 | 0.13 | | diff --git a/egs/csj/ASR/local/add_transcript_mode.py b/egs/csj/ASR/local/add_transcript_mode.py new file mode 100644 index 000000000..f6b4b2caf --- /dev/null +++ b/egs/csj/ASR/local/add_transcript_mode.py @@ -0,0 +1,94 @@ +import argparse +import logging +from configparser import ConfigParser +from pathlib import Path +from typing import List + +from lhotse import CutSet, SupervisionSet +from lhotse.recipes.csj import CSJSDBParser + +ARGPARSE_DESCRIPTION = """ +This script adds transcript modes to an existing CutSet or SupervisionSet. +""" + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description=ARGPARSE_DESCRIPTION, + ) + parser.add_argument( + "-f", + "--fbank-dir", + type=Path, + help="Path to directory where manifests are stored.", + ) + parser.add_argument( + "-c", + "--config", + type=Path, + nargs="+", + help="Path to config file for transcript parsing.", + ) + return parser.parse_args() + + +def get_CSJParsers(config_files: List[Path]) -> List[CSJSDBParser]: + parsers = [] + for config_file in config_files: + config = ConfigParser() + config.optionxform = str + assert config.read(config_file), f"{config_file} could not be found." 
+ decisions = {} + for k, v in config["DECISIONS"].items(): + try: + decisions[k] = int(v) + except ValueError: + decisions[k] = v + parsers.append( + (config["CONSTANTS"].get("MODE"), CSJSDBParser(decisions=decisions)) + ) + return parsers + + +def main(): + args = get_args() + logging.basicConfig( + format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"), + level=logging.INFO, + ) + parsers = get_CSJParsers(args.config) + config = ConfigParser() + config.optionxform = str + assert config.read(args.config), args.config + decisions = {} + for k, v in config["DECISIONS"].items(): + try: + decisions[k] = int(v) + except ValueError: + decisions[k] = v + + logging.info(f"Adding {', '.join(x[0] for x in parsers)} transcript mode.") + + manifests = args.fbank_dir.glob("csj_cuts_*.jsonl.gz") + assert manifests, f"No cuts to be found in {args.fbank_dir}" + + for manifest in manifests: + results = [] + logging.info(f"Adding transcript modes to {manifest.name} now.") + cutset = CutSet.from_file(manifest) + for cut in cutset: + for name, parser in parsers: + cut.supervisions[0].custom[name] = parser.parse( + cut.supervisions[0].custom["raw"] + ) + cut.supervisions[0].text = "" + results.append(cut) + results = CutSet.from_items(results) + res_file = manifest.as_posix() + manifest.replace(manifest.parent / ("bak." + manifest.name)) + results.to_file(res_file) + + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/local/compute_fbank_csj.py b/egs/csj/ASR/local/compute_fbank_csj.py index 667ad427e..ce560025d 100644 --- a/egs/csj/ASR/local/compute_fbank_csj.py +++ b/egs/csj/ASR/local/compute_fbank_csj.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright 2022 The University of Electro-Communications (Author: Teo Wen Shen) # noqa +# Copyright 2023 The University of Electro-Communications (Author: Teo Wen Shen) # noqa # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -19,9 +19,7 @@ import argparse import logging import os -from itertools import islice from pathlib import Path -from random import Random from typing import List, Tuple import torch @@ -35,20 +33,10 @@ from lhotse import ( # See the following for why LilcomChunkyWriter is preferre RecordingSet, SupervisionSet, ) +from lhotse.recipes.csj import concat_csj_supervisions # fmt: on -ARGPARSE_DESCRIPTION = """ -This script follows the espnet method of splitting the remaining core+noncore -utterances into valid and train cutsets at an index which is by default 4000. - -In other words, the core+noncore utterances are shuffled, where 4000 utterances -of the shuffled set go to the `valid` cutset and are not subject to speed -perturbation. The remaining utterances become the `train` cutset and are speed- -perturbed (0.9x, 1.0x, 1.1x). - -""" - # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. 
# Do this outside of main() in case it needs to take effect @@ -57,66 +45,101 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) RNG_SEED = 42 +# concat_params_train = [ +# {"gap": 1.0, "maxlen": 10.0}, +# {"gap": 1.5, "maxlen": 8.0}, +# {"gap": 1.0, "maxlen": 18.0}, +# ] + +concat_params = {"gap": 1.0, "maxlen": 10.0} def make_cutset_blueprints( manifest_dir: Path, - split: int, ) -> List[Tuple[str, CutSet]]: cut_sets = [] + logging.info("Creating non-train cuts.") + # Create eval datasets - logging.info("Creating eval cuts.") for i in range(1, 4): + sps = sorted( + SupervisionSet.from_file( + manifest_dir / f"csj_supervisions_eval{i}.jsonl.gz" + ), + key=lambda x: x.id, + ) + cut_set = CutSet.from_manifests( recordings=RecordingSet.from_file( manifest_dir / f"csj_recordings_eval{i}.jsonl.gz" ), - supervisions=SupervisionSet.from_file( - manifest_dir / f"csj_supervisions_eval{i}.jsonl.gz" - ), + supervisions=concat_csj_supervisions(sps, **concat_params), ) cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) cut_sets.append((f"eval{i}", cut_set)) - # Create train and valid cuts - logging.info("Loading, trimming, and shuffling the remaining core+noncore cuts.") - recording_set = RecordingSet.from_file( - manifest_dir / "csj_recordings_core.jsonl.gz" - ) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz") - supervision_set = SupervisionSet.from_file( - manifest_dir / "csj_supervisions_core.jsonl.gz" - ) + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz") - + # Create excluded dataset + sps = sorted( + SupervisionSet.from_file(manifest_dir / "csj_supervisions_excluded.jsonl.gz"), + key=lambda x: x.id, + ) cut_set = CutSet.from_manifests( - recordings=recording_set, - supervisions=supervision_set, + recordings=RecordingSet.from_file( + manifest_dir / "csj_recordings_excluded.jsonl.gz" + ), + supervisions=concat_csj_supervisions(sps, **concat_params), ) cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) - cut_set = cut_set.shuffle(Random(RNG_SEED)) + cut_sets.append(("excluded", cut_set)) - logging.info( - "Creating valid and train cuts from core and noncore, split at {split}." 
+ # Create valid dataset + sps = sorted( + SupervisionSet.from_file(manifest_dir / "csj_supervisions_valid.jsonl.gz"), + key=lambda x: x.id, ) - valid_set = CutSet.from_cuts(islice(cut_set, 0, split)) + cut_set = CutSet.from_manifests( + recordings=RecordingSet.from_file( + manifest_dir / "csj_recordings_valid.jsonl.gz" + ), + supervisions=concat_csj_supervisions(sps, **concat_params), + ) + cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) + cut_sets.append(("valid", cut_set)) - train_set = CutSet.from_cuts(islice(cut_set, split, None)) + logging.info("Creating train cuts.") + + # Create train dataset + sps = sorted( + SupervisionSet.from_file(manifest_dir / "csj_supervisions_core.jsonl.gz") + + SupervisionSet.from_file(manifest_dir / "csj_supervisions_noncore.jsonl.gz"), + key=lambda x: x.id, + ) + + recording = RecordingSet.from_file( + manifest_dir / "csj_recordings_core.jsonl.gz" + ) + RecordingSet.from_file(manifest_dir / "csj_recordings_noncore.jsonl.gz") + + train_set = CutSet.from_manifests( + recordings=recording, supervisions=concat_csj_supervisions(sps, **concat_params) + ).trim_to_supervisions(keep_overlapping=False) train_set = train_set + train_set.perturb_speed(0.9) + train_set.perturb_speed(1.1) - cut_sets.extend([("valid", valid_set), ("train", train_set)]) + cut_sets.append(("train", train_set)) return cut_sets def get_args(): parser = argparse.ArgumentParser( - description=ARGPARSE_DESCRIPTION, formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - - parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests") - parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features") - parser.add_argument("--split", type=int, default=4000, help="Split at this index") + parser.add_argument( + "-m", "--manifest-dir", type=Path, help="Path to save manifests" + ) + parser.add_argument( + "-f", "--fbank-dir", type=Path, help="Path to save fbank features" + ) return parser.parse_args() @@ -138,7 +161,7 @@ def main(): ) return else: - cut_sets = make_cutset_blueprints(args.manifest_dir, args.split) + cut_sets = make_cutset_blueprints(args.manifest_dir) for part, cut_set in cut_sets: logging.info(f"Processing {part}") cut_set = cut_set.compute_and_store_features( @@ -147,7 +170,7 @@ def main(): storage_path=(args.fbank_dir / f"feats_{part}").as_posix(), storage_type=LilcomChunkyWriter, ) - cut_set.to_file(args.manifest_dir / f"csj_cuts_{part}.jsonl.gz") + cut_set.to_file(args.fbank_dir / f"csj_cuts_{part}.jsonl.gz") logging.info("All fbank computed for CSJ.") (args.fbank_dir / ".done").touch() diff --git a/egs/csj/ASR/local/compute_fbank_musan.py b/egs/csj/ASR/local/compute_fbank_musan.py index f60e62c85..c942df98e 100644 --- a/egs/csj/ASR/local/compute_fbank_musan.py +++ b/egs/csj/ASR/local/compute_fbank_musan.py @@ -28,9 +28,7 @@ from icefall.utils import get_executor ARGPARSE_DESCRIPTION = """ This file computes fbank features of the musan dataset. -It looks for manifests in the directory data/manifests. -The generated fbank features are saved in data/fbank. 
""" # Torch's multithreaded behavior needs to be disabled or @@ -42,8 +40,6 @@ torch.set_num_interop_threads(1) def compute_fbank_musan(manifest_dir: Path, fbank_dir: Path): - # src_dir = Path("data/manifests") - # output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) num_mel_bins = 80 @@ -104,8 +100,12 @@ def get_args(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("--manifest-dir", type=Path, help="Path to save manifests") - parser.add_argument("--fbank-dir", type=Path, help="Path to save fbank features") + parser.add_argument( + "-m", "--manifest-dir", type=Path, help="Path to save manifests" + ) + parser.add_argument( + "-f", "--fbank-dir", type=Path, help="Path to save fbank features" + ) return parser.parse_args() diff --git a/egs/csj/ASR/local/conf/disfluent.ini b/egs/csj/ASR/local/conf/disfluent.ini index c987e72c5..4f0a9ec0e 100644 --- a/egs/csj/ASR/local/conf/disfluent.ini +++ b/egs/csj/ASR/local/conf/disfluent.ini @@ -1,320 +1,79 @@ -; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj -[SEGMENTS] -; # Allowed period of nonverbal noise. If exceeded, a new segment is created. -gap = 0.5 -; # Maximum length of segment (s). -maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. -minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = - [CONSTANTS] ; # Name of this mode MODE = disfluent -; # Suffixes to use after the word surface (no longer used) -MORPH = pos1 cForm cType2 pos2 -; # Used to differentiate between A tag and A_num tag -JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 . -; # Dummy character to delineate multiline words -PLUS = + [DECISIONS] -; # TAG+'^'とは、タグが一つの転記単位に独立していない場合 -; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries - ; # フィラー、感情表出系感動詞 ; # 0 to remain, 1 to delete ; # Example: '(F ぎょっ)' F = 0 -; # Example: '(L (F ン))', '比べ(F えー)る' -F^ = 0 ; # 言い直し、いいよどみなどによる語断片 ; # 0 to remain, 1 to delete ; # Example: '(D だ)(D だいが) 大学の学部の会議' D = 0 -; # Example: '(L (D ドゥ)+(D ヒ))' -D^ = 0 ; # 助詞、助動詞、接辞の言い直し ; # 0 to remain, 1 to delete ; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか' D2 = 0 -; # Example: '(X (D2 ノ))' -D2^ = 0 ; # 聞き取りや語彙の判断に自信がない場合 ; # 0 to remain, 1 to delete ; # Example: (? 字数) の ; # If no option: empty string is returned regardless of output ; # Example: '(?) で' ? = 0 -; # Example: '(D (? すー))+そう+です+よ+ね' -?^ = 0 ; # タグ?で、値は複数の候補が想定される場合 ; # 0 for main guess with matching morph info, 1 for second guess ; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)' ?, = 0 -; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))' -?,^ = 0 ; # 音や言葉に関するメタ的な引用 ; # 0 to remain, 1 to delete ; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)' M = 0 -; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? 
ヒ+ヒ)))' -M^ = 0 ; # 外国語や古語、方言など ; # 0 to remain, 1 to delete ; # Example: '(O ザッツファイン)' O = 0 -; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))' -O^ = 0 ; # 講演者の名前、差別語、誹謗中傷など ; # 0 to remain, 1 to delete ; # Example: '国語研の (R ××) です' R = 0 -R^ = 0 ; # 非朗読対象発話(朗読における言い間違い等) ; # 0 to remain, 1 to delete ; # Example: '(X 実際は) 実際には' X = 0 -; # Example: '(L (X (D2 ニ)))' -X^ = 0 ; # アルファベットや算用数字、記号の表記 ; # 0 to use Japanese form, 1 to use alphabet form ; # Example: '(A シーディーアール;CD-R)' A = 1 -; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)') -A^ = 1 ; # タグAで、単語は算用数字の場合 ; # 0 to use Japanese form, 1 to use Arabic numerals ; # Example: (A 二千;2000) -A_num = eval:self.notag -A_num^ = eval:self.notag +A_num = 0 ; # 何らかの原因で漢字表記できなくなった場合 ; # 0 to use broken form, 1 to use orthodox form ; # Example: '(K たち (F えー) ばな;橘)' K = 1 -; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)' -K^ = 1 ; # 転訛、発音の怠けなど、一時的な発音エラー ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(W ギーツ;ギジュツ)' W = 1 -; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)' -W^ = 1 ; # 語の読みに関する知識レベルのいい間違い ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(B シブタイ;ジュータイ)' B = 0 -; # Example: 'データー(B カズ;スー)' -B^ = 0 ; # 笑いながら発話 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', -笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete ; # Example: '(泣 ドンナニ)' 泣 = 0 -泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete ; # Example: 'シャ(咳 リン) ノ' 咳 = 0 -; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' -咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete ; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 -; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' -L^ = 0 - -[REPLACEMENTS] -; # ボーカルフライなどで母音が同定できない場合 - = -; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = -; # 非語彙的な母音の引き延ばし - = -; # 非語彙的な子音の引き延ばし - = -; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = -; # 言語音と独立に講演者の咳が生じている場合 -<咳> = -; # 言語音と独立に講演者の息が生じている場合 -<息> = -; # 講演者の泣き声 -<泣> = -; # 聴衆(司会者なども含む)の発話 -<フロア発話> = -; # 聴衆の笑い -<フロア笑> = -; # 聴衆の拍手 -<拍手> = -; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = -; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = -; # 転記単位全体が再度読み直された場合 -<朗読間違い> = -; # 上記以外の音で特に目立った音 -<雑音> = -; # 0.2秒以上のポーズ -

= -; # Redacted information, for R -; # It is \x00D7 multiplication sign, not your normal 'x' -× = × - -[FIELDS] -; # Time information for segment -time = 3 -; # Word surface -surface = 5 -; # Word surface root form without CSJ tags -notag = 9 -; # Part Of Speech -pos1 = 11 -; # Conjugated Form -cForm = 12 -; # Conjugation Type -cType1 = 13 -; # Subcategory of POS -pos2 = 14 -; # Euphonic Change / Subcategory of Conjugation Type -cType2 = 15 -; # Other information -other = 16 -; # Pronunciation for lexicon -pron = 10 -; # Speaker ID -spk_id = 2 - -[KATAKANA2ROMAJI] -ア = 'a -イ = 'i -ウ = 'u -エ = 'e -オ = 'o -カ = ka -キ = ki -ク = ku -ケ = ke -コ = ko -ガ = ga -ギ = gi -グ = gu -ゲ = ge -ゴ = go -サ = sa -シ = si -ス = su -セ = se -ソ = so -ザ = za -ジ = zi -ズ = zu -ゼ = ze -ゾ = zo -タ = ta -チ = ti -ツ = tu -テ = te -ト = to -ダ = da -ヂ = di -ヅ = du -デ = de -ド = do -ナ = na -ニ = ni -ヌ = nu -ネ = ne -ノ = no -ハ = ha -ヒ = hi -フ = hu -ヘ = he -ホ = ho -バ = ba -ビ = bi -ブ = bu -ベ = be -ボ = bo -パ = pa -ピ = pi -プ = pu -ペ = pe -ポ = po -マ = ma -ミ = mi -ム = mu -メ = me -モ = mo -ヤ = ya -ユ = yu -ヨ = yo -ラ = ra -リ = ri -ル = ru -レ = re -ロ = ro -ワ = wa -ヰ = we -ヱ = wi -ヲ = wo -ン = ŋ -ッ = q -ー = - -キャ = kǐa -キュ = kǐu -キョ = kǐo -ギャ = gǐa -ギュ = gǐu -ギョ = gǐo -シャ = sǐa -シュ = sǐu -ショ = sǐo -ジャ = zǐa -ジュ = zǐu -ジョ = zǐo -チャ = tǐa -チュ = tǐu -チョ = tǐo -ヂャ = dǐa -ヂュ = dǐu -ヂョ = dǐo -ニャ = nǐa -ニュ = nǐu -ニョ = nǐo -ヒャ = hǐa -ヒュ = hǐu -ヒョ = hǐo -ビャ = bǐa -ビュ = bǐu -ビョ = bǐo -ピャ = pǐa -ピュ = pǐu -ピョ = pǐo -ミャ = mǐa -ミュ = mǐu -ミョ = mǐo -リャ = rǐa -リュ = rǐu -リョ = rǐo -ァ = a -ィ = i -ゥ = u -ェ = e -ォ = o -ヮ = ʍ -ヴ = vu -ャ = ǐa -ュ = ǐu -ョ = ǐo diff --git a/egs/csj/ASR/local/conf/fluent.ini b/egs/csj/ASR/local/conf/fluent.ini index f7f27f5bc..5d033ed17 100644 --- a/egs/csj/ASR/local/conf/fluent.ini +++ b/egs/csj/ASR/local/conf/fluent.ini @@ -1,320 +1,79 @@ -; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj -[SEGMENTS] -; # Allowed period of nonverbal noise. If exceeded, a new segment is created. -gap = 0.5 -; # Maximum length of segment (s). -maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. -minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = - [CONSTANTS] ; # Name of this mode MODE = fluent -; # Suffixes to use after the word surface (no longer used) -MORPH = pos1 cForm cType2 pos2 -; # Used to differentiate between A tag and A_num tag -JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 . -; # Dummy character to delineate multiline words -PLUS = + [DECISIONS] -; # TAG+'^'とは、タグが一つの転記単位に独立していない場合 -; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries - ; # フィラー、感情表出系感動詞 ; # 0 to remain, 1 to delete ; # Example: '(F ぎょっ)' F = 1 -; # Example: '(L (F ン))', '比べ(F えー)る' -F^ = 1 ; # 言い直し、いいよどみなどによる語断片 ; # 0 to remain, 1 to delete ; # Example: '(D だ)(D だいが) 大学の学部の会議' D = 1 -; # Example: '(L (D ドゥ)+(D ヒ))' -D^ = 1 ; # 助詞、助動詞、接辞の言い直し ; # 0 to remain, 1 to delete ; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか' D2 = 1 -; # Example: '(X (D2 ノ))' -D2^ = 1 ; # 聞き取りや語彙の判断に自信がない場合 ; # 0 to remain, 1 to delete ; # Example: (? 字数) の ; # If no option: empty string is returned regardless of output ; # Example: '(?) で' ? = 0 -; # Example: '(D (? 
すー))+そう+です+よ+ね' -?^ = 0 ; # タグ?で、値は複数の候補が想定される場合 ; # 0 for main guess with matching morph info, 1 for second guess ; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)' ?, = 0 -; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))' -?,^ = 0 ; # 音や言葉に関するメタ的な引用 ; # 0 to remain, 1 to delete ; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)' M = 0 -; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))' -M^ = 0 ; # 外国語や古語、方言など ; # 0 to remain, 1 to delete ; # Example: '(O ザッツファイン)' O = 0 -; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))' -O^ = 0 ; # 講演者の名前、差別語、誹謗中傷など ; # 0 to remain, 1 to delete ; # Example: '国語研の (R ××) です' R = 0 -R^ = 0 ; # 非朗読対象発話(朗読における言い間違い等) ; # 0 to remain, 1 to delete ; # Example: '(X 実際は) 実際には' X = 0 -; # Example: '(L (X (D2 ニ)))' -X^ = 0 ; # アルファベットや算用数字、記号の表記 ; # 0 to use Japanese form, 1 to use alphabet form ; # Example: '(A シーディーアール;CD-R)' A = 1 -; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)') -A^ = 1 ; # タグAで、単語は算用数字の場合 ; # 0 to use Japanese form, 1 to use Arabic numerals ; # Example: (A 二千;2000) -A_num = eval:self.notag -A_num^ = eval:self.notag +A_num = 0 ; # 何らかの原因で漢字表記できなくなった場合 ; # 0 to use broken form, 1 to use orthodox form ; # Example: '(K たち (F えー) ばな;橘)' K = 1 -; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)' -K^ = 1 ; # 転訛、発音の怠けなど、一時的な発音エラー ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(W ギーツ;ギジュツ)' W = 1 -; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)' -W^ = 1 ; # 語の読みに関する知識レベルのいい間違い ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(B シブタイ;ジュータイ)' B = 0 -; # Example: 'データー(B カズ;スー)' -B^ = 0 ; # 笑いながら発話 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', -笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete ; # Example: '(泣 ドンナニ)' 泣 = 0 -泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete ; # Example: 'シャ(咳 リン) ノ' 咳 = 0 -; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' -咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete ; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 -; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' -L^ = 0 - -[REPLACEMENTS] -; # ボーカルフライなどで母音が同定できない場合 - = -; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = -; # 非語彙的な母音の引き延ばし - = -; # 非語彙的な子音の引き延ばし - = -; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = -; # 言語音と独立に講演者の咳が生じている場合 -<咳> = -; # 言語音と独立に講演者の息が生じている場合 -<息> = -; # 講演者の泣き声 -<泣> = -; # 聴衆(司会者なども含む)の発話 -<フロア発話> = -; # 聴衆の笑い -<フロア笑> = -; # 聴衆の拍手 -<拍手> = -; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = -; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = -; # 転記単位全体が再度読み直された場合 -<朗読間違い> = -; # 上記以外の音で特に目立った音 -<雑音> = -; # 0.2秒以上のポーズ -

= -; # Redacted information, for R -; # It is \x00D7 multiplication sign, not your normal 'x' -× = × - -[FIELDS] -; # Time information for segment -time = 3 -; # Word surface -surface = 5 -; # Word surface root form without CSJ tags -notag = 9 -; # Part Of Speech -pos1 = 11 -; # Conjugated Form -cForm = 12 -; # Conjugation Type -cType1 = 13 -; # Subcategory of POS -pos2 = 14 -; # Euphonic Change / Subcategory of Conjugation Type -cType2 = 15 -; # Other information -other = 16 -; # Pronunciation for lexicon -pron = 10 -; # Speaker ID -spk_id = 2 - -[KATAKANA2ROMAJI] -ア = 'a -イ = 'i -ウ = 'u -エ = 'e -オ = 'o -カ = ka -キ = ki -ク = ku -ケ = ke -コ = ko -ガ = ga -ギ = gi -グ = gu -ゲ = ge -ゴ = go -サ = sa -シ = si -ス = su -セ = se -ソ = so -ザ = za -ジ = zi -ズ = zu -ゼ = ze -ゾ = zo -タ = ta -チ = ti -ツ = tu -テ = te -ト = to -ダ = da -ヂ = di -ヅ = du -デ = de -ド = do -ナ = na -ニ = ni -ヌ = nu -ネ = ne -ノ = no -ハ = ha -ヒ = hi -フ = hu -ヘ = he -ホ = ho -バ = ba -ビ = bi -ブ = bu -ベ = be -ボ = bo -パ = pa -ピ = pi -プ = pu -ペ = pe -ポ = po -マ = ma -ミ = mi -ム = mu -メ = me -モ = mo -ヤ = ya -ユ = yu -ヨ = yo -ラ = ra -リ = ri -ル = ru -レ = re -ロ = ro -ワ = wa -ヰ = we -ヱ = wi -ヲ = wo -ン = ŋ -ッ = q -ー = - -キャ = kǐa -キュ = kǐu -キョ = kǐo -ギャ = gǐa -ギュ = gǐu -ギョ = gǐo -シャ = sǐa -シュ = sǐu -ショ = sǐo -ジャ = zǐa -ジュ = zǐu -ジョ = zǐo -チャ = tǐa -チュ = tǐu -チョ = tǐo -ヂャ = dǐa -ヂュ = dǐu -ヂョ = dǐo -ニャ = nǐa -ニュ = nǐu -ニョ = nǐo -ヒャ = hǐa -ヒュ = hǐu -ヒョ = hǐo -ビャ = bǐa -ビュ = bǐu -ビョ = bǐo -ピャ = pǐa -ピュ = pǐu -ピョ = pǐo -ミャ = mǐa -ミュ = mǐu -ミョ = mǐo -リャ = rǐa -リュ = rǐu -リョ = rǐo -ァ = a -ィ = i -ゥ = u -ェ = e -ォ = o -ヮ = ʍ -ヴ = vu -ャ = ǐa -ュ = ǐu -ョ = ǐo diff --git a/egs/csj/ASR/local/conf/number.ini b/egs/csj/ASR/local/conf/number.ini index cf9038f62..3ada9aa24 100644 --- a/egs/csj/ASR/local/conf/number.ini +++ b/egs/csj/ASR/local/conf/number.ini @@ -1,320 +1,79 @@ -; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj -[SEGMENTS] -; # Allowed period of nonverbal noise. If exceeded, a new segment is created. -gap = 0.5 -; # Maximum length of segment (s). -maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. -minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = - [CONSTANTS] ; # Name of this mode MODE = number -; # Suffixes to use after the word surface (no longer used) -MORPH = pos1 cForm cType2 pos2 -; # Used to differentiate between A tag and A_num tag -JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 . -; # Dummy character to delineate multiline words -PLUS = + [DECISIONS] -; # TAG+'^'とは、タグが一つの転記単位に独立していない場合 -; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries - ; # フィラー、感情表出系感動詞 ; # 0 to remain, 1 to delete ; # Example: '(F ぎょっ)' F = 1 -; # Example: '(L (F ン))', '比べ(F えー)る' -F^ = 1 ; # 言い直し、いいよどみなどによる語断片 ; # 0 to remain, 1 to delete ; # Example: '(D だ)(D だいが) 大学の学部の会議' D = 1 -; # Example: '(L (D ドゥ)+(D ヒ))' -D^ = 1 ; # 助詞、助動詞、接辞の言い直し ; # 0 to remain, 1 to delete ; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか' D2 = 1 -; # Example: '(X (D2 ノ))' -D2^ = 1 ; # 聞き取りや語彙の判断に自信がない場合 ; # 0 to remain, 1 to delete ; # Example: (? 字数) の ; # If no option: empty string is returned regardless of output ; # Example: '(?) で' ? = 0 -; # Example: '(D (? 
すー))+そう+です+よ+ね' -?^ = 0 ; # タグ?で、値は複数の候補が想定される場合 ; # 0 for main guess with matching morph info, 1 for second guess ; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)' ?, = 0 -; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))' -?,^ = 0 ; # 音や言葉に関するメタ的な引用 ; # 0 to remain, 1 to delete ; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)' M = 0 -; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))' -M^ = 0 ; # 外国語や古語、方言など ; # 0 to remain, 1 to delete ; # Example: '(O ザッツファイン)' O = 0 -; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))' -O^ = 0 ; # 講演者の名前、差別語、誹謗中傷など ; # 0 to remain, 1 to delete ; # Example: '国語研の (R ××) です' R = 0 -R^ = 0 ; # 非朗読対象発話(朗読における言い間違い等) ; # 0 to remain, 1 to delete ; # Example: '(X 実際は) 実際には' X = 0 -; # Example: '(L (X (D2 ニ)))' -X^ = 0 ; # アルファベットや算用数字、記号の表記 ; # 0 to use Japanese form, 1 to use alphabet form ; # Example: '(A シーディーアール;CD-R)' A = 1 -; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)') -A^ = 1 ; # タグAで、単語は算用数字の場合 ; # 0 to use Japanese form, 1 to use Arabic numerals ; # Example: (A 二千;2000) A_num = 1 -A_num^ = 1 ; # 何らかの原因で漢字表記できなくなった場合 ; # 0 to use broken form, 1 to use orthodox form ; # Example: '(K たち (F えー) ばな;橘)' K = 1 -; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)' -K^ = 1 ; # 転訛、発音の怠けなど、一時的な発音エラー ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(W ギーツ;ギジュツ)' W = 1 -; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)' -W^ = 1 ; # 語の読みに関する知識レベルのいい間違い ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(B シブタイ;ジュータイ)' B = 0 -; # Example: 'データー(B カズ;スー)' -B^ = 0 ; # 笑いながら発話 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', -笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete ; # Example: '(泣 ドンナニ)' 泣 = 0 -泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete ; # Example: 'シャ(咳 リン) ノ' 咳 = 0 -; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' -咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete ; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 -; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' -L^ = 0 - -[REPLACEMENTS] -; # ボーカルフライなどで母音が同定できない場合 - = -; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = -; # 非語彙的な母音の引き延ばし - = -; # 非語彙的な子音の引き延ばし - = -; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = -; # 言語音と独立に講演者の咳が生じている場合 -<咳> = -; # 言語音と独立に講演者の息が生じている場合 -<息> = -; # 講演者の泣き声 -<泣> = -; # 聴衆(司会者なども含む)の発話 -<フロア発話> = -; # 聴衆の笑い -<フロア笑> = -; # 聴衆の拍手 -<拍手> = -; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = -; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = -; # 転記単位全体が再度読み直された場合 -<朗読間違い> = -; # 上記以外の音で特に目立った音 -<雑音> = -; # 0.2秒以上のポーズ -

= -; # Redacted information, for R -; # It is \x00D7 multiplication sign, not your normal 'x' -× = × - -[FIELDS] -; # Time information for segment -time = 3 -; # Word surface -surface = 5 -; # Word surface root form without CSJ tags -notag = 9 -; # Part Of Speech -pos1 = 11 -; # Conjugated Form -cForm = 12 -; # Conjugation Type -cType1 = 13 -; # Subcategory of POS -pos2 = 14 -; # Euphonic Change / Subcategory of Conjugation Type -cType2 = 15 -; # Other information -other = 16 -; # Pronunciation for lexicon -pron = 10 -; # Speaker ID -spk_id = 2 - -[KATAKANA2ROMAJI] -ア = 'a -イ = 'i -ウ = 'u -エ = 'e -オ = 'o -カ = ka -キ = ki -ク = ku -ケ = ke -コ = ko -ガ = ga -ギ = gi -グ = gu -ゲ = ge -ゴ = go -サ = sa -シ = si -ス = su -セ = se -ソ = so -ザ = za -ジ = zi -ズ = zu -ゼ = ze -ゾ = zo -タ = ta -チ = ti -ツ = tu -テ = te -ト = to -ダ = da -ヂ = di -ヅ = du -デ = de -ド = do -ナ = na -ニ = ni -ヌ = nu -ネ = ne -ノ = no -ハ = ha -ヒ = hi -フ = hu -ヘ = he -ホ = ho -バ = ba -ビ = bi -ブ = bu -ベ = be -ボ = bo -パ = pa -ピ = pi -プ = pu -ペ = pe -ポ = po -マ = ma -ミ = mi -ム = mu -メ = me -モ = mo -ヤ = ya -ユ = yu -ヨ = yo -ラ = ra -リ = ri -ル = ru -レ = re -ロ = ro -ワ = wa -ヰ = we -ヱ = wi -ヲ = wo -ン = ŋ -ッ = q -ー = - -キャ = kǐa -キュ = kǐu -キョ = kǐo -ギャ = gǐa -ギュ = gǐu -ギョ = gǐo -シャ = sǐa -シュ = sǐu -ショ = sǐo -ジャ = zǐa -ジュ = zǐu -ジョ = zǐo -チャ = tǐa -チュ = tǐu -チョ = tǐo -ヂャ = dǐa -ヂュ = dǐu -ヂョ = dǐo -ニャ = nǐa -ニュ = nǐu -ニョ = nǐo -ヒャ = hǐa -ヒュ = hǐu -ヒョ = hǐo -ビャ = bǐa -ビュ = bǐu -ビョ = bǐo -ピャ = pǐa -ピュ = pǐu -ピョ = pǐo -ミャ = mǐa -ミュ = mǐu -ミョ = mǐo -リャ = rǐa -リュ = rǐu -リョ = rǐo -ァ = a -ィ = i -ゥ = u -ェ = e -ォ = o -ヮ = ʍ -ヴ = vu -ャ = ǐa -ュ = ǐu -ョ = ǐo diff --git a/egs/csj/ASR/local/conf/symbol.ini b/egs/csj/ASR/local/conf/symbol.ini index f9801284b..dafd65c9a 100644 --- a/egs/csj/ASR/local/conf/symbol.ini +++ b/egs/csj/ASR/local/conf/symbol.ini @@ -1,321 +1,80 @@ -; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj -[SEGMENTS] -; # Allowed period of nonverbal noise. If exceeded, a new segment is created. -gap = 0.5 -; # Maximum length of segment (s). -maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. -minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = - [CONSTANTS] ; # Name of this mode -; # See https://www.isca-speech.org/archive/pdfs/interspeech_2022/horii22_interspeech.pdf +; # From https://www.isca-speech.org/archive/pdfs/interspeech_2022/horii22_interspeech.pdf MODE = symbol -; # Suffixes to use after the word surface (no longer used) -MORPH = pos1 cForm cType2 pos2 -; # Used to differentiate between A tag and A_num tag -JPN_NUM = ゼロ 0 零 一 二 三 四 五 六 七 八 九 十 百 千 . 
-; # Dummy character to delineate multiline words -PLUS = + [DECISIONS] -; # TAG+'^'とは、タグが一つの転記単位に独立していない場合 -; # The PLUS (fullwidth) sign '+' marks line boundaries for multiline entries - ; # フィラー、感情表出系感動詞 ; # 0 to remain, 1 to delete ; # Example: '(F ぎょっ)' -F = # -; # Example: '(L (F ン))', '比べ(F えー)る' -F^ = # +F = "#", ["F"] ; # 言い直し、いいよどみなどによる語断片 ; # 0 to remain, 1 to delete ; # Example: '(D だ)(D だいが) 大学の学部の会議' -D = @ -; # Example: '(L (D ドゥ)+(D ヒ))' -D^ = @ +D = "@", ["D"] ; # 助詞、助動詞、接辞の言い直し ; # 0 to remain, 1 to delete ; # Example: '西洋 (D2 的)(F えー)(D ふ) 風というか' -D2 = @ -; # Example: '(X (D2 ノ))' -D2^ = @ +D2 = "@", ["D2"] ; # 聞き取りや語彙の判断に自信がない場合 ; # 0 to remain, 1 to delete ; # Example: (? 字数) の ; # If no option: empty string is returned regardless of output ; # Example: '(?) で' ? = 0 -; # Example: '(D (? すー))+そう+です+よ+ね' -?^ = 0 ; # タグ?で、値は複数の候補が想定される場合 ; # 0 for main guess with matching morph info, 1 for second guess ; # Example: '(? 次数, 実数)', '(? これ,ここで)+(? 説明+し+た+方+が+いい+か+な)' ?, = 0 -; # Example: '(W (? テユクー);(? ケッキョク,テユウコトデ))', '(W マシ;(? マシ+タ,マス))' -?,^ = 0 ; # 音や言葉に関するメタ的な引用 ; # 0 to remain, 1 to delete ; # Example: '助詞の (M は) は (M は) と書くが発音は (M わ)' M = 0 -; # Example: '(L (M ヒ)+(M ヒ))', '(L (M (? ヒ+ヒ)))' -M^ = 0 ; # 外国語や古語、方言など ; # 0 to remain, 1 to delete ; # Example: '(O ザッツファイン)' O = 0 -; # Example: '(笑 (O エクスキューズ+ミー))', '(笑 メダッ+テ+(O ナンボ))' -O^ = 0 ; # 講演者の名前、差別語、誹謗中傷など ; # 0 to remain, 1 to delete ; # Example: '国語研の (R ××) です' R = 0 -R^ = 0 ; # 非朗読対象発話(朗読における言い間違い等) ; # 0 to remain, 1 to delete ; # Example: '(X 実際は) 実際には' X = 0 -; # Example: '(L (X (D2 ニ)))' -X^ = 0 ; # アルファベットや算用数字、記号の表記 ; # 0 to use Japanese form, 1 to use alphabet form ; # Example: '(A シーディーアール;CD-R)' A = 1 -; # Example: 'スモール(A エヌ;N)', 'ラージ(A キュー;Q)', '(A ティーエフ;TF)+(A アイディーエフ;IDF)' (Strung together by pron: '(W (? ティーワイド);ティーエフ+アイディーエフ)') -A^ = 1 ; # タグAで、単語は算用数字の場合 ; # 0 to use Japanese form, 1 to use Arabic numerals ; # Example: (A 二千;2000) -A_num = eval:self.notag -A_num^ = eval:self.notag +A_num = 1 ; # 何らかの原因で漢字表記できなくなった場合 ; # 0 to use broken form, 1 to use orthodox form ; # Example: '(K たち (F えー) ばな;橘)' K = 1 -; # Example: '合(K か(?)く;格)', '宮(K ま(?)え;前)' -K^ = 1 ; # 転訛、発音の怠けなど、一時的な発音エラー ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(W ギーツ;ギジュツ)' W = 1 -; # Example: '(F (W エド;エト))', 'イベント(W リレーティッド;リレーテッド)' -W^ = 1 ; # 語の読みに関する知識レベルのいい間違い ; # 0 to use wrong form, 1 to use orthodox form ; # Example: '(B シブタイ;ジュータイ)' B = 0 -; # Example: 'データー(B カズ;スー)' -B^ = 0 ; # 笑いながら発話 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', -笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete ; # Example: '(泣 ドンナニ)' 泣 = 0 -泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete ; # Example: 'シャ(咳 リン) ノ' 咳 = 0 -; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' -咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete ; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 -; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' -L^ = 0 - -[REPLACEMENTS] -; # ボーカルフライなどで母音が同定できない場合 - = -; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = -; # 非語彙的な母音の引き延ばし - = -; # 非語彙的な子音の引き延ばし - = -; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = -; # 言語音と独立に講演者の咳が生じている場合 -<咳> = -; # 言語音と独立に講演者の息が生じている場合 -<息> = -; # 講演者の泣き声 -<泣> = -; # 聴衆(司会者なども含む)の発話 -<フロア発話> = -; # 聴衆の笑い -<フロア笑> = -; # 聴衆の拍手 -<拍手> = -; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = -; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = -; # 転記単位全体が再度読み直された場合 -<朗読間違い> = -; # 上記以外の音で特に目立った音 -<雑音> = -; # 0.2秒以上のポーズ -

= -; # Redacted information, for R -; # It is \x00D7 multiplication sign, not your normal 'x' -× = × - -[FIELDS] -; # Time information for segment -time = 3 -; # Word surface -surface = 5 -; # Word surface root form without CSJ tags -notag = 9 -; # Part Of Speech -pos1 = 11 -; # Conjugated Form -cForm = 12 -; # Conjugation Type -cType1 = 13 -; # Subcategory of POS -pos2 = 14 -; # Euphonic Change / Subcategory of Conjugation Type -cType2 = 15 -; # Other information -other = 16 -; # Pronunciation for lexicon -pron = 10 -; # Speaker ID -spk_id = 2 - -[KATAKANA2ROMAJI] -ア = 'a -イ = 'i -ウ = 'u -エ = 'e -オ = 'o -カ = ka -キ = ki -ク = ku -ケ = ke -コ = ko -ガ = ga -ギ = gi -グ = gu -ゲ = ge -ゴ = go -サ = sa -シ = si -ス = su -セ = se -ソ = so -ザ = za -ジ = zi -ズ = zu -ゼ = ze -ゾ = zo -タ = ta -チ = ti -ツ = tu -テ = te -ト = to -ダ = da -ヂ = di -ヅ = du -デ = de -ド = do -ナ = na -ニ = ni -ヌ = nu -ネ = ne -ノ = no -ハ = ha -ヒ = hi -フ = hu -ヘ = he -ホ = ho -バ = ba -ビ = bi -ブ = bu -ベ = be -ボ = bo -パ = pa -ピ = pi -プ = pu -ペ = pe -ポ = po -マ = ma -ミ = mi -ム = mu -メ = me -モ = mo -ヤ = ya -ユ = yu -ヨ = yo -ラ = ra -リ = ri -ル = ru -レ = re -ロ = ro -ワ = wa -ヰ = we -ヱ = wi -ヲ = wo -ン = ŋ -ッ = q -ー = - -キャ = kǐa -キュ = kǐu -キョ = kǐo -ギャ = gǐa -ギュ = gǐu -ギョ = gǐo -シャ = sǐa -シュ = sǐu -ショ = sǐo -ジャ = zǐa -ジュ = zǐu -ジョ = zǐo -チャ = tǐa -チュ = tǐu -チョ = tǐo -ヂャ = dǐa -ヂュ = dǐu -ヂョ = dǐo -ニャ = nǐa -ニュ = nǐu -ニョ = nǐo -ヒャ = hǐa -ヒュ = hǐu -ヒョ = hǐo -ビャ = bǐa -ビュ = bǐu -ビョ = bǐo -ピャ = pǐa -ピュ = pǐu -ピョ = pǐo -ミャ = mǐa -ミュ = mǐu -ミョ = mǐo -リャ = rǐa -リュ = rǐu -リョ = rǐo -ァ = a -ィ = i -ゥ = u -ェ = e -ォ = o -ヮ = ʍ -ヴ = vu -ャ = ǐa -ュ = ǐu -ョ = ǐo diff --git a/egs/csj/ASR/local/disfluent_recogs_to_fluent.py b/egs/csj/ASR/local/disfluent_recogs_to_fluent.py new file mode 100644 index 000000000..45c9c7656 --- /dev/null +++ b/egs/csj/ASR/local/disfluent_recogs_to_fluent.py @@ -0,0 +1,202 @@ +import argparse +from pathlib import Path + +import kaldialign +from lhotse import CutSet + +ARGPARSE_DESCRIPTION = """ +This helper code takes in a disfluent recogs file generated from icefall.utils.store_transcript, +compares it against a fluent transcript, and saves the results in a separate directory. +This is useful to compare disfluent models with fluent models on the same metric. + +""" + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description=ARGPARSE_DESCRIPTION, + ) + parser.add_argument( + "--recogs", + type=Path, + required=True, + help="Path to the recogs-XXX file generated by icefall.utils.store_transcript.", + ) + parser.add_argument( + "--cut", + type=Path, + required=True, + help="Path to the cut manifest to be compared to. Assumes that disfluent_tag exists in the custom dict.", + ) + parser.add_argument( + "--res-dir", type=Path, required=True, help="Path to save results" + ) + return parser.parse_args() + + +def d2f(stats): + """ + Compare the outputs of a disfluent model against a fluent reference. + Indicates a disfluent model's performance only on the content words + + CER^d_f = (sub_f + ins + del_f) / Nf + + """ + return stats["base"] / stats["Nf"] + + +def calc_cer(refs, hyps): + subs = { + "F": 0, + "D": 0, + } + ins = 0 + dels = { + "F": 0, + "D": 0, + } + cors = { + "F": 0, + "D": 0, + } + dis_ref_len = 0 + flu_ref_len = 0 + + for ref, hyp in zip(refs, hyps): + assert ( + ref[0] == hyp[0] + ), f"Expected ref cut id {ref[0]} to be the same as hyp cut id {hyp[0]}." 
+ tag = ref[2].copy() + ref = ref[1] + dis_ref_len += len(ref) + # Remember that the 'D' and 'F' tags here refer to CSJ tags, not disfluent and fluent respectively. + flu_ref_len += len([t for t in tag if ("D" not in t and "F" not in t)]) + hyp = hyp[1] + ali = kaldialign.align(ref, hyp, "*") + tags = ["*" if r[0] == "*" else tag.pop(0) for r in ali] + for tag, (ref_word, hyp_word) in zip(tags, ali): + if "D" in tag or "F" in tag: + tag = "D" + else: + tag = "F" + + if ref_word == "*": + ins += 1 + elif hyp_word == "*": + dels[tag] += 1 + elif ref_word != hyp_word: + subs[tag] += 1 + else: + cors[tag] += 1 + + return { + "subs": subs, + "ins": ins, + "dels": dels, + "cors": cors, + "dis_ref_len": dis_ref_len, + "flu_ref_len": flu_ref_len, + } + + +def for_each_recogs(recogs_file: Path, refs, out_dir): + hyps = [] + with recogs_file.open() as fin: + for line in fin: + if "ref" in line: + continue + cutid, hyp = line.split(":\thyp=") + hyps.append((cutid, eval(hyp))) + + assert len(refs) == len( + hyps + ), f"Expected refs len {len(refs)} and hyps len {len(hyps)} to be equal." + stats = calc_cer(refs, hyps) + stat_table = ["tag,yes,no"] + + for cer_type in ["subs", "dels", "cors", "ins"]: + ret = f"{cer_type}" + for df in ["D", "F"]: + try: + ret += f",{stats[cer_type][df]}" + except TypeError: + # insertions do not belong to F or D, and is not subscriptable. + ret += f",{stats[cer_type]}," + break + stat_table.append(ret) + stat_table = "\n".join(stat_table) + + stats = { + "subd": stats["subs"]["D"], + "deld": stats["dels"]["D"], + "cord": stats["cors"]["D"], + "Nf": stats["flu_ref_len"], + "base": stats["subs"]["F"] + stats["ins"] + stats["dels"]["F"], + } + + cer = d2f(stats) + results = [ + f"{cer:.2%}", + f"Nf,{stats['Nf']}", + ] + results = "\n".join(results) + + with (out_dir / (recogs_file.stem + ".dfcer")).open("w") as fout: + fout.write(results) + fout.write("\n\n") + fout.write(stat_table) + + +def main(): + args = get_args() + recogs_file: Path = args.recogs + assert ( + recogs_file.is_file() or recogs_file.is_dir() + ), f"recogs_file cannot be found at {recogs_file}." + + args.res_dir.mkdir(parents=True, exist_ok=True) + + if recogs_file.is_file() and recogs_file.stem.startswith("recogs-"): + assert ( + "csj_cuts" in args.cut.name + ), f"Expected {args.cut} to be a cuts manifest." 
+ + refs: CutSet = CutSet.from_file(args.cut) + refs = sorted( + [ + ( + e.id, + list(e.supervisions[0].custom["disfluent"]), + e.supervisions[0].custom["disfluent_tag"].split(","), + ) + for e in refs + ], + key=lambda x: x[0], + ) + for_each_recogs(recogs_file, refs, args.res_dir) + + elif recogs_file.is_dir(): + recogs_file_path = recogs_file + for partname in ["eval1", "eval2", "eval3", "excluded", "valid"]: + refs: CutSet = CutSet.from_file(args.cut / f"csj_cuts_{partname}.jsonl.gz") + refs = sorted( + [ + ( + r.id, + list(r.supervisions[0].custom["disfluent"]), + r.supervisions[0].custom["disfluent_tag"].split(","), + ) + for r in refs + ], + key=lambda x: x[0], + ) + for recogs_file in recogs_file_path.glob(f"recogs-{partname}-*.txt"): + for_each_recogs(recogs_file, refs, args.res_dir) + + else: + raise TypeError(f"Unrecognised recogs file provided: {recogs_file}") + + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/local/display_manifest_statistics.py b/egs/csj/ASR/local/display_manifest_statistics.py index c043cf853..924474d33 100644 --- a/egs/csj/ASR/local/display_manifest_statistics.py +++ b/egs/csj/ASR/local/display_manifest_statistics.py @@ -45,8 +45,8 @@ def get_parser(): def main(): args = get_parser() - for path in args.manifest_dir.glob("csj_cuts_*.jsonl.gz"): - + for part in ["eval1", "eval2", "eval3", "valid", "excluded", "train"]: + path = args.manifest_dir / f"csj_cuts_{part}.jsonl.gz" cuts: CutSet = load_manifest(path) print("\n---------------------------------\n") @@ -58,123 +58,271 @@ if __name__ == "__main__": main() """ -## eval1 -Cuts count: 1272 -Total duration (hh:mm:ss): 01:50:07 -Speech duration (hh:mm:ss): 01:50:07 (100.0%) -Duration statistics (seconds): -mean 5.2 -std 3.9 -min 0.2 -25% 1.9 -50% 4.0 -75% 8.1 -99% 14.3 -99.5% 14.7 -99.9% 16.0 -max 16.9 -Recordings available: 1272 -Features available: 1272 -Supervisions available: 1272 +csj_cuts_eval1.jsonl.gz: +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 1023 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 01:55:40 │ +├───────────────────────────┼──────────┤ +│ mean │ 6.8 │ +├───────────────────────────┼──────────┤ +│ std │ 2.7 │ +├───────────────────────────┼──────────┤ +│ min │ 0.2 │ +├───────────────────────────┼──────────┤ +│ 25% │ 4.9 │ +├───────────────────────────┼──────────┤ +│ 50% │ 7.7 │ +├───────────────────────────┼──────────┤ +│ 75% │ 9.0 │ +├───────────────────────────┼──────────┤ +│ 99% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ max │ 10.0 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 1023 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 1023 │ +╘═══════════════════════════╧══════════╛ SUPERVISION custom fields: -- fluent (in 1272 cuts) -- disfluent (in 1272 cuts) -- number (in 1272 cuts) -- symbol (in 1272 cuts) +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 01:55:40 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 01:55:40 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ 
+╘══════════════════════════════╧══════════╧══════════════════════╛ -## eval2 -Cuts count: 1292 -Total duration (hh:mm:ss): 01:56:50 -Speech duration (hh:mm:ss): 01:56:50 (100.0%) -Duration statistics (seconds): -mean 5.4 -std 3.9 -min 0.1 -25% 2.1 -50% 4.6 -75% 8.6 -99% 14.1 -99.5% 15.2 -99.9% 16.1 -max 16.9 -Recordings available: 1292 -Features available: 1292 -Supervisions available: 1292 -SUPERVISION custom fields: -- fluent (in 1292 cuts) -- number (in 1292 cuts) -- symbol (in 1292 cuts) -- disfluent (in 1292 cuts) +--------------------------------- -## eval3 -Cuts count: 1385 -Total duration (hh:mm:ss): 01:19:21 -Speech duration (hh:mm:ss): 01:19:21 (100.0%) -Duration statistics (seconds): -mean 3.4 -std 3.0 -min 0.2 -25% 1.2 -50% 2.5 -75% 4.6 -99% 12.7 -99.5% 13.7 -99.9% 15.0 -max 15.9 -Recordings available: 1385 -Features available: 1385 -Supervisions available: 1385 +csj_cuts_eval2.jsonl.gz: +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 1025 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 02:02:07 │ +├───────────────────────────┼──────────┤ +│ mean │ 7.1 │ +├───────────────────────────┼──────────┤ +│ std │ 2.5 │ +├───────────────────────────┼──────────┤ +│ min │ 0.1 │ +├───────────────────────────┼──────────┤ +│ 25% │ 5.9 │ +├───────────────────────────┼──────────┤ +│ 50% │ 7.9 │ +├───────────────────────────┼──────────┤ +│ 75% │ 9.1 │ +├───────────────────────────┼──────────┤ +│ 99% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ max │ 10.0 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 1025 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 1025 │ +╘═══════════════════════════╧══════════╛ SUPERVISION custom fields: -- number (in 1385 cuts) -- symbol (in 1385 cuts) -- fluent (in 1385 cuts) -- disfluent (in 1385 cuts) +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 02:02:07 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 02:02:07 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ -## valid -Cuts count: 4000 -Total duration (hh:mm:ss): 05:08:09 -Speech duration (hh:mm:ss): 05:08:09 (100.0%) -Duration statistics (seconds): -mean 4.6 -std 3.8 -min 0.1 -25% 1.5 -50% 3.4 -75% 7.0 -99% 13.8 -99.5% 14.8 -99.9% 16.0 -max 17.3 -Recordings available: 4000 -Features available: 4000 -Supervisions available: 4000 -SUPERVISION custom fields: -- fluent (in 4000 cuts) -- symbol (in 4000 cuts) -- disfluent (in 4000 cuts) -- number (in 4000 cuts) +--------------------------------- -## train -Cuts count: 1291134 -Total duration (hh:mm:ss): 1596:37:27 -Speech duration (hh:mm:ss): 1596:37:27 (100.0%) -Duration statistics (seconds): -mean 4.5 -std 3.6 -min 0.0 -25% 1.6 -50% 3.3 -75% 6.4 -99% 14.0 -99.5% 14.8 -99.9% 16.6 -max 27.8 -Recordings available: 1291134 -Features available: 1291134 -Supervisions available: 1291134 +csj_cuts_eval3.jsonl.gz: +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 865 │ +├───────────────────────────┼──────────┤ +│ Total duration 
(hh:mm:ss) │ 01:26:44 │ +├───────────────────────────┼──────────┤ +│ mean │ 6.0 │ +├───────────────────────────┼──────────┤ +│ std │ 3.0 │ +├───────────────────────────┼──────────┤ +│ min │ 0.3 │ +├───────────────────────────┼──────────┤ +│ 25% │ 3.3 │ +├───────────────────────────┼──────────┤ +│ 50% │ 6.8 │ +├───────────────────────────┼──────────┤ +│ 75% │ 8.7 │ +├───────────────────────────┼──────────┤ +│ 99% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ max │ 10.0 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 865 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 865 │ +╘═══════════════════════════╧══════════╛ SUPERVISION custom fields: -- disfluent (in 1291134 cuts) -- fluent (in 1291134 cuts) -- symbol (in 1291134 cuts) -- number (in 1291134 cuts) +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 01:26:44 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 01:26:44 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ + +--------------------------------- + +csj_cuts_valid.jsonl.gz: +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 3743 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 06:40:15 │ +├───────────────────────────┼──────────┤ +│ mean │ 6.4 │ +├───────────────────────────┼──────────┤ +│ std │ 3.0 │ +├───────────────────────────┼──────────┤ +│ min │ 0.1 │ +├───────────────────────────┼──────────┤ +│ 25% │ 3.9 │ +├───────────────────────────┼──────────┤ +│ 50% │ 7.4 │ +├───────────────────────────┼──────────┤ +│ 75% │ 9.0 │ +├───────────────────────────┼──────────┤ +│ 99% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.1 │ +├───────────────────────────┼──────────┤ +│ max │ 11.8 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 3743 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 3743 │ +╘═══════════════════════════╧══════════╛ +SUPERVISION custom fields: +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 06:40:15 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 06:40:15 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ + +--------------------------------- + +csj_cuts_excluded.jsonl.gz: +Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 980 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 00:56:06 │ +├───────────────────────────┼──────────┤ +│ mean │ 3.4 │ +├───────────────────────────┼──────────┤ +│ std │ 3.1 │ +├───────────────────────────┼──────────┤ +│ min │ 0.1 │ +├───────────────────────────┼──────────┤ +│ 25% │ 
0.8 │ +├───────────────────────────┼──────────┤ +│ 50% │ 2.2 │ +├───────────────────────────┼──────────┤ +│ 75% │ 5.8 │ +├───────────────────────────┼──────────┤ +│ 99% │ 9.9 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 9.9 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 10.0 │ +├───────────────────────────┼──────────┤ +│ max │ 10.0 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 980 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 980 │ +╘═══════════════════════════╧══════════╛ +SUPERVISION custom fields: +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 00:56:06 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 00:56:06 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ + +--------------------------------- + +csj_cuts_train.jsonl.gz: +Cut statistics: +╒═══════════════════════════╤════════════╕ +│ Cuts count: │ 914151 │ +├───────────────────────────┼────────────┤ +│ Total duration (hh:mm:ss) │ 1695:29:43 │ +├───────────────────────────┼────────────┤ +│ mean │ 6.7 │ +├───────────────────────────┼────────────┤ +│ std │ 2.9 │ +├───────────────────────────┼────────────┤ +│ min │ 0.1 │ +├───────────────────────────┼────────────┤ +│ 25% │ 4.6 │ +├───────────────────────────┼────────────┤ +│ 50% │ 7.5 │ +├───────────────────────────┼────────────┤ +│ 75% │ 8.9 │ +├───────────────────────────┼────────────┤ +│ 99% │ 11.0 │ +├───────────────────────────┼────────────┤ +│ 99.5% │ 11.0 │ +├───────────────────────────┼────────────┤ +│ 99.9% │ 11.1 │ +├───────────────────────────┼────────────┤ +│ max │ 18.0 │ +├───────────────────────────┼────────────┤ +│ Recordings available: │ 914151 │ +├───────────────────────────┼────────────┤ +│ Features available: │ 0 │ +├───────────────────────────┼────────────┤ +│ Supervisions available: │ 914151 │ +╘═══════════════════════════╧════════════╛ +SUPERVISION custom fields: +Speech duration statistics: +╒══════════════════════════════╤════════════╤══════════════════════╕ +│ Total speech duration │ 1695:29:43 │ 100.00% of recording │ +├──────────────────────────────┼────────────┼──────────────────────┤ +│ Total speaking time duration │ 1695:29:43 │ 100.00% of recording │ +├──────────────────────────────┼────────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧════════════╧══════════════════════╛ """ diff --git a/egs/csj/ASR/local/prepare_lang_char.py b/egs/csj/ASR/local/prepare_lang_char.py index 16107f543..58b197922 100644 --- a/egs/csj/ASR/local/prepare_lang_char.py +++ b/egs/csj/ASR/local/prepare_lang_char.py @@ -21,24 +21,14 @@ import logging from pathlib import Path from lhotse import CutSet +from lhotse.recipes.csj import CSJSDBParser ARGPARSE_DESCRIPTION = """ -This script gathers all training transcripts of the specified {trans_mode} type -and produces a token_list that would be output set of the ASR system. +This script gathers all training transcripts, parses them in disfluent mode, and produces a token list that would be the output set of the ASR system. 
-It splits transcripts by whitespace into lists, then, for each word in the -list, if the word does not appear in the list of user-defined multicharacter -strings, it further splits that word into individual characters to be counted -into the output token set. - -It outputs 4 files into the lang directory: -- trans_mode: the name of transcript mode. If trans_mode was not specified, - this will be an empty file. -- userdef_string: a list of user defined strings that should not be split - further into individual characters. By default, it contains "", "", - "" -- words_len: the total number of tokens in the output set. -- words.txt: a list of tokens in the output set. The length matches words_len. +It outputs 3 files into the lang directory: +- tokens.txt: a list of tokens in the output set. +- lang_type: a file that contains the string "char" """ @@ -50,98 +40,52 @@ def get_args(): ) parser.add_argument( - "--train-cut", type=Path, required=True, help="Path to the train cut" - ) - - parser.add_argument( - "--trans-mode", - type=str, - default=None, - help=( - "Name of the transcript mode to use. " - "If lang-dir is not set, this will also name the lang-dir" - ), + "train_cut", metavar="train-cut", type=Path, help="Path to the train cut" ) parser.add_argument( "--lang-dir", type=Path, - default=None, + default=Path("data/lang_char"), help=( "Name of lang dir. " "If not set, this will default to lang_char_{trans-mode}" ), ) - parser.add_argument( - "--userdef-string", - type=Path, - default=None, - help="Multicharacter strings that do not need to be split", - ) - return parser.parse_args() def main(): args = get_args() - logging.basicConfig( format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"), level=logging.INFO, ) - if not args.lang_dir: - p = "lang_char" - if args.trans_mode: - p += f"_{args.trans_mode}" - args.lang_dir = Path(p) + sysdef_string = set(["", "", ""]) - if args.userdef_string: - args.userdef_string = set(args.userdef_string.read_text().split()) - else: - args.userdef_string = set() + # Using disfluent parsing as fluent is a subset of disfluent + parser = CSJSDBParser() - sysdef_string = ["", "", ""] - args.userdef_string.update(sysdef_string) + token_set = set() + logging.info(f"Creating vocabulary from {args.train_cut}.") + train_cut: CutSet = CutSet.from_file(args.train_cut) + for cut in train_cut: + if "_sp" in cut.id: + continue - train_set: CutSet = CutSet.from_file(args.train_cut) - - words = set() - logging.info( - f"Creating vocabulary from {args.train_cut.name} at {args.trans_mode} mode." 
- ) - for cut in train_set: - try: - text: str = ( - cut.supervisions[0].custom[args.trans_mode] - if args.trans_mode - else cut.supervisions[0].text - ) - except KeyError: - raise KeyError( - f"Could not find {args.trans_mode} in {cut.supervisions[0].custom}" - ) - for t in text.split(): - if t in args.userdef_string: - words.add(t) - else: - words.update(c for c in list(t)) - - words -= set(sysdef_string) - words = sorted(words) - words = [""] + words + ["", ""] + text: str = cut.supervisions[0].custom["raw"] + for w in parser.parse(text, sep=" ").split(" "): + token_set.update(w) + token_set = [""] + sorted(token_set - sysdef_string) + ["", ""] args.lang_dir.mkdir(parents=True, exist_ok=True) - (args.lang_dir / "words.txt").write_text( - "\n".join(f"{word}\t{i}" for i, word in enumerate(words)) + (args.lang_dir / "tokens.txt").write_text( + "\n".join(f"{t}\t{i}" for i, t in enumerate(token_set)) ) - (args.lang_dir / "words_len").write_text(f"{len(words)}") - - (args.lang_dir / "userdef_string").write_text("\n".join(args.userdef_string)) - - (args.lang_dir / "trans_mode").write_text(args.trans_mode) + (args.lang_dir / "lang_type").write_text("char") logging.info("Done.") diff --git a/egs/csj/ASR/local/utils/asr_datamodule.py b/egs/csj/ASR/local/utils/asr_datamodule.py new file mode 100644 index 000000000..619820a75 --- /dev/null +++ b/egs/csj/ASR/local/utils/asr_datamodule.py @@ -0,0 +1,462 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
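Before moving on to the data module below, here is a minimal sketch of the lang directory that the revised local/prepare_lang_char.py above writes out. The layout follows the code above; the concrete tokens and ids are illustrative, and the special symbols are assumed to be icefall's usual <blk>, <unk> and <sos/eos>.

from pathlib import Path

lang_dir = Path("data/lang_char")   # default --lang-dir of prepare_lang_char.py

# tokens.txt holds one "token<TAB>id" pair per line, with <blk> first and the
# <unk> / <sos/eos> symbols appended after the sorted character inventory, e.g.:
#   <blk>       0
#   あ          1
#   ...         ...
#   <unk>       (second-to-last id)
#   <sos/eos>   (last id)
token2id = {}
for line in (lang_dir / "tokens.txt").read_text().splitlines():
    token, idx = line.split("\t")
    token2id[token] = int(idx)

# lang_type records the tokenization scheme ("char" here); the Tokenizer added
# further below reads this file when no lang type is passed explicitly.
assert (lang_dir / "lang_type").read_text().strip() == "char"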
+ + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class AsrVariableTranscriptDataset(K2SpeechRecognitionDataset): + def __init__( + self, + *args, + transcript_mode: str = "", + return_cuts: bool = False, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.transcript_mode = transcript_mode + self.return_cuts = True + self._return_cuts = return_cuts + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]: + batch = super().__getitem__(cuts) + + if self.transcript_mode: + batch["supervisions"]["text"] = [ + supervision.custom[self.transcript_mode] + for cut in batch["supervisions"]["cut"] + for supervision in cut.supervisions + ] + + if not self._return_cuts: + del batch["supervisions"]["cut"] + + return batch + + +class CSJAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--transcript-mode", + type=str, + default="", + help="Mode of transcript in supervision to use.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/manifests"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--musan-dir", type=Path, help="Path to directory with musan cuts. " + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. 
You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.musan_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. 
mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = AsrVariableTranscriptDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + transcript_mode=self.args.transcript_mode, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = AsrVariableTranscriptDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + transcript_mode=self.args.transcript_mode, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = AsrVariableTranscriptDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + transcript_mode=self.args.transcript_mode, + ) + else: + validate = AsrVariableTranscriptDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + transcript_mode=self.args.transcript_mode, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + + test = AsrVariableTranscriptDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + transcript_mode=self.args.transcript_mode, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_train.jsonl.gz") + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get valid cuts") + return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_valid.jsonl.gz") + + @lru_cache() + def excluded_cuts(self) -> CutSet: + logging.info("About to get excluded cuts") + return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_excluded.jsonl.gz") + + @lru_cache() + def eval1_cuts(self) -> CutSet: + logging.info("About to get eval1 cuts") + return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_eval1.jsonl.gz") + + @lru_cache() + def eval2_cuts(self) -> CutSet: + logging.info("About to get eval2 cuts") + return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_eval2.jsonl.gz") + + @lru_cache() + def eval3_cuts(self) -> CutSet: + logging.info("About to get eval3 cuts") + return load_manifest_lazy(self.args.manifest_dir / "csj_cuts_eval3.jsonl.gz") diff --git a/egs/csj/ASR/local/utils/tokenizer.py b/egs/csj/ASR/local/utils/tokenizer.py new file mode 100644 index 000000000..c9be72be1 --- /dev/null +++ b/egs/csj/ASR/local/utils/tokenizer.py @@ -0,0 +1,253 @@ +import argparse +from pathlib import Path +from typing import Callable, List, Union + +import sentencepiece as spm +from k2 import SymbolTable + + +class Tokenizer: + text2word: Callable[[str], List[str]] + + @staticmethod + def add_arguments(parser: argparse.ArgumentParser): + group = 
parser.add_argument_group(title="Lang related options") + + group.add_argument("--lang", type=Path, help="Path to lang directory.") + + group.add_argument( + "--lang-type", + type=str, + default=None, + help=( + "Either 'bpe' or 'char'. If not provided, it expects lang_dir/lang_type to exists. " + "Note: 'bpe' directly loads sentencepiece.SentencePieceProcessor" + ), + ) + + @staticmethod + def Load(lang_dir: Path, lang_type="", oov=""): + + if not lang_type: + assert (lang_dir / "lang_type").exists(), "lang_type not specified." + lang_type = (lang_dir / "lang_type").read_text().strip() + + tokenizer = None + + if lang_type == "bpe": + assert ( + lang_dir / "bpe.model" + ).exists(), f"No BPE .model could be found in {lang_dir}." + tokenizer = spm.SentencePieceProcessor() + tokenizer.Load(str(lang_dir / "bpe.model")) + elif lang_type == "char": + tokenizer = CharTokenizer(lang_dir, oov=oov) + else: + raise NotImplementedError(f"{lang_type} not supported at the moment.") + + return tokenizer + + load = Load + + def PieceToId(self, piece: str) -> int: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + piece_to_id = PieceToId + + def IdToPiece(self, id: int) -> str: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + id_to_piece = IdToPiece + + def GetPieceSize(self) -> int: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + get_piece_size = GetPieceSize + + def __len__(self) -> int: + return self.get_piece_size() + + def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + def EncodeAsIds(self, input: str) -> List[int]: + return self.EncodeAsIdsBatch([input])[0] + + def EncodeAsPieces(self, input: str) -> List[str]: + return self.EncodeAsPiecesBatch([input])[0] + + def Encode( + self, input: Union[str, List[str]], out_type=int + ) -> Union[List, List[List]]: + if not input: + return [] + + if isinstance(input, list): + if out_type is int: + return self.EncodeAsIdsBatch(input) + if out_type is str: + return self.EncodeAsPiecesBatch(input) + + if out_type is int: + return self.EncodeAsIds(input) + if out_type is str: + return self.EncodeAsPieces(input) + + encode = Encode + + def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + def DecodeIds(self, input: List[int]) -> str: + return self.DecodeIdsBatch([input])[0] + + def DecodePieces(self, input: List[str]) -> str: + return self.DecodePiecesBatch([input])[0] + + def Decode( + self, + input: Union[int, List[int], List[str], List[List[int]], List[List[str]]], + ) -> Union[List[str], str]: + + if not input: + return "" + + if isinstance(input, int): + return self.id_to_piece(input) + elif isinstance(input, str): + raise TypeError( + "Unlike spm.SentencePieceProcessor, cannot decode from type str." 
+ ) + + if isinstance(input[0], list): + if not input[0] or isinstance(input[0][0], int): + return self.DecodeIdsBatch(input) + + if isinstance(input[0][0], str): + return self.DecodePiecesBatch(input) + + if isinstance(input[0], int): + return self.DecodeIds(input) + if isinstance(input[0], str): + return self.DecodePieces(input) + + raise RuntimeError("Unknown input type") + + decode = Decode + + def SplitBatch(self, input: List[str]) -> List[List[str]]: + raise NotImplementedError( + "You need to implement this function in the child class." + ) + + def Split(self, input: Union[List[str], str]) -> Union[List[List[str]], List[str]]: + if isinstance(input, list): + return self.SplitBatch(input) + elif isinstance(input, str): + return self.SplitBatch([input])[0] + raise RuntimeError("Unknown input type") + + split = Split + + +class CharTokenizer(Tokenizer): + def __init__(self, lang_dir: Path, oov="", sep=""): + assert ( + lang_dir / "tokens.txt" + ).exists(), f"tokens.txt could not be found in {lang_dir}." + token_table = SymbolTable.from_file(lang_dir / "tokens.txt") + assert ( + "#0" not in token_table + ), "This tokenizer does not support disambig symbols." + self._id2sym = token_table._id2sym + self._sym2id = token_table._sym2id + self.oov = oov + self.oov_id = self._sym2id[oov] + self.sep = sep + if self.sep: + self.text2word = lambda x: x.split(self.sep) + else: + self.text2word = lambda x: list(x.replace(" ", "")) + + def piece_to_id(self, piece: str) -> int: + try: + return self._sym2id[piece] + except KeyError: + return self.oov_id + + def id_to_piece(self, id: int) -> str: + return self._id2sym[id] + + def get_piece_size(self) -> int: + return len(self._sym2id) + + def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]: + return [[self.piece_to_id(i) for i in self.text2word(text)] for text in input] + + def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]: + return [ + [i if i in self._sym2id else self.oov for i in self.text2word(text)] + for text in input + ] + + def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]: + return [self.sep.join(self.id_to_piece(i) for i in text) for text in input] + + def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]: + return [self.sep.join(text) for text in input] + + def SplitBatch(self, input: List[str]) -> List[List[str]]: + return [self.text2word(text) for text in input] + + +def test_CharTokenizer(): + test_single_string = "こんにちは" + test_multiple_string = [ + "今日はいい天気ですよね", + "諏訪湖は綺麗でしょう", + "这在词表外", + "分かち 書き に し た 文章 です", + "", + ] + test_empty_string = "" + sp = Tokenizer.load(Path("lang_char"), "char", oov="") + splitter = sp.split + print(sp.encode(test_single_string, out_type=str)) + print(sp.encode(test_single_string, out_type=int)) + print(sp.encode(test_multiple_string, out_type=str)) + print(sp.encode(test_multiple_string, out_type=int)) + print(sp.encode(test_empty_string, out_type=str)) + print(sp.encode(test_empty_string, out_type=int)) + print(sp.decode(sp.encode(test_single_string, out_type=str))) + print(sp.decode(sp.encode(test_single_string, out_type=int))) + print(sp.decode(sp.encode(test_multiple_string, out_type=str))) + print(sp.decode(sp.encode(test_multiple_string, out_type=int))) + print(sp.decode(sp.encode(test_empty_string, out_type=str))) + print(sp.decode(sp.encode(test_empty_string, out_type=int))) + print(splitter(test_single_string)) + print(splitter(test_multiple_string)) + print(splitter(test_empty_string)) + + +if __name__ == "__main__": + test_CharTokenizer() 
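Beyond test_CharTokenizer() above, here is a minimal usage sketch of the Tokenizer interface as the decoding scripts use it. It assumes a lang directory produced by local/prepare_lang_char.py with <unk> as the OOV symbol; the example strings and lang path are illustrative.

from pathlib import Path
from tokenizer import Tokenizer  # as done in decode.py

sp = Tokenizer.load(Path("data/lang_char"), oov="<unk>")  # lang_type is read from lang_dir/lang_type

ids = sp.encode("こんにちは", out_type=int)     # one id per character, OOV characters map to <unk>
pieces = sp.encode("こんにちは", out_type=str)  # ["こ", "ん", "に", "ち", "は"] if all are in tokens.txt
text = sp.decode(ids)                           # joined back with sep="" in char mode
words = sp.text2word(text)                      # per-character split used when scoring CER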
diff --git a/egs/csj/ASR/prepare.sh b/egs/csj/ASR/prepare.sh index c4ce91984..52339bb35 100755 --- a/egs/csj/ASR/prepare.sh +++ b/egs/csj/ASR/prepare.sh @@ -32,7 +32,7 @@ # - speech # # By default, this script produces the original transcript like kaldi and espnet. Optionally, you -# can generate other transcript formats by supplying your own config files. A few examples of these +# can add other transcript formats by supplying your own config files. A few examples of these # config files can be found in local/conf. # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 @@ -44,10 +44,10 @@ nj=8 stage=-1 stop_stage=100 -csj_dir=/mnt/minami_data_server/t2131178/corpus/CSJ -musan_dir=/mnt/minami_data_server/t2131178/corpus/musan/musan -trans_dir=$csj_dir/retranscript -csj_fbank_dir=/mnt/host/csj_data/fbank +csj_dir=/mnt/host/corpus/csj +musan_dir=/mnt/host/corpus/musan/musan +trans_dir=$csj_dir/transcript +csj_fbank_dir=/mnt/host/corpus/csj/fbank musan_fbank_dir=$musan_dir/fbank csj_manifest_dir=data/manifests musan_manifest_dir=$musan_dir/manifests @@ -63,12 +63,8 @@ log() { if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then log "Stage 1: Prepare CSJ manifest" - # If you want to generate more transcript modes, append the path to those config files at c. - # Example: lhotse prepare csj $csj_dir $trans_dir $csj_manifest_dir -c local/conf/disfluent.ini - # NOTE: In case multiple config files are supplied, the second config file and onwards will inherit - # the segment boundaries of the first config file. if [ ! -e $csj_manifest_dir/.csj.done ]; then - lhotse prepare csj $csj_dir $trans_dir $csj_manifest_dir -j 4 + lhotse prepare csj $csj_dir $csj_manifest_dir -t $trans_dir -j 16 touch $csj_manifest_dir/.csj.done fi fi @@ -88,32 +84,24 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then python local/compute_fbank_csj.py --manifest-dir $csj_manifest_dir \ --fbank-dir $csj_fbank_dir parts=( - train - valid eval1 eval2 eval3 + valid + excluded + train ) for part in ${parts[@]}; do - python local/validate_manifest.py --manifest $csj_manifest_dir/csj_cuts_$part.jsonl.gz + python local/validate_manifest.py --manifest $csj_fbank_dir/csj_cuts_$part.jsonl.gz done touch $csj_fbank_dir/.csj-validated.done fi fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Prepare CSJ lang" - modes=disfluent - - # If you want prepare the lang directory for other transcript modes, just append - # the names of those modes behind. 
An example is shown as below:- - # modes="$modes fluent symbol number" - - for mode in ${modes[@]}; do - python local/prepare_lang_char.py --trans-mode $mode \ - --train-cut $csj_manifest_dir/csj_cuts_train.jsonl.gz \ - --lang-dir lang_char_$mode - done + log "Stage 4: Prepare CSJ lang_char" + python local/prepare_lang_char.py $csj_fbank_dir/csj_cuts_train.jsonl.gz + python local/add_transcript_mode.py -f $csj_fbank_dir -c local/conf/fluent.ini local/conf/number.ini fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then @@ -128,6 +116,6 @@ fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then log "Stage 6: Show manifest statistics" - python local/display_manifest_statistics.py --manifest-dir $csj_manifest_dir > $csj_manifest_dir/manifest_statistics.txt - cat $csj_manifest_dir/manifest_statistics.txt + python local/display_manifest_statistics.py --manifest-dir $csj_fbank_dir > $csj_fbank_dir/manifest_statistics.txt + cat $csj_fbank_dir/manifest_statistics.txt fi diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/TelegramStreamIO.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/TelegramStreamIO.py new file mode 100644 index 000000000..f5235207a --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/TelegramStreamIO.py @@ -0,0 +1,76 @@ +import logging +from configparser import ConfigParser + +import requests + + +def escape_html(text: str): + """ + Escapes all html characters in text + :param str text: + :rtype: str + """ + return text.replace("&", "&").replace("<", "<").replace(">", ">") + + +class TelegramStreamIO(logging.Handler): + + API_ENDPOINT = "https://api.telegram.org" + MAX_MESSAGE_LEN = 4096 + formatter = logging.Formatter( + "%(asctime)s - %(levelname)s at %(funcName)s " + "(line %(lineno)s):\n\n%(message)s" + ) + + def __init__(self, tg_configfile: str): + super(TelegramStreamIO, self).__init__() + config = ConfigParser() + if not config.read(tg_configfile): + raise FileNotFoundError( + f"{tg_configfile} not found. " "Retry without --telegram-cred flag." + ) + config = config["TELEGRAM"] + token = config["token"] + self.chat_id = config["chat_id"] + self.url = f"{self.API_ENDPOINT}/bot{token}/sendMessage" + + @staticmethod + def setup_logger(params): + if not params.telegram_cred: + return + formatter = logging.Formatter( + f"{params.exp_dir.name} %(asctime)s \n%(message)s" + ) + tg = TelegramStreamIO(params.telegram_cred) + tg.setLevel(logging.WARN) + tg.setFormatter(formatter) + logging.getLogger("").addHandler(tg) + + def emit(self, record: logging.LogRecord): + """ + Emit a record. + Send the record to the Web server as a percent-encoded dictionary + """ + data = { + "chat_id": self.chat_id, + "text": self.format(self.mapLogRecord(record)), + "parse_mode": "HTML", + } + try: + requests.get(self.url, json=data) + # return response.json() + except Exception as e: + logging.error(f"Failed to send telegram message: {repr(e)}") + pass + + def mapLogRecord(self, record): + """ + Default implementation of mapping the log record into a dict + that is sent as the CGI data. Overwrite in your class. + Contributed by Franz Glasner. 
+ """ + + for k, v in record.__dict__.items(): + if isinstance(v, str): + setattr(record, k, escape_html(v)) + return record diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py new file mode 120000 index 000000000..a48591198 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py @@ -0,0 +1 @@ +../local/utils/asr_datamodule.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/beam_search.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/beam_search.py new file mode 120000 index 000000000..d7349b0a3 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py new file mode 100755 index 000000000..19d3c79c8 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -0,0 +1,852 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
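As a side note on the TelegramStreamIO logging handler added above, the sketch below shows how it is meant to be attached to the root logger. The [TELEGRAM] section layout mirrors what the class reads; the config file name and credential values are hypothetical.

import logging

from TelegramStreamIO import TelegramStreamIO

# telegram.ini (hypothetical credentials):
#   [TELEGRAM]
#   token = 123456:ABC-DEF...
#   chat_id = 987654321
tg = TelegramStreamIO("telegram.ini")
tg.setLevel(logging.WARN)             # only WARNING and above are forwarded
logging.getLogger("").addHandler(tg)  # root logger, as in TelegramStreamIO.setup_logger()

logging.warning("Training crashed")   # delivered to the configured Telegram chat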
+""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --lang data/lang_char \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method beam_search \ + --lang data/lang_char \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method modified_beam_search \ + --lang data/lang_char \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --lang data/lang_char \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --lang data/lang_char \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --lang data/lang_char \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_streaming/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --max-duration 600 \ + --decode-chunk-len 32 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --lang data/lang_char \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import CSJAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from tokenizer import Tokenizer +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. 
+ Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--gpu", + type=int, + default=0, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--res-dir", + type=Path, + default=None, + help="The path to save results.", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_char", + help="The lang dir. It should contain at least a word table.", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--decoding-graph", + type=str, + default="", + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. 
+ Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--pad-feature", + type=int, + default=30, + help=""" + Number of frames to pad at the end. + """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: Tokenizer, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.pad_feature: + feature_lens += params.pad_feature + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.pad_feature), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(sp.text2word(hyp)) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.text2word(sp.decode(hyp))) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + 
key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: Tokenizer, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = sp.text2word(ref_text) + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + return test_set_wers + + +@torch.no_grad() +def main(): + parser = get_parser() + CSJAsrDataModule.add_arguments(parser) + Tokenizer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + if not params.res_dir: + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", params.gpu) + + logging.info(f"Device: {device}") + + sp = Tokenizer.load(params.lang, params.lang_type) + + # and are defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, ( + model.encoder.decode_chunk_size, + params.decode_chunk_len, + ) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" 
+ f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + decoding_graph = None + word_table = None + + if params.decoding_graph: + decoding_graph = k2.Fsa.from_dict( + torch.load(params.decoding_graph, map_location=device) + ) + elif "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
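Both checkpoint-selection branches above ultimately average parameter tensors across several saved models. A simplified sketch of the plain averaging case is shown below, assuming each checkpoint stores a `"model"` state dict as the files loaded here do; icefall's `average_checkpoints` and `average_checkpoints_with_averaged_model` remain the authoritative implementations and handle details glossed over in this sketch:

```python
import torch


def naive_average_checkpoints(filenames, device="cpu"):
    """Average the "model" state dicts of several checkpoint files."""
    avg = None
    for name in filenames:
        state = torch.load(name, map_location=device)["model"]
        if avg is None:
            avg = {k: v.detach().clone().float() for k, v in state.items()}
        else:
            for k, v in state.items():
                avg[k] += v.float()
    return {k: v / len(filenames) for k, v in avg.items()}
```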
+ args.return_cuts = True + csj_corpus = CSJAsrDataModule(args) + + for subdir in ["eval1", "eval2", "eval3", "excluded", "valid"]: + results_dict = decode_dataset( + dl=csj_corpus.test_dataloaders(getattr(csj_corpus, f"{subdir}_cuts")()), + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + tot_err = save_results( + params=params, + test_set_name=subdir, + results_dict=results_dict, + ) + with ( + params.res_dir + / ( + f"{subdir}-{params.decode_chunk_len}_{params.beam_size}" + f"_{params.avg}_{params.epoch}.cer" + ) + ).open("w") as fout: + if len(tot_err) == 1: + fout.write(f"{tot_err[0][1]}") + else: + fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode_stream.py new file mode 120000 index 000000000..ca8fed319 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decoder.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decoder.py new file mode 120000 index 000000000..1ce277aa6 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py new file mode 120000 index 000000000..cb673b3eb --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export.py new file mode 100644 index 000000000..2d45ecca3 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/export.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.script() + +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. 
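For instance, the exported file can be loaded back like this (the path is only illustrative):

```python
import torch

# Parameters of the loaded torchscript model start out on CPU.
model = torch.jit.load("pruned_transducer_stateless7_streaming/exp/cpu_jit.pt")
model.eval()
model = model.to("cuda")  # optional: move the parameters to a CUDA device
```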
You can use `to("cuda")` to move them to a CUDA device. + +Check +https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --epoch 20 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `pruned_transducer_stateless7_streaming/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/csj/ASR + ./pruned_transducer_stateless7_streaming/decode.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --lang data/lang_char + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208 + # You will find the pre-trained model in icefall-asr-csj-pruned-transducer-stateless7-230208/exp_fluent +""" + +import argparse +import logging +from pathlib import Path + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from tokenizer import Tokenizer +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. 
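As a sketch of what consuming `pretrained.pt` looks like in plain Python, mirroring what `pretrained.py` does (the parameter set-up is elided and the path is illustrative):

```python
import torch
from train import get_params, get_transducer_model  # this recipe's helpers

params = get_params()
# ... fill in params.blank_id, params.vocab_size, etc. as pretrained.py does ...
model = get_transducer_model(params)

ckpt = torch.load(
    "pruned_transducer_stateless7_streaming/exp/pretrained.pt", map_location="cpu"
)
model.load_state_dict(ckpt["model"], strict=False)
model.eval()
```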
+ """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + Tokenizer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = Tokenizer.load(params.lang, params.lang_type) + + # is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. 
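A toy illustration of the trick used in the next line, namely keeping a non-scriptable `forward` out of `torch.jit.script` while still exporting the methods that are needed (the `Toy` module below is made up for this sketch):

```python
import torch


class Toy(torch.nn.Module):
    """A stand-in for the transducer model; only `infer` gets compiled."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    @torch.jit.ignore
    def forward(self, x, ragged):
        # Takes arguments the TorchScript compiler cannot handle, so it is ignored.
        return self.linear(x)

    @torch.jit.export
    def infer(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


scripted = torch.jit.script(Toy())  # compiles `infer`, skips the ignored `forward`
out = scripted.infer(torch.randn(2, 4))
```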
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py new file mode 100644 index 000000000..ab7c8748a --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py @@ -0,0 +1,308 @@ +#!/usr/bin/env python3 + +""" +Usage: +# use -O to skip assertions and avoid some of the TracerWarnings +python -O pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --epoch 30 \ + --avg 10 \ + --use-averaged-model=True \ + --decode-chunk-len 32 +""" + +import argparse +import logging +from pathlib import Path + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from tokenizer import Tokenizer +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import AttributeDict, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: torch.nn.Module, + encoder_filename: str, + params: AttributeDict, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + decode_chunk_len = params.decode_chunk_len # before subsampling + pad_length = 7 + s = f"decode_chunk_len: {decode_chunk_len}" + logging.info(s) + assert encoder_model.decode_chunk_size == decode_chunk_len // 2, ( + encoder_model.decode_chunk_size, + decode_chunk_len, + ) + + T = decode_chunk_len + pad_length + + x = torch.zeros(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + states = encoder_model.get_init_state(device=x.device) + + encoder_model.__class__.forward = encoder_model.__class__.streaming_forward + traced_model = torch.jit.trace(encoder_model, (x, x_lens, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: torch.nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: torch.nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. 
+ + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + parser = get_parser() + Tokenizer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + sp = Tokenizer.load(params.lang, params.lang_type) + + # is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.trace()") + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / "encoder_jit_trace.pt" + 
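`torch.jit.trace` records the operations executed for one set of example inputs, which is why each `export_*_jit_trace` helper above builds dummy tensors with the right shapes before tracing. A stand-alone toy version of the same pattern (the `TinyJoiner` module and the output path are made up for this sketch):

```python
import torch


class TinyJoiner(torch.nn.Module):
    def __init__(self, enc_dim: int = 8, dec_dim: int = 8, vocab: int = 16):
        super().__init__()
        self.encoder_proj = torch.nn.Linear(enc_dim, vocab)
        self.decoder_proj = torch.nn.Linear(dec_dim, vocab)

    def forward(self, encoder_out: torch.Tensor, decoder_out: torch.Tensor):
        return self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)


toy = TinyJoiner()
example = (torch.rand(1, 8), torch.rand(1, 8))
traced = torch.jit.trace(toy, example)  # shapes/dtypes follow the example inputs
traced.save("/tmp/tiny_joiner_jit_trace.pt")
```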
export_encoder_model_jit_trace(model.encoder, encoder_filename, params) + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / "decoder_jit_trace.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / "joiner_jit_trace.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py new file mode 100644 index 000000000..d84cf04a3 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 +# flake8: noqa +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models exported by `torch.jit.trace()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless7_streaming/jit_trace_export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --epoch 30 \ + --avg 10 \ + --use-averaged-model=True \ + --decode-chunk-len 32 + +Usage of this script: + +./pruned_transducer_stateless7_streaming/jit_trace_pretrained.py \ + --encoder-model-filename ./pruned_transducer_stateless7_streaming/exp/encoder_jit_trace.pt \ + --decoder-model-filename ./pruned_transducer_stateless7_streaming/exp/decoder_jit_trace.pt \ + --joiner-model-filename ./pruned_transducer_stateless7_streaming/exp/joiner_jit_trace.pt \ + --lang data/lang_char \ + --decode-chunk-len 32 \ + /path/to/foo.wav \ +""" + +import argparse +import logging +from typing import List, Optional + +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature +from tokenizer import Tokenizer + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder torchscript model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder torchscript model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner torchscript model. 
", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: torch.jit.ScriptModule, + joiner: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +): + assert encoder_out.ndim == 2 + context_size = 2 + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor(hyp, dtype=torch.int32).unsqueeze(0) + # decoder_input.shape (1,, 1 context_size) + decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) + else: + assert decoder_out.ndim == 2 + assert hyp is not None, hyp + + T = encoder_out.size(0) + for i in range(T): + cur_encoder_out = encoder_out[i : i + 1] + joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + + decoder_input = torch.tensor(decoder_input, dtype=torch.int32).unsqueeze(0) + decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) + + return hyp, decoder_out + + +def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. 
+ """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +@torch.no_grad() +def main(): + parser = get_parser() + Tokenizer.add_arguments(parser) + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + encoder = torch.jit.load(args.encoder_model_filename) + decoder = torch.jit.load(args.decoder_model_filename) + joiner = torch.jit.load(args.joiner_model_filename) + + encoder.eval() + decoder.eval() + joiner.eval() + + encoder.to(device) + decoder.to(device) + joiner.to(device) + + sp = Tokenizer.load(args.lang, args.lang_type) + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor(args.sample_rate) + + logging.info(f"Reading sound files: {args.sound_file}") + wave_samples = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=args.sample_rate, + )[0] + logging.info(wave_samples.shape) + + logging.info("Decoding started") + chunk_length = args.decode_chunk_len + assert encoder.decode_chunk_size == chunk_length // 2, ( + encoder.decode_chunk_size, + chunk_length, + ) + + # we subsample features with ((x_len - 7) // 2 + 1) // 2 + pad_length = 7 + T = chunk_length + pad_length + + logging.info(f"chunk_length: {chunk_length}") + + states = encoder.get_init_state(device) + + tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32) + + wave_samples = torch.cat([wave_samples, tail_padding]) + + chunk = int(0.25 * args.sample_rate) # 0.2 second + num_processed_frames = 0 + + hyp = None + decoder_out = None + + start = 0 + while start < wave_samples.numel(): + logging.info(f"{start}/{wave_samples.numel()}") + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + online_fbank.accept_waveform( + sampling_rate=args.sample_rate, + waveform=samples, + ) + while online_fbank.num_frames_ready - num_processed_frames >= T: + frames = [] + for i in range(T): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + frames = torch.cat(frames, dim=0).unsqueeze(0) + x_lens = torch.tensor([T], dtype=torch.int32) + encoder_out, out_lens, states = encoder( + x=frames, + x_lens=x_lens, + states=states, + ) + num_processed_frames += chunk_length + + hyp, decoder_out = greedy_search( + decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp + ) + + context_size = 2 + logging.info(args.sound_file) + logging.info(sp.decode(hyp[context_size:])) + + logging.info("Decoding Done") + + +torch.set_num_threads(4) +torch.set_num_interop_threads(1) +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/joiner.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/joiner.py new file mode 120000 index 000000000..482ebcfef --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/joiner.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/model.py 
b/egs/csj/ASR/pruned_transducer_stateless7_streaming/model.py new file mode 120000 index 000000000..16c2bf28d --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/model.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/optim.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/optim.py new file mode 120000 index 000000000..522bbaff9 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/optim.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py new file mode 100644 index 000000000..932026868 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -0,0 +1,347 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +./pruned_transducer_stateless7_streaming/export.py \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +(1) greedy search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --lang data/lang_char \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --lang data/lang_char \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --lang data/lang_char \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless7_streaming/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ + --lang data/lang_char \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./pruned_transducer_stateless7_streaming/exp/epoch-xx.pt`. 
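Before decoding, `pretrained.py` turns the input waves into 80-dim fbank features. A condensed sketch of that front-end, mirroring the code further down in this file (the wav path is a placeholder):

```python
import math

import kaldifeat
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence

opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)

wave, sr = torchaudio.load("/path/to/foo.wav")  # shape: (channels, samples)
assert sr == 16000, sr
features = fbank([wave[0]])                     # list of (T, 80) tensors
features = pad_sequence(
    features, batch_first=True, padding_value=math.log(1e-10)
)
```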
+ +Note: ./pruned_transducer_stateless7_streaming/exp/pretrained.pt is generated by +./pruned_transducer_stateless7_streaming/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from tokenizer import Tokenizer +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + Tokenizer.add_arguments(parser) + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = Tokenizer.load(params.lang, params.lang_type) + + # is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + 
for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling.py new file mode 120000 index 000000000..a7ef73bcb --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py new file mode 120000 index 000000000..566c317ff --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py new file mode 120000 index 000000000..2adf271c1 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py new file mode 100755 index 000000000..9700dd89e --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py @@ -0,0 +1,597 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Usage: +./pruned_transducer_stateless7_streaming/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --decode-chunk-len 32 \ + --exp-dir ./pruned_transducer_stateless7_streaming/exp \ + --decoding_method greedy_search \ + --lang data/lang_char \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import torch +import torch.nn as nn +from asr_datamodule import CSJAsrDataModule +from decode import save_results +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from tokenizer import Tokenizer +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import stack_states, unstack_states + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import AttributeDict, setup_logger, str2bool + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--gpu", + type=int, + default=0, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--decoding-graph", + type=str, + default="", + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. 
Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + parser.add_argument( + "--res-dir", + type=Path, + default=None, + help="The path to save results.", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + + features = [] + feature_lens = [] + states = [] + processed_lens = [] + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(params.decode_chunk_len) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling + # factor in encoders is 8. + # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8. 
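To make the frame bookkeeping in the comment above concrete, a small check assuming the quoted subsampling formula:

```python
def subsampled_frames(x_len: int) -> int:
    # ((x_len - 7) // 2 + 1) // 2, as quoted in the comment above
    return ((x_len - 7) // 2 + 1) // 2


# A 32-frame decoding chunk plus the 7 extra frames consumed by the
# convolutional front-end yields 8 frames at the encoder output:
assert subsampled_frames(32 + 7) == 8

# The 23-frame minimum enforced below gives (23 - 7) // 2 = 8 frames after
# the feature-embedding stage, i.e. one complete downsampled block:
assert (23 - 7) // 2 == 8
```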
+ tail_length = 23 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + processed_lens = torch.tensor(processed_lens, device=device) + + encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward( + x=features, + x_lens=feature_lens, + states=states, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: Tokenizer, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 50 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. 
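Both `decode.py` and `decode_one_chunk` above pad the features with `torch.nn.functional.pad`. The pad tuple is ordered from the last dimension backwards, so `(0, 0, 0, p)` on an `(N, T, C)` batch appends `p` frames on the time axis and leaves the feature dimension untouched. A quick check:

```python
import math

import torch
import torch.nn.functional as F

LOG_EPS = math.log(1e-10)

feats = torch.zeros(2, 16, 80)                   # (N, T, C)
padded = F.pad(feats, (0, 0, 0, 7), value=LOG_EPS)
assert padded.shape == (2, 23, 80)               # 7 padding frames appended
```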
+ initial_states = model.encoder.get_init_state(device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len) + decode_stream.ground_truth = cut.supervisions[0].custom[params.transcript_mode] + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + sp.text2word(decode_streams[i].ground_truth), + sp.text2word(sp.decode(decode_streams[i].decoding_result())), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + sp.text2word(decode_streams[i].ground_truth), + sp.text2word(sp.decode(decode_streams[i].decoding_result())), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +@torch.no_grad() +def main(): + parser = get_parser() + CSJAsrDataModule.add_arguments(parser) + Tokenizer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + if not params.res_dir: + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + # for streaming + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", params.gpu) + + logging.info(f"Device: {device}") + + sp = Tokenizer.load(params.lang, params.lang_type) + + # and is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("") + params.unk_id = 
sp.piece_to_id("<unk>") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_graph: + decoding_graph = k2.Fsa.from_dict( + torch.load(params.decoding_graph, map_location=device) + ) + elif params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + args.return_cuts = True + csj_corpus = CSJAsrDataModule(args) + + for subdir in ["eval1", "eval2", "eval3", "excluded", "valid"]: + results_dict = decode_dataset( + cuts=getattr(csj_corpus, f"{subdir}_cuts")(), + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + tot_err = save_results( + params=params, test_set_name=subdir, results_dict=results_dict + ) + + with ( + params.res_dir + / ( + f"{subdir}-{params.decode_chunk_len}" + f"_{params.avg}_{params.epoch}.cer" + ) + ).open("w") as fout: + if len(tot_err) == 1: + fout.write(f"{tot_err[0][1]}") + else: + fout.write("\n".join(f"{k}\t{v}" for k, v
in tot_err)) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/test_model.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/test_model.py new file mode 100755 index 000000000..0a82ccfa4 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/test_model.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +To run this file, do: + + cd icefall/egs/csj/ASR + python ./pruned_transducer_stateless7_streaming/test_model.py +""" + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import get_params, get_transducer_model + + +def test_model(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + params.num_left_chunks = 4 + params.short_chunk_size = 50 + params.decode_chunk_len = 32 + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + # Test jit script + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptable.
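+ # Once scripted below, the module could also be round-tripped through
+ # serialization, e.g. model.save("cpu_jit.pt") followed by
+ # torch.jit.load("cpu_jit.pt"); the file name here is only a placeholder.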
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward) + print("Using torch.jit.script") + model = torch.jit.script(model) + + +def test_model_jit_trace(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + params.num_left_chunks = 4 + params.short_chunk_size = 50 + params.decode_chunk_len = 32 + model = get_transducer_model(params) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + convert_scaled_to_non_scaled(model, inplace=True) + + # Test encoder + def _test_encoder(): + encoder = model.encoder + assert encoder.decode_chunk_size == params.decode_chunk_len // 2, ( + encoder.decode_chunk_size, + params.decode_chunk_len, + ) + T = params.decode_chunk_len + 7 + + x = torch.zeros(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + states = encoder.get_init_state(device=x.device) + encoder.__class__.forward = encoder.__class__.streaming_forward + traced_encoder = torch.jit.trace(encoder, (x, x_lens, states)) + + states1 = encoder.get_init_state(device=x.device) + states2 = traced_encoder.get_init_state(device=x.device) + for i in range(5): + x = torch.randn(1, T, 80, dtype=torch.float32) + x_lens = torch.full((1,), T, dtype=torch.int32) + y1, _, states1 = encoder.streaming_forward(x, x_lens, states1) + y2, _, states2 = traced_encoder(x, x_lens, states2) + assert torch.allclose(y1, y2, atol=1e-6), (i, (y1 - y2).abs().mean()) + + # Test decoder + def _test_decoder(): + decoder = model.decoder + y = torch.zeros(10, decoder.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_decoder = torch.jit.trace(decoder, (y, need_pad)) + d1 = decoder(y, need_pad) + d2 = traced_decoder(y, need_pad) + assert torch.equal(d1, d2), (d1 - d2).abs().mean() + + # Test joiner + def _test_joiner(): + joiner = model.joiner + encoder_out_dim = joiner.encoder_proj.weight.shape[1] + decoder_out_dim = joiner.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_joiner = torch.jit.trace(joiner, (encoder_out, decoder_out)) + j1 = joiner(encoder_out, decoder_out) + j2 = traced_joiner(encoder_out, decoder_out) + assert torch.equal(j1, j2), (j1 - j2).abs().mean() + + _test_encoder() + _test_decoder() + _test_joiner() + + +def main(): + test_model() + test_model_jit_trace() + + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/tokenizer.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/tokenizer.py new file mode 120000 index 000000000..958c99e85 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/tokenizer.py @@ -0,0 +1 @@ +../local/utils/tokenizer.py \ No newline at end of file diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py new file mode 100755 index 000000000..601de2c41 --- /dev/null +++ 
b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py @@ -0,0 +1,1304 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_streaming/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_streaming/exp \ + --lang data/lang_char \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import math +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import CSJAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from tokenizer import Tokenizer +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LOG_EPS = math.log(1e-10) + +try: + from TelegramStreamIO import TelegramStreamIO + + HAS_TELEGRAM = True +except ImportError: + HAS_TELEGRAM = False + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward 
dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=50, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). + """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + parser.add_argument( + "--decode-chunk-len", + type=int, + default=32, + help="The chunk size for decoding (in frames before subsampling)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument("--debug", action="store_true", help="Use hardcoded arguments") + + parser.add_argument( + "--telegram-cred", + type=Path, + default=None, + help="Telegram credentials to report progress in telegram", + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="pruned_transducer_stateless7_streaming/exp", + help="""The experiment dir. 
+ It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--pad-feature", + type=int, + default=0, + help=""" + Number of frames to pad at the end. 
+ """, + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 1000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. 
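+ # Assuming the default 10 ms frame shift, a 10 s cut (~1000 feature frames)
+ # yields roughly 1000 / 4 = 250 encoder frames after subsampling; see
+ # remove_short_and_long_utt() below for the exact expression.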
+ "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + num_left_chunks=params.num_left_chunks, + short_chunk_size=params.short_chunk_size, + decode_chunk_size=params.decode_chunk_len // 2, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
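+ # For example (values are only illustrative): --start-epoch 11 resumes from
+ # exp-dir/epoch-10.pt, while --start-batch 8000 takes precedence and resumes
+ # from exp-dir/checkpoint-8000.pt.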
+ + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: Tokenizer, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
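+
+ A rough sketch of what is computed below: the returned loss is
+
+ loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ where simple_loss_scale decays linearly from 1.0 to params.simple_loss_scale
+ and pruned_loss_scale grows from 0.1 to 1.0 over the first params.warm_step
+ batches. With the defaults (0.5 and 2000), at batch 1000 the scales are
+ about 0.75 and 0.55 respectively.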
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.pad_feature: + feature_lens += params.pad_feature + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.pad_feature), + value=LOG_EPS, + ) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: Tokenizer, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: Tokenizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. 
+ model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except Exception as e: # noqa + logging.error(e, exc_info=True) + display_and_save_batch(batch, params=params, sp=sp) + raise e + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
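+ # Concretely (numbers are illustrative): a grad scale that has collapsed to
+ # 0.25 is doubled at every 100-batch check (0.25 -> 0.5 -> 1.0); once it is
+ # >= 1.0 but still < 8.0 it is only doubled every 400 batches.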
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + if HAS_TELEGRAM and batch_idx in [0, 500] and not rank: + logging.warning( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + else: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + if ( + HAS_TELEGRAM + and batch_idx % (params.valid_interval * 3) == 0 + and not rank + ): + log_mode = logging.warning + else: + log_mode = logging.info + log_mode(f"Epoch {params.cur_epoch}, validation: {valid_info}") + log_mode( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, master_port=params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + if HAS_TELEGRAM and params.telegram_cred: + TelegramStreamIO.setup_logger(params) + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = Tokenizer.load(args.lang, args.lang_type) + + # <blk> is defined in local/prepare_lang_char.py + params.blank_id = sp.piece_to_id("<blk>") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 0.3 seconds and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 0.3 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.info( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. 
" + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + csj_corpus = CSJAsrDataModule(args) + train_cuts = csj_corpus.train_cuts() + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = csj_corpus.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = csj_corpus.valid_cuts() + valid_dl = csj_corpus.valid_dataloaders(valid_cuts) + + if params.start_batch <= 0 and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: Tokenizer, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: Tokenizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + CSJAsrDataModule.add_arguments(parser) + Tokenizer.add_arguments(parser) + args = parser.parse_args() + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer.py new file mode 120000 index 000000000..ec183baa7 --- /dev/null +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py \ No newline at end of file