From 8aaa9761e46c6d71e63096160ddee0197f64a5ff Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 27 Feb 2023 21:23:04 +0800 Subject: [PATCH 1/9] Add doc about exporting streaming zipformer to sherpa-ncnn (#927) --- docs/source/conf.py | 1 + ...t-zipformer-transducer-for-ncnn-output.txt | 74 ++++ ...ncnn-decode-zipformer-transducer-libri.txt | 7 + .../export-ncnn-conv-emformer.rst | 4 + .../model-export/export-ncnn-zipformer.rst | 383 ++++++++++++++++++ docs/source/model-export/export-ncnn.rst | 2 + docs/source/model-export/export-onnx.rst | 16 + 7 files changed, 487 insertions(+) create mode 100644 docs/source/model-export/code/export-zipformer-transducer-for-ncnn-output.txt create mode 100644 docs/source/model-export/code/test-streaming-ncnn-decode-zipformer-transducer-libri.txt create mode 100644 docs/source/model-export/export-ncnn-zipformer.rst diff --git a/docs/source/conf.py b/docs/source/conf.py index 6452c5d6d..6901dec02 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -81,6 +81,7 @@ todo_include_todos = True rst_epilog = """ .. _sherpa-ncnn: https://github.com/k2-fsa/sherpa-ncnn +.. _sherpa-onnx: https://github.com/k2-fsa/sherpa-onnx .. _icefall: https://github.com/k2-fsa/icefall .. _git-lfs: https://git-lfs.com/ .. _ncnn: https://github.com/tencent/ncnn diff --git a/docs/source/model-export/code/export-zipformer-transducer-for-ncnn-output.txt b/docs/source/model-export/code/export-zipformer-transducer-for-ncnn-output.txt new file mode 100644 index 000000000..25874a414 --- /dev/null +++ b/docs/source/model-export/code/export-zipformer-transducer-for-ncnn-output.txt @@ -0,0 +1,74 @@ +2023-02-27 20:23:07,473 INFO [export-for-ncnn.py:246] device: cpu +2023-02-27 20:23:07,477 INFO [export-for-ncnn.py:255] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 50, 'reset_interval': 200, 'valid_interval': 3000, 'feature_dim': 80, 'subsampling_factor': 4, 'warm_step': 2000, 'env_info': {'k2-version': '1.23.4', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '62e404dd3f3a811d73e424199b3408e309c06e1a', 'k2-git-date': 'Mon Jan 30 10:26:16 2023', 'lhotse-version': '1.12.0.dev+missing.version.file', 'torch-version': '1.10.0+cu102', 'torch-cuda-available': True, 'torch-cuda-version': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'master', 'icefall-git-sha1': '6d7a559-clean', 'icefall-git-date': 'Thu Feb 16 19:47:54 2023', 'icefall-path': '/star-fj/fangjun/open-source/icefall-2', 'k2-path': '/star-fj/fangjun/open-source/k2/k2/python/k2/__init__.py', 'lhotse-path': '/star-fj/fangjun/open-source/lhotse/lhotse/__init__.py', 'hostname': 'de-74279-k2-train-3-1220120619-7695ff496b-s9n4w', 'IP address': '10.177.6.147'}, 'epoch': 99, 'iter': 0, 'avg': 1, 'exp_dir': PosixPath('icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp'), 'bpe_model': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model', 'context_size': 2, 'use_averaged_model': False, 'num_encoder_layers': '2,4,3,2,4', 'feedforward_dims': '1024,1024,2048,2048,1024', 'nhead': '8,8,8,8,8', 'encoder_dims': '384,384,384,384,384', 'attention_dims': '192,192,192,192,192', 'encoder_unmasked_dims': '256,256,256,256,256', 'zipformer_downsampling_factors': '1,2,4,8,2', 'cnn_module_kernels': '31,31,31,31,31', 'decoder_dim': 512, 'joiner_dim': 512, 'short_chunk_size': 50, 'num_left_chunks': 4, 'decode_chunk_len': 32, 'blank_id': 0, 'vocab_size': 500} 
+2023-02-27 20:23:07,477 INFO [export-for-ncnn.py:257] About to create model +2023-02-27 20:23:08,023 INFO [zipformer2.py:419] At encoder stack 4, which has downsampling_factor=2, we will combine the outputs of layers 1 and 3, with downsampling_factors=2 and 8. +2023-02-27 20:23:08,037 INFO [checkpoint.py:112] Loading checkpoint from icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/epoch-99.pt +2023-02-27 20:23:08,655 INFO [export-for-ncnn.py:346] encoder parameters: 68944004 +2023-02-27 20:23:08,655 INFO [export-for-ncnn.py:347] decoder parameters: 260096 +2023-02-27 20:23:08,655 INFO [export-for-ncnn.py:348] joiner parameters: 716276 +2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:349] total parameters: 69920376 +2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:351] Using torch.jit.trace() +2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:353] Exporting encoder +2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:174] decode_chunk_len: 32 +2023-02-27 20:23:08,656 INFO [export-for-ncnn.py:175] T: 39 +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1344: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_len.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1348: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_avg.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1352: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_key.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1356: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_val.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1360: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_val2.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1364: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. 
This means that the trace might not generalize to other inputs! + assert cached_conv1.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1368: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_conv2.size(0) == self.num_layers, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1373: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert self.left_context_len == cached_key.shape[1], ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1884: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert self.x_size == x.size(0), (self.x_size, x.size(0)) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2442: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_key.shape[0] == self.left_context_len, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2449: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_key.shape[0] == cached_val.shape[0], ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2469: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_key.shape[0] == left_context_len, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2473: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_val.shape[0] == left_context_len, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2483: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. 
This means that the trace might not generalize to other inputs! + assert kv_len == k.shape[0], (kv_len, k.shape) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2570: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2926: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cache.shape == (x.size(0), x.size(1), self.lorder), ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2652: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert x.shape[0] == self.x_size, (x.shape[0], self.x_size) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2653: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert x.shape[2] == self.embed_dim, (x.shape[2], self.embed_dim) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:2666: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert cached_val.shape[0] == self.left_context_len, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1543: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert src.shape[0] == self.in_x_size, (src.shape[0], self.in_x_size) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1637: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert src.shape[0] == self.in_x_size, ( +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1643: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. 
We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert src.shape[2] == self.in_channels, (src.shape[2], self.in_channels) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1571: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + if src.shape[0] != self.in_x_size: +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1763: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert src1.shape[:-1] == src2.shape[:-1], (src1.shape, src2.shape) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1779: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert src1.shape[-1] == self.dim1, (src1.shape[-1], self.dim1) +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py:1780: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! + assert src2.shape[-1] == self.dim2, (src2.shape[-1], self.dim2) +/star-fj/fangjun/py38/lib/python3.8/site-packages/torch/jit/_trace.py:958: TracerWarning: Encountering a list at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module's inputs. Consider using a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a `NamedTuple` instead). If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior. + module._c._create_method_from_trace( +2023-02-27 20:23:19,640 INFO [export-for-ncnn.py:182] Saved to icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.pt +2023-02-27 20:23:19,646 INFO [export-for-ncnn.py:357] Exporting decoder +/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py:102: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! 
+ assert embedding_out.size(-1) == self.context_size +2023-02-27 20:23:19,686 INFO [export-for-ncnn.py:204] Saved to icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.pt +2023-02-27 20:23:19,686 INFO [export-for-ncnn.py:361] Exporting joiner +2023-02-27 20:23:19,735 INFO [export-for-ncnn.py:231] Saved to icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.pt diff --git a/docs/source/model-export/code/test-streaming-ncnn-decode-zipformer-transducer-libri.txt b/docs/source/model-export/code/test-streaming-ncnn-decode-zipformer-transducer-libri.txt new file mode 100644 index 000000000..5b4969e0f --- /dev/null +++ b/docs/source/model-export/code/test-streaming-ncnn-decode-zipformer-transducer-libri.txt @@ -0,0 +1,7 @@ +2023-02-27 20:43:40,283 INFO [streaming-ncnn-decode.py:349] {'tokens': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/tokens.txt', 'encoder_param_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.param', 'encoder_bin_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.bin', 'decoder_param_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.param', 'decoder_bin_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.bin', 'joiner_param_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.param', 'joiner_bin_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.bin', 'sound_filename': './icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/test_wavs/1089-134686-0001.wav'} +2023-02-27 20:43:41,260 INFO [streaming-ncnn-decode.py:357] Constructing Fbank computer +2023-02-27 20:43:41,264 INFO [streaming-ncnn-decode.py:360] Reading sound files: ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/test_wavs/1089-134686-0001.wav +2023-02-27 20:43:41,269 INFO [streaming-ncnn-decode.py:365] torch.Size([106000]) +2023-02-27 20:43:41,280 INFO [streaming-ncnn-decode.py:372] number of states: 35 +2023-02-27 20:43:45,026 INFO [streaming-ncnn-decode.py:410] ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/test_wavs/1089-134686-0001.wav +2023-02-27 20:43:45,026 INFO [streaming-ncnn-decode.py:411] AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS diff --git a/docs/source/model-export/export-ncnn-conv-emformer.rst b/docs/source/model-export/export-ncnn-conv-emformer.rst index 133915da7..12b370143 100644 --- a/docs/source/model-export/export-ncnn-conv-emformer.rst +++ b/docs/source/model-export/export-ncnn-conv-emformer.rst @@ -166,6 +166,10 @@ Next, we use the following code to export our model: --memory-size 32 \ --encoder-dim 512 +.. caution:: + + If your model has different configuration parameters, please change them accordingly. + .. hint:: We have renamed our model to ``epoch-30.pt`` so that we can use ``--epoch 30``. 
diff --git a/docs/source/model-export/export-ncnn-zipformer.rst b/docs/source/model-export/export-ncnn-zipformer.rst
new file mode 100644
index 000000000..5c81d25ca
--- /dev/null
+++ b/docs/source/model-export/export-ncnn-zipformer.rst
@@ -0,0 +1,383 @@

.. _export_streaming_zipformer_transducer_models_to_ncnn:

Export streaming Zipformer transducer models to ncnn
----------------------------------------------------

We use the pre-trained model from the following repository as an example:

`<https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`_

We will show you step by step how to export it to `ncnn`_ and run it with `sherpa-ncnn`_.

.. hint::

   We use ``Ubuntu 18.04``, ``torch 1.13``, and ``Python 3.8`` for testing.

.. caution::

   Please use a more recent version of PyTorch. For instance, ``torch 1.8``
   may ``not`` work.

1. Download the pre-trained model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. hint::

   You have to install `git-lfs`_ before you continue.

.. code-block:: bash

   cd egs/librispeech/ASR
   GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
   cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29

   git lfs pull --include "exp/pretrained.pt"
   git lfs pull --include "data/lang_bpe_500/bpe.model"

   cd ..

.. note::

   We downloaded ``exp/pretrained-xxx.pt``, not ``exp/cpu-jit_xxx.pt``.

In the above code, we downloaded the pre-trained model into the directory
``egs/librispeech/ASR/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29``.

2. Install ncnn and pnnx
^^^^^^^^^^^^^^^^^^^^^^^^

Please refer to :ref:`export_for_ncnn_install_ncnn_and_pnnx`.

3. Export the model via torch.jit.trace()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

First, let us rename our pre-trained model:

.. code-block::

   cd egs/librispeech/ASR

   cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp

   ln -s pretrained.pt epoch-99.pt

   cd ../..

Next, we use the following code to export our model:

.. code-block:: bash

   dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29

   ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \
     --bpe-model $dir/data/lang_bpe_500/bpe.model \
     --exp-dir $dir/exp \
     --use-averaged-model 0 \
     --epoch 99 \
     --avg 1 \
     \
     --decode-chunk-len 32 \
     --num-left-chunks 4 \
     --num-encoder-layers "2,4,3,2,4" \
     --feedforward-dims "1024,1024,2048,2048,1024" \
     --nhead "8,8,8,8,8" \
     --encoder-dims "384,384,384,384,384" \
     --attention-dims "192,192,192,192,192" \
     --encoder-unmasked-dims "256,256,256,256,256" \
     --zipformer-downsampling-factors "1,2,4,8,2" \
     --cnn-module-kernels "31,31,31,31,31" \
     --decoder-dim 512 \
     --joiner-dim 512

.. caution::

   If your model has different configuration parameters, please change them accordingly.

.. hint::

   We have renamed our model to ``epoch-99.pt`` so that we can use ``--epoch 99``.
   There is only one pre-trained model, so we use ``--avg 1 --use-averaged-model 0``.

   If you have trained a model by yourself and if you have all checkpoints
   available, please first use ``decode.py`` to tune ``--epoch --avg``
   and select the best combination with ``--use-averaged-model 1``.

.. note::

   You will see the following log output:

   .. literalinclude:: ./code/export-zipformer-transducer-for-ncnn-output.txt

   The log shows the model has ``69920376`` parameters, i.e., ``~69.9 M``.
   .. code-block:: bash

      ls -lh icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/pretrained.pt
      -rw-r--r-- 1 kuangfangjun root 269M Jan 12 12:53 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/pretrained.pt

   You can see that the file size of the pre-trained model is ``269 MB``, which
   is roughly equal to ``69920376*4/1024/1024 = 266.725 MB``.

After running ``pruned_transducer_stateless7_streaming/export-for-ncnn.py``,
we will get the following files:

.. code-block:: bash

   ls -lh icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/*pnnx.pt

   -rw-r--r-- 1 kuangfangjun root 1022K Feb 27 20:23 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.pt
   -rw-r--r-- 1 kuangfangjun root 266M Feb 27 20:23 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.pt
   -rw-r--r-- 1 kuangfangjun root 2.8M Feb 27 20:23 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.pt

.. _zipformer-transducer-step-4-export-torchscript-model-via-pnnx:

4. Export torchscript model via pnnx
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. hint::

   Make sure you have set up the ``PATH`` environment variable
   in :ref:`export_for_ncnn_install_ncnn_and_pnnx`. Otherwise,
   it will throw an error saying that ``pnnx`` could not be found.

Now, it's time to export our models to `ncnn`_ via ``pnnx``.

.. code-block::

   cd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/

   pnnx ./encoder_jit_trace-pnnx.pt
   pnnx ./decoder_jit_trace-pnnx.pt
   pnnx ./joiner_jit_trace-pnnx.pt

It will generate the following files:

.. code-block:: bash

   ls -lh icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/*ncnn*{bin,param}

   -rw-r--r-- 1 kuangfangjun root 509K Feb 27 20:31 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.bin
   -rw-r--r-- 1 kuangfangjun root  437 Feb 27 20:31 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.param
   -rw-r--r-- 1 kuangfangjun root 133M Feb 27 20:30 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.bin
   -rw-r--r-- 1 kuangfangjun root 152K Feb 27 20:30 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.param
   -rw-r--r-- 1 kuangfangjun root 1.4M Feb 27 20:31 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.bin
   -rw-r--r-- 1 kuangfangjun root  488 Feb 27 20:31 icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.param

There are two types of files:

- ``param``: It is a text file containing the model architectures. You can
  use a text editor to view its content.
- ``bin``: It is a binary file containing the model parameters.

We compare the file sizes of the models below before and after converting via ``pnnx``:
.. see https://tableconvert.com/restructuredtext-generator

+----------------------------------+------------+
| File name                        | File size  |
+==================================+============+
| encoder_jit_trace-pnnx.pt        | 266 MB     |
+----------------------------------+------------+
| decoder_jit_trace-pnnx.pt        | 1022 KB    |
+----------------------------------+------------+
| joiner_jit_trace-pnnx.pt         | 2.8 MB     |
+----------------------------------+------------+
| encoder_jit_trace-pnnx.ncnn.bin  | 133 MB     |
+----------------------------------+------------+
| decoder_jit_trace-pnnx.ncnn.bin  | 509 KB     |
+----------------------------------+------------+
| joiner_jit_trace-pnnx.ncnn.bin   | 1.4 MB     |
+----------------------------------+------------+

You can see that the file sizes of the models after conversion are about half
of those before conversion:

  - encoder: 266 MB vs 133 MB
  - decoder: 1022 KB vs 509 KB
  - joiner: 2.8 MB vs 1.4 MB

The reason is that by default ``pnnx`` converts ``float32`` parameters
to ``float16``. A ``float32`` parameter occupies 4 bytes, while a ``float16``
parameter occupies 2 bytes. Thus, the converted models are roughly half the
original size.

.. hint::

   If you use ``pnnx ./encoder_jit_trace-pnnx.pt fp16=0``, then ``pnnx``
   won't convert ``float32`` to ``float16``.

5. Test the exported models in icefall
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. note::

   We assume you have set up the environment variable ``PYTHONPATH`` when
   building `ncnn`_.

Now we have successfully converted our pre-trained model to `ncnn`_ format.
The 6 generated files are what we need. You can use the following code to
test the converted models:

.. code-block:: bash

   python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \
     --tokens ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/tokens.txt \
     --encoder-param-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.param \
     --encoder-bin-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/encoder_jit_trace-pnnx.ncnn.bin \
     --decoder-param-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.param \
     --decoder-bin-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/decoder_jit_trace-pnnx.ncnn.bin \
     --joiner-param-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.param \
     --joiner-bin-filename ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/joiner_jit_trace-pnnx.ncnn.bin \
     ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/test_wavs/1089-134686-0001.wav

.. hint::

   `ncnn`_ supports only ``batch size == 1``, so ``streaming-ncnn-decode.py`` accepts
   only 1 wave file as input.

The output is given below:

.. literalinclude:: ./code/test-streaming-ncnn-decode-zipformer-transducer-libri.txt

Congratulations! You have successfully exported a model from PyTorch to `ncnn`_!

.. _zipformer-modify-the-exported-encoder-for-sherpa-ncnn:

6. Modify the exported encoder for sherpa-ncnn
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In order to use the exported models in `sherpa-ncnn`_, we have to modify
``encoder_jit_trace-pnnx.ncnn.param``.
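The following small sketch (not part of icefall; ``head -3`` on the command
line works equally well) prints the first few lines of the file so that you
can see what we are about to change:

.. code-block:: python

   # A sketch: print the first three lines of the ncnn param file.
   with open("encoder_jit_trace-pnnx.ncnn.param") as f:
       for _ in range(3):
           print(f.readline(), end="")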
Let us have a look at the first few lines of ``encoder_jit_trace-pnnx.ncnn.param``:

.. code-block::

   7767517
   2028 2547
   Input in0 0 1 in0

**Explanation** of the above three lines:

  1. ``7767517``, it is a magic number and should not be changed.
  2. ``2028 2547``, the first number ``2028`` specifies the number of layers
     in this file, while ``2547`` specifies the number of intermediate outputs
     of this file.
  3. ``Input in0 0 1 in0``, ``Input`` is the layer type of this layer; ``in0``
     is the layer name of this layer; ``0`` means this layer has no input;
     ``1`` means this layer has one output; ``in0`` is the output name of
     this layer.

We need to add 1 extra line and also increment the number of layers.
The result looks like below:

.. code-block:: bash

   7767517
   2029 2547
   SherpaMetaData sherpa_meta_data1 0 0 0=2 1=32 2=4 3=7 -23316=5,2,4,3,2,4 -23317=5,384,384,384,384,384 -23318=5,192,192,192,192,192 -23319=5,1,2,4,8,2 -23320=5,31,31,31,31,31
   Input in0 0 1 in0

**Explanation**

  1. ``7767517``, it is still the same.
  2. ``2029 2547``, we have added an extra layer, so we need to update ``2028`` to ``2029``.
     We don't need to change ``2547`` since the newly added layer has no inputs or outputs.
  3. ``SherpaMetaData sherpa_meta_data1 0 0 0=2 1=32 2=4 3=7 -23316=5,2,4,3,2,4 -23317=5,384,384,384,384,384 -23318=5,192,192,192,192,192 -23319=5,1,2,4,8,2 -23320=5,31,31,31,31,31``
     This line is newly added. Its explanation is given below:

     - ``SherpaMetaData`` is the type of this layer. Must be ``SherpaMetaData``.
     - ``sherpa_meta_data1`` is the name of this layer. Must be ``sherpa_meta_data1``.
     - ``0 0`` means this layer has no inputs or outputs. Must be ``0 0``.
     - ``0=2``, 0 is the key and 2 is the value. Must be ``0=2``.
     - ``1=32``, 1 is the key and 32 is the value of the
       parameter ``--decode-chunk-len`` that you provided when running
       ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``.
     - ``2=4``, 2 is the key and 4 is the value of the
       parameter ``--num-left-chunks`` that you provided when running
       ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``.
     - ``3=7``, 3 is the key and 7 is the value for the amount of padding
       used in the Conv2DSubsampling layer. It should be 7 for zipformer
       if you don't change zipformer.py.
     - ``-23316=5,2,4,3,2,4``, attribute 16, this is an array attribute.
       It is attribute 16 since -23300 - (-23316) = 16.
       The first element of the array is the length of the array, which is 5 in our case.
       ``2,4,3,2,4`` is the value of ``--num-encoder-layers`` that you provided
       when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``.
     - ``-23317=5,384,384,384,384,384``, attribute 17.
       The first element of the array is the length of the array, which is 5 in our case.
       ``384,384,384,384,384`` is the value of ``--encoder-dims`` that you provided
       when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``.
     - ``-23318=5,192,192,192,192,192``, attribute 18.
       The first element of the array is the length of the array, which is 5 in our case.
       ``192,192,192,192,192`` is the value of ``--attention-dims`` that you provided
       when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``.
     - ``-23319=5,1,2,4,8,2``, attribute 19.
       The first element of the array is the length of the array, which is 5 in our case.
       ``1,2,4,8,2`` is the value of ``--zipformer-downsampling-factors`` that you provided
       when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``.
     - ``-23320=5,31,31,31,31,31``, attribute 20.
       The first element of the array is the length of the array, which is 5 in our case.
       ``31,31,31,31,31`` is the value of ``--cnn-module-kernels`` that you provided
       when running ``./pruned_transducer_stateless7_streaming/export-for-ncnn.py``.

     For ease of reference, we list the key-value pairs that you need to add
     in the following table. If your model has a different setting, please
     change the values for ``SherpaMetaData`` accordingly. Otherwise, you
     will be ``SAD``.

         +----------+--------------------------------------------+
         | key      | value                                      |
         +==========+============================================+
         | 0        | 2 (fixed)                                  |
         +----------+--------------------------------------------+
         | 1        | ``--decode-chunk-len``                     |
         +----------+--------------------------------------------+
         | 2        | ``--num-left-chunks``                      |
         +----------+--------------------------------------------+
         | 3        | 7 (if you don't change code)               |
         +----------+--------------------------------------------+
         | -23316   | ``--num-encoder-layers``                   |
         +----------+--------------------------------------------+
         | -23317   | ``--encoder-dims``                         |
         +----------+--------------------------------------------+
         | -23318   | ``--attention-dims``                       |
         +----------+--------------------------------------------+
         | -23319   | ``--zipformer-downsampling-factors``       |
         +----------+--------------------------------------------+
         | -23320   | ``--cnn-module-kernels``                   |
         +----------+--------------------------------------------+

  4. ``Input in0 0 1 in0``. No need to change it.

.. caution::

   When you add a new layer ``SherpaMetaData``, please remember to update the
   number of layers. In our case, update ``2028`` to ``2029``. Otherwise,
   you will be SAD later.

.. hint::

   After adding the new layer ``SherpaMetaData``, you cannot use this model
   with ``streaming-ncnn-decode.py`` anymore since ``SherpaMetaData`` is
   supported only in `sherpa-ncnn`_.

.. hint::

   `ncnn`_ is very flexible. You can add new layers to it just by text-editing
   the ``param`` file! You don't need to change the ``bin`` file.
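If you would rather script this edit than do it in a text editor, below is a
minimal Python sketch. It is not part of icefall; it only automates the manual
edit described above, and the ``SherpaMetaData`` values hard-coded in it are
the ones for this example model, so adapt them to your own settings:

.. code-block:: python

   # A sketch that automates the edit described above. It assumes the usual
   # layout of an ncnn param file: the magic number on the first line and
   # "<num-layers> <num-intermediate-outputs>" on the second line.

   filename = "encoder_jit_trace-pnnx.ncnn.param"

   # Use the values matching the arguments you passed to export-for-ncnn.py.
   meta = (
       "SherpaMetaData sherpa_meta_data1 0 0 0=2 1=32 2=4 3=7"
       " -23316=5,2,4,3,2,4 -23317=5,384,384,384,384,384"
       " -23318=5,192,192,192,192,192 -23319=5,1,2,4,8,2"
       " -23320=5,31,31,31,31,31"
   )

   with open(filename) as f:
       lines = f.read().splitlines()

   assert lines[0] == "7767517", "Not an ncnn param file?"

   # We add one extra layer; the number of intermediate outputs is unchanged.
   num_layers, num_blobs = lines[1].split()
   lines[1] = f"{int(num_layers) + 1} {num_blobs}"

   # Insert SherpaMetaData right before the Input layer.
   lines.insert(2, meta)

   with open(filename, "w") as f:
       f.write("\n".join(lines) + "\n")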
Now you can use this model in `sherpa-ncnn`_.
Please refer to the following documentation:

  - Linux/macOS/Windows/arm/aarch64: ``_
  - ``Android``: ``_
  - ``iOS``: ``_
  - Python: ``_

We have a list of pre-trained models that have been exported for `sherpa-ncnn`_:

  - ``_

    You can find more usages there.

diff --git a/docs/source/model-export/export-ncnn.rst b/docs/source/model-export/export-ncnn.rst
index 841d1d4de..9eb5f85d2 100644
--- a/docs/source/model-export/export-ncnn.rst
+++ b/docs/source/model-export/export-ncnn.rst
@@ -21,6 +21,7 @@ It has been tested on the following platforms:
 - ``iOS``
 - ``Raspberry Pi``
 - `爱芯派 `_ (`MAIX-III AXera-Pi `_).
+ - `RV1126 `_

 `sherpa-ncnn`_ is self-contained and can be statically linked to produce
 a binary containing everything needed. Please refer
 to its documentation for details:

@@ -31,5 +32,6 @@

 .. toctree::

+   export-ncnn-zipformer
    export-ncnn-conv-emformer
    export-ncnn-lstm

diff --git a/docs/source/model-export/export-onnx.rst b/docs/source/model-export/export-onnx.rst
index 8f0cb11fb..aa77204cb 100644
--- a/docs/source/model-export/export-onnx.rst
+++ b/docs/source/model-export/export-onnx.rst
@@ -9,6 +9,22 @@ to export trained models to `ONNX`_.
There is also a file named ``onnx_pretrained.py``, which you can use to run
the exported `ONNX`_ model in Python with `onnxruntime`_ to decode sound files.

sherpa-onnx
-----------

We have a separate repository `sherpa-onnx`_ for deploying your exported models
on various platforms such as:

 - iOS
 - Android
 - Raspberry Pi
 - Linux/macOS/Windows

Please see the documentation of `sherpa-onnx`_ for details:

 ``_

Example
-------

From 07243d136a2aa42c71eda7a7f9ada10a07e82662 Mon Sep 17 00:00:00 2001 From: pehonnet Date: Wed, 8 Mar 2023 14:06:07 +0100 Subject: [PATCH 2/9] remove key from result filename (#936) Co-authored-by: pe-honnet --- .../ASR/pruned_transducer_stateless2/decode.py | 6 +++--- egs/aishell/ASR/pruned_transducer_stateless2/decode.py | 6 +++--- egs/aishell/ASR/pruned_transducer_stateless3/decode.py | 6 +++--- egs/aishell/ASR/transducer_stateless/decode.py | 6 +++--- egs/aishell/ASR/transducer_stateless_modified-2/decode.py | 6 +++--- egs/aishell/ASR/transducer_stateless_modified/decode.py | 6 +++--- egs/aishell2/ASR/pruned_transducer_stateless5/decode.py | 6 +++--- egs/aishell4/ASR/pruned_transducer_stateless5/decode.py | 6 +++--- egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py | 6 +++--- .../ASR_v2/pruned_transducer_stateless7/decode.py | 6 +++--- egs/ami/ASR/pruned_transducer_stateless7/decode.py | 8 ++++---- .../ASR/pruned_transducer_stateless7_streaming/decode.py | 6 +++--- egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py | 6 +++--- egs/librispeech/ASR/conformer_ctc3/decode.py | 8 ++++---- .../ASR/conv_emformer_transducer_stateless/decode.py | 6 +++--- .../streaming_decode.py | 6 +++--- .../ASR/conv_emformer_transducer_stateless2/decode.py | 6 +++--- .../streaming_decode.py | 6 +++--- egs/librispeech/ASR/lstm_transducer_stateless/decode.py | 6 +++--- .../ASR/lstm_transducer_stateless/streaming_decode.py | 6 +++--- egs/librispeech/ASR/lstm_transducer_stateless2/decode.py | 6 +++--- egs/librispeech/ASR/lstm_transducer_stateless3/decode.py | 8 ++++---- .../ASR/lstm_transducer_stateless3/streaming_decode.py | 6 +++--- egs/librispeech/ASR/pruned2_knowledge/decode.py | 6 +++--- .../ASR/pruned_stateless_emformer_rnnt2/decode.py | 6 +++--- egs/librispeech/ASR/pruned_transducer_stateless/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless/streaming_decode.py | 6 +++--- .../ASR/pruned_transducer_stateless2/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless2/streaming_decode.py | 6 +++--- .../ASR/pruned_transducer_stateless3/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless3/streaming_decode.py | 6 +++--- .../ASR/pruned_transducer_stateless4/decode.py | 8 ++++---- .../ASR/pruned_transducer_stateless4/streaming_decode.py | 6 +++--- .../ASR/pruned_transducer_stateless5/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless5/streaming_decode.py | 6 +++--- .../ASR/pruned_transducer_stateless6/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless7/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless7_ctc/ctc_decode.py | 6 +++--- .../ASR/pruned_transducer_stateless7_ctc/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py | 6 +++--- .../ASR/pruned_transducer_stateless7_ctc_bs/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless7_streaming/decode.py | 6 +++--- .../streaming_decode.py | 6 +++--- .../ASR/pruned_transducer_stateless8/decode.py | 6 +++--- egs/librispeech/ASR/transducer/decode.py | 6 +++--- egs/librispeech/ASR/transducer_lstm/decode.py | 6 +++---
egs/librispeech/ASR/transducer_stateless/decode.py | 6 +++--- egs/librispeech/ASR/transducer_stateless2/decode.py | 6 +++--- .../ASR/transducer_stateless_multi_datasets/decode.py | 6 +++--- egs/librispeech/ASR/zipformer_mmi/decode.py | 6 +++--- egs/mgb2/ASR/pruned_transducer_stateless5/decode.py | 6 +++--- egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py | 8 ++++---- egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py | 6 +++--- egs/tedlium3/ASR/pruned_transducer_stateless/decode.py | 6 +++--- egs/tedlium3/ASR/transducer_stateless/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless2/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless5/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless5/streaming_decode.py | 6 +++--- .../ASR/pruned_transducer_stateless5/decode.py | 6 +++--- .../ASR/pruned_transducer_stateless7/decode.py | 6 +++--- 60 files changed, 185 insertions(+), 185 deletions(-) diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py index d0f118959..090f7ff84 100755 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py @@ -392,7 +392,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -401,7 +401,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -413,7 +413,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py index 20a4f21c7..04888fbc1 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py @@ -389,7 +389,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -398,7 +398,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) # we compute CER for aishell dataset. 
results_char = [] @@ -414,7 +414,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py index bac829ae1..6e97f338f 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py @@ -407,7 +407,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -416,7 +416,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) # we compute CER for aishell dataset. results_char = [] @@ -432,7 +432,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tCER", file=f) diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py index e019d2329..d57fe6de4 100755 --- a/egs/aishell/ASR/transducer_stateless/decode.py +++ b/egs/aishell/ASR/transducer_stateless/decode.py @@ -326,7 +326,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -334,7 +334,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) # we compute CER for aishell dataset. 
results_char = [] @@ -350,7 +350,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tCER", file=f) diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py index 41cc1c01c..743fc7f45 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py @@ -371,7 +371,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -380,7 +380,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) # we compute CER for aishell dataset. results_char = [] @@ -396,7 +396,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tCER", file=f) diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py index 7c06e6e51..9a1645915 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py @@ -375,7 +375,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -384,7 +384,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) # we compute CER for aishell dataset. 
results_char = [] @@ -400,7 +400,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tCER", file=f) diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py index b5da0959b..80194ad12 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py @@ -544,7 +544,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -553,7 +553,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -565,7 +565,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py index 37d766ec8..eb202f8a8 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py @@ -407,7 +407,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -416,7 +416,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -428,7 +428,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py index e4a90ef71..675f0739f 100755 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py @@ -392,7 +392,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -401,7 +401,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -413,7 +413,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py index 53381c1f4..9a7eef9bf 100755 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py @@ -463,7 +463,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -472,7 +472,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -484,7 +484,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/ami/ASR/pruned_transducer_stateless7/decode.py b/egs/ami/ASR/pruned_transducer_stateless7/decode.py index f47228fbe..fc4005325 100755 --- a/egs/ami/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/ami/ASR/pruned_transducer_stateless7/decode.py @@ -479,7 +479,7 @@ def save_results( test_set_cers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") @@ -487,7 +487,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. wers_filename = ( - params.res_dir / f"wers-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wers-{test_set_name}-{params.suffix}.txt" ) with open(wers_filename, "w") as f: wer = write_error_stats( @@ -500,7 +500,7 @@ def save_results( for res in results: results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) cers_filename = ( - params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"cers-{test_set_name}-{params.suffix}.txt" ) with open(cers_filename, "w") as f: cer = write_error_stats( @@ -513,7 +513,7 @@ def save_results( test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])} test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])} errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER\tCER", file=f) diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py index 19d3c79c8..c5892f511 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -600,7 +600,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -610,7 +610,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -622,7 +622,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py
index 8595c27bd..27ce41c87 100755
--- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py
@@ -400,7 +400,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = post_processing(results)
         results = sorted(results)
@@ -410,7 +410,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -422,7 +422,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/conformer_ctc3/decode.py b/egs/librispeech/ASR/conformer_ctc3/decode.py
index 6fbf9d674..cdee1ec9c 100755
--- a/egs/librispeech/ASR/conformer_ctc3/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc3/decode.py
@@ -729,7 +729,7 @@ def save_results(
     test_set_delays = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts_and_timestamps(filename=recog_path, texts=results)
@@ -738,7 +738,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer, mean_delay, var_delay = write_error_stats_with_timestamps(
@@ -755,7 +755,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -766,7 +766,7 @@ def save_results(
     test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0][0])
     delays_info = (
         params.res_dir
-        / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(delays_info, "w") as f:
         print("settings\t(start, end) symbol-delay (s) (start, end)", file=f)
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py
index 365e8b8a7..5d241ccbf 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py
@@ -433,7 +433,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -442,7 +442,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -454,7 +454,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py
index c93125c80..e6c9d2ca2 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py
@@ -751,7 +751,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         store_transcripts(filename=recog_path, texts=sorted(results))
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -759,7 +759,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -771,7 +771,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py
index 78e1f4096..f9c1633d8 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py
@@ -433,7 +433,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -442,7 +442,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -454,7 +454,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py
index b2cb2c96b..6b3c1b563 100755
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py
@@ -751,7 +751,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         store_transcripts(filename=recog_path, texts=sorted(results))
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -759,7 +759,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -771,7 +771,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py
index 3ad08f56a..6dc11bdb2 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py
@@ -567,7 +567,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -576,7 +576,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -588,7 +588,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py
index 961d8ddfb..d510d9659 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py
@@ -743,7 +743,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         store_transcripts(filename=recog_path, texts=sorted(results))
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -751,7 +751,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -763,7 +763,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
index 78be9c01f..15e1109f2 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
@@ -703,7 +703,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -712,7 +712,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -724,7 +724,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py
index a380bc470..7ac9d5f34 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py
@@ -612,7 +612,7 @@ def save_results(
     test_set_delays = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts_and_timestamps(filename=recog_path, texts=results)
@@ -621,7 +621,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer, mean_delay, var_delay = write_error_stats_with_timestamps(
@@ -634,7 +634,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -644,7 +644,7 @@ def save_results(
     test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
     delays_info = (
         params.res_dir
-        / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(delays_info, "w") as f:
         print("settings\tsymbol-delay", file=f)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py
index 109746ed5..b8b6e4f43 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py
@@ -743,7 +743,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         store_transcripts(filename=recog_path, texts=sorted(results))
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -751,7 +751,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -763,7 +763,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/pruned2_knowledge/decode.py b/egs/librispeech/ASR/pruned2_knowledge/decode.py
index 40d14bb5a..f22731469 100755
--- a/egs/librispeech/ASR/pruned2_knowledge/decode.py
+++ b/egs/librispeech/ASR/pruned2_knowledge/decode.py
@@ -387,7 +387,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -395,7 +395,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -407,7 +407,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py
index 0e3b7ff74..ea7692f49 100755
--- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py
+++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py
@@ -421,7 +421,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -430,7 +430,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -442,7 +442,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
index 0444afe40..8a719ae3b 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py
@@ -586,7 +586,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -595,7 +595,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -607,7 +607,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py
index fbc39fb65..28c40c780 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py
@@ -424,7 +424,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         # sort results so we can easily compare the difference between two
         # recognition results
@@ -435,7 +435,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -447,7 +447,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
index 5f135f219..2791a60de 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -610,7 +610,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -619,7 +619,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -631,7 +631,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py
index bb08246d9..eac8f8393 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py
@@ -426,7 +426,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         # sort results so we can easily compare the difference between two
         # recognition results
@@ -437,7 +437,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -449,7 +449,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
index 109a94a69..298c6c950 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
@@ -870,7 +870,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -879,7 +879,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -891,7 +891,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py
index 0e5111f33..421bfb0b7 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py
@@ -427,7 +427,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -436,7 +436,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -448,7 +448,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
index c44db0206..dca2ec081 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
@@ -656,7 +656,7 @@ def save_results(
     test_set_delays = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts_and_timestamps(filename=recog_path, texts=results)
@@ -665,7 +665,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer, mean_delay, var_delay = write_error_stats_with_timestamps(
@@ -678,7 +678,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -688,7 +688,7 @@ def save_results(
     test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
     delays_info = (
         params.res_dir
-        / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(delays_info, "w") as f:
         print("settings\tsymbol-delay", file=f)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py
index c4e3cef16..cb5d52859 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py
@@ -443,7 +443,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -452,7 +452,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -464,7 +464,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
index 90b0fcf4b..5c5d3ecd9 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
@@ -736,7 +736,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -745,7 +745,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -757,7 +757,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
index 064811f1c..ae221eaba 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
@@ -443,7 +443,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -452,7 +452,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -464,7 +464,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py
index fd9de052a..c81186295 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py
@@ -417,7 +417,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -426,7 +426,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -438,7 +438,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
index b9bce465f..856ef845a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
@@ -723,7 +723,7 @@ def save_results(
     test_set_wers = dict()
    for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -732,7 +732,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -744,7 +744,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py
index 4b373e4c7..6c11d95b4 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py
@@ -542,7 +542,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -551,7 +551,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(f, f"{test_set_name}-{key}", results)
@@ -561,7 +561,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py
index 32a9b6bb2..643486a6a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py
@@ -594,7 +594,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -603,7 +603,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -615,7 +615,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py
index f137485b2..aadf75c5f 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py
@@ -533,7 +533,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -542,7 +542,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(f, f"{test_set_name}-{key}", results)
@@ -552,7 +552,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py
index ce45a4beb..77160a9d4 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py
@@ -594,7 +594,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -603,7 +603,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -615,7 +615,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py
index aebe2b94b..ed499d043 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py
@@ -569,7 +569,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -578,7 +578,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -590,7 +590,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
index 7a349ecb2..9191edaab 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
@@ -410,7 +410,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -419,7 +419,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -431,7 +431,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py
index e61367134..8314d6acf 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py
@@ -595,7 +595,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -604,7 +604,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -616,7 +616,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py
index 804713a20..c0413e2d1 100755
--- a/egs/librispeech/ASR/transducer/decode.py
+++ b/egs/librispeech/ASR/transducer/decode.py
@@ -326,7 +326,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -335,7 +335,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -347,7 +347,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py
index 9511ca6d7..cd6d722bd 100755
--- a/egs/librispeech/ASR/transducer_lstm/decode.py
+++ b/egs/librispeech/ASR/transducer_lstm/decode.py
@@ -323,7 +323,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -332,7 +332,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -344,7 +344,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py
index 643238f1b..a72d60b9f 100755
--- a/egs/librispeech/ASR/transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless/decode.py
@@ -380,7 +380,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -389,7 +389,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -401,7 +401,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/transducer_stateless2/decode.py b/egs/librispeech/ASR/transducer_stateless2/decode.py
index 9a6363629..c91a1f490 100755
--- a/egs/librispeech/ASR/transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless2/decode.py
@@ -380,7 +380,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -389,7 +389,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -401,7 +401,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
index 56ad558c6..5c20e2bfd 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
@@ -381,7 +381,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -390,7 +390,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -402,7 +402,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/zipformer_mmi/decode.py b/egs/librispeech/ASR/zipformer_mmi/decode.py
index 7d0ea78bb..a96c5c6f0 100755
--- a/egs/librispeech/ASR/zipformer_mmi/decode.py
+++ b/egs/librispeech/ASR/zipformer_mmi/decode.py
@@ -472,7 +472,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -481,7 +481,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(f, f"{test_set_name}-{key}", results)
@@ -491,7 +491,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py b/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py
index 1463f8f67..f72d4d7f6 100755
--- a/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py
@@ -411,7 +411,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
@@ -419,7 +419,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -431,7 +431,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
index 219c96d60..cb9417d2a 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -392,7 +392,7 @@ def save_results(
     test_set_cers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -401,7 +401,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         wers_filename = (
-            params.res_dir / f"wers-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"wers-{test_set_name}-{params.suffix}.txt"
         )
         with open(wers_filename, "w") as f:
             wer = write_error_stats(
@@ -414,7 +414,7 @@ def save_results(
         for res in results:
             results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
         cers_filename = (
-            params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"cers-{test_set_name}-{params.suffix}.txt"
         )
         with open(cers_filename, "w") as f:
             cer = write_error_stats(
@@ -427,7 +427,7 @@ def save_results(
     test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])}
     test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])}
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER\tCER", file=f)
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
index bf91fef7e..1d6a22973 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
@@ -510,7 +510,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -519,7 +519,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
@@ -531,7 +531,7 @@ def save_results(

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
index 38f2ae83c..0d1fe9aa1 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
@@ -380,7 +380,7 @@ def save_results(
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
@@ -389,7 +389,7 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -401,7 +401,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/tedlium3/ASR/transducer_stateless/decode.py b/egs/tedlium3/ASR/transducer_stateless/decode.py index 01f08ce59..c88760854 100755 --- a/egs/tedlium3/ASR/transducer_stateless/decode.py +++ b/egs/tedlium3/ASR/transducer_stateless/decode.py @@ -355,7 +355,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -364,7 +364,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -376,7 +376,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py index 04602ea2e..a0bf77b39 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py @@ -517,7 +517,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -526,7 +526,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -538,7 +538,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py index 7bd1177bd..9f6043926 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py @@ -490,7 +490,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -499,7 +499,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -511,7 +511,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py index c7863415b..398690d48 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -467,7 +467,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) # sort results so we can easily compare the difference between two # recognition results @@ -478,7 +478,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -490,7 +490,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py index 6a67e26f8..5b7f5f95b 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py @@ -702,7 +702,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -711,7 +711,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -723,7 +723,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py index ace792e13..a291bb303 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py @@ -594,7 +594,7 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -603,7 +603,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -615,7 +615,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) From f5de2e90c6672a843d5e94166fbd60f339cb6b9b Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 8 Mar 2023 22:56:04 +0800 Subject: [PATCH 3/9] Fix style issues. 
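
This patch only reflows code; there are no behavior changes. Each
touched assignment had a parenthesized right-hand side that now fits on
a single line, so the wrapped form

    errs_filename = (
        params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
    )

collapses to

    errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"

This is the layout black produces with its default 88-column line
length. The formatter invocation itself is not recorded in this patch,
so the exact tool and settings are an inference from the diffs below.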
---
 .../ASR/pruned_transducer_stateless2/decode.py | 12 +++---------
 .../ASR/pruned_transducer_stateless2/decode.py | 12 +++---------
 .../ASR/pruned_transducer_stateless3/decode.py | 12 +++---------
 egs/aishell/ASR/transducer_stateless/decode.py | 12 +++---------
 .../transducer_stateless_modified-2/decode.py | 12 +++---------
 .../ASR/transducer_stateless_modified/decode.py | 12 +++---------
 .../ASR/pruned_transducer_stateless5/decode.py | 12 +++---------
 .../ASR/pruned_transducer_stateless5/decode.py | 12 +++---------
 .../ASR/pruned_transducer_stateless2/decode.py | 12 +++---------
 .../pruned_transducer_stateless7/decode.py | 12 +++---------
 .../ASR/pruned_transducer_stateless7/decode.py | 16 ++++------------
 .../decode.py | 12 +++---------
 .../ASR/pruned_transducer_stateless2/decode.py | 12 +++---------
 egs/librispeech/ASR/conformer_ctc3/decode.py | 15 ++++-----------
 .../conv_emformer_transducer_stateless/decode.py | 12 +++---------
 .../streaming_decode.py | 12 +++---------
 .../decode.py | 12 +++---------
 .../streaming_decode.py | 12 +++---------
 .../ASR/lstm_transducer_stateless/decode.py | 12 +++---------
 .../streaming_decode.py | 12 +++---------
 .../ASR/lstm_transducer_stateless2/decode.py | 12 +++---------
 .../ASR/lstm_transducer_stateless3/decode.py | 15 ++++-----------
 .../streaming_decode.py | 12 +++---------
 egs/librispeech/ASR/pruned2_knowledge/decode.py | 12 +++---------
 .../pruned_stateless_emformer_rnnt2/decode.py | 12 +++---------
 .../ASR/pruned_transducer_stateless/decode.py | 12 +++---------
 .../streaming_decode.py | 12 +++---------
 .../ASR/pruned_transducer_stateless2/decode.py | 12 +++---------
 .../streaming_decode.py | 12 +++---------
 .../ASR/pruned_transducer_stateless3/decode.py | 12 +++---------
 .../streaming_decode.py | 12 +++---------
 .../ASR/pruned_transducer_stateless4/decode.py | 15 ++++-----------
 .../streaming_decode.py | 12 +++---------
 .../ASR/pruned_transducer_stateless5/decode.py | 12 +++---------
 .../streaming_decode.py | 12 +++---------
 .../ASR/pruned_transducer_stateless6/decode.py | 12 +++---------
 .../ASR/pruned_transducer_stateless7/decode.py | 12 +++---------
 .../ctc_decode.py | 12 +++---------
 .../pruned_transducer_stateless7_ctc/decode.py | 12 +++---------
 .../ctc_decode.py | 12 +++---------
 .../decode.py | 12 +++---------
 .../decode.py | 12 +++---------
 .../streaming_decode.py | 12 +++---------
 .../ASR/pruned_transducer_stateless8/decode.py | 12 +++---------
 egs/librispeech/ASR/transducer/decode.py | 12 +++---------
 egs/librispeech/ASR/transducer_lstm/decode.py | 12 +++---------
 .../ASR/transducer_stateless/decode.py | 12 +++---------
 .../ASR/transducer_stateless2/decode.py | 12 +++---------
 .../decode.py | 12 +++---------
 egs/librispeech/ASR/zipformer_mmi/decode.py | 12 +++---------
 .../ASR/pruned_transducer_stateless5/decode.py | 12 +++---------
 .../ASR/pruned_transducer_stateless2/decode.py | 16 ++++------------
 .../ASR/pruned_transducer_stateless5/decode.py | 12 +++---------
 .../ASR/pruned_transducer_stateless/decode.py | 12 +++---------
 egs/tedlium3/ASR/transducer_stateless/decode.py | 12 +++---------
 .../ASR/pruned_transducer_stateless2/decode.py | 12 +++---------
 .../ASR/pruned_transducer_stateless5/decode.py | 12 +++---------
 .../streaming_decode.py | 12 +++---------
 .../ASR/pruned_transducer_stateless5/decode.py | 12 +++---------
 .../ASR/pruned_transducer_stateless7/decode.py | 12 +++---------
 60 files changed, 185 insertions(+), 552 deletions(-)

diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py
b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py index 090f7ff84..2512f233f 100755 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py @@ -391,18 +391,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -412,9 +408,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py index 04888fbc1..fb6c7c481 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/decode.py @@ -388,18 +388,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" # we compute CER for aishell dataset. 
results_char = [] for res in results: @@ -413,9 +409,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py index 6e97f338f..954d9dc7e 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py @@ -406,18 +406,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" # we compute CER for aishell dataset. results_char = [] for res in results: @@ -431,9 +427,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tCER", file=f) for key, val in test_set_wers: diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py index d57fe6de4..d23f4f883 100755 --- a/egs/aishell/ASR/transducer_stateless/decode.py +++ b/egs/aishell/ASR/transducer_stateless/decode.py @@ -325,17 +325,13 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" # we compute CER for aishell dataset. 
results_char = [] for res in results: @@ -349,9 +345,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tCER", file=f) for key, val in test_set_wers: diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py index 743fc7f45..d164b6890 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py @@ -370,18 +370,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" # we compute CER for aishell dataset. results_char = [] for res in results: @@ -395,9 +391,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tCER", file=f) for key, val in test_set_wers: diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py index 9a1645915..0a7d87fe8 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py @@ -374,18 +374,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" # we compute CER for aishell dataset. 
results_char = [] for res in results: @@ -399,9 +395,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tCER", file=f) for key, val in test_set_wers: diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py index 80194ad12..9e44b4e34 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py @@ -543,18 +543,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -564,9 +560,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py index eb202f8a8..068e2749a 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/decode.py @@ -406,18 +406,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -427,9 +423,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py index 675f0739f..6c170c392 100755 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/decode.py @@ -391,18 +391,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -412,9 +408,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py index 9a7eef9bf..2741e0eeb 100755 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/decode.py @@ -462,18 +462,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -483,9 +479,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/ami/ASR/pruned_transducer_stateless7/decode.py b/egs/ami/ASR/pruned_transducer_stateless7/decode.py index fc4005325..9999894d1 100755 --- a/egs/ami/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/ami/ASR/pruned_transducer_stateless7/decode.py @@ -478,17 +478,13 @@ def save_results( test_set_wers = dict() test_set_cers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - wers_filename = ( - params.res_dir / f"wers-{test_set_name}-{params.suffix}.txt" - ) + wers_filename = params.res_dir / f"wers-{test_set_name}-{params.suffix}.txt" with open(wers_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -499,9 +495,7 @@ def save_results( results_char = [] for res in results: results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - cers_filename = ( - params.res_dir / f"cers-{test_set_name}-{params.suffix}.txt" - ) + cers_filename = params.res_dir / f"cers-{test_set_name}-{params.suffix}.txt" with open(cers_filename, "w") as f: cer = write_error_stats( f, f"{test_set_name}-{key}", results_char, enable_log=True @@ -512,9 +506,7 @@ def save_results( test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])} test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])} - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER\tCER", file=f) for key in test_set_wers: diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py index c5892f511..f5a1d750d 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/decode.py @@ -599,9 +599,7 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -609,9 +607,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -621,9 +617,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py index 27ce41c87..ee694a9e0 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py @@ -399,9 +399,7 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = post_processing(results) results = sorted(results) store_transcripts(filename=recog_path, texts=results) @@ -409,9 +407,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -421,9 +417,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/conformer_ctc3/decode.py b/egs/librispeech/ASR/conformer_ctc3/decode.py index cdee1ec9c..e6327bb5e 100755 --- a/egs/librispeech/ASR/conformer_ctc3/decode.py +++ b/egs/librispeech/ASR/conformer_ctc3/decode.py @@ -728,18 +728,14 @@ def save_results( test_set_wers = dict() test_set_delays = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts_and_timestamps(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer, mean_delay, var_delay = write_error_stats_with_timestamps( f, @@ -754,9 +750,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -765,8 +759,7 @@ def save_results( # sort according to the mean start symbol delay test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0][0]) delays_info = ( - params.res_dir - / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt" + params.res_dir / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt" ) with open(delays_info, "w") as f: print("settings\t(start, end) symbol-delay (s) (start, end)", file=f) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py index 5d241ccbf..7be3299f3 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py @@ -432,18 +432,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -453,9 +449,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index e6c9d2ca2..e5a7c7116 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -750,17 +750,13 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=sorted(results)) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -770,9 +766,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py index f9c1633d8..d022d463e 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py @@ -432,18 +432,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -453,9 +449,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py index 6b3c1b563..f5d894a7b 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py @@ -750,17 +750,13 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=sorted(results)) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -770,9 +766,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py index 6dc11bdb2..856c9d945 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py @@ -566,18 +566,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -587,9 +583,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py index d510d9659..f989d9bc0 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py @@ -742,17 +742,13 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=sorted(results)) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -762,9 +758,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index 15e1109f2..6c58a57e1 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -702,18 +702,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -723,9 +719,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py index 7ac9d5f34..a2b4f9e1a 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -611,18 +611,14 @@ def save_results( test_set_wers = dict() test_set_delays = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts_and_timestamps(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer, mean_delay, var_delay = write_error_stats_with_timestamps( f, f"{test_set_name}-{key}", results, enable_log=True @@ -633,9 +629,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: @@ -643,8 +637,7 @@ def save_results( test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0]) delays_info = ( - params.res_dir - / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt" + params.res_dir / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt" ) with open(delays_info, "w") as f: print("settings\tsymbol-delay", file=f) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py index b8b6e4f43..c737e3611 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py @@ -742,17 +742,13 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=sorted(results)) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -762,9 +758,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned2_knowledge/decode.py b/egs/librispeech/ASR/pruned2_knowledge/decode.py index f22731469..82fd103ea 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/decode.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decode.py @@ -386,17 +386,13 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -406,9 +402,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py index ea7692f49..072d49d9c 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py @@ -420,18 +420,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -441,9 +437,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 8a719ae3b..6dfe11cee 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -585,18 +585,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -606,9 +602,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index 28c40c780..f4b01fd06 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -423,9 +423,7 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" # sort results so we can easily compare the difference between two # recognition results results = sorted(results) @@ -434,9 +432,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -446,9 +442,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 2791a60de..172c9ab7c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -609,18 +609,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
- errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -630,9 +626,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index eac8f8393..9c4a13606 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -425,9 +425,7 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" # sort results so we can easily compare the difference between two # recognition results results = sorted(results) @@ -436,9 +434,7 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - ) + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True @@ -448,9 +444,7 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - ) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 298c6c950..aa055049e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -869,18 +869,14 @@ def save_results( ): test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. 
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -890,9 +886,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py
index 421bfb0b7..3a1ecb7ed 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py
@@ -426,18 +426,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -447,9 +443,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
index dca2ec081..5ec3d3b45 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
@@ -655,18 +655,14 @@ def save_results(
     test_set_wers = dict()
     test_set_delays = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts_and_timestamps(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer, mean_delay, var_delay = write_error_stats_with_timestamps(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -677,9 +673,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -687,8 +681,7 @@ def save_results(
 
     test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
     delays_info = (
-        params.res_dir
-        / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt"
+        params.res_dir / f"symbol-delay-summary-{test_set_name}-{params.suffix}.txt"
     )
     with open(delays_info, "w") as f:
         print("settings\tsymbol-delay", file=f)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py
index cb5d52859..ca3a023ce 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py
@@ -442,18 +442,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -463,9 +459,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
index 5c5d3ecd9..2be895feb 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py
@@ -735,18 +735,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -756,9 +752,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
index ae221eaba..5b15dcee7 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
@@ -442,18 +442,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -463,9 +459,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py
index c81186295..95534efef 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py
@@ -416,18 +416,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -437,9 +433,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
index 856ef845a..32b3134b9 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
@@ -722,18 +722,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -743,9 +739,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py
index 6c11d95b4..629bec058 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py
@@ -541,18 +541,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(f, f"{test_set_name}-{key}", results)
             test_set_wers[key] = wer
@@ -560,9 +556,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py
index 643486a6a..7641fa5af 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py
@@ -593,18 +593,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -614,9 +610,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py
index aadf75c5f..fa7144f0f 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/ctc_decode.py
@@ -532,18 +532,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(f, f"{test_set_name}-{key}", results)
             test_set_wers[key] = wer
@@ -551,9 +547,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py
index 77160a9d4..e497787d3 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/decode.py
@@ -593,18 +593,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -614,9 +610,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py
index ed499d043..e7616fbc5 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py
@@ -568,18 +568,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -589,9 +585,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
index 9191edaab..c272ed641 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
@@ -409,18 +409,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -430,9 +426,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py
index 8314d6acf..7b651a632 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py
@@ -594,18 +594,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -615,9 +611,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py
index c0413e2d1..8d379d1fa 100755
--- a/egs/librispeech/ASR/transducer/decode.py
+++ b/egs/librispeech/ASR/transducer/decode.py
@@ -325,18 +325,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -346,9 +342,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py
index cd6d722bd..806b68f40 100755
--- a/egs/librispeech/ASR/transducer_lstm/decode.py
+++ b/egs/librispeech/ASR/transducer_lstm/decode.py
@@ -322,18 +322,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -343,9 +339,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py
index a72d60b9f..42125e19f 100755
--- a/egs/librispeech/ASR/transducer_stateless/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless/decode.py
@@ -379,18 +379,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -400,9 +396,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/librispeech/ASR/transducer_stateless2/decode.py b/egs/librispeech/ASR/transducer_stateless2/decode.py
index c91a1f490..b05fe2a4d 100755
--- a/egs/librispeech/ASR/transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless2/decode.py
@@ -379,18 +379,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -400,9 +396,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
index 5c20e2bfd..5570b30ae 100755
--- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
+++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py
@@ -380,18 +380,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -401,9 +397,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/librispeech/ASR/zipformer_mmi/decode.py b/egs/librispeech/ASR/zipformer_mmi/decode.py
index a96c5c6f0..33c0bf199 100755
--- a/egs/librispeech/ASR/zipformer_mmi/decode.py
+++ b/egs/librispeech/ASR/zipformer_mmi/decode.py
@@ -471,18 +471,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(f, f"{test_set_name}-{key}", results)
             test_set_wers[key] = wer
@@ -490,9 +486,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py b/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py
index f72d4d7f6..72338bade 100755
--- a/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py
@@ -410,17 +410,13 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -430,9 +426,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
index cb9417d2a..4434aae62 100755
--- a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py
@@ -391,18 +391,14 @@ def save_results(
     test_set_wers = dict()
     test_set_cers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        wers_filename = (
-            params.res_dir / f"wers-{test_set_name}-{params.suffix}.txt"
-        )
+        wers_filename = params.res_dir / f"wers-{test_set_name}-{params.suffix}.txt"
         with open(wers_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -413,9 +409,7 @@ def save_results(
         results_char = []
         for res in results:
             results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
-        cers_filename = (
-            params.res_dir / f"cers-{test_set_name}-{params.suffix}.txt"
-        )
+        cers_filename = params.res_dir / f"cers-{test_set_name}-{params.suffix}.txt"
         with open(cers_filename, "w") as f:
             cer = write_error_stats(
                 f, f"{test_set_name}-{key}", results_char, enable_log=True
@@ -426,9 +420,7 @@ def save_results(
 
     test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])}
     test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])}
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER\tCER", file=f)
         for key in test_set_wers:
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
index 1d6a22973..3bfb832fb 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
@@ -509,18 +509,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -530,9 +526,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
index 0d1fe9aa1..abba9d403 100755
--- a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py
@@ -379,18 +379,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -400,9 +396,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/tedlium3/ASR/transducer_stateless/decode.py b/egs/tedlium3/ASR/transducer_stateless/decode.py
index c88760854..fb0e3116b 100755
--- a/egs/tedlium3/ASR/transducer_stateless/decode.py
+++ b/egs/tedlium3/ASR/transducer_stateless/decode.py
@@ -354,18 +354,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -375,9 +371,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
index a0bf77b39..823b33ae5 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
@@ -516,18 +516,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -537,9 +533,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
index 9f6043926..32d5738b1 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
@@ -489,18 +489,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -510,9 +506,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
index 398690d48..3a4dc3cb8 100644
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py
@@ -466,9 +464,7 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         # sort results so we can easily compare the difference between two
         # recognition results
         results = sorted(results)
@@ -477,9 +475,7 @@ def save_results(
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -489,9 +485,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py
index 5b7f5f95b..b77f734e3 100755
--- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py
@@ -701,18 +701,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -722,9 +718,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py
index a291bb303..e334e690a 100755
--- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py
@@ -593,18 +593,14 @@ def save_results(
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = (
-            params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
-        )
+        errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=True
@@ -614,9 +610,7 @@ def save_results(
         logging.info("Wrote detailed error stats to {}".format(errs_filename))
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
    with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:

From 28af269e5e27cb8ab62f1bc82d1c5a2b7f659843 Mon Sep 17 00:00:00 2001
From: Yifan Yang <64255737+yfyeung@users.noreply.github.com>
Date: Thu, 9 Mar 2023 17:38:15 +0800
Subject: [PATCH 4/9] Fix for workflow (#934)

---
 ...k-librispeech-test-clean-and-test-other.sh | 2 +-
 ...pruned-transducer-stateless3-2022-06-20.sh | 1 -
 ...n-librispeech-conformer-ctc3-2022-11-28.sh | 1 -
 ...h-lstm-transducer-stateless2-2022-09-03.sh | 1 -
 ...-pruned-transducer-stateless-2022-03-12.sh | 1 -
 ...pruned-transducer-stateless2-2022-04-29.sh | 1 -
 ...pruned-transducer-stateless3-2022-04-29.sh | 1 -
 ...pruned-transducer-stateless3-2022-05-13.sh | 1 -
 ...pruned-transducer-stateless5-2022-05-13.sh | 1 -
 ...pruned-transducer-stateless7-2022-11-11.sh | 1 -
 ...ed-transducer-stateless7-ctc-2022-12-01.sh | 3 +-
 ...transducer-stateless7-ctc-bs-2022-12-15.sh | 3 +-
 ...nsducer-stateless7-streaming-2022-12-29.sh | 1 -
 ...pruned-transducer-stateless8-2022-11-14.sh | 1 -
 ...pruned-transducer-stateless2-2022-06-26.sh | 1 -
 ...speech-transducer-stateless2-2022-04-19.sh | 1 -
 ...un-librispeech-zipformer-mmi-2022-12-08.sh | 1 -
 .github/scripts/run-pre-trained-conformer-ctc.sh | 1 -
 ...d-transducer-stateless-librispeech-100h.sh | 1 -
 ...d-transducer-stateless-librispeech-960h.sh | 1 -
 ...transducer-stateless-modified-2-aishell.sh | 1 -
 ...d-transducer-stateless-modified-aishell.sh | 1 -
 .../run-pre-trained-transducer-stateless.sh | 1 -
 .github/scripts/run-pre-trained-transducer.sh | 1 -
 ...enetspeech-pruned-transducer-stateless2.sh | 1 -
 .github/scripts/test-ncnn-export.sh | 67 ------------------
 .github/workflows/run-aishell-2022-06-20.yml | 4 +-
 .../workflows/run-gigaspeech-2022-05-13.yml | 2 +-
 .../workflows/run-librispeech-2022-03-12.yml | 4 +-
 .../workflows/run-librispeech-2022-04-29.yml | 4 +-
 .../workflows/run-librispeech-2022-05-13.yml | 4 +-
 .../run-librispeech-2022-11-11-stateless7.yml | 4 +-
 .../run-librispeech-2022-11-14-stateless8.yml | 4 +-
 ...-librispeech-2022-12-01-stateless7-ctc.yml | 4 +-
 ...n-librispeech-2022-12-08-zipformer-mmi.yml | 4 +-
 ...brispeech-2022-12-15-stateless7-ctc-bs.yml | 6 +-
 ...speech-2022-12-29-stateless7-streaming.yml | 4 +-
 ...-librispeech-conformer-ctc3-2022-11-28.yml | 4 +-
 ...-lstm-transducer-stateless2-2022-09-03.yml | 4 +-
 ...runed-transducer-stateless3-2022-05-13.yml | 4 +-
 ...aming-transducer-stateless2-2022-06-26.yml | 4 +-
 ...peech-transducer-stateless2-2022-04-19.yml | 4 +-
 .../run-pretrained-conformer-ctc.yml | 4 +-
 ...-transducer-stateless-librispeech-100h.yml | 4 +-
 ...r-stateless-librispeech-multi-datasets.yml | 4 +-
 ...ransducer-stateless-modified-2-aishell.yml | 4 +-
 ...-transducer-stateless-modified-aishell.yml | 4 +-
 .../run-pretrained-transducer-stateless.yml | 4 +-
 .../run-pretrained-transducer.yml | 4 +-
 .github/workflows/run-ptb-rnn-lm.yml | 2 +-
 ...netspeech-pruned-transducer-stateless2.yml | 4 +-
 .github/workflows/run-yesno-recipe.yml | 2 +-
 .github/workflows/test-ncnn-export.yml | 2 +-
 .github/workflows/test-onnx-export.yml | 2 +-
 .github/workflows/test.yml | 4 +-
 .../ASR/local/compute_fbank_librispeech.py | 36 +++++++---
 56 files changed, 82 insertions(+), 159 deletions(-)

diff --git a/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh b/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
index bb7c7dfdc..0bec8c0c4 100755
--- a/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+++ b/.github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
@@ -15,5 +15,5 @@ mkdir -p data
 cd data
 [ ! -e fbank ] && ln -s ~/tmp/fbank-libri fbank
 cd ..
-./local/compute_fbank_librispeech.py
+./local/compute_fbank_librispeech.py --dataset 'test-clean test-other'
 ls -lh data/fbank/
diff --git a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh
index e70a1848d..4c393f6be 100755
--- a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh
+++ b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh
@@ -25,7 +25,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
diff --git a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
index df29f188e..c68ccc954 100755
--- a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
+++ b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh
@@ -18,7 +18,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh
index 91cdea01a..4cd2c4bec 100755
--- a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh
+++ b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh
@@ -20,7 +20,6 @@ abs_repo=$(realpath $repo)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh
index dafea56db..6792c7088 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless-2022-03-12.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 for sym in 1 2 3; do
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh
index c3d07dc0e..dbf678d72 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh
@@ -23,7 +23,6 @@ popd
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh
index 22de3b45d..b6d477afe 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-04-29.sh
@@ -22,7 +22,6 @@ popd
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
index ceb77c7c3..efa4b53f0 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless3-2022-05-13.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh
index c6a781318..511fe0c9e 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh
index 8e485d2e6..2bc179c86 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-2022-11-11.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
index 3cbb480f6..192438353 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
@@ -18,7 +18,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
@@ -148,4 +147,4 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" ==
   done
 
   rm pruned_transducer_stateless7_ctc/exp/*.pt
-fi
\ No newline at end of file
+fi
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh
index ed66a728e..761eb72e2 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh
@@ -10,7 +10,7 @@ log() {
 
 cd egs/librispeech/ASR
 
-repo_url=https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2022-12-14
+repo_url=https://huggingface.co/yfyeung/icefall-asr-librispeech-pruned_transducer_stateless7_ctc_bs-2023-01-29
 log "Downloading pre-trained model from $repo_url"
 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
 repo=$(basename $repo_url)
@@ -18,7 +18,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh
index 584f5d488..e1e4e1f10 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-streaming-2022-12-29.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo
diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh
index e782b8425..5d9485692 100755
--- a/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh
+++ b/.github/scripts/run-librispeech-pruned-transducer-stateless8-2022-11-14.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
diff --git a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh
index af37102d5..77cd59506 100755
--- a/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh
+++ b/.github/scripts/run-librispeech-streaming-pruned-transducer-stateless2-2022-06-26.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
diff --git a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh
index 5b8ed396b..b4aca1b6b 100755
--- a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh
+++ b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 for sym in 1 2 3; do
diff --git a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh
index 77f28b054..a58b8ec56 100755
--- a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh
+++ b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh
@@ -18,7 +18,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
diff --git a/.github/scripts/run-pre-trained-conformer-ctc.sh b/.github/scripts/run-pre-trained-conformer-ctc.sh
index 96c320616..125d1f3b1 100755
--- a/.github/scripts/run-pre-trained-conformer-ctc.sh
+++ b/.github/scripts/run-pre-trained-conformer-ctc.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.flac
 ls -lh $repo/test_wavs/*.flac
 
 log "CTC decoding"
diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh
index 209d4814f..89115e88d 100755
--- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh
+++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 for sym in 1 2 3; do
diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh
index 34ff76fe4..85e2c89e6 100755
--- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh
+++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 for sym in 1 2 3; do
diff --git a/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh b/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh
index 75650c2d3..0644d9be0 100755
--- a/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh
+++ b/.github/scripts/run-pre-trained-transducer-stateless-modified-2-aishell.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 for sym in 1 2 3; do
diff --git a/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh b/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh
index bcc2d74cb..79fb64311 100755
--- a/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh
+++ b/.github/scripts/run-pre-trained-transducer-stateless-modified-aishell.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 for sym in 1 2 3; do
diff --git a/.github/scripts/run-pre-trained-transducer-stateless.sh b/.github/scripts/run-pre-trained-transducer-stateless.sh
index d3e40315a..41456f11b 100755
--- a/.github/scripts/run-pre-trained-transducer-stateless.sh
+++ b/.github/scripts/run-pre-trained-transducer-stateless.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 for sym in 1 2 3; do
diff --git a/.github/scripts/run-pre-trained-transducer.sh b/.github/scripts/run-pre-trained-transducer.sh
index cfa006776..1331c966c 100755
--- a/.github/scripts/run-pre-trained-transducer.sh
+++ b/.github/scripts/run-pre-trained-transducer.sh
@@ -19,7 +19,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 log "Beam search decoding"
diff --git a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh
index 2d237dcf2..90097c752 100755
--- a/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh
+++ b/.github/scripts/run-wenetspeech-pruned-transducer-stateless2.sh
@@ -20,7 +20,6 @@ repo=$(basename $repo_url)
 
 log "Display test files"
 tree $repo/
-soxi $repo/test_wavs/*.wav
 ls -lh $repo/test_wavs/*.wav
 
 pushd $repo/exp
diff --git a/.github/scripts/test-ncnn-export.sh b/.github/scripts/test-ncnn-export.sh
index 9f5df2d58..52491d2ea 100755
--- a/.github/scripts/test-ncnn-export.sh
+++ b/.github/scripts/test-ncnn-export.sh
@@ -232,70 +232,3 @@ python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \
 
 rm -rf $repo
 log "--------------------------------------------------------------------------"
-
-# Go back to the root directory of icefall repo
-popd
-
-pushd egs/csj/ASR
-
-log "=========================================================================="
-repo_url=https://huggingface.co/TeoWenShen/icefall-asr-csj-pruned-transducer-stateless7-streaming-230208
-GIT_LFS_SKIP_SMUDGE=1 git
clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "exp_fluent/pretrained.pt" -git lfs pull --include "exp_disfluent/pretrained.pt" - -cd exp_fluent -ln -s pretrained.pt epoch-99.pt - -cd ../exp_disfluent -ln -s pretrained.pt epoch-99.pt - -cd ../test_wavs -git lfs pull --include "*.wav" -popd - -log "Export via torch.jit.trace()" - -for exp in exp_fluent exp_disfluent; do - ./pruned_transducer_stateless7_streaming/export-for-ncnn.py \ - --exp-dir $repo/$exp/ \ - --lang $repo/data/lang_char \ - --epoch 99 \ - --avg 1 \ - --use-averaged-model 0 \ - \ - --decode-chunk-len 32 \ - --num-left-chunks 4 \ - --num-encoder-layers "2,4,3,2,4" \ - --feedforward-dims "1024,1024,2048,2048,1024" \ - --nhead "8,8,8,8,8" \ - --encoder-dims "384,384,384,384,384" \ - --attention-dims "192,192,192,192,192" \ - --encoder-unmasked-dims "256,256,256,256,256" \ - --zipformer-downsampling-factors "1,2,4,8,2" \ - --cnn-module-kernels "31,31,31,31,31" \ - --decoder-dim 512 \ - --joiner-dim 512 - - pnnx $repo/$exp/encoder_jit_trace-pnnx.pt - pnnx $repo/$exp/decoder_jit_trace-pnnx.pt - pnnx $repo/$exp/joiner_jit_trace-pnnx.pt - - for wav in aps-smp.wav interview_aps-smp.wav reproduction-smp.wav sps-smp.wav; do - python3 ./pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ - --tokens $repo/data/lang_char/tokens.txt \ - --encoder-param-filename $repo/$exp/encoder_jit_trace-pnnx.ncnn.param \ - --encoder-bin-filename $repo/$exp/encoder_jit_trace-pnnx.ncnn.bin \ - --decoder-param-filename $repo/$exp/decoder_jit_trace-pnnx.ncnn.param \ - --decoder-bin-filename $repo/$exp/decoder_jit_trace-pnnx.ncnn.bin \ - --joiner-param-filename $repo/$exp/joiner_jit_trace-pnnx.ncnn.param \ - --joiner-bin-filename $repo/$exp/joiner_jit_trace-pnnx.ncnn.bin \ - $repo/test_wavs/$wav - done -done - -rm -rf $repo -log "--------------------------------------------------------------------------" diff --git a/.github/workflows/run-aishell-2022-06-20.yml b/.github/workflows/run-aishell-2022-06-20.yml index 1865a0da8..f5ba73195 100644 --- a/.github/workflows/run-aishell-2022-06-20.yml +++ b/.github/workflows/run-aishell-2022-06-20.yml @@ -65,7 +65,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -87,7 +87,7 @@ jobs: GITHUB_EVENT_NAME: ${{ github.event_name }} GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} run: | - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-gigaspeech-2022-05-13.yml b/.github/workflows/run-gigaspeech-2022-05-13.yml index e438c5dba..c7b9cc79d 100644 --- a/.github/workflows/run-gigaspeech-2022-05-13.yml +++ b/.github/workflows/run-gigaspeech-2022-05-13.yml @@ -64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache diff --git a/.github/workflows/run-librispeech-2022-03-12.yml b/.github/workflows/run-librispeech-2022-03-12.yml index 3ba6850cd..9c7cd1228 100644 --- a/.github/workflows/run-librispeech-2022-03-12.yml +++ b/.github/workflows/run-librispeech-2022-03-12.yml @@ 
-64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -123,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-2022-04-29.yml b/.github/workflows/run-librispeech-2022-04-29.yml index 595b410b8..78c9e759f 100644 --- a/.github/workflows/run-librispeech-2022-04-29.yml +++ b/.github/workflows/run-librispeech-2022-04-29.yml @@ -64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -123,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-2022-05-13.yml b/.github/workflows/run-librispeech-2022-05-13.yml index eb0b06a2d..04799bf52 100644 --- a/.github/workflows/run-librispeech-2022-05-13.yml +++ b/.github/workflows/run-librispeech-2022-05-13.yml @@ -64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -123,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml index 365e2761a..6dfc23920 100644 --- a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml +++ b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml @@ -64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -123,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-2022-11-14-stateless8.yml b/.github/workflows/run-librispeech-2022-11-14-stateless8.yml index acb11a8f4..0544e68b3 100644 --- a/.github/workflows/run-librispeech-2022-11-14-stateless8.yml +++ b/.github/workflows/run-librispeech-2022-11-14-stateless8.yml 
@@ -64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -123,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml b/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml index ccd8d50d0..62e1f2a01 100644 --- a/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml +++ b/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml @@ -60,7 +60,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -119,7 +119,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml b/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml index 5472ca59b..7dc33aaa9 100644 --- a/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml +++ b/.github/workflows/run-librispeech-2022-12-08-zipformer-mmi.yml @@ -64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -123,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml b/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml index 6e2b40cf3..de55847ad 100644 --- a/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml +++ b/.github/workflows/run-librispeech-2022-12-15-stateless7-ctc-bs.yml @@ -35,7 +35,7 @@ on: jobs: run_librispeech_2022_12_15_zipformer_ctc_bs: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'blank-skip' || github.event_name == 'push' || github.event_name == 'schedule' + if: github.event.label.name == 'run-decode' || github.event.label.name == 'blank-skip' || github.event_name == 'push' || github.event_name == 'schedule' runs-on: ${{ matrix.os }} strategy: matrix: @@ -60,7 +60,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -119,7 +119,7 @@ jobs: ln -sfv 
~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml b/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml index 6dd93946a..feb5c6fd0 100644 --- a/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml +++ b/.github/workflows/run-librispeech-2022-12-29-stateless7-streaming.yml @@ -64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -123,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml index d763fb1c5..c95ed8b9a 100644 --- a/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml +++ b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml @@ -64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -123,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml index f737d9a25..e14d4e92f 100644 --- a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml +++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml @@ -47,7 +47,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -106,7 +106,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml index f67f7599b..73d91fcd4 100644 --- a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml +++ b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml @@ 
-64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -123,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml index ac7e58b20..8a690393e 100644 --- a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml +++ b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml @@ -64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -123,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml index 575727e22..217dbdfa1 100644 --- a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml +++ b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml @@ -64,7 +64,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -123,7 +123,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-conformer-ctc.yml b/.github/workflows/run-pretrained-conformer-ctc.yml index 7dbfd2bd9..4e8e7b8db 100644 --- a/.github/workflows/run-pretrained-conformer-ctc.yml +++ b/.github/workflows/run-pretrained-conformer-ctc.yml @@ -54,7 +54,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -73,7 +73,7 @@ jobs: - name: Inference with pre-trained model shell: bash run: | - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml 
b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml index d6b3de8d4..ddde4f1d6 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml @@ -63,7 +63,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -122,7 +122,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml index 749fb3fca..00ea97b2a 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml @@ -63,7 +63,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -122,7 +122,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml index 92bf6feb8..b3cfc9efd 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml @@ -54,7 +54,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -73,7 +73,7 @@ jobs: - name: Inference with pre-trained model shell: bash run: | - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml index e51da8bd8..ab598541d 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml @@ -54,7 +54,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -73,7 +73,7 @@ jobs: - name: Inference with pre-trained model shell: bash run: | - sudo apt-get 
-qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-transducer-stateless.yml b/.github/workflows/run-pretrained-transducer-stateless.yml index 2103d0510..d663d49dd 100644 --- a/.github/workflows/run-pretrained-transducer-stateless.yml +++ b/.github/workflows/run-pretrained-transducer-stateless.yml @@ -63,7 +63,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -122,7 +122,7 @@ jobs: ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank ls -lh egs/librispeech/ASR/data/* - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-pretrained-transducer.yml b/.github/workflows/run-pretrained-transducer.yml index 902319b55..9cb9d3b59 100644 --- a/.github/workflows/run-pretrained-transducer.yml +++ b/.github/workflows/run-pretrained-transducer.yml @@ -54,7 +54,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -73,7 +73,7 @@ jobs: - name: Inference with pre-trained model shell: bash run: | - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-ptb-rnn-lm.yml b/.github/workflows/run-ptb-rnn-lm.yml index 47ed958f2..f8d9c02c5 100644 --- a/.github/workflows/run-ptb-rnn-lm.yml +++ b/.github/workflows/run-ptb-rnn-lm.yml @@ -47,7 +47,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | grep -v kaldifst | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Prepare data shell: bash diff --git a/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml b/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml index 8a7be0b80..14fb96ec8 100644 --- a/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml +++ b/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml @@ -54,7 +54,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache @@ -76,7 +76,7 @@ jobs: GITHUB_EVENT_NAME: ${{ github.event_name }} GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} run: | - sudo apt-get -qq install git-lfs tree sox + sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml index ed343aee5..1187dbf38 100644 --- 
a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/run-yesno-recipe.yml @@ -67,7 +67,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | grep -v kaldifst | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Run yesno recipe shell: bash diff --git a/.github/workflows/test-ncnn-export.yml b/.github/workflows/test-ncnn-export.yml index e10cfe76b..cdea54854 100644 --- a/.github/workflows/test-ncnn-export.yml +++ b/.github/workflows/test-ncnn-export.yml @@ -46,7 +46,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache diff --git a/.github/workflows/test-onnx-export.yml b/.github/workflows/test-onnx-export.yml index c7729dedb..3dc4261ab 100644 --- a/.github/workflows/test-onnx-export.yml +++ b/.github/workflows/test-onnx-export.yml @@ -46,7 +46,7 @@ jobs: run: | grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* - name: Cache kaldifeat id: my-cache diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c062a2a3d..0da4f6b4b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -56,7 +56,7 @@ jobs: run: | sudo apt update sudo apt install -q -y libsndfile1-dev libsndfile1 ffmpeg - sudo apt install -q -y --fix-missing sox libsox-dev libsox-fmt-all + sudo apt install -q -y --fix-missing libsox-dev libsox-fmt-all - name: Install Python dependencies run: | @@ -70,7 +70,7 @@ jobs: pip install git+https://github.com/lhotse-speech/lhotse # icefall requirements pip uninstall -y protobuf - pip install --no-binary protobuf protobuf + pip install --no-binary protobuf protobuf==3.20.* pip install kaldifst pip install onnxruntime diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index 9f8503814..745eaf1e8 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -54,10 +54,20 @@ def get_args(): help="""Path to the bpe.model. If not None, we will remove short and long utterances before extracting features""", ) + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. 
If None, we will use all parts.""",
+    )
+
     return parser.parse_args()
 
 
-def compute_fbank_librispeech(bpe_model: Optional[str] = None):
+def compute_fbank_librispeech(
+    bpe_model: Optional[str] = None,
+    dataset: Optional[str] = None,
+):
     src_dir = Path("data/manifests")
     output_dir = Path("data/fbank")
     num_jobs = min(15, os.cpu_count())
@@ -68,15 +78,19 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None):
         sp = spm.SentencePieceProcessor()
         sp.load(bpe_model)
 
-    dataset_parts = (
-        "dev-clean",
-        "dev-other",
-        "test-clean",
-        "test-other",
-        "train-clean-100",
-        "train-clean-360",
-        "train-other-500",
-    )
+    if dataset is None:
+        dataset_parts = (
+            "dev-clean",
+            "dev-other",
+            "test-clean",
+            "test-other",
+            "train-clean-100",
+            "train-clean-360",
+            "train-other-500",
+        )
+    else:
+        dataset_parts = dataset.split(" ", -1)
+
     prefix = "librispeech"
     suffix = "jsonl.gz"
     manifests = read_manifests_if_cached(
@@ -131,4 +145,4 @@ if __name__ == "__main__":
     logging.basicConfig(format=formatter, level=logging.INFO)
     args = get_args()
     logging.info(vars(args))
-    compute_fbank_librispeech(bpe_model=args.bpe_model)
+    compute_fbank_librispeech(bpe_model=args.bpe_model, dataset=args.dataset)

From 9ddd811925534dc47b183a23429a4727c6416e81 Mon Sep 17 00:00:00 2001
From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com>
Date: Fri, 10 Mar 2023 14:37:28 +0800
Subject: [PATCH 5/9] Fix padding_idx (#942)

* fix padding_idx

* update RESULTS.md
---
 egs/librispeech/ASR/RESULTS.md                              | 4 ++++
 egs/librispeech/ASR/pruned_transducer_stateless/decoder.py  | 1 -
 egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py | 1 -
 egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py | 1 -
 4 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index ecb84eb01..9ca7a19b8 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -540,6 +540,10 @@ for m in greedy_search fast_beam_search modified_beam_search ; do
 done
 ```
 
+Note that a small change is made to the `pruned_transducer_stateless7/decoder.py` in
+this [PR](https://github.com/k2-fsa/icefall/pull/942) to address the
+problem of emitting the first symbol at the very beginning; a brief illustration of why follows below.
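A minimal editorial sketch (not part of this patch) of the PyTorch behaviour behind the fix: `nn.Embedding` pins the row at `padding_idx` to zeros and gives it a zero gradient, so with `padding_idx=blank_id` (blank_id is 0 in these recipes) the embedding for the blank symbol is never trained.

```python
import torch
import torch.nn as nn

# With padding_idx=0, the row for the blank symbol is pinned to zeros ...
emb = nn.Embedding(num_embeddings=500, embedding_dim=8, padding_idx=0)
print(emb.weight[0])  # all zeros by construction

# ... and it receives a zero gradient, so training never updates it.
loss = emb(torch.tensor([0, 1, 2])).sum()
loss.backward()
print(emb.weight.grad[0])  # all zeros
```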
If you need a +model without this issue, please download the model from here: ### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T + gradient filter) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py index 72593173c..49b82c433 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py @@ -58,7 +58,6 @@ class Decoder(nn.Module): self.embedding = nn.Embedding( num_embeddings=vocab_size, embedding_dim=embedding_dim, - padding_idx=blank_id, ) self.blank_id = blank_id self.unk_id = unk_id diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index b59928103..d44ed6f81 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -59,7 +59,6 @@ class Decoder(nn.Module): self.embedding = ScaledEmbedding( num_embeddings=vocab_size, embedding_dim=decoder_dim, - padding_idx=blank_id, ) self.blank_id = blank_id diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py index 384b78524..b085a1817 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py @@ -56,7 +56,6 @@ class Decoder(nn.Module): self.embedding = nn.Embedding( num_embeddings=vocab_size, embedding_dim=decoder_dim, - padding_idx=blank_id, ) self.blank_id = blank_id From cad6735e0739f149ba3f452e52a948da946527dc Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Fri, 10 Mar 2023 19:28:59 +0800 Subject: [PATCH 6/9] Modify make_pad_mask to support TensorRT (#943) * Modify make_pad_mask to support TensorRT * Fix for test --- egs/librispeech/ASR/transducer/test_rnn.py | 10 +++++----- icefall/utils.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/egs/librispeech/ASR/transducer/test_rnn.py b/egs/librispeech/ASR/transducer/test_rnn.py index 74c94cc70..d8effb996 100755 --- a/egs/librispeech/ASR/transducer/test_rnn.py +++ b/egs/librispeech/ASR/transducer/test_rnn.py @@ -432,11 +432,11 @@ def test_layernorm_lstm_forward(device="cpu"): def test_layernorm_lstm_with_projection_forward(device="cpu"): - input_size = torch.randint(low=2, high=100, size=(1,)).item() - hidden_size = torch.randint(low=10, high=100, size=(1,)).item() - proj_size = torch.randint(low=2, high=hidden_size, size=(1,)).item() - num_layers = torch.randint(low=2, high=100, size=(1,)).item() - bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + input_size = 40 # torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = 40 # torch.randint(low=10, high=100, size=(1,)).item() + proj_size = 20 # torch.randint(low=2, high=hidden_size, size=(1,)).item() + num_layers = 12 # torch.randint(low=2, high=100, size=(1,)).item() + bias = True # torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 self_lstm = LayerNormLSTM( input_size=input_size, diff --git a/icefall/utils.py b/icefall/utils.py index 2358ed02f..5d86472b5 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1095,10 +1095,10 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: assert lengths.ndim == 1, lengths.ndim max_len = max(max_len, lengths.max()) n = lengths.size(0) + seq_range = torch.arange(0, max_len, 
device=lengths.device) + expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len) - expaned_lengths = torch.arange(max_len).expand(n, max_len).to(lengths) - - return expaned_lengths >= lengths.unsqueeze(1) + return expaned_lengths >= lengths.unsqueeze(-1) # Copied and modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py From a48812ddb307069339e029942321b8c7417aed93 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Wed, 15 Mar 2023 22:02:20 +0800 Subject: [PATCH 7/9] Ban the test_rnn.py in ci-test (#949) --- .github/workflows/test.yml | 8 ++++---- egs/librispeech/ASR/transducer/test_rnn.py | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0da4f6b4b..079772e97 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -119,8 +119,8 @@ jobs: cd ../transducer_stateless pytest -v -s - cd ../transducer - pytest -v -s + # cd ../transducer + # pytest -v -s cd ../transducer_stateless2 pytest -v -s @@ -157,8 +157,8 @@ jobs: cd ../transducer_stateless pytest -v -s - cd ../transducer - pytest -v -s + # cd ../transducer + # pytest -v -s cd ../transducer_stateless2 pytest -v -s diff --git a/egs/librispeech/ASR/transducer/test_rnn.py b/egs/librispeech/ASR/transducer/test_rnn.py index d8effb996..74c94cc70 100755 --- a/egs/librispeech/ASR/transducer/test_rnn.py +++ b/egs/librispeech/ASR/transducer/test_rnn.py @@ -432,11 +432,11 @@ def test_layernorm_lstm_forward(device="cpu"): def test_layernorm_lstm_with_projection_forward(device="cpu"): - input_size = 40 # torch.randint(low=2, high=100, size=(1,)).item() - hidden_size = 40 # torch.randint(low=10, high=100, size=(1,)).item() - proj_size = 20 # torch.randint(low=2, high=hidden_size, size=(1,)).item() - num_layers = 12 # torch.randint(low=2, high=100, size=(1,)).item() - bias = True # torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 + input_size = torch.randint(low=2, high=100, size=(1,)).item() + hidden_size = torch.randint(low=10, high=100, size=(1,)).item() + proj_size = torch.randint(low=2, high=hidden_size, size=(1,)).item() + num_layers = torch.randint(low=2, high=100, size=(1,)).item() + bias = torch.randint(low=0, high=1000, size=(1,)).item() & 2 == 0 self_lstm = LayerNormLSTM( input_size=input_size, From 6196b4a407f0ff4359814c81c385eefb5636f04d Mon Sep 17 00:00:00 2001 From: Jason's Lab <563042811@qq.com> Date: Thu, 16 Mar 2023 09:52:11 +0800 Subject: [PATCH 8/9] Add char-based language model training process for aishell. (#945) * Add char-based language model training process for aishell. 
Add soft link from librispeech/ASR/local/sort_lm_training_data.py to
aishell/ASR/local/

---------

Co-authored-by: lichao
---
 .../local/prepare_char_lm_training_data.py | 164 ++++++++++++++++++
 egs/aishell/ASR/prepare.sh                 |  92 +++++++++-
 2 files changed, 255 insertions(+), 1 deletion(-)
 create mode 100644 egs/aishell/ASR/local/prepare_char_lm_training_data.py

diff --git a/egs/aishell/ASR/local/prepare_char_lm_training_data.py b/egs/aishell/ASR/local/prepare_char_lm_training_data.py
new file mode 100644
index 000000000..e7995680b
--- /dev/null
+++ b/egs/aishell/ASR/local/prepare_char_lm_training_data.py
@@ -0,0 +1,164 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2021  Xiaomi Corporation (authors: Daniel Povey
+#                                                  Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script takes a `tokens.txt` and a text file such as
+./download/lm/aishell-transcript.txt
+and outputs the LM training data to a supplied directory such
+as data/lm_training_char. The format is as follows:
+It creates a PyTorch archive (.pt file), say data/lm_training.pt, which is a
+representation of a dict in the same format as the librispeech recipe.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import torch
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-char",
+        type=str,
+        help="""Lang dir of asr model, e.g. data/lang_char""",
+    )
+    parser.add_argument(
+        "--lm-data",
+        type=str,
+        help="""Input LM training data as text, e.g.
+        download/lm/aishell-train-word.txt""",
+    )
+    parser.add_argument(
+        "--lm-archive",
+        type=str,
+        help="""Path to output archive, e.g. data/lm_training_char/lm_data.pt;
+        look at the source of this script to see the format.""",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+
+    if Path(args.lm_archive).exists():
+        logging.warning(f"{args.lm_archive} exists - skipping")
+        return
+
+    # make token_dict from tokens.txt in order to map characters to tokens.
+    token_dict = {}
+    token_file = args.lang_char + "/tokens.txt"
+
+    with open(token_file, "r") as f:
+        for line in f.readlines():
+            line_list = line.split()
+            token_dict[line_list[0]] = int(line_list[1])
+
+    # word2index is a dictionary from words to integer ids. No need to reserve
+    # space for epsilon, etc.; the words are just used as a convenient way to
+    # compress the sequences of tokens.
+    word2index = dict()
+
+    word2token = []  # Will be a list-of-list-of-int, representing tokens.
+    sentences = []  # Will be a list-of-list-of-int, representing word-ids.
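To make the data layout above concrete, here is a small editorial example (the token ids are hypothetical, not from the patch) of what these three structures would hold after processing two lines of LM text:

```python
# Suppose tokens.txt maps 你 -> 5, 好 -> 6, 世 -> 7, 界 -> 8, and the LM text
# contains the two lines "你好 世界" and "你好 你好". We would then build:
word2index = {"你好": 0, "世界": 1}  # each distinct word gets an integer id
word2token = [[5, 6], [7, 8]]  # the token-id sequence for each word
sentences = [[0, 1], [0, 0]]  # each line as a sequence of word ids
```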
+
+    if "aishell-lm" in args.lm_data:
+        num_lines_in_total = 120098.0
+        step = 50000
+    elif "valid" in args.lm_data:
+        num_lines_in_total = 14326.0
+        step = 3000
+    elif "test" in args.lm_data:
+        num_lines_in_total = 7176.0
+        step = 3000
+    else:
+        num_lines_in_total = None
+        step = None
+
+    processed = 0
+
+    with open(args.lm_data) as f:
+        while True:
+            line = f.readline()
+            if line == "":
+                break
+
+            if step and processed % step == 0:
+                logging.info(
+                    f"Processed number of lines: {processed} "
+                    f"({processed / num_lines_in_total * 100: .3f}%)"
+                )
+            processed += 1
+
+            line_words = line.split()
+            for w in line_words:
+                if w not in word2index:
+                    w_token = []
+                    for t in w:
+                        if t in token_dict:
+                            w_token.append(token_dict[t])
+                        else:
+                            w_token.append(token_dict["<unk>"])
+                    word2index[w] = len(word2token)
+                    word2token.append(w_token)
+            sentences.append([word2index[w] for w in line_words])
+
+    logging.info("Constructing ragged tensors")
+    words = k2.ragged.RaggedTensor(word2token)
+    sentences = k2.ragged.RaggedTensor(sentences)
+
+    output = dict(words=words, sentences=sentences)
+
+    num_sentences = sentences.dim0
+    logging.info(f"Computing sentence lengths, num_sentences: {num_sentences}")
+    sentence_lengths = [0] * num_sentences
+    for i in range(num_sentences):
+        if step and i % step == 0:
+            logging.info(
+                f"Processed number of lines: {i} ({i / num_sentences * 100: .3f}%)"
+            )
+
+        word_ids = sentences[i]
+
+        # NOTE: If word_ids is a tensor with only 1 entry,
+        # token_ids is a torch.Tensor
+        token_ids = words[word_ids]
+        if isinstance(token_ids, k2.RaggedTensor):
+            token_ids = token_ids.values
+
+        # token_ids is a 1-D tensor containing the BPE tokens
+        # of the current sentence
+
+        sentence_lengths[i] = token_ids.numel()
+
+    output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32)
+
+    torch.save(output, args.lm_archive)
+    logging.info(f"Saved to {args.lm_archive}")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()

diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh
index 5917668a1..cf4ee7818 100755
--- a/egs/aishell/ASR/prepare.sh
+++ b/egs/aishell/ASR/prepare.sh
@@ -7,7 +7,7 @@ set -eou pipefail
 
 nj=15
 stage=-1
-stop_stage=10
+stop_stage=11
 
 # We assume dl_dir (download dir) contains the following
 # directories and files. If not, they will be downloaded
@@ -219,3 +219,93 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
   ./local/compile_hlg.py --lang-dir $lang_phone_dir
   ./local/compile_hlg.py --lang-dir $lang_char_dir
 fi
+
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+  log "Stage 9: Generate LM training data"
+
+  log "Processing char based data"
+  out_dir=data/lm_training_char
+  mkdir -p $out_dir $dl_dir/lm
+
+  if [ ! -f $dl_dir/lm/aishell-train-word.txt ]; then
+    cp $lang_phone_dir/transcript_words.txt $dl_dir/lm/aishell-train-word.txt
+  fi
+
+  ./local/prepare_char_lm_training_data.py \
+    --lang-char data/lang_char \
+    --lm-data $dl_dir/lm/aishell-train-word.txt \
+    --lm-archive $out_dir/lm_data.pt
+
+  if [ ! -f $dl_dir/lm/aishell-valid-word.txt ]; then
+    aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
+    aishell_valid_uid=$dl_dir/aishell/data_aishell/transcript/aishell_valid_uid
+    find $dl_dir/aishell/data_aishell/wav/dev -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_valid_uid
+    awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_valid_uid $aishell_text |
+      cut -d " " -f 2- > $dl_dir/lm/aishell-valid-word.txt
+  fi
+
+  ./local/prepare_char_lm_training_data.py \
+    --lang-char data/lang_char \
+    --lm-data $dl_dir/lm/aishell-valid-word.txt \
+    --lm-archive $out_dir/lm_data_valid.pt
+
+  if [ ! -f $dl_dir/lm/aishell-test-word.txt ]; then
+    aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
+    aishell_test_uid=$dl_dir/aishell/data_aishell/transcript/aishell_test_uid
+    find $dl_dir/aishell/data_aishell/wav/test -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_test_uid
+    awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_test_uid $aishell_text |
+      cut -d " " -f 2- > $dl_dir/lm/aishell-test-word.txt
+  fi
+
+  ./local/prepare_char_lm_training_data.py \
+    --lang-char data/lang_char \
+    --lm-data $dl_dir/lm/aishell-test-word.txt \
+    --lm-archive $out_dir/lm_data_test.pt
+fi
+
+
+if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
+  log "Stage 10: Sort LM training data"
+  # Sort LM training data by sentence length in descending order
+  # for ease of training.
+  #
+  # Sentence length equals the number of tokens
+  # in a sentence.
+
+  out_dir=data/lm_training_char
+  mkdir -p $out_dir
+  ln -snf ../../../librispeech/ASR/local/sort_lm_training_data.py local/
+
+  ./local/sort_lm_training_data.py \
+    --in-lm-data $out_dir/lm_data.pt \
+    --out-lm-data $out_dir/sorted_lm_data.pt \
+    --out-statistics $out_dir/statistics.txt
+
+  ./local/sort_lm_training_data.py \
+    --in-lm-data $out_dir/lm_data_valid.pt \
+    --out-lm-data $out_dir/sorted_lm_data-valid.pt \
+    --out-statistics $out_dir/statistics-valid.txt
+
+  ./local/sort_lm_training_data.py \
+    --in-lm-data $out_dir/lm_data_test.pt \
+    --out-lm-data $out_dir/sorted_lm_data-test.pt \
+    --out-statistics $out_dir/statistics-test.txt
+fi
+
+if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
+  log "Stage 11: Train RNN LM model"
+  python ../../../icefall/rnn_lm/train.py \
+    --start-epoch 0 \
+    --world-size 1 \
+    --num-epochs 20 \
+    --use-fp16 0 \
+    --embedding-dim 512 \
+    --hidden-dim 512 \
+    --num-layers 2 \
+    --batch-size 400 \
+    --exp-dir rnnlm_char/exp \
+    --lm-data data/lm_training_char/sorted_lm_data.pt \
+    --lm-data-valid data/lm_training_char/sorted_lm_data-valid.pt \
+    --vocab-size 4336 \
+    --master-port 12345
+fi

From 7948624a220b9fc40dbfa87cb1eb83041af45ef3 Mon Sep 17 00:00:00 2001
From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com>
Date: Fri, 17 Mar 2023 13:44:29 +0800
Subject: [PATCH 9/9] Support fine-tuning (#944)

* support finetune

* add files for decoding giga

* support initializing modules

* add a fine-tune bash script
---
 egs/librispeech/ASR/finetune.sh               |   85 ++
 .../decode_gigaspeech.py                      |  861 +++++++++++
 .../pruned_transducer_stateless7/finetune.py  | 1342 +++++++++++++++++
 .../gigaspeech.py                             |  406 +++++
 .../gigaspeech_scoring.py                     |    1 +
 .../ASR/pruned_transducer_stateless7/optim.py |   46 +-
 6 files changed, 2739 insertions(+), 2 deletions(-)
 create mode 100755 egs/librispeech/ASR/finetune.sh
 create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py
 create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py
 create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py
 create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech_scoring.py
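The fine-tuning flow added in the file diffs below amounts to initializing the model from a pretrained checkpoint (pretrained.pt) and then training with a small base learning rate. As a hedged editorial sketch of that initialization step (the helper name and the `strict=False` choice are mine, not from the patch; it assumes the usual icefall convention of storing the weights under a `"model"` key):

```python
import torch


def init_from_pretrained(model: torch.nn.Module, ckpt_path: str) -> None:
    """Load pretrained weights (e.g. pretrained.pt) before fine-tuning."""
    ckpt = torch.load(ckpt_path, map_location="cpu")
    state_dict = ckpt.get("model", ckpt)  # tolerate a raw state dict too
    # strict=False initializes only the modules whose names match, in the
    # spirit of the "support initializing modules" item above.
    model.load_state_dict(state_dict, strict=False)
```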
diff --git a/egs/librispeech/ASR/finetune.sh b/egs/librispeech/ASR/finetune.sh
new file mode 100755
index 000000000..63d0966ed
--- /dev/null
+++ b/egs/librispeech/ASR/finetune.sh
@@ -0,0 +1,85 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -eou pipefail
+
+stage=-1
+stop_stage=100
+
+# This is an example script for fine-tuning. Here, we fine-tune a model trained
+# on LibriSpeech on GigaSpeech. The model used for fine-tuning is
+# pruned_transducer_stateless7 (zipformer). If you want to fine-tune a model
+# from another recipe, you can adapt ./pruned_transducer_stateless7/finetune.py
+# for that recipe. If you have any problem, please open up an issue in
+# https://github.com/k2-fsa/icefall/issues.
+
+# We assume that you have already prepared the GigaSpeech manifests and
+# features under ./data. If you haven't done that, please see
+# https://github.com/k2-fsa/icefall/blob/master/egs/gigaspeech/ASR/prepare.sh.
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+  log "Stage -1: Download Pre-trained model"
+
+  # clone from huggingface
+  git lfs install
+  git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
+
+fi
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Start fine-tuning"
+
+  # The following configuration of the lr schedule should work well.
+  # You may also tune the following parameters to adjust the learning rate schedule.
+  base_lr=0.005
+  lr_epochs=100
+  lr_batches=100000
+
+  # We recommend starting from an averaged model
+  finetune_ckpt=icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp/pretrained.pt
+  export CUDA_VISIBLE_DEVICES="0,1"
+
+  ./pruned_transducer_stateless7/finetune.py \
+    --world-size 2 \
+    --master-port 18180 \
+    --num-epochs 20 \
+    --start-epoch 1 \
+    --exp-dir pruned_transducer_stateless7/exp_giga_finetune \
+    --subset S \
+    --use-fp16 1 \
+    --base-lr $base_lr \
+    --lr-epochs $lr_epochs \
+    --lr-batches $lr_batches \
+    --bpe-model icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/data/lang_bpe_500/bpe.model \
+    --do-finetune True \
+    --finetune-ckpt $finetune_ckpt \
+    --max-duration 500
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Decoding"
+
+  epoch=15
+  avg=10
+
+  for m in greedy_search modified_beam_search; do
+    python pruned_transducer_stateless7/decode_gigaspeech.py \
+      --epoch $epoch \
+      --avg $avg \
+      --use-averaged-model True \
+      --beam-size 4 \
+      --exp-dir pruned_transducer_stateless7/exp_giga_finetune \
+      --max-duration 400 \
+      --decoding-method $m
+  done
+fi

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py
new file mode 100644
index 000000000..4f64850b6
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py
@@ -0,0 +1,861 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn + +# from asr_datamodule import LibriSpeechAsrDataModule +from gigaspeech import GigaSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from gigaspeech_scoring import asr_text_post_processing +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + 
+LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + add_model_arguments(parser) + + return parser + + +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + chunk_size=params.decode_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: 
{params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = post_processing(results) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
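+        # Editor's note: write_error_stats reports WER as
+        #     (#substitutions + #insertions + #deletions) / #reference_words.
+        # Toy example: ref "the cat sat" vs. hyp "a cat sat down" gives one
+        # substitution plus one insertion over 3 reference words, i.e. 66.7%.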
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    """
+    This script tests a LibriSpeech model, with a LibriSpeech BPE model,
+    on GigaSpeech.
+    """
+    parser = get_parser()
+    GigaSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "fast_beam_search",
+        "fast_beam_search_nbest",
+        "fast_beam_search_nbest_LG",
+        "fast_beam_search_nbest_oracle",
+        "modified_beam_search",
+    )
+    params.res_dir = params.exp_dir / (params.decoding_method + "_gigaspeech")
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if params.simulate_streaming:
+        params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
+        params.suffix += f"-left-context-{params.left_context}"
+
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+        if "nbest" in params.decoding_method:
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
+            params.suffix += f"-num-paths-{params.num_paths}"
+            if "LG" in params.decoding_method:
+                params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+    params.vocab_size = sp.get_piece_size()
+
+    if params.simulate_streaming:
+        assert (
+            params.causal_convolution
+        ), "Decoding in streaming requires causal convolution"
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif 
len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
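+    # Editor's note: with --return-cuts enabled, each batch carries
+    # batch["supervisions"]["cut"]; decode_dataset reads cut.id from there to
+    # key the per-utterance results written by store_transcripts.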
+ args.return_cuts = True + gigaspeech = GigaSpeechAsrDataModule(args) + + dev_cuts = gigaspeech.dev_cuts() + test_cuts = gigaspeech.test_cuts() + + dev_dl = gigaspeech.test_dataloaders(dev_cuts) + test_dl = gigaspeech.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dls = [dev_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py new file mode 100755 index 000000000..726a24809 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py @@ -0,0 +1,1342 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7/exp \ + --full-libri 1 \ + --max-duration 550 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from decoder import Decoder +from gigaspeech import GigaSpeechAsrDataModule +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + filter_uneven_sized_batch, + 
setup_logger,
+    str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+    if isinstance(model, DDP):
+        # get underlying nn.Module
+        model = model.module
+    for module in model.modules():
+        if hasattr(module, "batch_count"):
+            module.batch_count = batch_count
+
+
+def add_finetune_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--do-finetune",
+        type=str2bool,
+        default=False,
+        help="Whether to fine-tune from the checkpoint given by --finetune-ckpt.",
+    )
+
+    parser.add_argument(
+        "--init-modules",
+        type=str,
+        default=None,
+        help="""
+        Modules to be initialized. It matches all parameters starting with
+        a specific key. The keys are given as a comma-separated list. If None,
+        all modules will be initialised. For example, if you only want to
+        initialise all parameters starting with "encoder", use "encoder";
+        if you want to initialise parameters starting with encoder or decoder,
+        use "encoder,decoder".
+        """,
+    )
+
+    parser.add_argument(
+        "--finetune-ckpt",
+        type=str,
+        default=None,
+        help="The checkpoint to fine-tune from (a path to a .pt file)",
+    )
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=str,
+        default="2,4,3,2,4",
+        help="Number of zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--feedforward-dims",
+        type=str,
+        default="1024,1024,2048,2048,1024",
+        help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=str,
+        default="8,8,8,8,8",
+        help="Number of attention heads in the zipformer encoder layers.",
+    )
+
+    parser.add_argument(
+        "--encoder-dims",
+        type=str,
+        default="384,384,384,384,384",
+        help="""Embedding dimension in the 2 blocks of zipformer encoder
+        layers, comma separated
+        """,
+    )
+
+    parser.add_argument(
+        "--attention-dims",
+        type=str,
+        default="192,192,192,192,192",
+        help="""Attention dimension in the 2 blocks of zipformer encoder layers,\
+        comma separated; not the same as embedding dimension.
+        """,
+    )
+
+    parser.add_argument(
+        "--encoder-unmasked-dims",
+        type=str,
+        default="256,256,256,256,256",
+        help="""Unmasked dimensions in the encoders, relates to augmentation
+        during training. Must be <= each of encoder_dims. Empirically, less
+        than 256 seems to make performance worse.
+        """,
+    )
+
+    parser.add_argument(
+        "--zipformer-downsampling-factors",
+        type=str,
+        default="1,2,4,8,2",
+        help="Downsampling factor for each stack of encoder layers.",
+    )
+
+    parser.add_argument(
+        "--cnn-module-kernels",
+        type=str,
+        default="31,31,31,31,31",
+        help="Sizes of kernels in convolution modules",
+    )
+
+    parser.add_argument(
+        "--decoder-dim",
+        type=int,
+        default=512,
+        help="Embedding dimension in the decoder model.",
+    )
+
+    parser.add_argument(
+        "--joiner-dim",
+        type=int,
+        default=512,
+        help="""Dimension used in the joiner model.
+        Outputs from the encoder and decoder model are projected
+        to this dimension before adding.
+ """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="""Path to the BPE model. + This should be the bpe model of the original model + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.005, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=100000, + help="""Number of steps that affects how rapidly the learning rate + decreases. During fine-tuning, we set this very large so that the + learning rate slowly decays with number of batches. You may tune + its value by yourself. + """, + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=100, + help="""Number of epochs that affects how rapidly the learning rate + decreases. During fine-tuning, we set this very large so that the + learning rate slowly decays with number of batches. You may tune + its value by yourself. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + add_finetune_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. 
+ + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "frame_shift_ms": 10.0, + "allowed_excess_duration_ratio": 0.1, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. 
+ """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [k for k in src_state_dict.keys() if k.startswith(module)] + dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. 
+ """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + # For the uneven-sized batch, the total duration after padding would possibly + # cause OOM. Hence, for each batch, which is sorted descendingly by length, + # we simply drop the last few shortest samples, so that the retained total frames + # (after padding) would not exceed `allowed_max_frames`: + # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`, + # where `max_frames = max_duration * 1000 // frame_shift_ms`. + # We set allowed_excess_duration_ratio=0.1. + max_frames = params.max_duration * 1000 // params.frame_shift_ms + allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio)) + batch = filter_uneven_sized_batch(batch, allowed_max_frames) + + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. 
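+        # Editor's note: for batch_idx_train t < warm_step this is a linear
+        # interpolation from 1.0 down to s, i.e.
+        #     scale(t) = 1.0 - (t / warm_step) * (1.0 - s),
+        # and it stays at s afterwards. pruned_loss_scale below likewise
+        # ramps from 0.1 up to 1.0 over the same warm-up period.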
+ simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
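+            # Editor's note: with --use-fp16 the GradScaler multiplies the
+            # loss before backward() so that small fp16 gradients do not
+            # underflow to zero; scaler.step(optimizer) then unscales the
+            # gradients and skips the update if any of them became inf/nan.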
+ scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have + # different behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    # load model parameters for model fine-tuning
+    if params.do_finetune:
+        modules = params.init_modules.split(",") if params.init_modules else None
+        checkpoints = load_model_params(
+            ckpt=params.finetune_ckpt, model=model, init_modules=modules
+        )
+    else:
+        assert params.start_epoch > 0, params.start_epoch
+        checkpoints = load_checkpoint_if_available(
+            params=params, model=model, model_avg=model_avg
+        )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    parameters_names = []
+    parameters_names.append(
+        [name_param_pair[0] for name_param_pair in model.named_parameters()]
+    )
+    optimizer = ScaledAdam(
+        model.parameters(),
+        lr=params.base_lr,
+        clipping_scale=2.0,
+        parameters_names=parameters_names,
+    )
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    gigaspeech = GigaSpeechAsrDataModule(args)
+
+    train_cuts = gigaspeech.train_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        if c.duration < 1.0 or c.duration > 20.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = gigaspeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = gigaspeech.dev_cuts() + valid_dl = gigaspeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. 
+ """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments( + parser + ) # you may replace this with your own dataset + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py new file mode 100644 index 000000000..5c01d7190 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech.py @@ -0,0 +1,406 @@ +# Copyright 2021 Piotr Żelasko +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
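+
+# Editor's note: a minimal, hedged sketch of how this data module is driven
+# from a script (finetune.py and decode_gigaspeech.py are the real call
+# sites; the argument values below are made up for illustration):
+#
+#     parser = argparse.ArgumentParser()
+#     GigaSpeechAsrDataModule.add_arguments(parser)
+#     args = parser.parse_args(["--subset", "S", "--max-duration", "500"])
+#     gigaspeech = GigaSpeechAsrDataModule(args)
+#     train_dl = gigaspeech.train_dataloaders(gigaspeech.train_cuts())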
+
+
+import argparse
+import inspect
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.dataset import (
+    CutConcatenate,
+    CutMix,
+    DynamicBucketingSampler,
+    K2SpeechRecognitionDataset,
+    PrecomputedFeatures,
+    SingleCutSampler,
+    SpecAugment,
+)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
+class GigaSpeechAsrDataModule:
+    """
+    DataModule for k2 ASR experiments.
+    It assumes there is always one train and valid dataloader,
+    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+    and test-other).
+
+    It contains all the common data pipeline modules used in ASR
+    experiments, e.g.:
+    - dynamic batch size,
+    - bucketing samplers,
+    - cut concatenation,
+    - augmentation,
+    - on-the-fly feature extraction
+
+    This class should be derived for specific corpora used in ASR tasks.
+    """
+
+    def __init__(self, args: argparse.Namespace):
+        self.args = args
+
+    @classmethod
+    def add_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(
+            title="ASR data related options",
+            description="These options are used for the preparation of "
+            "PyTorch DataLoaders from Lhotse CutSets -- they control the "
+            "effective batch sizes, sampling strategies, applied data "
+            "augmentations, etc.",
+        )
+        group.add_argument(
+            "--manifest-dir",
+            type=Path,
+            default=Path("data/fbank"),
+            help="Path to directory with train/valid/test cuts.",
+        )
+        group.add_argument(
+            "--max-duration",
+            type=float,
+            default=200.0,
+            help="Maximum pooled recordings duration (seconds) in a "
+            "single batch. You can reduce it if it causes CUDA OOM.",
+        )
+        group.add_argument(
+            "--bucketing-sampler",
+            type=str2bool,
+            default=True,
+            help="When enabled, the batches will come from buckets of "
+            "similar duration (saves padding frames).",
+        )
+        group.add_argument(
+            "--num-buckets",
+            type=int,
+            default=30,
+            help="The number of buckets for the DynamicBucketingSampler "
+            "(you might want to increase it for larger datasets).",
+        )
+        group.add_argument(
+            "--concatenate-cuts",
+            type=str2bool,
+            default=False,
+            help="When enabled, utterances (cuts) will be concatenated "
+            "to minimize the amount of padding.",
+        )
+        group.add_argument(
+            "--duration-factor",
+            type=float,
+            default=1.0,
+            help="Determines the maximum duration of a concatenated cut "
+            "relative to the duration of the longest cut in a batch.",
+        )
+        group.add_argument(
+            "--gap",
+            type=float,
+            default=1.0,
+            help="The amount of padding (in seconds) inserted between "
+            "concatenated cuts. This padding is filled with noise when "
+            "noise augmentation is used.",
+        )
+        group.add_argument(
+            "--on-the-fly-feats",
+            type=str2bool,
+            default=False,
+            help="When enabled, use on-the-fly cut mixing and feature "
+            "extraction. Will drop existing precomputed feature manifests "
+            "if available.",
+        )
+        group.add_argument(
+            "--shuffle",
+            type=str2bool,
+            default=True,
+            help="When enabled (=default), the examples will be "
+            "shuffled for each epoch.",
+        )
+        group.add_argument(
+            "--return-cuts",
+            type=str2bool,
+            default=True,
+            help="When enabled, each batch will have the "
+            "field: batch['supervisions']['cut'] with the cuts that "
+            "were used to construct it.",
+        )
+
+        group.add_argument(
+            "--num-workers",
+            type=int,
+            default=2,
+            help="The number of training dataloader workers that "
+            "collect the batches.",
+        )
+
+        group.add_argument(
+            "--enable-spec-aug",
+            type=str2bool,
+            default=True,
+            help="When enabled, use SpecAugment for the training dataset.",
+        )
+
+        group.add_argument(
+            "--spec-aug-time-warp-factor",
+            type=int,
+            default=80,
+            help="Used only when --enable-spec-aug is True. "
+            "It specifies the factor for time warping in SpecAugment. "
+            "Larger values mean more warping. "
+            "A value less than 1 means to disable time warp.",
+        )
+
+        group.add_argument(
+            "--enable-musan",
+            type=str2bool,
+            default=True,
+            help="When enabled, select noise from MUSAN and mix it "
+            "with the training dataset.",
+        )
+
+        # GigaSpeech specific arguments
+        group.add_argument(
+            "--subset",
+            type=str,
+            default="XL",
+            help="Select the GigaSpeech subset (XS|S|M|L|XL)",
+        )
+        group.add_argument(
+            "--small-dev",
+            type=str2bool,
+            default=False,
+            help="Should we use only 1000 utterances for dev (speeds up training)",
+        )
+
+    def train_dataloaders(
+        self,
+        cuts_train: CutSet,
+        sampler_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> DataLoader:
+        """
+        Args:
+          cuts_train:
+            CutSet for training.
+          sampler_state_dict:
+            The state dict for the training sampler.
+        """
+
+        transforms = []
+        if self.args.enable_musan:
+            logging.info("Enable MUSAN")
+            logging.info("About to get Musan cuts")
+            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+            transforms.append(
+                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+            )
+        else:
+            logging.info("Disable MUSAN")
+
+        if self.args.concatenate_cuts:
+            logging.info(
+                f"Using cut concatenation with duration factor "
+                f"{self.args.duration_factor} and gap {self.args.gap}."
+            )
+            # Cut concatenation should be the first transform in the list,
+            # so that if we e.g. mix noise in, it will fill the gaps between
+            # different utterances.
+            transforms = [
+                CutConcatenate(
+                    duration_factor=self.args.duration_factor, gap=self.args.gap
+                )
+            ] + transforms
+
+        input_transforms = []
+        if self.args.enable_spec_aug:
+            logging.info("Enable SpecAugment")
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            # Set the value of num_frame_masks according to Lhotse's version.
+            # In different Lhotse versions, the default of num_frame_masks is
+            # different.
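+            # The runtime `inspect` check below avoids pinning a Lhotse
+            # version: if SpecAugment's `num_frame_masks` defaults to 1
+            # (older Lhotse), 2 frame masks are used; otherwise 10.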
+ num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=True, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
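+        # _SeedWorkers gives worker i the seed `seed + i`, so each
+        # dataloader worker draws different, but reproducible,
+        # augmentation randomness.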
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info(f"About to get train_{self.args.subset} cuts") + path = self.args.manifest_dir / f"cuts_{self.args.subset}.jsonl.gz" + cuts_train = CutSet.from_jsonl_lazy(path) + return cuts_train + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + if self.args.small_dev: + return cuts_valid.subset(first=1000) + else: + return cuts_valid + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech_scoring.py b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech_scoring.py new file mode 120000 index 000000000..fdfa6ce4b --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/gigaspeech_scoring.py @@ -0,0 +1 @@ +../../../gigaspeech/ASR/pruned_transducer_stateless2/gigaspeech_scoring.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 374b78cb3..b84e518d0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -799,6 +799,47 @@ def _test_eden(): logging.info(f"state dict = {scheduler.state_dict()}") +def _plot_eden_lr(): + import matplotlib.pyplot as plt + + m = torch.nn.Linear(100, 100) + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in m.named_parameters()] + ) + + for lr_epoch in [4, 10, 100]: + for 
lr_batch in [100, 400]: + optim = ScaledAdam( + m.parameters(), lr=0.03, parameters_names=parameters_names + ) + scheduler = Eden( + optim, lr_batches=lr_batch, lr_epochs=lr_epoch, verbose=True + ) + lr = [] + + for epoch in range(10): + scheduler.step_epoch(epoch) # sets epoch to `epoch` + + for step in range(500): + lr.append(scheduler.get_lr()) + + x = torch.randn(200, 100).detach() + x.requires_grad = True + y = m(x) + dy = torch.randn(200, 100).detach() + f = (y * dy).sum() + f.backward() + + optim.step() + scheduler.step_batch() + optim.zero_grad() + plt.plot(lr, label=f"lr_epoch:{lr_epoch}, lr_batch:{lr_batch}") + + plt.legend() + plt.savefig("lr.png") + + # This is included mostly as a baseline for ScaledAdam. class Eve(Optimizer): """ @@ -1057,5 +1098,6 @@ if __name__ == "__main__": else: hidden_dim = 200 - _test_scaled_adam(hidden_dim) - _test_eden() + # _test_scaled_adam(hidden_dim) + # _test_eden() + _plot_eden_lr()
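
For reference, the curves produced by _plot_eden_lr() can also be evaluated
in closed form. Below is a standalone sketch, assuming Eden follows the
schedule given in its docstring,
lr = base_lr * ((batch^2 + lr_batches^2) / lr_batches^2) ** -0.25
            * ((epoch^2 + lr_epochs^2) / lr_epochs^2) ** -0.25,
with base_lr=0.03 and 500 steps per epoch to mirror _plot_eden_lr():

import matplotlib.pyplot as plt


def eden_lr(base_lr, batch, epoch, lr_batches, lr_epochs):
    # Each factor is ~1 early on and decays like x**-0.5 once `batch`
    # exceeds lr_batches (resp. `epoch` exceeds lr_epochs).
    batch_factor = ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25
    epoch_factor = ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
    return base_lr * batch_factor * epoch_factor


for lr_epoch in [4, 10, 100]:
    for lr_batch in [100, 400]:
        lrs = [
            eden_lr(0.03, step, step // 500, lr_batch, lr_epoch)
            for step in range(5000)
        ]
        plt.plot(lrs, label=f"lr_epoch:{lr_epoch}, lr_batch:{lr_batch}")

plt.legend()
plt.savefig("lr-closed-form.png")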