Merge branch 'k2-fsa:master' into master
commit 4fa0ec342b
@ -21,9 +21,9 @@ tree $repo/
ls -lh $repo/test_wavs/*.wav

pushd $repo/exp
git lfs pull --include "data/lang_bpe_500/HLG.pt"
git lfs pull --include "data/lang_bpe_500/L.pt"
git lfs pull --include "data/lang_bpe_500/LG.pt"
git lfs pull --include "data/lang_bpe_500/HLG.pt"
git lfs pull --include "data/lang_bpe_500/Linv.pt"
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/cpu_jit.pt"
4 .github/workflows/run-aishell-2022-06-20.yml vendored
@ -44,7 +44,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]

fail-fast: false
@ -119,5 +119,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: aishell-torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-06-20
name: aishell-torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless3-2022-06-20
path: egs/aishell/ASR/pruned_transducer_stateless3/exp/

@ -43,7 +43,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]

fail-fast: false
@ -122,5 +122,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-gigaspeech-pruned_transducer_stateless2-2022-05-12
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-gigaspeech-pruned_transducer_stateless2-2022-05-12
path: egs/gigaspeech/ASR/pruned_transducer_stateless2/exp/
@ -43,7 +43,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]

fail-fast: false
@ -155,5 +155,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless-2022-03-12
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless-2022-03-12
path: egs/librispeech/ASR/pruned_transducer_stateless/exp/

@ -43,7 +43,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]

fail-fast: false
@ -174,12 +174,12 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless2-2022-04-29
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless2-2022-04-29
path: egs/librispeech/ASR/pruned_transducer_stateless2/exp/

- name: Upload decoding results for pruned_transducer_stateless3
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-04-29
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless3-2022-04-29
path: egs/librispeech/ASR/pruned_transducer_stateless3/exp/
@ -43,7 +43,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]

fail-fast: false
@ -155,5 +155,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless5-2022-05-13
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless5-2022-05-13
path: egs/librispeech/ASR/pruned_transducer_stateless5/exp/

@ -155,5 +155,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-2022-11-11
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-2022-11-11
path: egs/librispeech/ASR/pruned_transducer_stateless7/exp/

@ -155,5 +155,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless8-2022-11-14
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless8-2022-11-14
path: egs/librispeech/ASR/pruned_transducer_stateless8/exp/

@ -159,5 +159,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-ctc-2022-12-01
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-ctc-2022-12-01
path: egs/librispeech/ASR/pruned_transducer_stateless7_ctc/exp/

@ -163,5 +163,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-zipformer_mmi-2022-12-08
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer_mmi-2022-12-08
path: egs/librispeech/ASR/zipformer_mmi/exp/

@ -168,5 +168,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-streaming-2022-12-29
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-streaming-2022-12-29
path: egs/librispeech/ASR/pruned_transducer_stateless7_streaming/exp/
@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

name: run-librispeech-2022-12-15-stateless7-ctc-bs
name: run-librispeech-2023-01-29-stateless7-ctc-bs
# zipformer

on:
@ -34,7 +34,7 @@ on:
- cron: "50 15 * * *"

jobs:
run_librispeech_2022_12_15_zipformer_ctc_bs:
run_librispeech_2023_01_29_zipformer_ctc_bs:
if: github.event.label.name == 'run-decode' || github.event.label.name == 'blank-skip' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
@ -124,7 +124,7 @@ jobs:
export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH

.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2022-12-15.sh
.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-bs-2023-01-29.sh

- name: Display decoding results for librispeech pruned_transducer_stateless7_ctc_bs
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
@ -159,5 +159,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-ctc-bs-2022-12-15
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless7-ctc-bs-2023-01-29
path: egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/exp/
@ -151,5 +151,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-conformer_ctc3-2022-11-28
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-conformer_ctc3-2022-11-28
path: egs/librispeech/ASR/conformer_ctc3/exp/

@ -26,7 +26,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.8]

fail-fast: false
@ -159,5 +159,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'LODR'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-lstm_transducer_stateless2-2022-09-03
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-lstm_transducer_stateless2-2022-09-03
path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/
@ -43,7 +43,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]

fail-fast: false
@ -153,5 +153,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-04-29
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless3-2022-04-29
path: egs/librispeech/ASR/pruned_transducer_stateless3/exp/

@ -43,7 +43,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]

fail-fast: false
@ -155,5 +155,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless2-2022-06-26
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-pruned_transducer_stateless2-2022-06-26
path: egs/librispeech/ASR/pruned_transducer_stateless2/exp/

@ -170,5 +170,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-zipformer-2022-11-11
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11
path: egs/librispeech/ASR/zipformer/exp/
@ -43,7 +43,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]

fail-fast: false
@ -155,5 +155,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless2-2022-04-19
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless2-2022-04-19
path: egs/librispeech/ASR/transducer_stateless2/exp/

@ -155,5 +155,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-zipformer-2022-11-11
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11
path: egs/librispeech/ASR/zipformer/exp/

@ -151,5 +151,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-zipformer-2022-11-11
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11
path: egs/librispeech/ASR/zipformer/exp/
@ -33,7 +33,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]

fail-fast: false

@ -42,7 +42,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]

fail-fast: false
@ -154,5 +154,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless_multi_datasets-100h-2022-02-21
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless_multi_datasets-100h-2022-02-21
path: egs/librispeech/ASR/transducer_stateless_multi_datasets/exp/

@ -42,7 +42,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]

fail-fast: false
@ -154,5 +154,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless_multi_datasets-100h-2022-03-01
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless_multi_datasets-100h-2022-03-01
path: egs/librispeech/ASR/transducer_stateless_multi_datasets/exp/
@ -33,7 +33,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]

fail-fast: false

@ -33,7 +33,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]

fail-fast: false

@ -42,7 +42,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]

fail-fast: false
@ -154,5 +154,5 @@ jobs:
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-transducer_stateless-2022-02-07
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-transducer_stateless-2022-02-07
path: egs/librispeech/ASR/transducer_stateless/exp/
@ -33,7 +33,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.7, 3.8, 3.9]

fail-fast: false

@ -33,7 +33,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
os: [ubuntu-latest]
python-version: [3.8]

fail-fast: false
2 .github/workflows/run-yesno-recipe.yml vendored
@ -33,7 +33,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
# os: [ubuntu-18.04, macos-10.15]
# os: [ubuntu-latest, macos-10.15]
# TODO: enable macOS for CPU testing
os: [ubuntu-latest]
python-version: [3.8]
@ -86,6 +86,7 @@ rst_epilog = """
|
||||
.. _git-lfs: https://git-lfs.com/
|
||||
.. _ncnn: https://github.com/tencent/ncnn
|
||||
.. _LibriSpeech: https://www.openslr.org/12
|
||||
.. _Gigaspeech: https://github.com/SpeechColab/GigaSpeech
|
||||
.. _musan: http://www.openslr.org/17/
|
||||
.. _ONNX: https://github.com/onnx/onnx
|
||||
.. _onnxruntime: https://github.com/microsoft/onnxruntime
|
||||
|
184 docs/source/decoding-with-langugage-models/LODR.rst Normal file
@ -0,0 +1,184 @@
.. _LODR:

LODR for RNN Transducer
=======================


As a type of E2E model, neural transducers are usually considered to have an internal
language model, which learns language-level information from the training corpus.
In real-life scenarios, there is often a mismatch between the training corpus and the target corpus.
This mismatch can be a problem when decoding neural transducer models with an external language model,
because the internal language model can act "against" the external LM. In this tutorial, we show how to use
`Low-order Density Ratio <https://arxiv.org/abs/2203.16776>`_ to alleviate this effect and further improve the performance
of language model integration.

.. note::

   This tutorial is based on the recipe
   `pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`_,
   which is a streaming transducer model trained on `LibriSpeech`_.
   However, you can easily apply LODR to other recipes.
   If you encounter any problems, please open an issue here `icefall <https://github.com/k2-fsa/icefall/issues>`__.


.. note::

   For simplicity, the training and testing corpora in this tutorial are the same (`LibriSpeech`_). However,
   you can change the testing set to any other domain (e.g. `GigaSpeech`_) and prepare the language models
   using that corpus.

First, let's have a look at some background information. As the predecessor of LODR, Density Ratio (DR) was first proposed `here <https://arxiv.org/abs/2002.11268>`_
to address the language information mismatch between the training
corpus (source domain) and the testing corpus (target domain). Assuming that the source domain and the test domain
are acoustically similar, DR derives the following formula for decoding with Bayes' theorem:

.. math::

   \text{score}\left(y_u|\mathit{x},y\right) =
   \log p\left(y_u|\mathit{x},y_{1:u-1}\right) +
   \lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) -
   \lambda_2 \log p_{\text{Source LM}}\left(y_u|\mathit{x},y_{1:u-1}\right)


where :math:`\lambda_1` and :math:`\lambda_2` are the weights of the LM scores for the target domain and the source domain, respectively.
Here, the source domain LM is trained on the training corpus. The only difference in the above formula compared to
shallow fusion is the subtraction of the source domain LM.

Some works treat the predictor and the joiner of the neural transducer as its internal LM. However, this internal LM is
considered to be weak and can only capture low-level language information. Therefore, `LODR <https://arxiv.org/abs/2203.16776>`__ proposed to use
a low-order n-gram LM as an approximation of the ILM of the neural transducer. This leads to the following formula
during decoding for transducer models:

.. math::

   \text{score}\left(y_u|\mathit{x},y\right) =
   \log p_{rnnt}\left(y_u|\mathit{x},y_{1:u-1}\right) +
   \lambda_1 \log p_{\text{Target LM}}\left(y_u|\mathit{x},y_{1:u-1}\right) -
   \lambda_2 \log p_{\text{bi-gram}}\left(y_u|\mathit{x},y_{1:u-1}\right)

In LODR, an additional bi-gram LM estimated on the source domain (e.g. the training corpus) is required. Compared to DR,
the only difference lies in the choice of the source domain LM. According to the original `paper <https://arxiv.org/abs/2203.16776>`_,
LODR achieves similar performance compared to DR in both intra-domain and cross-domain settings.
As a bi-gram is much faster to evaluate, LODR is usually much faster.
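To make the second formula concrete, the per-token score combination can be sketched in a few lines of
Python. This is only an illustration of the math above, not the actual icefall implementation (which is used
through the ``modified_beam_search_lm_LODR`` decoding method shown later); all names below are made up for the example:

.. code-block:: python

   def lodr_score(logp_rnnt: float, logp_target_lm: float, logp_bigram: float,
                  lm_scale: float = 0.42, lodr_scale: float = -0.24) -> float:
       """Combine the three log-probabilities for one candidate token.

       logp_rnnt:      log p_rnnt(y_u | x, y_{1:u-1}) from the transducer
       logp_target_lm: log-probability from the external (target-domain) LM
       logp_bigram:    log-probability from the low-order source-domain bi-gram
       lm_scale:       lambda_1, cf. --lm-scale in the commands below
       lodr_scale:     passed as a negative --ngram-lm-scale, so the bi-gram
                       term is effectively subtracted (lambda_2 = -lodr_scale)
       """
       return logp_rnnt + lm_scale * logp_target_lm + lodr_scale * logp_bigram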

Now, we will show you how to use LODR in ``icefall``.
For illustration purposes, we will use a pre-trained ASR model from this `link <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`_.
If you want to train your model from scratch, please have a look at :ref:`non_streaming_librispeech_pruned_transducer_stateless`.
The testing scenario here is intra-domain (we decode the model trained on `LibriSpeech`_ on `LibriSpeech`_ testing sets).

As the initial step, let's download the pre-trained model.

.. code-block:: bash

   $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
   $ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
   $ git lfs pull --include "pretrained.pt"
   $ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded

To test the model, let's have a look at the decoding results **without** using LM. This can be done via the following command:

.. code-block:: bash

   $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/
   $ ./pruned_transducer_stateless7_streaming/decode.py \
       --epoch 99 \
       --avg 1 \
       --use-averaged-model False \
       --exp-dir $exp_dir \
       --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
       --max-duration 600 \
       --decode-chunk-len 32 \
       --decoding-method modified_beam_search

The following WERs are achieved on test-clean and test-other:

.. code-block:: text

   $ For test-clean, WER of different settings are:
   $ beam_size_4 3.11 best for test-clean
   $ For test-other, WER of different settings are:
   $ beam_size_4 7.93 best for test-other

Then, we download the external language model and the bi-gram LM that are necessary for LODR.
Note that the bi-gram is estimated on the LibriSpeech 960 hours' text.

.. code-block:: bash

   $ # download the external LM
   $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
   $ # create a symbolic link so that the checkpoint can be loaded
   $ pushd icefall-librispeech-rnn-lm/exp
   $ git lfs pull --include "pretrained.pt"
   $ ln -s pretrained.pt epoch-99.pt
   $ popd
   $
   $ # download the bi-gram
   $ git lfs install
   $ git clone https://huggingface.co/marcoyang/librispeech_bigram
   $ pushd data/lang_bpe_500
   $ ln -s ../../librispeech_bigram/2gram.fst.txt .
   $ popd

Then, we perform LODR decoding by setting ``--decoding-method`` to ``modified_beam_search_lm_LODR``:

.. code-block:: bash

   $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
   $ lm_dir=./icefall-librispeech-rnn-lm/exp
   $ lm_scale=0.42
   $ LODR_scale=-0.24
   $ ./pruned_transducer_stateless7_streaming/decode.py \
       --epoch 99 \
       --avg 1 \
       --use-averaged-model False \
       --beam-size 4 \
       --exp-dir $exp_dir \
       --max-duration 600 \
       --decode-chunk-len 32 \
       --decoding-method modified_beam_search_lm_LODR \
       --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
       --use-shallow-fusion 1 \
       --lm-type rnn \
       --lm-exp-dir $lm_dir \
       --lm-epoch 99 \
       --lm-scale $lm_scale \
       --lm-avg 1 \
       --rnn-lm-embedding-dim 2048 \
       --rnn-lm-hidden-dim 2048 \
       --rnn-lm-num-layers 3 \
       --lm-vocab-size 500 \
       --tokens-ngram 2 \
       --ngram-lm-scale $LODR_scale

There are two extra arguments that need to be given when doing LODR. ``--tokens-ngram`` specifies the order of the n-gram LM; as we
are using a bi-gram, we set it to 2. ``--ngram-lm-scale`` is the scale of the bi-gram; it should be a negative number,
as we subtract the bi-gram's score during decoding.

The decoding results obtained with the above command are shown below:

.. code-block:: text

   $ For test-clean, WER of different settings are:
   $ beam_size_4 2.61 best for test-clean
   $ For test-other, WER of different settings are:
   $ beam_size_4 6.74 best for test-other

Recall that the lowest WER we obtained in :ref:`shallow_fusion` with a beam size of 4 is ``2.77/7.08``; LODR
indeed **further improves** the WER. We can do even better if we increase ``--beam-size``:

.. list-table:: WER of LODR with different beam sizes
   :widths: 25 25 50
   :header-rows: 1

   * - Beam size
     - test-clean
     - test-other
   * - 4
     - 2.61
     - 6.74
   * - 8
     - 2.45
     - 6.38
   * - 12
     - 2.4
     - 6.23
12 docs/source/decoding-with-langugage-models/index.rst Normal file
@ -0,0 +1,12 @@
Decoding with language models
=============================

This section describes how to use external language models
during decoding to improve the WER of transducer models.


.. toctree::
   :maxdepth: 2

   shallow-fusion
   LODR
   rescoring
252 docs/source/decoding-with-langugage-models/rescoring.rst Normal file
@ -0,0 +1,252 @@
.. _rescoring:

LM rescoring for Transducer
=================================

LM rescoring is a commonly used approach to incorporate external LM information. Unlike shallow-fusion-based
methods (see :ref:`shallow_fusion`, :ref:`LODR`), rescoring is usually performed to re-rank the n-best hypotheses after beam search.
Rescoring is usually more efficient than shallow fusion, since less computation is performed on the external LM.
In this tutorial, we will show you how to use an external LM to rescore the n-best hypotheses decoded from neural transducer models in
`icefall <https://github.com/k2-fsa/icefall>`__.

.. note::

   This tutorial is based on the recipe
   `pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`_,
   which is a streaming transducer model trained on `LibriSpeech`_.
   However, you can easily apply LM rescoring to other recipes.
   If you encounter any problems, please open an issue `here <https://github.com/k2-fsa/icefall/issues>`_.

.. note::

   For simplicity, the training and testing corpora in this tutorial are the same (`LibriSpeech`_). However, you can change the testing set
   to any other domain (e.g. `GigaSpeech`_) and use an external LM trained on that domain.

.. HINT::

   We recommend using a GPU for decoding.

For illustration purposes, we will use a pre-trained ASR model from this `link <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`__.
If you want to train your model from scratch, please have a look at :ref:`non_streaming_librispeech_pruned_transducer_stateless`.

As the initial step, let's download the pre-trained model.

.. code-block:: bash

   $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
   $ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
   $ git lfs pull --include "pretrained.pt"
   $ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded

As usual, we first test the model's performance without an external LM. This can be done via the following command:

.. code-block:: bash

   $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/
   $ ./pruned_transducer_stateless7_streaming/decode.py \
       --epoch 99 \
       --avg 1 \
       --use-averaged-model False \
       --exp-dir $exp_dir \
       --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
       --max-duration 600 \
       --decode-chunk-len 32 \
       --decoding-method modified_beam_search

The following WERs are achieved on test-clean and test-other:

.. code-block:: text

   $ For test-clean, WER of different settings are:
   $ beam_size_4 3.11 best for test-clean
   $ For test-other, WER of different settings are:
   $ beam_size_4 7.93 best for test-other

Now, we will try to improve the above WER numbers via external LM rescoring. We will download
a pre-trained LM from this `link <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm>`__.

.. note::

   This is an RNN LM trained on the LibriSpeech text corpus, so it might not be ideal for other corpora.
   You may also train an RNN LM from scratch. Please refer to this `script <https://github.com/k2-fsa/icefall/blob/master/icefall/rnn_lm/train.py>`__
   for training an RNN LM and this `script <https://github.com/k2-fsa/icefall/blob/master/icefall/transformer_lm/train.py>`__ to train a transformer LM.

.. code-block:: bash

   $ # download the external LM
   $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
   $ # create a symbolic link so that the checkpoint can be loaded
   $ pushd icefall-librispeech-rnn-lm/exp
   $ git lfs pull --include "pretrained.pt"
   $ ln -s pretrained.pt epoch-99.pt
   $ popd


With the RNN LM available, we can rescore the n-best hypotheses generated from `modified_beam_search`. Here,
`n` should be the number of beams, i.e. ``--beam-size``. The command for LM rescoring is
as follows. Note that ``--decoding-method`` is set to `modified_beam_search_lm_rescore` and ``--use-shallow-fusion``
is set to `False`.

.. code-block:: bash

   $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
   $ lm_dir=./icefall-librispeech-rnn-lm/exp
   $ lm_scale=0.43
   $ ./pruned_transducer_stateless7_streaming/decode.py \
       --epoch 99 \
       --avg 1 \
       --use-averaged-model False \
       --beam-size 4 \
       --exp-dir $exp_dir \
       --max-duration 600 \
       --decode-chunk-len 32 \
       --decoding-method modified_beam_search_lm_rescore \
       --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
       --use-shallow-fusion 0 \
       --lm-type rnn \
       --lm-exp-dir $lm_dir \
       --lm-epoch 99 \
       --lm-scale $lm_scale \
       --lm-avg 1 \
       --rnn-lm-embedding-dim 2048 \
       --rnn-lm-hidden-dim 2048 \
       --rnn-lm-num-layers 3 \
       --lm-vocab-size 500

.. code-block:: text

   $ For test-clean, WER of different settings are:
   $ beam_size_4 2.93 best for test-clean
   $ For test-other, WER of different settings are:
   $ beam_size_4 7.6 best for test-other

Great! We made some improvements! Increasing the size of the n-best list will further boost the performance;
see the following table:

.. list-table:: WERs of LM rescoring with different beam sizes
   :widths: 25 25 25
   :header-rows: 1

   * - Beam size
     - test-clean
     - test-other
   * - 4
     - 2.93
     - 7.6
   * - 8
     - 2.67
     - 7.11
   * - 12
     - 2.59
     - 6.86

In fact, we can also apply LODR (see :ref:`LODR`) when doing LM rescoring. To do so, we need to
download the bi-gram required by LODR:

.. code-block:: bash

   $ # download the bi-gram
   $ git lfs install
   $ git clone https://huggingface.co/marcoyang/librispeech_bigram
   $ pushd data/lang_bpe_500
   $ ln -s ../../librispeech_bigram/2gram.arpa .
   $ popd

Then we can perform LM rescoring + LODR by changing the decoding method to `modified_beam_search_lm_rescore_LODR`.

.. note::

   This decoding method requires the `kenlm <https://github.com/kpu/kenlm>`_ dependency. You can install it
   via this command: `pip install https://github.com/kpu/kenlm/archive/master.zip`.

.. code-block:: bash

   $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
   $ lm_dir=./icefall-librispeech-rnn-lm/exp
   $ lm_scale=0.43
   $ ./pruned_transducer_stateless7_streaming/decode.py \
       --epoch 99 \
       --avg 1 \
       --use-averaged-model False \
       --beam-size 4 \
       --exp-dir $exp_dir \
       --max-duration 600 \
       --decode-chunk-len 32 \
       --decoding-method modified_beam_search_lm_rescore_LODR \
       --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
       --use-shallow-fusion 0 \
       --lm-type rnn \
       --lm-exp-dir $lm_dir \
       --lm-epoch 99 \
       --lm-scale $lm_scale \
       --lm-avg 1 \
       --rnn-lm-embedding-dim 2048 \
       --rnn-lm-hidden-dim 2048 \
       --rnn-lm-num-layers 3 \
       --lm-vocab-size 500

You should see the following WERs after executing the commands above:

.. code-block:: text

   $ For test-clean, WER of different settings are:
   $ beam_size_4 2.9 best for test-clean
   $ For test-other, WER of different settings are:
   $ beam_size_4 7.57 best for test-other

It's slightly better than LM rescoring. If we further increase the beam size, we will see
further improvements from LM rescoring + LODR:

.. list-table:: WERs of LM rescoring + LODR with different beam sizes
   :widths: 25 25 25
   :header-rows: 1

   * - Beam size
     - test-clean
     - test-other
   * - 4
     - 2.9
     - 7.57
   * - 8
     - 2.63
     - 7.04
   * - 12
     - 2.52
     - 6.73

As mentioned earlier, LM rescoring is usually faster than shallow-fusion-based methods.
Here, we benchmark the WERs and decoding speed of these methods:

.. list-table:: LM-rescoring-based methods vs shallow-fusion-based methods (the numbers in each field are WER on test-clean, WER on test-other, and decoding time on test-clean)
   :widths: 25 25 25 25
   :header-rows: 1

   * - Decoding method
     - beam=4
     - beam=8
     - beam=12
   * - `modified_beam_search`
     - 3.11/7.93; 132s
     - 3.1/7.95; 177s
     - 3.1/7.96; 210s
   * - `modified_beam_search_lm_shallow_fusion`
     - 2.77/7.08; 262s
     - 2.62/6.65; 352s
     - 2.58/6.65; 488s
   * - LODR
     - 2.61/6.74; 400s
     - 2.45/6.38; 610s
     - 2.4/6.23; 870s
   * - `modified_beam_search_lm_rescore`
     - 2.93/7.6; 156s
     - 2.67/7.11; 203s
     - 2.59/6.86; 255s
   * - `modified_beam_search_lm_rescore_LODR`
     - 2.9/7.57; 160s
     - 2.63/7.04; 203s
     - 2.52/6.73; 263s

.. note::

   Decoding is performed with a single 32G V100; we set ``--max-duration`` to 600.
   Decoding time here is only for reference, and it may vary.
176 docs/source/decoding-with-langugage-models/shallow-fusion.rst Normal file
@ -0,0 +1,176 @@
.. _shallow_fusion:

Shallow fusion for Transducer
=================================

External language models (LMs) are commonly used to improve WERs for E2E ASR models.
This tutorial shows you how to perform ``shallow fusion`` with an external LM
to improve the word-error-rate of a transducer model.
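At each step of beam search, shallow fusion simply adds a scaled external-LM log-probability to the
transducer's log-probability before extending the hypotheses. The snippet below is only a sketch of this
idea; the function and variable names are made up for illustration and are not icefall's actual code:

.. code-block:: python

   import torch

   def fuse_logprobs(asr_logprobs: torch.Tensor,
                     lm_logprobs: torch.Tensor,
                     lm_scale: float = 0.29) -> torch.Tensor:
       """Interpolate per-token log-probabilities at one decoding step.

       asr_logprobs: (vocab_size,) log-probabilities from the transducer
       lm_logprobs:  (vocab_size,) log-probabilities from the external LM
       lm_scale:     weight of the external LM, cf. --lm-scale below
       """
       return asr_logprobs + lm_scale * lm_logprobs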

.. note::

   This tutorial is based on the recipe
   `pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`_,
   which is a streaming transducer model trained on `LibriSpeech`_.
   However, you can easily apply shallow fusion to other recipes.
   If you encounter any problems, please open an issue here `icefall <https://github.com/k2-fsa/icefall/issues>`_.

.. note::

   For simplicity, the training and testing corpora in this tutorial are the same (`LibriSpeech`_). However, you can change the testing set
   to any other domain (e.g. `GigaSpeech`_) and use an external LM trained on that domain.

.. HINT::

   We recommend using a GPU for decoding.

For illustration purposes, we will use a pre-trained ASR model from this `link <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`__.
If you want to train your model from scratch, please have a look at :ref:`non_streaming_librispeech_pruned_transducer_stateless`.

As the initial step, let's download the pre-trained model.

.. code-block:: bash

   $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
   $ pushd icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
   $ git lfs pull --include "pretrained.pt"
   $ ln -s pretrained.pt epoch-99.pt # create a symbolic link so that the checkpoint can be loaded

To test the model, let's have a look at the decoding results without using LM. This can be done via the following command:

.. code-block:: bash

   $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp/
   $ ./pruned_transducer_stateless7_streaming/decode.py \
       --epoch 99 \
       --avg 1 \
       --use-averaged-model False \
       --exp-dir $exp_dir \
       --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
       --max-duration 600 \
       --decode-chunk-len 32 \
       --decoding-method modified_beam_search

The following WERs are achieved on test-clean and test-other:

.. code-block:: text

   $ For test-clean, WER of different settings are:
   $ beam_size_4 3.11 best for test-clean
   $ For test-other, WER of different settings are:
   $ beam_size_4 7.93 best for test-other

These are already good numbers! But we can further improve them by using shallow fusion with an external LM.
Training a language model usually takes a long time, so we can download a pre-trained LM from this `link <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm>`__.

.. code-block:: bash

   $ # download the external LM
   $ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm
   $ # create a symbolic link so that the checkpoint can be loaded
   $ pushd icefall-librispeech-rnn-lm/exp
   $ git lfs pull --include "pretrained.pt"
   $ ln -s pretrained.pt epoch-99.pt
   $ popd

.. note::

   This is an RNN LM trained on the LibriSpeech text corpus, so it might not be ideal for other corpora.
   You may also train an RNN LM from scratch. Please refer to this `script <https://github.com/k2-fsa/icefall/blob/master/icefall/rnn_lm/train.py>`__
   for training an RNN LM and this `script <https://github.com/k2-fsa/icefall/blob/master/icefall/transformer_lm/train.py>`__ to train a transformer LM.

To use shallow fusion for decoding, we can execute the following command:

.. code-block:: bash

   $ exp_dir=./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
   $ lm_dir=./icefall-librispeech-rnn-lm/exp
   $ lm_scale=0.29
   $ ./pruned_transducer_stateless7_streaming/decode.py \
       --epoch 99 \
       --avg 1 \
       --use-averaged-model False \
       --beam-size 4 \
       --exp-dir $exp_dir \
       --max-duration 600 \
       --decode-chunk-len 32 \
       --decoding-method modified_beam_search_lm_shallow_fusion \
       --bpe-model ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/data/lang_bpe_500/bpe.model \
       --use-shallow-fusion 1 \
       --lm-type rnn \
       --lm-exp-dir $lm_dir \
       --lm-epoch 99 \
       --lm-scale $lm_scale \
       --lm-avg 1 \
       --rnn-lm-embedding-dim 2048 \
       --rnn-lm-hidden-dim 2048 \
       --rnn-lm-num-layers 3 \
       --lm-vocab-size 500

Note that we set ``--decoding-method modified_beam_search_lm_shallow_fusion`` and ``--use-shallow-fusion True``
to use shallow fusion. ``--lm-type`` specifies the type of neural LM we are going to use; you can choose
between ``rnn`` and ``transformer``. The following three arguments are associated with the RNN LM:

- ``--rnn-lm-embedding-dim``

  The embedding dimension of the RNN LM

- ``--rnn-lm-hidden-dim``

  The hidden dimension of the RNN LM

- ``--rnn-lm-num-layers``

  The number of RNN layers in the RNN LM.
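To see what these three dimensions control, here is a simplified PyTorch model with the same
hyper-parameters as in the command above. It is purely illustrative; the real LM is the one trained by the
scripts referenced in the note above, and its actual class is not shown here:

.. code-block:: python

   import torch
   import torch.nn as nn

   class TinyRnnLm(nn.Module):
       """A simplified RNN LM: vocab 500, embedding 2048, hidden 2048, 3 layers."""

       def __init__(self, vocab_size=500, embedding_dim=2048,
                    hidden_dim=2048, num_layers=3):
           super().__init__()
           self.embed = nn.Embedding(vocab_size, embedding_dim)
           self.rnn = nn.LSTM(embedding_dim, hidden_dim,
                              num_layers=num_layers, batch_first=True)
           self.out = nn.Linear(hidden_dim, vocab_size)

       def forward(self, tokens: torch.Tensor) -> torch.Tensor:
           # tokens: (batch, seq_len) -> log-probabilities over the next token
           x, _ = self.rnn(self.embed(tokens))
           return self.out(x).log_softmax(dim=-1)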

The decoding results obtained with the above command are shown below.

.. code-block:: text

   $ For test-clean, WER of different settings are:
   $ beam_size_4 2.77 best for test-clean
   $ For test-other, WER of different settings are:
   $ beam_size_4 7.08 best for test-other

The improvement brought by shallow fusion is very obvious! The relative WER reduction on test-other is around 10.5%.
A few parameters can be tuned to further boost the performance of shallow fusion:

- ``--lm-scale``

  Controls the scale of the LM. If it is too small, the external language model may not be fully utilized; if it is too large,
  the LM score may dominate during decoding, leading to a bad WER. A typical value of this is around 0.3.

- ``--beam-size``

  The number of active paths in the search beam. It controls the trade-off between decoding efficiency and accuracy.

Here, we also show how ``--beam-size`` affects the WER and decoding time:

.. list-table:: WERs and decoding time (on test-clean) of shallow fusion with different beam sizes
   :widths: 25 25 25 25
   :header-rows: 1

   * - Beam size
     - test-clean
     - test-other
     - Decoding time on test-clean (s)
   * - 4
     - 2.77
     - 7.08
     - 262
   * - 8
     - 2.62
     - 6.65
     - 352
   * - 12
     - 2.58
     - 6.65
     - 488

As we see, a larger beam size during shallow fusion improves the WER, but is also slower.
@ -34,3 +34,8 @@ speech recognition recipes using `k2 <https://github.com/k2-fsa/k2>`_.
contributing/index
huggingface/index

.. toctree::
:maxdepth: 2

decoding-with-langugage-models/index
@ -1,7 +1,7 @@
Distillation with HuBERT
========================

This tutorial shows you how to perform knowledge distillation in `icefall`_
This tutorial shows you how to perform knowledge distillation in `icefall <https://github.com/k2-fsa/icefall>`_
with the `LibriSpeech`_ dataset. The distillation method
used here is called "Multi Vector Quantization Knowledge Distillation" (MVQ-KD).
Please have a look at our paper `Predicting Multi-Codebook Vector Quantization Indexes for Knowledge Distillation <https://arxiv.org/abs/2211.00508>`_
@ -13,7 +13,7 @@ for more details about MVQ-KD.
`pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`_.
Currently, we only implement MVQ-KD in this recipe. However, MVQ-KD is theoretically applicable to all recipes
with only minor changes needed. Feel free to try out MVQ-KD in different recipes. If you
encounter any problems, please open an issue here `icefall <https://github.com/k2-fsa/icefall/issues>`_.
encounter any problems, please open an issue here `icefall <https://github.com/k2-fsa/icefall/issues>`__.

.. note::

@ -217,7 +217,7 @@ the following command.
--exp-dir $exp_dir \
--enable-distillation True

You should get similar results as `here <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS-100hours.md#distillation-with-hubert>`_.
You should get similar results as `here <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS-100hours.md#distillation-with-hubert>`__.

That's all! Feel free to experiment with your own setups and report your results.
If you encounter any problems during training, please open up an issue `here <https://github.com/k2-fsa/icefall/issues>`_.
If you encounter any problems during training, please open up an issue `here <https://github.com/k2-fsa/icefall/issues>`__.
@ -8,10 +8,10 @@ with the `LibriSpeech <https://www.openslr.org/12>`_ dataset.

.. Note::

The tutorial is suitable for `pruned_transducer_stateless <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless>`_,
`pruned_transducer_stateless2 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless2>`_,
`pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`_,
`pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless5>`_,
The tutorial is suitable for `pruned_transducer_stateless <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless>`__,
`pruned_transducer_stateless2 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless2>`__,
`pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`__,
`pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless5>`__,
We will take pruned_transducer_stateless4 as an example in this tutorial.

.. HINT::

@ -237,7 +237,7 @@ them, please modify ``./pruned_transducer_stateless4/train.py`` directly.

.. NOTE::

The options for `pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless5/train.py>`_ are a little different from
The options for `pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless5/train.py>`__ are a little different from
other recipes. It allows you to configure ``--num-encoder-layers``, ``--dim-feedforward``, ``--nhead``, ``--encoder-dim``, ``--decoder-dim``, ``--joiner-dim`` from commandline, so that you can train models with different size with pruned_transducer_stateless5.


@ -529,13 +529,13 @@ Download pretrained models
If you don't want to train from scratch, you can download the pretrained models
by visiting the following links:

- `pruned_transducer_stateless <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12>`_
- `pruned_transducer_stateless <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12>`__

- `pruned_transducer_stateless2 <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29>`_
- `pruned_transducer_stateless2 <https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29>`__

- `pruned_transducer_stateless4 <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless4-2022-06-03>`_
- `pruned_transducer_stateless4 <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless4-2022-06-03>`__

- `pruned_transducer_stateless5 <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-2022-07-07>`_
- `pruned_transducer_stateless5 <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless5-2022-07-07>`__

See `<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md>`_
for the details of the above pretrained models
@ -45,9 +45,9 @@ the input features.

We have three variants of Emformer models in ``icefall``.

- ``pruned_stateless_emformer_rnnt2`` using Emformer from torchaudio, see `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2>`_.
- ``pruned_stateless_emformer_rnnt2`` using Emformer from torchaudio, see `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2>`__.
- ``conv_emformer_transducer_stateless`` using ConvEmformer implemented by ourself. Different from the Emformer in torchaudio,
ConvEmformer has a convolution in each layer and uses the mechanisms in our reworked conformer model.
See `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conv_emformer_transducer_stateless>`_.
See `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conv_emformer_transducer_stateless>`__.
- ``conv_emformer_transducer_stateless2`` using ConvEmformer implemented by ourself. The only difference from the above one is that
it uses a simplified memory bank. See `LibriSpeech recipe <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/conv_emformer_transducer_stateless2>`_.
@ -6,10 +6,10 @@ with the `LibriSpeech <https://www.openslr.org/12>`_ dataset.

.. Note::

The tutorial is suitable for `pruned_transducer_stateless <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless>`_,
`pruned_transducer_stateless2 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless2>`_,
`pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`_,
`pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless5>`_,
The tutorial is suitable for `pruned_transducer_stateless <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless>`__,
`pruned_transducer_stateless2 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless2>`__,
`pruned_transducer_stateless4 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless4>`__,
`pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless5>`__,
We will take pruned_transducer_stateless4 as an example in this tutorial.

.. HINT::

@ -264,7 +264,7 @@ them, please modify ``./pruned_transducer_stateless4/train.py`` directly.

.. NOTE::

The options for `pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless5/train.py>`_ are a little different from
The options for `pruned_transducer_stateless5 <https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless5/train.py>`__ are a little different from
other recipes. It allows you to configure ``--num-encoder-layers``, ``--dim-feedforward``, ``--nhead``, ``--encoder-dim``, ``--decoder-dim``, ``--joiner-dim`` from commandline, so that you can train models with different size with pruned_transducer_stateless5.
@ -6,7 +6,7 @@ with the `LibriSpeech <https://www.openslr.org/12>`_ dataset.

.. Note::

The tutorial is suitable for `pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`_,
The tutorial is suitable for `pruned_transducer_stateless7_streaming <https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming>`__,

.. HINT::

@ -642,7 +642,7 @@ Download pretrained models
If you don't want to train from scratch, you can download the pretrained models
by visiting the following links:

- `pruned_transducer_stateless7_streaming <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`_
- `pruned_transducer_stateless7_streaming <https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29>`__

See `<https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/RESULTS.md>`_
for the details of the above pretrained models
@ -240,7 +240,7 @@ def get_parser():
    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless3/exp",
        default="pruned_transducer_stateless7/exp",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved

@ -243,7 +243,7 @@ def get_parser():
    parser.add_argument(
        "--exp-dir",
        type=str,
        default="pruned_transducer_stateless3/exp",
        default="pruned_transducer_stateless7/exp",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
156
egs/ami/SURT/README.md
Normal file
156
egs/ami/SURT/README.md
Normal file
@ -0,0 +1,156 @@
|
||||
# Introduction
|
||||
|
||||
This is a multi-talker ASR recipe for the AMI and ICSI datasets. We train a Streaming
|
||||
Unmixing and Recognition Transducer (SURT) model for the task.
|
||||
|
||||
Please refer to the `egs/libricss/SURT` recipe README for details about the task and the
|
||||
model.
|
||||
|
||||
## Description of the recipe
|
||||
|
||||
### Pre-requisites
|
||||
|
||||
The recipes in this directory need the following packages to be installed:
|
||||
|
||||
- [meeteval](https://github.com/fgnt/meeteval)
|
||||
- [einops](https://github.com/arogozhnikov/einops)
|
||||
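Both packages are available from PyPI, so a minimal installation sketch (assuming a standard pip-based environment; adapt to conda or your own setup as needed) is:

```bash
# Assumed installation commands; versions are not pinned here.
pip install meeteval einops
```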
|
||||
Additionally, we initialize the model with the pre-trained model from the LibriCSS recipe.
|
||||
Please download this checkpoint (see below) or train the LibriCSS recipe first.
|
||||
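A minimal download sketch, assuming the Hugging Face repository layout linked under "Pre-trained models and logs" below (the exact checkpoint file name inside `exp/surt_base` may differ, so treat the `cp` source path as a placeholder):

```bash
# Hypothetical example: fetch the LibriCSS SURT-base checkpoint and place it where
# the training command below expects it (exp/libricss_base.pt).
git lfs install
git clone https://huggingface.co/desh2608/icefall-surt-libricss-dprnn-zipformer
mkdir -p exp
cp icefall-surt-libricss-dprnn-zipformer/exp/surt_base/epoch-30.pt exp/libricss_base.pt
```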
|
||||
### Training
|
||||
|
||||
To train the model, run the following from within `egs/ami/SURT`:
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
python dprnn_zipformer/train.py \
|
||||
--use-fp16 True \
|
||||
--exp-dir dprnn_zipformer/exp/surt_base \
|
||||
--world-size 4 \
|
||||
--max-duration 500 \
|
||||
--max-duration-valid 250 \
|
||||
--max-cuts 200 \
|
||||
--num-buckets 50 \
|
||||
--num-epochs 30 \
|
||||
--enable-spec-aug True \
|
||||
--enable-musan False \
|
||||
--ctc-loss-scale 0.2 \
|
||||
--heat-loss-scale 0.2 \
|
||||
--base-lr 0.004 \
|
||||
--model-init-ckpt exp/libricss_base.pt \
|
||||
--chunk-width-randomization True \
|
||||
--num-mask-encoder-layers 4 \
|
||||
--num-encoder-layers 2,2,2,2,2
|
||||
```
|
||||
|
||||
The above command trains SURT-base (~26M parameters). For SURT-large (~38M parameters), use:
|
||||
|
||||
```bash
|
||||
--model-init-ckpt exp/libricss_large.pt \
|
||||
--num-mask-encoder-layers 6 \
|
||||
--num-encoder-layers 2,4,3,2,4 \
|
||||
```
|
||||
|
||||
**NOTE:** You may need to decrease the `--max-duration` for SURT-large to avoid OOM.
|
||||
|
||||
### Adaptation
|
||||
|
||||
The training step above only trains on simulated mixtures. For best results, we also
|
||||
adapt the final model on the AMI+ICSI train set. For this, run the following from within
|
||||
`egs/ami/SURT`:
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES="0"
|
||||
|
||||
python dprnn_zipformer/train_adapt.py \
|
||||
--use-fp16 True \
|
||||
--exp-dir dprnn_zipformer/exp/surt_base_adapt \
|
||||
--world-size 4 \
|
||||
--max-duration 500 \
|
||||
--max-duration-valid 250 \
|
||||
--max-cuts 200 \
|
||||
--num-buckets 50 \
|
||||
--num-epochs 8 \
|
||||
--lr-epochs 2 \
|
||||
--enable-spec-aug True \
|
||||
--enable-musan False \
|
||||
--ctc-loss-scale 0.2 \
|
||||
--base-lr 0.0004 \
|
||||
--model-init-ckpt dprnn_zipformer/exp/surt_base/epoch-30.pt \
|
||||
--chunk-width-randomization True \
|
||||
--num-mask-encoder-layers 4 \
|
||||
--num-encoder-layers 2,2,2,2,2
|
||||
```
|
||||
|
||||
For SURT-large, use the following config:
|
||||
|
||||
```bash
|
||||
--num-mask-encoder-layers 6 \
|
||||
--num-encoder-layers 2,4,3,2,4 \
|
||||
--model-init-ckpt dprnn_zipformer/exp/surt_large/epoch-30.pt \
|
||||
--num-epochs 15 \
|
||||
--lr-epochs 4 \
|
||||
```
|
||||
|
||||
|
||||
### Decoding
|
||||
|
||||
To decode the model, run the following from within `egs/ami/SURT`:
|
||||
|
||||
#### Greedy search
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES="0"
|
||||
|
||||
python dprnn_zipformer/decode.py \
|
||||
--epoch 20 --avg 1 --use-averaged-model False \
|
||||
--exp-dir dprnn_zipformer/exp/surt_base_adapt \
|
||||
--max-duration 250 \
|
||||
--decoding-method greedy_search
|
||||
```
|
||||
|
||||
#### Beam search
|
||||
|
||||
```bash
|
||||
python dprnn_zipformer/decode.py \
|
||||
--epoch 20 --avg 1 --use-averaged-model False \
|
||||
--exp-dir dprnn_zipformer/exp/surt_base_adapt \
|
||||
--max-duration 250 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
```
|
||||
|
||||
## Results (using beam search)
|
||||
|
||||
**AMI**
|
||||
|
||||
| Model | IHM-Mix | SDM | MDM |
|
||||
|------------|:-------:|:----:|:----:|
|
||||
| SURT-base | 39.8 | 65.4 | 46.6 |
|
||||
| + adapt | 37.4 | 46.9 | 43.7 |
|
||||
| SURT-large | 36.8 | 62.5 | 44.4 |
|
||||
| + adapt | **35.1** | **44.6** | **41.4** |
|
||||
|
||||
**ICSI**
|
||||
|
||||
| Model | IHM-Mix | SDM |
|
||||
|------------|:-------:|:----:|
|
||||
| SURT-base | 28.3 | 60.0 |
|
||||
| + adapt | 26.3 | 33.9 |
|
||||
| SURT-large | 27.8 | 59.7 |
|
||||
| + adapt | **24.4** | **32.3** |
|
||||
|
||||
## Pre-trained models and logs
|
||||
|
||||
* LibriCSS pre-trained model (for initialization): [base](https://huggingface.co/desh2608/icefall-surt-libricss-dprnn-zipformer/tree/main/exp/surt_base) [large](https://huggingface.co/desh2608/icefall-surt-libricss-dprnn-zipformer/tree/main/exp/surt_large)
|
||||
|
||||
* Pre-trained models: <https://huggingface.co/desh2608/icefall-surt-ami-dprnn-zipformer>
|
||||
|
||||
* Training logs:
|
||||
- surt_base: <https://tensorboard.dev/experiment/8awy98VZSWegLmH4l2JWSA/>
|
||||
- surt_base_adapt: <https://tensorboard.dev/experiment/aGVgXVzYRDKbGUbPekcNjg/>
|
||||
- surt_large: <https://tensorboard.dev/experiment/ZXMkez0VSYKbPLqRk4clOQ/>
|
||||
- surt_large_adapt: <https://tensorboard.dev/experiment/WLKL1e7bTVyEjSonYSNYwg/>
|
399
egs/ami/SURT/dprnn_zipformer/asr_datamodule.py
Normal file
399
egs/ami/SURT/dprnn_zipformer/asr_datamodule.py
Normal file
@ -0,0 +1,399 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
K2SurtDataset,
|
||||
PrecomputedFeatures,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class AmiAsrDataModule:
|
||||
"""
|
||||
DataModule for k2 SURT experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g., AMI and ICSI dev and test).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/manifests"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration-valid",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-cuts",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum number of cuts in a single batch. You can "
|
||||
"reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help=(
|
||||
"When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available."
|
||||
),
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-spec-aug",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use SpecAugment for training dataset.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--spec-aug-time-warp-factor",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Used only when --enable-spec-aug is True. "
|
||||
"It specifies the factor for time warping in SpecAugment. "
|
||||
"Larger values mean more warping. "
|
||||
"A value less than 1 means to disable time warp.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-musan",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, select noise from MUSAN and mix it"
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
sources: bool = False,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
transforms.append(
|
||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
num_frame_masks = 10
|
||||
num_frame_masks_parameter = inspect.signature(
|
||||
SpecAugment.__init__
|
||||
).parameters["num_frame_masks"]
|
||||
if num_frame_masks_parameter.default == 1:
|
||||
num_frame_masks = 2
|
||||
logging.info(f"Num frame mask: {num_frame_masks}")
|
||||
input_transforms.append(
|
||||
SpecAugment(
|
||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||
num_frame_masks=num_frame_masks,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable SpecAugment")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = K2SurtDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
return_sources=sources,
|
||||
strict=False,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
quadratic_duration=30.0,
|
||||
max_cuts=self.args.max_cuts,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
max_cuts=self.args.max_cuts,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
transforms = []
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
validate = K2SurtDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
return_sources=False,
|
||||
strict=False,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration_valid,
|
||||
quadratic_duration=30.0,
|
||||
max_cuts=self.args.max_cuts,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SurtDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
return_sources=False,
|
||||
strict=False,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration_valid,
|
||||
max_cuts=self.args.max_cuts,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def aimix_train_cuts(
|
||||
self,
|
||||
rvb_affix: str = "clean",
|
||||
sources: bool = True,
|
||||
) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
source_affix = "_sources" if sources else ""
|
||||
cs = load_manifest_lazy(
|
||||
self.args.manifest_dir / f"cuts_train_{rvb_affix}{source_affix}.jsonl.gz"
|
||||
)
|
||||
cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 30.0)
|
||||
return cs
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(
|
||||
self,
|
||||
) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_train_ami_icsi.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def ami_cuts(self, split: str = "dev", type: str = "sdm") -> CutSet:
|
||||
logging.info(f"About to get AMI {split} {type} cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / f"cuts_ami-{type}_{split}.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def icsi_cuts(self, split: str = "dev", type: str = "sdm") -> CutSet:
|
||||
logging.info(f"About to get ICSI {split} {type} cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / f"cuts_icsi-{type}_{split}.jsonl.gz"
|
||||
)
|
1
egs/ami/SURT/dprnn_zipformer/beam_search.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/beam_search.py
|
622
egs/ami/SURT/dprnn_zipformer/decode.py
Executable file
622
egs/ami/SURT/dprnn_zipformer/decode.py
Executable file
@ -0,0 +1,622 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./dprnn_zipformer/decode.py \
|
||||
--epoch 20 \
|
||||
--avg 1 \
|
||||
--use-averaged-model false \
|
||||
--exp-dir ./dprnn_zipformer/exp_adapt \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search (not recommended)
|
||||
./dprnn_zipformer/decode.py \
|
||||
--epoch 20 \
|
||||
--avg 1 \
|
||||
--use-averaged-model false \
|
||||
--exp-dir ./dprnn_zipformer/exp_adapt \
|
||||
--max-duration 600 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./dprnn_zipformer/decode.py \
|
||||
--epoch 20 \
|
||||
--avg 1 \
|
||||
--use-averaged-model false \
|
||||
--exp-dir ./dprnn_zipformer/exp_adapt \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import AmiAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from lhotse.utils import EPSILON
|
||||
from train import add_model_arguments, get_params, get_surt_model
|
||||
|
||||
from icefall import LmScorer, NgramLm
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_surt_error_stats,
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=20,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="dprnn_zipformer/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
feature = feature.to(device)
|
||||
feature_lens = batch["input_lens"].to(device)
|
||||
|
||||
# Apply the mask encoder
|
||||
B, T, F = feature.shape
|
||||
processed = model.mask_encoder(feature) # B,T,F*num_channels
|
||||
masks = processed.view(B, T, F, params.num_channels).unbind(dim=-1)
|
||||
x_masked = [feature * m for m in masks]
|
||||
|
||||
# Recognition
|
||||
# Stack the inputs along the batch axis
|
||||
h = torch.cat(x_masked, dim=0)
|
||||
h_lens = torch.cat([feature_lens for _ in range(params.num_channels)], dim=0)
|
||||
encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens)
|
||||
|
||||
if model.joint_encoder_layer is not None:
|
||||
encoder_out = model.joint_encoder_layer(encoder_out)
|
||||
|
||||
def _group_channels(hyps: List[str]) -> List[List[str]]:
|
||||
"""
|
||||
Currently we have a batch of size M*B, where M is the number of
|
||||
channels and B is the batch size. We need to group the hypotheses
|
||||
into B groups, each of which contains M hypotheses.
|
||||
|
||||
Example:
|
||||
hyps = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2']
|
||||
_group_channels(hyps) = [['a1', 'a2'], ['b1', 'b2'], ['c1', 'c2']]
|
||||
"""
|
||||
assert len(hyps) == B * params.num_channels
|
||||
out_hyps = []
|
||||
for i in range(B):
|
||||
out_hyps.append(hyps[i::B])
|
||||
return out_hyps
|
||||
|
||||
hyps = []
|
||||
if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp)
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp))
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": _group_channels(hyps)}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
key += f"max_states_{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
key += f"_num_paths_{params.num_paths}_"
|
||||
key += f"nbest_scale_{params.nbest_scale}"
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: _group_channels(hyps)}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": _group_channels(hyps)}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains three elements:
the cut ID, the reference transcript, and the predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
cut_ids = [cut.id for cut in batch["cuts"]]
|
||||
cuts_batch = batch["cuts"]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
for cut_id, hyp_words in zip(cut_ids, hyps):
|
||||
# Reference is a list of supervision texts sorted by start time.
|
||||
ref_words = [
|
||||
s.text.strip()
|
||||
for s in sorted(
|
||||
cuts_batch[cut_id].supervisions, key=lambda s: s.start
|
||||
)
|
||||
]
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(cut_ids)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_surt_error_stats(
|
||||
f,
|
||||
f"{test_set_name}-{key}",
|
||||
results,
|
||||
enable_log=True,
|
||||
num_channels=params.num_channels,
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LmScorer.add_arguments(parser)
|
||||
AmiAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
args.lang_dir = Path(args.lang_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"modified_beam_search",
|
||||
), f"Decoding method {params.decoding_method} is not supported."
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_surt_model(params)
|
||||
assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, (
|
||||
model.encoder.decode_chunk_size,
|
||||
params.decode_chunk_len,
|
||||
)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
ami = AmiAsrDataModule(args)
|
||||
|
||||
# NOTE(@desh2608): we filter segments longer than 120s to avoid OOM errors in decoding.
|
||||
# However, 99.9% of the segments are shorter than 120s, so this should not
|
||||
# substantially affect the results. In the future, we will implement an overlapped
|
||||
# inference method to avoid OOM errors.
|
||||
|
||||
test_sets = {}
|
||||
for split in ["dev", "test"]:
|
||||
for type in ["ihm-mix", "sdm", "mdm8-bf"]:
|
||||
test_sets[f"ami-{split}_{type}"] = (
|
||||
ami.ami_cuts(split=split, type=type)
|
||||
.trim_to_supervision_groups(max_pause=0.0)
|
||||
.filter(lambda c: 0.1 < c.duration < 120.0)
|
||||
.to_eager()
|
||||
)
|
||||
|
||||
for split in ["dev", "test"]:
|
||||
for type in ["ihm-mix", "sdm"]:
|
||||
test_sets[f"icsi-{split}_{type}"] = (
|
||||
ami.icsi_cuts(split=split, type=type)
|
||||
.trim_to_supervision_groups(max_pause=0.0)
|
||||
.filter(lambda c: 0.1 < c.duration < 120.0)
|
||||
.to_eager()
|
||||
)
|
||||
|
||||
for test_set, test_cuts in test_sets.items():
|
||||
test_dl = ami.test_dataloaders(test_cuts)
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
egs/ami/SURT/dprnn_zipformer/decoder.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/decoder.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/decoder.py
|
1
egs/ami/SURT/dprnn_zipformer/dprnn.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/dprnn.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/dprnn.py
|
1
egs/ami/SURT/dprnn_zipformer/encoder_interface.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/encoder_interface.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/encoder_interface.py
|
1
egs/ami/SURT/dprnn_zipformer/export.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/export.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/export.py
|
1
egs/ami/SURT/dprnn_zipformer/joiner.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/joiner.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/joiner.py
|
1
egs/ami/SURT/dprnn_zipformer/model.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/model.py
|
1
egs/ami/SURT/dprnn_zipformer/optim.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/optim.py
|
1
egs/ami/SURT/dprnn_zipformer/scaling.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/scaling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/scaling.py
|
1
egs/ami/SURT/dprnn_zipformer/scaling_converter.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/scaling_converter.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/scaling_converter.py
|
1
egs/ami/SURT/dprnn_zipformer/test_model.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/test_model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py
|
1420
egs/ami/SURT/dprnn_zipformer/train.py
Executable file
1420
egs/ami/SURT/dprnn_zipformer/train.py
Executable file
File diff suppressed because it is too large
Load Diff
1411
egs/ami/SURT/dprnn_zipformer/train_adapt.py
Executable file
1411
egs/ami/SURT/dprnn_zipformer/train_adapt.py
Executable file
File diff suppressed because it is too large
Load Diff
1
egs/ami/SURT/dprnn_zipformer/zipformer.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/zipformer.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/zipformer.py
|
78
egs/ami/SURT/local/add_source_feats.py
Executable file
78
egs/ami/SURT/local/add_source_feats.py
Executable file
@ -0,0 +1,78 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file adds source features as temporal arrays to the mixture manifests.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from lhotse import CutSet, LilcomChunkyWriter, load_manifest, load_manifest_lazy
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def add_source_feats():
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
|
||||
logging.info("Reading mixed cuts")
|
||||
mixed_cuts_clean = load_manifest_lazy(src_dir / "cuts_train_clean.jsonl.gz")
|
||||
mixed_cuts_reverb = load_manifest_lazy(src_dir / "cuts_train_reverb.jsonl.gz")
|
||||
|
||||
logging.info("Reading source cuts")
|
||||
source_cuts = load_manifest(src_dir / "ihm_cuts_train_trimmed.jsonl.gz")
|
||||
|
||||
logging.info("Adding source features to the mixed cuts")
|
||||
pbar = tqdm(total=len(mixed_cuts_clean), desc="Adding source features")
|
||||
with CutSet.open_writer(
|
||||
src_dir / "cuts_train_clean_sources.jsonl.gz"
|
||||
) as cut_writer_clean, CutSet.open_writer(
|
||||
src_dir / "cuts_train_reverb_sources.jsonl.gz"
|
||||
) as cut_writer_reverb, LilcomChunkyWriter(
|
||||
output_dir / "feats_train_clean_sources"
|
||||
) as source_feat_writer:
|
||||
for cut_clean, cut_reverb in zip(mixed_cuts_clean, mixed_cuts_reverb):
|
||||
assert cut_reverb.id == cut_clean.id + "_rvb"
|
||||
source_feats = []
|
||||
source_feat_offsets = []
|
||||
cur_offset = 0
|
||||
for sup in sorted(
|
||||
cut_clean.supervisions, key=lambda s: (s.start, s.speaker)
|
||||
):
|
||||
source_cut = source_cuts[sup.id]
|
||||
source_feats.append(source_cut.load_features())
|
||||
source_feat_offsets.append(cur_offset)
|
||||
cur_offset += source_cut.num_frames
|
||||
cut_clean.source_feats = source_feat_writer.store_array(
|
||||
cut_clean.id, np.concatenate(source_feats, axis=0)
|
||||
)
|
||||
cut_clean.source_feat_offsets = source_feat_offsets
|
||||
cut_writer_clean.write(cut_clean)
|
||||
# Also write the reverb cut
|
||||
cut_reverb.source_feats = cut_clean.source_feats
|
||||
cut_reverb.source_feat_offsets = cut_clean.source_feat_offsets
|
||||
cut_writer_reverb.write(cut_reverb)
|
||||
pbar.update(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
add_source_feats()
|
185
egs/ami/SURT/local/compute_fbank_aimix.py
Executable file
185
egs/ami/SURT/local/compute_fbank_aimix.py
Executable file
@ -0,0 +1,185 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the synthetically mixed AMI and ICSI
|
||||
train set.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
import logging
|
||||
import random
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing
|
||||
import torchaudio
|
||||
from lhotse import (
|
||||
AudioSource,
|
||||
LilcomChunkyWriter,
|
||||
Recording,
|
||||
load_manifest,
|
||||
load_manifest_lazy,
|
||||
)
|
||||
from lhotse.audio import set_ffmpeg_torchaudio_info_enabled
|
||||
from lhotse.cut import MixedCut, MixTrack, MultiCut
|
||||
from lhotse.features.kaldifeat import (
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
KaldifeatFrameOptions,
|
||||
KaldifeatMelOptions,
|
||||
)
|
||||
from lhotse.utils import fix_random_seed, uuid4
|
||||
from tqdm import tqdm
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slows things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||
torchaudio.set_audio_backend("soundfile")
|
||||
set_ffmpeg_torchaudio_info_enabled(False)
|
||||
|
||||
|
||||
def compute_fbank_aimix():
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
|
||||
sampling_rate = 16000
|
||||
num_mel_bins = 80
|
||||
|
||||
extractor = KaldifeatFbank(
|
||||
KaldifeatFbankConfig(
|
||||
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
|
||||
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
|
||||
device="cuda",
|
||||
)
|
||||
)
|
||||
|
||||
logging.info("Reading manifests")
|
||||
train_cuts = load_manifest_lazy(src_dir / "ai-mix_cuts_clean_full.jsonl.gz")
|
||||
|
||||
# Only use RIRs and noises from the REVERB challenge
|
||||
real_rirs = load_manifest(src_dir / "real-rir_recordings_all.jsonl.gz").filter(
|
||||
lambda r: "RVB2014" in r.id
|
||||
)
|
||||
noises = load_manifest(src_dir / "iso-noise_recordings_all.jsonl.gz").filter(
|
||||
lambda r: "RVB2014" in r.id
|
||||
)
|
||||
|
||||
# Apply perturbation to the training cuts
|
||||
logging.info("Applying perturbation to the training cuts")
|
||||
train_cuts_rvb = train_cuts.map(
|
||||
lambda c: augment(
|
||||
c, perturb_snr=True, rirs=real_rirs, noises=noises, perturb_loudness=True
|
||||
)
|
||||
)
|
||||
|
||||
logging.info("Extracting fbank features for training cuts")
|
||||
_ = train_cuts.compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=output_dir / "ai-mix_feats_clean",
|
||||
manifest_path=src_dir / "cuts_train_clean.jsonl.gz",
|
||||
batch_duration=5000,
|
||||
num_workers=4,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
_ = train_cuts_rvb.compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=output_dir / "ai-mix_feats_reverb",
|
||||
manifest_path=src_dir / "cuts_train_reverb.jsonl.gz",
|
||||
batch_duration=5000,
|
||||
num_workers=4,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
|
||||
def augment(cut, perturb_snr=False, rirs=None, noises=None, perturb_loudness=False):
|
||||
"""
|
||||
Given a mixed cut, this function optionally applies the following augmentations:
|
||||
- Perturbing the SNRs of the tracks (in range [-5, 5] dB)
|
||||
- Reverberation using a randomly selected RIR
|
||||
- Adding noise
|
||||
- Perturbing the loudness (in range [-25, -20] dB)
|
||||
"""
|
||||
out_cut = cut.drop_features()
|
||||
|
||||
# Perturb the SNRs (optional)
|
||||
if perturb_snr:
|
||||
snrs = [random.uniform(-5, 5) for _ in range(len(cut.tracks))]
|
||||
for i, (track, snr) in enumerate(zip(out_cut.tracks, snrs)):
|
||||
if i == 0:
|
||||
# Skip the first track since it is the reference
|
||||
continue
|
||||
track.snr = snr
|
||||
|
||||
# Reverberate the cut (optional)
|
||||
if rirs is not None:
|
||||
# Select an RIR at random
|
||||
rir = random.choice(rirs)
|
||||
# Select a channel at random
|
||||
rir_channel = random.choice(list(range(rir.num_channels)))
|
||||
# Reverberate the cut
|
||||
out_cut = out_cut.reverb_rir(rir_recording=rir, rir_channels=[rir_channel])
|
||||
|
||||
# Add noise (optional)
|
||||
if noises is not None:
|
||||
# Select a noise recording at random
|
||||
noise = random.choice(noises).to_cut()
|
||||
if isinstance(noise, MultiCut):
|
||||
noise = noise.to_mono()[0]
|
||||
# Select an SNR at random
|
||||
snr = random.uniform(10, 30)
|
||||
# Repeat the noise to match the duration of the cut
|
||||
noise = repeat_cut(noise, out_cut.duration)
|
||||
out_cut = MixedCut(
|
||||
id=out_cut.id,
|
||||
tracks=[
|
||||
MixTrack(cut=out_cut, type="MixedCut"),
|
||||
MixTrack(cut=noise, type="DataCut", snr=snr),
|
||||
],
|
||||
)
|
||||
|
||||
# Perturb the loudness (optional)
|
||||
if perturb_loudness:
|
||||
target_loudness = random.uniform(-20, -25)
|
||||
out_cut = out_cut.normalize_loudness(target_loudness, mix_first=True)
|
||||
return out_cut
|
||||
|
||||
|
||||
def repeat_cut(cut, duration):
|
||||
while cut.duration < duration:
|
||||
cut = cut.mix(cut, offset_other_by=cut.duration)
|
||||
return cut.truncate(duration=duration)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
fix_random_seed(42)
|
||||
compute_fbank_aimix()
|
94
egs/ami/SURT/local/compute_fbank_ami.py
Executable file
94
egs/ami/SURT/local/compute_fbank_ami.py
Executable file
@ -0,0 +1,94 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the AMI dataset.
|
||||
We compute features for full recordings (i.e., without trimming to supervisions).
|
||||
This way we can create arbitrary segmentations later.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing
|
||||
from lhotse import CutSet, LilcomChunkyWriter
|
||||
from lhotse.features.kaldifeat import (
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
KaldifeatFrameOptions,
|
||||
KaldifeatMelOptions,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slows things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||
|
||||
|
||||
def compute_fbank_ami():
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
|
||||
sampling_rate = 16000
|
||||
num_mel_bins = 80
|
||||
|
||||
extractor = KaldifeatFbank(
|
||||
KaldifeatFbankConfig(
|
||||
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
|
||||
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
|
||||
device="cuda",
|
||||
)
|
||||
)
|
||||
|
||||
logging.info("Reading manifests")
|
||||
manifests = {}
|
||||
for part in ["ihm-mix", "sdm", "mdm8-bf"]:
|
||||
manifests[part] = read_manifests_if_cached(
|
||||
dataset_parts=["train", "dev", "test"],
|
||||
output_dir=src_dir,
|
||||
prefix=f"ami-{part}",
|
||||
suffix="jsonl.gz",
|
||||
)
|
||||
|
||||
for part in ["ihm-mix", "sdm", "mdm8-bf"]:
|
||||
for split in ["train", "dev", "test"]:
|
||||
logging.info(f"Processing {part} {split}")
|
||||
cuts = CutSet.from_manifests(
|
||||
**manifests[part][split]
|
||||
).compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=output_dir / f"ami-{part}_{split}_feats",
|
||||
manifest_path=src_dir / f"cuts_ami-{part}_{split}.jsonl.gz",
|
||||
batch_duration=5000,
|
||||
num_workers=4,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
compute_fbank_ami()
|
95
egs/ami/SURT/local/compute_fbank_icsi.py
Executable file
95
egs/ami/SURT/local/compute_fbank_icsi.py
Executable file
@ -0,0 +1,95 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the ICSI dataset.
|
||||
We compute features for full recordings (i.e., without trimming to supervisions).
|
||||
This way we can create arbitrary segmentations later.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing
|
||||
from lhotse import CutSet, LilcomChunkyWriter
|
||||
from lhotse.features.kaldifeat import (
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
KaldifeatFrameOptions,
|
||||
KaldifeatMelOptions,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slows things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||
|
||||
|
||||
def compute_fbank_icsi():
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
|
||||
sampling_rate = 16000
|
||||
num_mel_bins = 80
|
||||
|
||||
extractor = KaldifeatFbank(
|
||||
KaldifeatFbankConfig(
|
||||
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
|
||||
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
|
||||
device="cuda",
|
||||
)
|
||||
)
|
||||
|
||||
logging.info("Reading manifests")
|
||||
manifests = {}
|
||||
for part in ["ihm-mix", "sdm"]:
|
||||
manifests[part] = read_manifests_if_cached(
|
||||
dataset_parts=["train"],
|
||||
output_dir=src_dir,
|
||||
prefix=f"icsi-{part}",
|
||||
suffix="jsonl.gz",
|
||||
)
|
||||
|
||||
for part in ["ihm-mix", "sdm"]:
|
||||
for split in ["train"]:
|
||||
logging.info(f"Processing {part} {split}")
|
||||
cuts = CutSet.from_manifests(
|
||||
**manifests[part][split]
|
||||
).compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=output_dir / f"icsi-{part}_{split}_feats",
|
||||
manifest_path=src_dir / f"cuts_icsi-{part}_{split}.jsonl.gz",
|
||||
batch_duration=5000,
|
||||
num_workers=4,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
compute_fbank_icsi()
|
101
egs/ami/SURT/local/compute_fbank_ihm.py
Executable file
@ -0,0 +1,101 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the trimmed sub-segments which will be
|
||||
used for simulating the training mixtures.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing
|
||||
import torchaudio
|
||||
from lhotse import CutSet, LilcomChunkyWriter, load_manifest
|
||||
from lhotse.audio import set_ffmpeg_torchaudio_info_enabled
|
||||
from lhotse.features.kaldifeat import (
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
KaldifeatFrameOptions,
|
||||
KaldifeatMelOptions,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
from tqdm import tqdm
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slows things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||
torchaudio.set_audio_backend("soundfile")
|
||||
set_ffmpeg_torchaudio_info_enabled(False)
|
||||
|
||||
|
||||
def compute_fbank_ihm():
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
|
||||
sampling_rate = 16000
|
||||
num_mel_bins = 80
|
||||
|
||||
extractor = KaldifeatFbank(
|
||||
KaldifeatFbankConfig(
|
||||
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
|
||||
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
|
||||
device="cuda",
|
||||
)
|
||||
)
|
||||
|
||||
logging.info("Reading manifests")
|
||||
manifests = {}
|
||||
for data in ["ami", "icsi"]:
|
||||
manifests[data] = read_manifests_if_cached(
|
||||
dataset_parts=["train"],
|
||||
output_dir=src_dir,
|
||||
types=["recordings", "supervisions"],
|
||||
prefix=f"{data}-ihm",
|
||||
suffix="jsonl.gz",
|
||||
)
|
||||
|
||||
logging.info("Computing features")
|
||||
for data in ["ami", "icsi"]:
|
||||
cs = CutSet.from_manifests(**manifests[data]["train"])
|
||||
cs = cs.trim_to_supervisions(keep_overlapping=False)
|
||||
cs = cs.normalize_loudness(target=-23.0, affix_id=False)
|
||||
cs = cs + cs.perturb_speed(0.9) + cs.perturb_speed(1.1)
|
||||
_ = cs.compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=output_dir / f"{data}-ihm_train_feats",
|
||||
manifest_path=src_dir / f"{data}-ihm_cuts_train.jsonl.gz",
|
||||
batch_duration=5000,
|
||||
num_workers=4,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
compute_fbank_ihm()
|
146
egs/ami/SURT/local/prepare_ami_train_cuts.py
Executable file
@ -0,0 +1,146 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file creates AMI train segments.
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing
|
||||
from lhotse import LilcomChunkyWriter, load_manifest_lazy
|
||||
from lhotse.cut import Cut, CutSet
|
||||
from lhotse.utils import EPSILON, add_durations
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def cut_into_windows(cuts: CutSet, duration: float):
|
||||
"""
|
||||
This function takes a CutSet and cuts each cut into windows of roughly
|
||||
`duration` seconds. By roughly, we mean that the window boundary is adjusted around supervisions
|
||||
that cross it: a supervision is kept in the current window if more than half of its duration lies inside it.
|
||||
"""
|
||||
res = []
|
||||
with tqdm() as pbar:
|
||||
for cut in cuts:
|
||||
pbar.update(1)
|
||||
sups = cut.index_supervisions()[cut.id]
|
||||
sr = cut.sampling_rate
|
||||
start = 0.0
|
||||
end = duration
|
||||
num_tries = 0
|
||||
while start < cut.duration and num_tries < 2:
|
||||
# Find the supervisions that are cut by the window endpoint
|
||||
hitlist = [iv for iv in sups.at(end) if iv.begin < end]
|
||||
# If there are no supervisions, we are done
|
||||
if not hitlist:
|
||||
res.append(
|
||||
cut.truncate(
|
||||
offset=start,
|
||||
duration=add_durations(end, -start, sampling_rate=sr),
|
||||
keep_excessive_supervisions=False,
|
||||
)
|
||||
)
|
||||
# Update the start and end for the next window
|
||||
start = end
|
||||
end = add_durations(end, duration, sampling_rate=sr)
|
||||
else:
|
||||
# find ratio of durations cut by the window endpoint
|
||||
ratios = [
|
||||
add_durations(end, -iv.end, sampling_rate=sr) / iv.length()
|
||||
for iv in hitlist
|
||||
]
|
||||
# we retain the supervisions that have >50% of their duration
|
||||
# in the window, and discard the others
|
||||
retained = []
|
||||
discarded = []
|
||||
for iv, ratio in zip(hitlist, ratios):
|
||||
if ratio > 0.5:
|
||||
retained.append(iv)
|
||||
else:
|
||||
discarded.append(iv)
|
||||
cur_end = max(iv.end for iv in retained) if retained else end
|
||||
res.append(
|
||||
cut.truncate(
|
||||
offset=start,
|
||||
duration=add_durations(cur_end, -start, sampling_rate=sr),
|
||||
keep_excessive_supervisions=False,
|
||||
)
|
||||
)
|
||||
# For the next window, we start at the earliest discarded supervision
|
||||
next_start = min(iv.begin for iv in discarded) if discarded else end
|
||||
next_end = add_durations(next_start, duration, sampling_rate=sr)
|
||||
# It may happen that next_start is the same as start, in which case
|
||||
# we will advance the window anyway
|
||||
if next_start == start:
|
||||
logging.warning(
|
||||
f"Next start is the same as start: {next_start} == {start} for cut {cut.id}"
|
||||
)
|
||||
start = end + EPSILON
|
||||
end = add_durations(start, duration, sampling_rate=sr)
|
||||
num_tries += 1
|
||||
else:
|
||||
start = next_start
|
||||
end = next_end
|
||||
return CutSet.from_cuts(res)
|
||||
|
||||
|
||||
def prepare_train_cuts():
|
||||
src_dir = Path("data/manifests")
|
||||
|
||||
logging.info("Loading the manifests")
|
||||
train_cuts_ihm = load_manifest_lazy(
|
||||
src_dir / "cuts_ami-ihm-mix_train.jsonl.gz"
|
||||
).map(lambda c: c.with_id(f"{c.id}_ihm-mix"))
|
||||
train_cuts_sdm = load_manifest_lazy(src_dir / "cuts_ami-sdm_train.jsonl.gz").map(
|
||||
lambda c: c.with_id(f"{c.id}_sdm")
|
||||
)
|
||||
train_cuts_mdm = load_manifest_lazy(
|
||||
src_dir / "cuts_ami-mdm8-bf_train.jsonl.gz"
|
||||
).map(lambda c: c.with_id(f"{c.id}_mdm8-bf"))
|
||||
|
||||
# Combine all cuts into one CutSet
|
||||
train_cuts = train_cuts_ihm + train_cuts_sdm + train_cuts_mdm
|
||||
|
||||
train_cuts_1 = train_cuts.trim_to_supervision_groups(max_pause=0.5)
|
||||
train_cuts_2 = train_cuts.trim_to_supervision_groups(max_pause=0.0)
|
||||
|
||||
# Combine the two segmentations
|
||||
train_all = train_cuts_1 + train_cuts_2
|
||||
|
||||
# At this point, some of the cuts may be very long. We will cut them into windows of
|
||||
# roughly 30 seconds.
|
||||
logging.info("Cutting the segments into windows of 30 seconds")
|
||||
train_all_30 = cut_into_windows(train_all, duration=30.0)
|
||||
logging.info(f"Number of cuts after cutting into windows: {len(train_all_30)}")
|
||||
|
||||
# Show statistics
|
||||
train_all.describe(full=True)
|
||||
|
||||
# Save the cuts
|
||||
logging.info("Saving the cuts")
|
||||
train_all.to_file(src_dir / "cuts_train_ami.jsonl.gz")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
prepare_train_cuts()
|
67
egs/ami/SURT/local/prepare_icsi_train_cuts.py
Executable file
@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file creates ICSI train segments.
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from lhotse import load_manifest_lazy
|
||||
from prepare_ami_train_cuts import cut_into_windows
|
||||
|
||||
|
||||
def prepare_train_cuts():
|
||||
src_dir = Path("data/manifests")
|
||||
|
||||
logging.info("Loading the manifests")
|
||||
train_cuts_ihm = load_manifest_lazy(
|
||||
src_dir / "cuts_icsi-ihm-mix_train.jsonl.gz"
|
||||
).map(lambda c: c.with_id(f"{c.id}_ihm-mix"))
|
||||
train_cuts_sdm = load_manifest_lazy(src_dir / "cuts_icsi-sdm_train.jsonl.gz").map(
|
||||
lambda c: c.with_id(f"{c.id}_sdm")
|
||||
)
|
||||
|
||||
# Combine all cuts into one CutSet
|
||||
train_cuts = train_cuts_ihm + train_cuts_sdm
|
||||
|
||||
train_cuts_1 = train_cuts.trim_to_supervision_groups(max_pause=0.5)
|
||||
train_cuts_2 = train_cuts.trim_to_supervision_groups(max_pause=0.0)
|
||||
|
||||
# Combine the two segmentations
|
||||
train_all = train_cuts_1 + train_cuts_2
|
||||
|
||||
# At this point, some of the cuts may be very long. We will cut them into windows of
|
||||
# roughly 30 seconds.
|
||||
logging.info("Cutting the segments into windows of 30 seconds")
|
||||
train_all_30 = cut_into_windows(train_all, duration=30.0)
|
||||
logging.info(f"Number of cuts after cutting into windows: {len(train_all_30)}")
|
||||
|
||||
# Show statistics
|
||||
train_all.describe(full=True)
|
||||
|
||||
# Save the cuts
|
||||
logging.info("Saving the cuts")
|
||||
train_all.to_file(src_dir / "cuts_train_icsi.jsonl.gz")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
prepare_train_cuts()
|
1
egs/ami/SURT/local/prepare_lang_bpe.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/prepare_lang_bpe.py
|
1
egs/ami/SURT/local/train_bpe_model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/local/train_bpe_model.py
|
195
egs/ami/SURT/prepare.sh
Executable file
@ -0,0 +1,195 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
stage=-1
|
||||
stop_stage=100
|
||||
|
||||
# We assume dl_dir (download dir) contains the following
|
||||
# directories and files. If not, they will be downloaded
|
||||
# by this script automatically.
|
||||
#
|
||||
# - $dl_dir/ami
|
||||
# You can find audio and transcripts for AMI in this path.
|
||||
#
|
||||
# - $dl_dir/icsi
|
||||
# You can find audio and transcripts for ICSI in this path.
|
||||
#
|
||||
# - $dl_dir/rirs_noises
|
||||
# This directory contains the RIRS_NOISES corpus downloaded from https://openslr.org/28/.
|
||||
#
|
||||
dl_dir=$PWD/download
|
||||
|
||||
. shared/parse_options.sh || exit 1
|
||||
|
||||
# All files generated by this script are saved in "data".
|
||||
# You can safely remove "data" and rerun this script to regenerate it.
|
||||
mkdir -p data
|
||||
vocab_size=500
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
log "dl_dir: $dl_dir"
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "Stage 0: Download data"
|
||||
|
||||
# If you have pre-downloaded it to /path/to/amicorpus,
|
||||
# you can create a symlink
|
||||
#
|
||||
# ln -sfv /path/to/amicorpus $dl_dir/amicorpus
|
||||
#
|
||||
if [ ! -d $dl_dir/amicorpus ]; then
|
||||
for mic in ihm ihm-mix sdm mdm8-bf; do
|
||||
lhotse download ami --mic $mic $dl_dir/amicorpus
|
||||
done
|
||||
fi
|
||||
|
||||
# If you have pre-downloaded it to /path/to/icsi,
|
||||
# you can create a symlink
|
||||
#
|
||||
# ln -sfv /path/to/icsi $dl_dir/icsi
|
||||
#
|
||||
if [ ! -d $dl_dir/icsi ]; then
|
||||
lhotse download icsi $dl_dir/icsi
|
||||
fi
|
||||
|
||||
# If you have pre-downloaded it to /path/to/rirs_noises,
|
||||
# you can create a symlink
|
||||
#
|
||||
# ln -sfv /path/to/rirs_noises $dl_dir/
|
||||
#
|
||||
if [ ! -d $dl_dir/rirs_noises ]; then
|
||||
lhotse download rirs_noises $dl_dir
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
log "Stage 1: Prepare AMI manifests"
|
||||
# We assume that you have downloaded the AMI corpus
|
||||
# to $dl_dir/amicorpus. We perform text normalization for the transcripts.
|
||||
mkdir -p data/manifests
|
||||
for mic in ihm ihm-mix sdm mdm8-bf; do
|
||||
log "Preparing AMI manifest for $mic"
|
||||
lhotse prepare ami --mic $mic --max-words-per-segment 30 --merge-consecutive $dl_dir/amicorpus data/manifests/
|
||||
done
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Prepare ICSI manifests"
|
||||
# We assume that you have downloaded the ICSI corpus
|
||||
# to $dl_dir/icsi. We perform text normalization for the transcripts.
|
||||
mkdir -p data/manifests
|
||||
log "Preparing ICSI manifest"
|
||||
for mic in ihm ihm-mix sdm; do
|
||||
lhotse prepare icsi --mic $mic $dl_dir/icsi data/manifests/
|
||||
done
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "Stage 3: Prepare RIRs"
|
||||
# We assume that you have downloaded the RIRS_NOISES corpus
|
||||
# to $dl_dir/rirs_noises
|
||||
lhotse prepare rir-noise -p real_rir -p iso_noise $dl_dir/rirs_noises data/manifests
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "Stage 3: Extract features for AMI and ICSI recordings"
|
||||
python local/compute_fbank_ami.py
|
||||
python local/compute_fbank_icsi.py
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Create sources for simulating mixtures"
|
||||
# In the following script, we speed-perturb the IHM recordings and extract features.
|
||||
python local/compute_fbank_ihm.py
|
||||
lhotse combine data/manifests/ami-ihm_cuts_train.jsonl.gz \
|
||||
data/manifests/icsi-ihm_cuts_train.jsonl.gz - |\
|
||||
lhotse cut trim-to-alignments --type word --max-pause 0.5 - - |\
|
||||
lhotse filter 'duration<=12.0' - - |\
|
||||
shuf | gzip -c > data/manifests/ihm_cuts_train_trimmed.jsonl.gz
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Stage 6: Create training mixtures"
|
||||
lhotse workflows simulate-meetings \
|
||||
--method conversational \
|
||||
--same-spk-pause 0.5 \
|
||||
--diff-spk-pause 0.5 \
|
||||
--diff-spk-overlap 1.0 \
|
||||
--prob-diff-spk-overlap 0.8 \
|
||||
--num-meetings 200000 \
|
||||
--num-speakers-per-meeting 2,3 \
|
||||
--max-duration-per-speaker 15.0 \
|
||||
--max-utterances-per-speaker 3 \
|
||||
--seed 1234 \
|
||||
--num-jobs 2 \
|
||||
data/manifests/ihm_cuts_train_trimmed.jsonl.gz \
|
||||
data/manifests/ai-mix_cuts_clean.jsonl.gz
|
||||
|
||||
python local/compute_fbank_aimix.py
|
||||
|
||||
# Add source features to the manifest (will be used for masking loss)
|
||||
# This may take ~2 hours.
|
||||
python local/add_source_feats.py
|
||||
|
||||
# Combine clean and reverb
|
||||
cat <(gunzip -c data/manifests/cuts_train_clean_sources.jsonl.gz) \
|
||||
<(gunzip -c data/manifests/cuts_train_reverb_sources.jsonl.gz) |\
|
||||
shuf | gzip -c > data/manifests/cuts_train_comb_sources.jsonl.gz
|
||||
fi
|
||||
|
||||
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||
log "Stage 7: Create training mixtures from real sessions"
|
||||
python local/prepare_ami_train_cuts.py
|
||||
python local/prepare_icsi_train_cuts.py
|
||||
|
||||
# Combine AMI and ICSI
|
||||
cat <(gunzip -c data/manifests/cuts_train_ami.jsonl.gz) \
|
||||
<(gunzip -c data/manifests/cuts_train_icsi.jsonl.gz) |\
|
||||
shuf | gzip -c > data/manifests/cuts_train_ami_icsi.jsonl.gz
|
||||
fi
|
||||
|
||||
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
||||
log "Stage 8: Dump transcripts for BPE model training (using AMI and ICSI)."
|
||||
mkdir -p data/lm
|
||||
cat <(gunzip -c data/manifests/ami-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g') \
|
||||
<(gunzip -c data/manifests/icsi-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g') \
|
||||
> data/lm/transcript_words.txt
|
||||
fi
|
||||
|
||||
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
||||
log "Stage 9: Prepare BPE based lang (combining AMI and ICSI)"
|
||||
|
||||
lang_dir=data/lang_bpe_${vocab_size}
|
||||
mkdir -p $lang_dir
|
||||
|
||||
# Add special words to words.txt
|
||||
echo "<eps> 0" > $lang_dir/words.txt
|
||||
echo "!SIL 1" >> $lang_dir/words.txt
|
||||
echo "<UNK> 2" >> $lang_dir/words.txt
|
||||
|
||||
# Add regular words to words.txt
|
||||
cat data/lm/transcript_words.txt | grep -o -E '\w+' | sort -u | awk '{print $0,NR+2}' >> $lang_dir/words.txt
|
||||
|
||||
# Add remaining special word symbols expected by LM scripts.
|
||||
num_words=$(cat $lang_dir/words.txt | wc -l)
|
||||
echo "<s> ${num_words}" >> $lang_dir/words.txt
|
||||
num_words=$(cat $lang_dir/words.txt | wc -l)
|
||||
echo "</s> ${num_words}" >> $lang_dir/words.txt
|
||||
num_words=$(cat $lang_dir/words.txt | wc -l)
|
||||
echo "#0 ${num_words}" >> $lang_dir/words.txt
|
||||
|
||||
./local/train_bpe_model.py \
|
||||
--lang-dir $lang_dir \
|
||||
--vocab-size $vocab_size \
|
||||
--transcript data/lm/transcript_words.txt
|
||||
|
||||
if [ ! -f $lang_dir/L_disambig.pt ]; then
|
||||
./local/prepare_lang_bpe.py --lang-dir $lang_dir
|
||||
fi
|
||||
fi
|
1
egs/ami/SURT/shared
Symbolic link
@ -0,0 +1 @@
|
||||
../../../icefall/shared
|
249
egs/libricss/SURT/README.md
Normal file
@ -0,0 +1,249 @@
|
||||
# Introduction
|
||||
|
||||
This is a multi-talker ASR recipe for the LibriCSS dataset. We train a Streaming
|
||||
Unmixing and Recognition Transducer (SURT) model for the task. In this README,
|
||||
we will describe the task, the model, and the training process. We will also
|
||||
provide links to pre-trained models and training logs.
|
||||
|
||||
## Task
|
||||
|
||||
LibriCSS is a multi-talker meeting corpus formed by mixing together LibriSpeech utterances
|
||||
and replaying them in a real meeting room. It consists of ten 1-hour sessions of audio, each
|
||||
recorded with a 7-channel microphone array. The sessions are recorded at a sampling rate of 16 kHz.
|
||||
For more information, refer to the paper:
|
||||
Z. Chen et al., "Continuous speech separation: dataset and analysis,"
|
||||
ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP),
|
||||
Barcelona, Spain, 2020
|
||||
|
||||
In this recipe, we perform the "continuous, streaming, multi-talker ASR" task on LibriCSS.
|
||||
|
||||
* By "continuous", we mean that the model should be able to transcribe unsegmented audio
|
||||
without the need for an external VAD.
|
||||
* By "streaming", we mean that the model has limited right context. We use a right-context
|
||||
of at most 32 frames (320 ms).
|
||||
* By "multi-talker", we mean that the model should be able to transcribe overlapping speech
|
||||
from multiple speakers.
|
||||
|
||||
For now, we do not care about speaker attribution, i.e., the transcription is speaker
|
||||
agnostic. The evaluation depends on the particular model type. In this case, we use
|
||||
the optimal reference combination WER (ORC-WER) metric as implemented in the
|
||||
[meeteval](https://github.com/fgnt/meeteval) toolkit.
|
||||
|
||||
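
Below is a minimal, self-contained sketch of what ORC-WER measures (use meeteval for the
real metric): every assignment of reference utterances to output channels is tried, and the
assignment giving the lowest total word error rate is reported. The helper names here are
illustrative and not part of meeteval, and the brute-force search is exponential, so it is
only meant for tiny examples.

```python
from itertools import product


def edit_distance(ref: list, hyp: list) -> int:
    """Levenshtein distance between two token sequences (single-row DP)."""
    dp = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, dp[0] = dp[0], i
        for j, h in enumerate(hyp, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (r != h))
    return dp[-1]


def orc_wer(refs: list, hyps: list) -> float:
    """Brute-force ORC-WER: route each reference utterance (kept in temporal
    order) to one output channel, and keep the assignment with fewest errors."""
    num_ref_words = sum(len(r) for r in refs)
    best = None
    for assignment in product(range(len(hyps)), repeat=len(refs)):
        # Concatenate the references routed to each channel, in order.
        per_channel = [[] for _ in hyps]
        for ref, ch in zip(refs, assignment):
            per_channel[ch].extend(ref)
        errors = sum(edit_distance(r, h) for r, h in zip(per_channel, hyps))
        best = errors if best is None else min(best, errors)
    return best / num_ref_words


# Toy example with two output channels:
refs = [["hello", "world"], ["good", "morning"]]
hyps = [["hello", "world", "good", "morning"], []]
print(orc_wer(refs, hyps))  # -> 0.0 (both references routed to channel 0)
```
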
## Model
|
||||
|
||||
We use the Streaming Unmixing and Recognition Transducer (SURT) model for this task.
|
||||
The model is based on the papers:
|
||||
|
||||
- Lu, Liang et al. “Streaming End-to-End Multi-Talker Speech Recognition.” IEEE Signal Processing Letters 28 (2020): 803-807.
|
||||
- Raj, Desh et al. “Continuous Streaming Multi-Talker ASR with Dual-Path Transducers.” ICASSP 2022 - 2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) (2021): 7317-7321.
|
||||
|
||||
The model is a combination of a speech separation model and a speech recognition model,
|
||||
but trained end-to-end with a single loss function. The overall architecture is shown
|
||||
in the figure below. Note that this architecture is slightly different from the one
|
||||
in the above papers. A detailed description of the model can be found in the following
|
||||
paper: [SURT 2.0: Advances in Transducer-based Multi-talker Speech Recognition](https://arxiv.org/abs/2306.10559).
|
||||
|
||||
<p align="center">
|
||||
|
||||
<img src="surt.png">
|
||||
Streaming Unmixing and Recognition Transducer
|
||||
|
||||
</p>
|
||||
|
||||
In the [dprnn_zipformer](./dprnn_zipformer) recipe, for example, we use a DPRNN-based masking network
|
||||
and a Zipformer-based recognition network. But other combinations are possible as well.
|
||||
|
||||
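
As a rough sketch of this unmixing-plus-recognition composition (illustrative only; the module
names below are placeholders, and the actual `SURT` class lives in `dprnn_zipformer/`, where the
masking network is a DPRNN and the recognition encoder is a Zipformer transducer): the masking
network predicts one mask per output channel, each masked copy of the input features is run
through the shared recognition encoder, and each channel's encoder output then feeds the
transducer head.

```python
import torch
import torch.nn as nn


class ToySURT(nn.Module):
    """Minimal sketch of the SURT structure: unmix into N channels, then run a
    shared recognition encoder on each masked feature stream."""

    def __init__(self, feat_dim: int = 80, num_channels: int = 2):
        super().__init__()
        self.num_channels = num_channels
        # Stand-in for the DPRNN masking network: one mask per output channel.
        self.mask_net = nn.Sequential(
            nn.Linear(feat_dim, 256),
            nn.ReLU(),
            nn.Linear(256, feat_dim * num_channels),
            nn.Sigmoid(),
        )
        # Stand-in for the Zipformer encoder, shared across channels.
        self.encoder = nn.GRU(feat_dim, 256, batch_first=True)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        # feats: (batch, time, feat_dim)
        b, t, f = feats.shape
        masks = self.mask_net(feats).view(b, t, self.num_channels, f)
        outs = []
        for ch in range(self.num_channels):
            enc, _ = self.encoder(feats * masks[:, :, ch, :])
            outs.append(enc)
        # (batch, num_channels, time, encoder_dim); each channel is then scored
        # by the transducer decoder/joiner against its own reference stream.
        return torch.stack(outs, dim=1)


x = torch.randn(2, 100, 80)
print(ToySURT()(x).shape)  # torch.Size([2, 2, 100, 256])
```
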
## Training objective
|
||||
|
||||
We train the model using the pruned transducer loss, similar to other ASR recipes in
|
||||
icefall. However, an important consideration is how to assign references to the output
|
||||
channels (2 in this case). For this, we use the heuristic error assignment training (HEAT)
|
||||
strategy, which assigns references to the first available channel based on their start
|
||||
times. An illustrative example is shown in the figure below:
|
||||
|
||||
<p align="center">
|
||||
|
||||
<img src="heat.png">
|
||||
Illustration of HEAT-based reference assignment.
|
||||
|
||||
</p>
|
||||
|
||||
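
The following is a minimal sketch of the HEAT rule described above (illustrative only; the
training code applies this logic to supervision segments when computing the transducer loss,
and the function and field names here are assumptions): each reference utterance, taken in
start-time order, is routed to the first output channel whose previous utterance has already
ended.

```python
from typing import List, Tuple


def heat_assign(
    utterances: List[Tuple[float, float, str]], num_channels: int = 2
) -> List[List[str]]:
    """Assign reference utterances (start, end, text) to output channels.

    Each utterance goes to the first channel whose previous utterance has
    already ended: the heuristic error assignment training (HEAT) rule.
    """
    channel_end = [0.0] * num_channels  # end time of the last utterance per channel
    channels: List[List[str]] = [[] for _ in range(num_channels)]
    for start, end, text in sorted(utterances):
        for ch in range(num_channels):
            if channel_end[ch] <= start:  # channel is free at this start time
                channels[ch].append(text)
                channel_end[ch] = end
                break
        else:
            raise ValueError(f"More than {num_channels} overlapping utterances")
    return channels


# Two speakers overlapping: the second utterance starts before the first ends,
# so it is routed to channel 1.
refs = [(0.0, 3.0, "how are you"), (2.0, 4.5, "fine thanks"), (4.0, 6.0, "good")]
print(heat_assign(refs))  # [['how are you', 'good'], ['fine thanks']]
```
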
## Description of the recipe
|
||||
|
||||
### Pre-requisites
|
||||
|
||||
The recipes in this directory need the following packages to be installed:
|
||||
|
||||
- [meeteval](https://github.com/fgnt/meeteval)
|
||||
- [einops](https://github.com/arogozhnikov/einops)
|
||||
|
||||
Additionally, we initialize the "recognition" transducer with a pre-trained model,
|
||||
trained on LibriSpeech. For this, please run the following from within `egs/librispeech/ASR`:
|
||||
|
||||
```bash
|
||||
./prepare.sh
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
python pruned_transducer_stateless7_streaming/train.py \
|
||||
--use-fp16 True \
|
||||
--exp-dir pruned_transducer_stateless7_streaming/exp \
|
||||
--world-size 4 \
|
||||
--max-duration 800 \
|
||||
--num-epochs 10 \
|
||||
--keep-last-k 1 \
|
||||
--manifest-dir data/manifests \
|
||||
--enable-musan true \
|
||||
--master-port 54321 \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--num-encoder-layers 2,2,2,2,2 \
|
||||
--feedforward-dims 768,768,768,768,768 \
|
||||
--nhead 8,8,8,8,8 \
|
||||
--encoder-dims 256,256,256,256,256 \
|
||||
--attention-dims 192,192,192,192,192 \
|
||||
--encoder-unmasked-dims 192,192,192,192,192 \
|
||||
--zipformer-downsampling-factors 1,2,4,8,2 \
|
||||
--cnn-module-kernels 31,31,31,31,31 \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512
|
||||
```
|
||||
|
||||
The above is for SURT-base (~26M). For SURT-large (~38M), use `--num-encoder-layers 2,4,3,2,4`.
|
||||
|
||||
Once the above model is trained for 10 epochs, copy it to `egs/libricss/SURT/exp`:
|
||||
|
||||
```bash
|
||||
cp -r pruned_transducer_stateless7_streaming/exp/epoch-10.pt exp/zipformer_base.pt
|
||||
```
|
||||
|
||||
**NOTE:** We also provide this pre-trained checkpoint (see the section below), so you can skip
|
||||
the above step if you want.
|
||||
|
||||
### Training
|
||||
|
||||
To train the model, run the following from within `egs/libricss/SURT`:
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
python dprnn_zipformer/train.py \
|
||||
--use-fp16 True \
|
||||
--exp-dir dprnn_zipformer/exp/surt_base \
|
||||
--world-size 4 \
|
||||
--max-duration 500 \
|
||||
--max-duration-valid 250 \
|
||||
--max-cuts 200 \
|
||||
--num-buckets 50 \
|
||||
--num-epochs 30 \
|
||||
--enable-spec-aug True \
|
||||
--enable-musan False \
|
||||
--ctc-loss-scale 0.2 \
|
||||
--heat-loss-scale 0.2 \
|
||||
--base-lr 0.004 \
|
||||
--model-init-ckpt exp/zipformer_base.pt \
|
||||
--chunk-width-randomization True \
|
||||
--num-mask-encoder-layers 4 \
|
||||
--num-encoder-layers 2,2,2,2,2
|
||||
```
|
||||
|
||||
The above is for SURT-base (~26M). For SURT-large (~38M), use:
|
||||
|
||||
```bash
|
||||
--num-mask-encoder-layers 6 \
|
||||
--num-encoder-layers 2,4,3,2,4 \
|
||||
--model-init-ckpt exp/zipformer_large.pt \
|
||||
```
|
||||
|
||||
**NOTE:** You may need to decrease the `--max-duration` for SURT-large to avoid OOM.
|
||||
|
||||
### Adaptation
|
||||
|
||||
The training step above only trains on simulated mixtures. For best results, we also
|
||||
adapt the final model on the LibriCSS dev set. For this, run the following from within
|
||||
`egs/libricss/SURT`:
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES="0"
|
||||
|
||||
python dprnn_zipformer/train_adapt.py \
|
||||
--use-fp16 True \
|
||||
--exp-dir dprnn_zipformer/exp/surt_base_adapt \
|
||||
--world-size 1 \
|
||||
--max-duration 500 \
|
||||
--max-duration-valid 250 \
|
||||
--max-cuts 200 \
|
||||
--num-buckets 50 \
|
||||
--num-epochs 8 \
|
||||
--lr-epochs 2 \
|
||||
--enable-spec-aug True \
|
||||
--enable-musan False \
|
||||
--ctc-loss-scale 0.2 \
|
||||
--base-lr 0.0004 \
|
||||
--model-init-ckpt dprnn_zipformer/exp/surt_base/epoch-30.pt \
|
||||
--chunk-width-randomization True \
|
||||
--num-mask-encoder-layers 4 \
|
||||
--num-encoder-layers 2,2,2,2,2
|
||||
```
|
||||
|
||||
For SURT-large, use the following config:
|
||||
|
||||
```bash
|
||||
--num-mask-encoder-layers 6 \
|
||||
--num-encoder-layers 2,4,3,2,4 \
|
||||
--model-init-ckpt dprnn_zipformer/exp/surt_large/epoch-30.pt \
|
||||
--num-epochs 15 \
|
||||
--lr-epochs 4 \
|
||||
```
|
||||
|
||||
|
||||
### Decoding
|
||||
|
||||
To decode the model, run the following from within `egs/libricss/SURT`:
|
||||
|
||||
#### Greedy search
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES="0"
|
||||
|
||||
python dprnn_zipformer/decode.py \
|
||||
--epoch 8 --avg 1 --use-averaged-model False \
|
||||
--exp-dir dprnn_zipformer/exp/surt_base_adapt \
|
||||
--max-duration 250 \
|
||||
--decoding-method greedy_search
|
||||
```
|
||||
|
||||
#### Beam search
|
||||
|
||||
```bash
|
||||
python dprnn_zipformer/decode.py \
|
||||
--epoch 8 --avg 1 --use-averaged-model False \
|
||||
--exp-dir dprnn_zipformer/exp/surt_base_adapt \
|
||||
--max-duration 250 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
```
|
||||
|
||||
## Results (using beam search)
|
||||
|
||||
#### IHM-Mix
|
||||
|
||||
| Model | # params | 0L | 0S | OV10 | OV20 | OV30 | OV40 | Avg. |
|
||||
|------------|:-------:|:----:|:---:|----:|:----:|:----:|:----:|:----:|
|
||||
| dprnn_zipformer (base) | 26.7 | 5.1 | 4.2 | 13.7 | 18.7 | 20.5 | 20.6 | 13.8 |
|
||||
| dprnn_zipformer (large) | 37.9 | 4.6 | 3.8 | 12.7 | 14.3 | 16.7 | 21.2 | 12.2 |
|
||||
|
||||
#### SDM
|
||||
|
||||
| Model | # params | 0L | 0S | OV10 | OV20 | OV30 | OV40 | Avg. |
|
||||
|------------|:-------:|:----:|:---:|----:|:----:|:----:|:----:|:----:|
|
||||
| dprnn_zipformer (base) | 26.7 | 6.8 | 7.2 | 21.4 | 24.5 | 28.6 | 31.2 | 20.0 |
|
||||
| dprnn_zipformer (large) | 37.9 | 6.4 | 6.9 | 17.9 | 19.7 | 25.2 | 25.5 | 16.9 |
|
||||
|
||||
## Pre-trained models and logs
|
||||
|
||||
* Pre-trained models: <https://huggingface.co/desh2608/icefall-surt-libricss-dprnn-zipformer>
|
||||
|
||||
* Training logs:
|
||||
- surt_base: <https://tensorboard.dev/experiment/YLGJTkBETb2aqDQ61jbxvQ/>
|
||||
- surt_base_adapt: <https://tensorboard.dev/experiment/pjXMFVL9RMej85rMHyd0EQ/>
|
||||
- surt_large: <https://tensorboard.dev/experiment/82HvYqfrSOKZ4w8Jod2QMw/>
|
||||
- surt_large_adapt: <https://tensorboard.dev/experiment/5oIdEgRqS9Wb6yVuxaExEw/>
|
372
egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py
Normal file
@ -0,0 +1,372 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||
# Copyright  2023  Johns Hopkins University (Author: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
K2SurtDataset,
|
||||
PrecomputedFeatures,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class LibriCssAsrDataModule:
|
||||
"""
|
||||
DataModule for k2 ASR experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/manifests"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration-valid",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-cuts",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum number of cuts in a single batch. You can "
|
||||
"reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help=(
|
||||
"When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available."
|
||||
),
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-spec-aug",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use SpecAugment for training dataset.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--spec-aug-time-warp-factor",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Used only when --enable-spec-aug is True. "
|
||||
"It specifies the factor for time warping in SpecAugment. "
|
||||
"Larger values mean more warping. "
|
||||
"A value less than 1 means to disable time warp.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-musan",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, select noise from MUSAN and mix it"
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
return_sources: bool = True,
|
||||
strict: bool = True,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
transforms.append(
|
||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
num_frame_masks = 10
|
||||
num_frame_masks_parameter = inspect.signature(
|
||||
SpecAugment.__init__
|
||||
).parameters["num_frame_masks"]
|
||||
if num_frame_masks_parameter.default == 1:
|
||||
num_frame_masks = 2
|
||||
logging.info(f"Num frame mask: {num_frame_masks}")
|
||||
input_transforms.append(
|
||||
SpecAugment(
|
||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||
num_frame_masks=num_frame_masks,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable SpecAugment")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = K2SurtDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
return_sources=return_sources,
|
||||
strict=strict,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
quadratic_duration=30.0,
|
||||
max_cuts=self.args.max_cuts,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
max_cuts=self.args.max_cuts,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
transforms = []
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
validate = K2SurtDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
return_sources=False,
|
||||
strict=False,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration_valid,
|
||||
max_cuts=self.args.max_cuts,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SurtDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
return_sources=False,
|
||||
strict=False,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration_valid,
|
||||
max_cuts=self.args.max_cuts,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def lsmix_cuts(
|
||||
self,
|
||||
rvb_affix: str = "clean",
|
||||
type_affix: str = "full",
|
||||
sources: bool = True,
|
||||
) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
source_affix = "_sources" if sources else ""
|
||||
cs = load_manifest_lazy(
|
||||
self.args.manifest_dir
|
||||
/ f"cuts_train_{rvb_affix}_{type_affix}{source_affix}.jsonl.gz"
|
||||
)
|
||||
cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 30.0)
|
||||
return cs
|
||||
|
||||
@lru_cache()
|
||||
def libricss_cuts(self, split="dev", type="sdm") -> CutSet:
|
||||
logging.info(f"About to get LibriCSS {split} {type} cuts")
|
||||
cs = load_manifest_lazy(
|
||||
self.args.manifest_dir / f"cuts_{split}_libricss-{type}.jsonl.gz"
|
||||
)
|
||||
return cs
|
730
egs/libricss/SURT/dprnn_zipformer/beam_search.py
Normal file
@ -0,0 +1,730 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
|
||||
# Xiaoyu Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from model import SURT
|
||||
|
||||
from icefall import NgramLmStateCost
|
||||
from icefall.utils import DecodingResults
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: SURT,
|
||||
encoder_out: torch.Tensor,
|
||||
max_sym_per_frame: int,
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[int], DecodingResults]:
|
||||
"""Greedy search for a single utterance.
|
||||
Args:
|
||||
model:
|
||||
An instance of `SURT`.
|
||||
encoder_out:
|
||||
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
||||
max_sym_per_frame:
|
||||
Maximum number of symbols per frame. If it is set to 0, the WER
|
||||
would be 100%.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
"""
|
||||
assert encoder_out.ndim == 4
|
||||
|
||||
# support only batch_size == 1 for now
|
||||
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
||||
|
||||
blank_id = model.decoder.blank_id
|
||||
context_size = model.decoder.context_size
|
||||
unk_id = getattr(model, "unk_id", blank_id)
|
||||
|
||||
device = next(model.parameters()).device
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64
|
||||
).reshape(1, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
|
||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||
|
||||
T = encoder_out.size(1)
|
||||
t = 0
|
||||
hyp = [blank_id] * context_size
|
||||
|
||||
# timestamp[i] is the frame index after subsampling
|
||||
# on which hyp[i] is decoded
|
||||
timestamp = []
|
||||
|
||||
# Maximum symbols per utterance.
|
||||
max_sym_per_utt = 1000
|
||||
|
||||
# symbols per frame
|
||||
sym_per_frame = 0
|
||||
|
||||
# symbols per utterance decoded so far
|
||||
sym_per_utt = 0
|
||||
|
||||
while t < T and sym_per_utt < max_sym_per_utt:
|
||||
if sym_per_frame >= max_sym_per_frame:
|
||||
sym_per_frame = 0
|
||||
t += 1
|
||||
continue
|
||||
|
||||
# fmt: off
|
||||
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
|
||||
# fmt: on
|
||||
logits = model.joiner(
|
||||
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
|
||||
)
|
||||
# logits is (1, 1, 1, vocab_size)
|
||||
|
||||
y = logits.argmax().item()
|
||||
if y not in (blank_id, unk_id):
|
||||
hyp.append(y)
|
||||
timestamp.append(t)
|
||||
decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
|
||||
1, context_size
|
||||
)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
|
||||
sym_per_utt += 1
|
||||
sym_per_frame += 1
|
||||
else:
|
||||
sym_per_frame = 0
|
||||
t += 1
|
||||
hyp = hyp[context_size:] # remove blanks
|
||||
|
||||
if not return_timestamps:
|
||||
return hyp
|
||||
else:
|
||||
return DecodingResults(
|
||||
hyps=[hyp],
|
||||
timestamps=[timestamp],
|
||||
)
|
||||
|
||||
|
||||
def greedy_search_batch(
|
||||
model: SURT,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
model:
|
||||
The SURT model.
|
||||
encoder_out:
|
||||
Output from the encoder. Its shape is (N, T, C), where N >= 1.
|
||||
encoder_out_lens:
|
||||
A 1-D tensor of shape (N,), containing number of valid frames in
|
||||
encoder_out before padding.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
|
||||
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||
input=encoder_out,
|
||||
lengths=encoder_out_lens.cpu(),
|
||||
batch_first=True,
|
||||
enforce_sorted=False,
|
||||
)
|
||||
|
||||
device = next(model.parameters()).device
|
||||
|
||||
blank_id = model.decoder.blank_id
|
||||
unk_id = getattr(model, "unk_id", blank_id)
|
||||
context_size = model.decoder.context_size
|
||||
|
||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||
N = encoder_out.size(0)
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
assert N == batch_size_list[0], (N, batch_size_list)
|
||||
|
||||
hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)]
|
||||
|
||||
# timestamp[n][i] is the frame index after subsampling
|
||||
# on which hyp[n][i] is decoded
|
||||
timestamps = [[] for _ in range(N)]
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
hyps,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
) # (N, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
# decoder_out: (N, 1, decoder_out_dim)
|
||||
|
||||
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
|
||||
|
||||
offset = 0
|
||||
for (t, batch_size) in enumerate(batch_size_list):
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = encoder_out.data[start:end]
|
||||
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
|
||||
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
|
||||
offset = end
|
||||
|
||||
decoder_out = decoder_out[:batch_size]
|
||||
|
||||
logits = model.joiner(
|
||||
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
|
||||
)
|
||||
# logits'shape (batch_size, 1, 1, vocab_size)
|
||||
|
||||
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
|
||||
assert logits.ndim == 2, logits.shape
|
||||
y = logits.argmax(dim=1).tolist()
|
||||
emitted = False
|
||||
for i, v in enumerate(y):
|
||||
if v not in (blank_id, unk_id):
|
||||
hyps[i].append(v)
|
||||
timestamps[i].append(t)
|
||||
emitted = True
|
||||
if emitted:
|
||||
# update decoder output
|
||||
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||
decoder_input = torch.tensor(
|
||||
decoder_input,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
|
||||
sorted_ans = [h[context_size:] for h in hyps]
|
||||
ans = []
|
||||
ans_timestamps = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
ans_timestamps.append(timestamps[unsorted_indices[i]])
|
||||
|
||||
if not return_timestamps:
|
||||
return ans
|
||||
else:
|
||||
return DecodingResults(
|
||||
hyps=ans,
|
||||
timestamps=ans_timestamps,
|
||||
)
|
||||
|
||||
|
||||
def modified_beam_search(
|
||||
model: SURT,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
beam: int = 4,
|
||||
temperature: float = 1.0,
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[List[int]], DecodingResults]:
|
||||
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
|
||||
|
||||
Args:
|
||||
model:
|
||||
The SURT model.
|
||||
encoder_out:
|
||||
Output from the encoder. Its shape is (N, T, C).
|
||||
encoder_out_lens:
|
||||
A 1-D tensor of shape (N,), containing number of valid frames in
|
||||
encoder_out before padding.
|
||||
beam:
|
||||
Number of active paths during the beam search.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
Returns:
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
"""
|
||||
assert encoder_out.ndim == 3, encoder_out.shape
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
|
||||
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||
input=encoder_out,
|
||||
lengths=encoder_out_lens.cpu(),
|
||||
batch_first=True,
|
||||
enforce_sorted=False,
|
||||
)
|
||||
|
||||
blank_id = model.decoder.blank_id
|
||||
unk_id = getattr(model, "unk_id", blank_id)
|
||||
context_size = model.decoder.context_size
|
||||
device = next(model.parameters()).device
|
||||
|
||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||
N = encoder_out.size(0)
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
assert N == batch_size_list[0], (N, batch_size_list)
|
||||
|
||||
B = [HypothesisList() for _ in range(N)]
|
||||
for i in range(N):
|
||||
B[i].add(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
timestamp=[],
|
||||
)
|
||||
)
|
||||
|
||||
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
|
||||
|
||||
offset = 0
|
||||
finalized_B = []
|
||||
for (t, batch_size) in enumerate(batch_size_list):
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = encoder_out.data[start:end]
|
||||
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
|
||||
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
|
||||
offset = end
|
||||
|
||||
finalized_B = B[batch_size:] + finalized_B
|
||||
B = B[:batch_size]
|
||||
|
||||
hyps_shape = get_hyps_shape(B).to(device)
|
||||
|
||||
A = [list(b) for b in B]
|
||||
B = [HypothesisList() for _ in range(batch_size)]
|
||||
|
||||
ys_log_probs = torch.cat(
|
||||
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
|
||||
) # (num_hyps, 1)
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
) # (num_hyps, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
|
||||
|
||||
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
|
||||
# as index, so we use `to(torch.int64)` below.
|
||||
current_encoder_out = torch.index_select(
|
||||
current_encoder_out,
|
||||
dim=0,
|
||||
index=hyps_shape.row_ids(1).to(torch.int64),
|
||||
) # (num_hyps, 1, 1, encoder_out_dim)
|
||||
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
decoder_out,
|
||||
project_input=False,
|
||||
) # (num_hyps, 1, 1, vocab_size)
|
||||
|
||||
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
|
||||
|
||||
log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size)
|
||||
|
||||
log_probs.add_(ys_log_probs)
|
||||
|
||||
vocab_size = log_probs.size(-1)
|
||||
|
||||
log_probs = log_probs.reshape(-1)
|
||||
|
||||
row_splits = hyps_shape.row_splits(1) * vocab_size
|
||||
log_probs_shape = k2.ragged.create_ragged_shape2(
|
||||
row_splits=row_splits, cached_tot_size=log_probs.numel()
|
||||
)
|
||||
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
|
||||
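# Illustrative note (hedged, not in the original file): multiplying the
# hypothesis row_splits by vocab_size regroups the flattened
# (num_hyps * vocab_size,) scores so that ragged_log_probs[i] holds every
# candidate (hypothesis, token) score for the i-th utterance still active at
# this frame.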
|
||||
for i in range(batch_size):
|
||||
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
||||
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
||||
|
||||
for k in range(len(topk_hyp_indexes)):
|
||||
hyp_idx = topk_hyp_indexes[k]
|
||||
hyp = A[i][hyp_idx]
|
||||
|
||||
new_ys = hyp.ys[:]
|
||||
new_token = topk_token_indexes[k]
|
||||
new_timestamp = hyp.timestamp[:]
|
||||
if new_token not in (blank_id, unk_id):
|
||||
new_ys.append(new_token)
|
||||
new_timestamp.append(t)
|
||||
|
||||
new_log_prob = topk_log_probs[k]
|
||||
new_hyp = Hypothesis(
|
||||
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
|
||||
)
|
||||
B[i].add(new_hyp)
|
||||
|
||||
B = B + finalized_B
|
||||
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
|
||||
|
||||
sorted_ans = [h.ys[context_size:] for h in best_hyps]
|
||||
sorted_timestamps = [h.timestamp for h in best_hyps]
|
||||
ans = []
|
||||
ans_timestamps = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
|
||||
|
||||
if not return_timestamps:
|
||||
return ans
|
||||
else:
|
||||
return DecodingResults(
|
||||
hyps=ans,
|
||||
timestamps=ans_timestamps,
|
||||
)
|
||||
|
||||
|
||||
def beam_search(
|
||||
model: SURT,
|
||||
encoder_out: torch.Tensor,
|
||||
beam: int = 4,
|
||||
temperature: float = 1.0,
|
||||
return_timestamps: bool = False,
|
||||
) -> Union[List[int], DecodingResults]:
|
||||
"""
|
||||
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
|
||||
|
||||
espnet/nets/beam_search_transducer.py#L247 is used as a reference.
|
||||
|
||||
Args:
|
||||
model:
|
||||
An instance of `SURT`.
|
||||
encoder_out:
|
||||
A tensor of shape (N, T, C) from the encoder. Only N == 1 is supported for now.
|
||||
beam:
|
||||
Beam size.
|
||||
temperature:
|
||||
Softmax temperature.
|
||||
return_timestamps:
|
||||
Whether to return timestamps.
|
||||
|
||||
Returns:
|
||||
If return_timestamps is False, return the decoded result.
|
||||
Else, return a DecodingResults object containing
|
||||
decoded result and corresponding timestamps.
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
|
||||
# support only batch_size == 1 for now
|
||||
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
||||
blank_id = model.decoder.blank_id
|
||||
unk_id = getattr(model, "unk_id", blank_id)
|
||||
context_size = model.decoder.context_size
|
||||
|
||||
device = next(model.parameters()).device
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[blank_id] * context_size,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
).reshape(1, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
|
||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||
|
||||
T = encoder_out.size(1)
|
||||
t = 0
|
||||
|
||||
B = HypothesisList()
|
||||
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[]))
|
||||
|
||||
max_sym_per_utt = 20000
|
||||
|
||||
sym_per_utt = 0
|
||||
|
||||
decoder_cache: Dict[str, torch.Tensor] = {}
|
||||
|
||||
while t < T and sym_per_utt < max_sym_per_utt:
|
||||
# fmt: off
|
||||
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
|
||||
# fmt: on
|
||||
A = B
|
||||
B = HypothesisList()
|
||||
|
||||
joint_cache: Dict[str, torch.Tensor] = {}
|
||||
|
||||
# TODO(fangjun): Implement prefix search to update the `log_prob`
|
||||
# of hypotheses in A
|
||||
|
||||
while True:
|
||||
y_star = A.get_most_probable()
|
||||
A.remove(y_star)
|
||||
|
||||
cached_key = y_star.key
|
||||
|
||||
if cached_key not in decoder_cache:
|
||||
decoder_input = torch.tensor(
|
||||
[y_star.ys[-context_size:]],
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
).reshape(1, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
decoder_cache[cached_key] = decoder_out
|
||||
else:
|
||||
decoder_out = decoder_cache[cached_key]
|
||||
|
||||
cached_key += f"-t-{t}"
|
||||
if cached_key not in joint_cache:
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
decoder_out.unsqueeze(1),
|
||||
project_input=False,
|
||||
)
|
||||
|
||||
# TODO(fangjun): Scale the blank posterior
|
||||
log_prob = (logits / temperature).log_softmax(dim=-1)
|
||||
# log_prob is (1, 1, 1, vocab_size)
|
||||
log_prob = log_prob.squeeze()
|
||||
# Now log_prob is (vocab_size,)
|
||||
joint_cache[cached_key] = log_prob
|
||||
else:
|
||||
log_prob = joint_cache[cached_key]
|
||||
|
||||
# First, process the blank symbol
|
||||
skip_log_prob = log_prob[blank_id]
|
||||
new_y_star_log_prob = y_star.log_prob + skip_log_prob
|
||||
|
||||
# ys[:] returns a copy of ys
|
||||
B.add(
|
||||
Hypothesis(
|
||||
ys=y_star.ys[:],
|
||||
log_prob=new_y_star_log_prob,
|
||||
timestamp=y_star.timestamp[:],
|
||||
)
|
||||
)
|
||||
|
||||
# Second, process other non-blank labels
|
||||
values, indices = log_prob.topk(beam + 1)
|
||||
for i, v in zip(indices.tolist(), values.tolist()):
|
||||
if i in (blank_id, unk_id):
|
||||
continue
|
||||
new_ys = y_star.ys + [i]
|
||||
new_log_prob = y_star.log_prob + v
|
||||
new_timestamp = y_star.timestamp + [t]
|
||||
A.add(
|
||||
Hypothesis(
|
||||
ys=new_ys,
|
||||
log_prob=new_log_prob,
|
||||
timestamp=new_timestamp,
|
||||
)
|
||||
)
|
||||
|
||||
# Check whether B contains more than "beam" elements more probable
|
||||
# than the most probable in A
|
||||
A_most_probable = A.get_most_probable()
|
||||
|
||||
kept_B = B.filter(A_most_probable.log_prob)
|
||||
|
||||
if len(kept_B) >= beam:
|
||||
B = kept_B.topk(beam)
|
||||
break
|
||||
|
||||
t += 1
|
||||
|
||||
best_hyp = B.get_most_probable(length_norm=True)
|
||||
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
|
||||
|
||||
if not return_timestamps:
|
||||
return ys
|
||||
else:
|
||||
return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp])
|
||||
|
||||
|
||||
@dataclass
|
||||
class Hypothesis:
|
||||
# The predicted tokens so far.
|
||||
# Newly predicted tokens are appended to `ys`.
|
||||
ys: List[int]
|
||||
|
||||
# The log prob of ys.
|
||||
# It contains only one entry.
|
||||
log_prob: torch.Tensor
|
||||
|
||||
# timestamp[i] is the frame index after subsampling
|
||||
# on which ys[i] is decoded
|
||||
timestamp: List[int] = field(default_factory=list)
|
||||
|
||||
# the lm score for next token given the current ys
|
||||
lm_score: Optional[torch.Tensor] = None
|
||||
|
||||
# the RNNLM states (h and c in LSTM)
|
||||
state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
|
||||
|
||||
# N-gram LM state
|
||||
state_cost: Optional[NgramLmStateCost] = None
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
"""Return a string representation of self.ys"""
|
||||
return "_".join(map(str, self.ys))
|
||||
|
||||
|
||||
class HypothesisList(object):
|
||||
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
|
||||
"""
|
||||
Args:
|
||||
data:
|
||||
A dict of Hypotheses. Its key is its `value.key`.
|
||||
"""
|
||||
if data is None:
|
||||
self._data = {}
|
||||
else:
|
||||
self._data = data
|
||||
|
||||
@property
|
||||
def data(self) -> Dict[str, Hypothesis]:
|
||||
return self._data
|
||||
|
||||
def add(self, hyp: Hypothesis) -> None:
|
||||
"""Add a Hypothesis to `self`.
|
||||
|
||||
If `hyp` already exists in `self`, its probability is updated using
|
||||
`log-sum-exp` with the existed one.
|
||||
|
||||
Args:
|
||||
hyp:
|
||||
The hypothesis to be added.
|
||||
"""
|
||||
key = hyp.key
|
||||
if key in self:
|
||||
old_hyp = self._data[key] # shallow copy
|
||||
torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
|
||||
else:
|
||||
self._data[key] = hyp
|
||||
|
||||
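# Hedged worked example (illustrative, not part of the original file): if a
# hypothesis with the same token sequence is added twice, its two scores are
# merged with log-sum-exp, e.g.
#   torch.logaddexp(torch.tensor(-1.2), torch.tensor(-2.3))
# equals log(exp(-1.2) + exp(-2.3)) ≈ -0.91, so the merged path is more
# probable than either individual path.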
def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
|
||||
"""Get the most probable hypothesis, i.e., the one with
|
||||
the largest `log_prob`.
|
||||
|
||||
Args:
|
||||
length_norm:
|
||||
If True, the `log_prob` of a hypothesis is normalized by the
|
||||
number of tokens in it.
|
||||
Returns:
|
||||
Return the hypothesis that has the largest `log_prob`.
|
||||
"""
|
||||
if length_norm:
|
||||
return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
|
||||
else:
|
||||
return max(self._data.values(), key=lambda hyp: hyp.log_prob)
|
||||
|
||||
def remove(self, hyp: Hypothesis) -> None:
|
||||
"""Remove a given hypothesis.
|
||||
|
||||
Caution:
|
||||
`self` is modified **in-place**.
|
||||
|
||||
Args:
|
||||
hyp:
|
||||
The hypothesis to be removed from `self`.
|
||||
Note: It must be contained in `self`. Otherwise,
|
||||
an exception is raised.
|
||||
"""
|
||||
key = hyp.key
|
||||
assert key in self, f"{key} does not exist"
|
||||
del self._data[key]
|
||||
|
||||
def filter(self, threshold: torch.Tensor) -> "HypothesisList":
|
||||
"""Remove all Hypotheses whose log_prob is less than threshold.
|
||||
|
||||
Caution:
|
||||
`self` is not modified. Instead, a new HypothesisList is returned.
|
||||
|
||||
Returns:
|
||||
Return a new HypothesisList containing all hypotheses from `self`
|
||||
with `log_prob` being greater than the given `threshold`.
|
||||
"""
|
||||
ans = HypothesisList()
|
||||
for _, hyp in self._data.items():
|
||||
if hyp.log_prob > threshold:
|
||||
ans.add(hyp) # shallow copy
|
||||
return ans
|
||||
|
||||
def topk(self, k: int) -> "HypothesisList":
|
||||
"""Return the top-k hypothesis."""
|
||||
hyps = list(self._data.items())
|
||||
|
||||
hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
|
||||
|
||||
ans = HypothesisList(dict(hyps))
|
||||
return ans
|
||||
|
||||
def __contains__(self, key: str):
|
||||
return key in self._data
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._data.values())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._data)
|
||||
|
||||
def __str__(self) -> str:
|
||||
s = []
|
||||
for key in self:
|
||||
s.append(key)
|
||||
return ", ".join(s)
|
||||
|
||||
|
||||
def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
|
||||
"""Return a ragged shape with axes [utt][num_hyps].
|
||||
|
||||
Args:
|
||||
hyps:
|
||||
len(hyps) == batch_size. It contains the current hypothesis for
|
||||
each utterance in the batch.
|
||||
Returns:
|
||||
Return a ragged shape with 2 axes [utt][num_hyps]. Note that
|
||||
the shape is on CPU.
|
||||
"""
|
||||
num_hyps = [len(h) for h in hyps]
|
||||
|
||||
# torch.cumsum() is inclusive sum, so we put a 0 at the beginning
|
||||
# to get exclusive sum later.
|
||||
num_hyps.insert(0, 0)
|
||||
|
||||
num_hyps = torch.tensor(num_hyps)
|
||||
row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
|
||||
ans = k2.ragged.create_ragged_shape2(
|
||||
row_splits=row_splits, cached_tot_size=row_splits[-1].item()
|
||||
)
|
||||
return ans
|
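# Hedged example (illustrative, not part of the original file): for a batch
# whose three utterances currently hold [2, 3, 1] hypotheses, num_hyps becomes
# [0, 2, 3, 1], the cumulative sum gives row_splits = [0, 2, 5, 6], and the
# returned ragged shape has 3 rows containing 2, 3 and 1 elements.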
654
egs/libricss/SURT/dprnn_zipformer/decode.py
Executable file
@ -0,0 +1,654 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./dprnn_zipformer/decode.py \
|
||||
--epoch 30 \
|
||||
--avg 9 \
|
||||
--use-averaged-model true \
|
||||
--exp-dir ./dprnn_zipformer/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) modified beam search
|
||||
./dprnn_zipformer/decode.py \
|
||||
--epoch 30 \
|
||||
--avg 9 \
|
||||
--use-averaged-model true \
|
||||
--exp-dir ./dprnn_zipformer/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriCssAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from lhotse.utils import EPSILON
|
||||
from train import add_model_arguments, get_params, get_surt_model
|
||||
|
||||
from icefall import LmScorer, NgramLm
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_surt_error_stats,
|
||||
)
|
||||
|
||||
OVERLAP_RATIOS = ["0L", "0S", "OV10", "OV20", "OV30", "OV40"]
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="dprnn_zipformer/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/lang_bpe_500",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--save-masks",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""If true, save masks generated by unmixing module.""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
feature = feature.to(device)
|
||||
feature_lens = batch["input_lens"].to(device)
|
||||
|
||||
# Apply the mask encoder
|
||||
B, T, F = feature.shape
|
||||
processed = model.mask_encoder(feature) # B,T,F*num_channels
|
||||
masks = processed.view(B, T, F, params.num_channels).unbind(dim=-1)
|
||||
x_masked = [feature * m for m in masks]
|
||||
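# Shape sketch (hedged, for illustration only; not in the original file): with
# B=2, T=1000, F=80 and params.num_channels=2, `processed` has shape
# (2, 1000, 160), each mask in `masks` has shape (2, 1000, 80), and `x_masked`
# is a list of two tensors of shape (2, 1000, 80), one per output channel.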
|
||||
masks_dict = {}
|
||||
if params.save_masks:
|
||||
# To save the masks, we split them by batch and trim each mask to the length of
|
||||
# the corresponding feature. We save them in a dict, where the key is the
|
||||
# cut ID and the value is the mask.
|
||||
for i in range(B):
|
||||
mask = torch.cat(
|
||||
[x_masked[j][i, : feature_lens[i]] for j in range(params.num_channels)],
|
||||
dim=-1,
|
||||
)
|
||||
mask = mask.cpu().numpy()
|
||||
masks_dict[batch["cuts"][i].id] = mask
|
||||
|
||||
# Recognition
|
||||
# Concatenate the inputs along the batch axis
|
||||
h = torch.cat(x_masked, dim=0)
|
||||
h_lens = feature_lens.repeat(params.num_channels)
|
||||
encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens)
|
||||
|
||||
if model.joint_encoder_layer is not None:
|
||||
encoder_out = model.joint_encoder_layer(encoder_out)
|
||||
|
||||
def _group_channels(hyps: List[str]) -> List[List[str]]:
|
||||
"""
|
||||
Currently we have a batch of size M*B, where M is the number of
|
||||
channels and B is the batch size. We need to group the hypotheses
|
||||
into B groups, each of which contains M hypotheses.
|
||||
|
||||
Example:
|
||||
hyps = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2']
|
||||
_group_channels(hyps) = [['a1', 'a2'], ['b1', 'b2'], ['c1', 'c2']]
|
||||
"""
|
||||
assert len(hyps) == B * params.num_channels
|
||||
out_hyps = []
|
||||
for i in range(B):
|
||||
out_hyps.append(hyps[i::B])
|
||||
return out_hyps
|
||||
|
||||
hyps = []
|
||||
if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp)
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp))
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": _group_channels(hyps)}, masks_dict
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": _group_channels(hyps)}, masks_dict
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains three elements:
|
||||
the cut ID, the reference transcript, and the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
masks = {}
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
cut_ids = [cut.id for cut in batch["cuts"]]
|
||||
cuts_batch = batch["cuts"]
|
||||
|
||||
hyps_dict, masks_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
batch=batch,
|
||||
)
|
||||
masks.update(masks_dict)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
for cut_id, hyp_words in zip(cut_ids, hyps):
|
||||
# Reference is a list of supervision texts sorted by start time.
|
||||
ref_words = [
|
||||
s.text.strip()
|
||||
for s in sorted(
|
||||
cuts_batch[cut_id].supervisions, key=lambda s: s.start
|
||||
)
|
||||
]
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(cut_ids)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results, masks
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_surt_error_stats(
|
||||
f,
|
||||
f"{test_set_name}-{key}",
|
||||
results,
|
||||
enable_log=True,
|
||||
num_channels=params.num_channels,
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
def save_masks(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
masks: List[torch.Tensor],
|
||||
):
|
||||
masks_path = params.res_dir / f"masks-{test_set_name}.txt"
|
||||
torch.save(masks, masks_path)
|
||||
logging.info(f"The masks are stored in {masks_path}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LmScorer.add_arguments(parser)
|
||||
LibriCssAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
args.lang_dir = Path(args.lang_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"modified_beam_search",
|
||||
), f"Decoding method {params.decoding_method} is not supported."
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_surt_model(params)
|
||||
assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, (
|
||||
model.encoder.decode_chunk_size,
|
||||
params.decode_chunk_len,
|
||||
)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
libricss = LibriCssAsrDataModule(args)
|
||||
|
||||
dev_cuts = libricss.libricss_cuts(split="dev", type="ihm-mix").to_eager()
|
||||
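# Note (hedged, added alongside the fix below; not in the original file): `ol`
# is bound as a default argument so that each filter keeps its own overlap
# ratio even if CutSet.filter is evaluated lazily; with Python's late-binding
# closures, every group would otherwise see the final value of `ol`.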
dev_cuts_grouped = [dev_cuts.filter(lambda x, ol=ol: ol in x.id) for ol in OVERLAP_RATIOS]
|
||||
test_cuts = libricss.libricss_cuts(split="test", type="ihm-mix").to_eager()
|
||||
test_cuts_grouped = [
|
||||
test_cuts.filter(lambda x, ol=ol: ol in x.id) for ol in OVERLAP_RATIOS
|
||||
]
|
||||
|
||||
for dev_set, ol in zip(dev_cuts_grouped, OVERLAP_RATIOS):
|
||||
dev_dl = libricss.test_dataloaders(dev_set)
|
||||
results_dict, masks = decode_dataset(
|
||||
dl=dev_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=f"dev_{ol}",
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
if params.save_masks:
|
||||
save_masks(
|
||||
params=params,
|
||||
test_set_name=f"dev_{ol}",
|
||||
masks=masks,
|
||||
)
|
||||
|
||||
for test_set, ol in zip(test_cuts_grouped, OVERLAP_RATIOS):
|
||||
test_dl = libricss.test_dataloaders(test_set)
|
||||
results_dict, masks = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=f"test_{ol}",
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
if params.save_masks:
|
||||
save_masks(
|
||||
params=params,
|
||||
test_set_name=f"test_{ol}",
|
||||
masks=masks,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
egs/libricss/SURT/dprnn_zipformer/decoder.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py
|
305
egs/libricss/SURT/dprnn_zipformer/dprnn.py
Normal file
@ -0,0 +1,305 @@
|
||||
import random
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from scaling import ActivationBalancer, BasicNorm, DoubleSwish, ScaledLinear, ScaledLSTM
|
||||
from torch.autograd import Variable
|
||||
|
||||
EPS = torch.finfo(torch.get_default_dtype()).eps
|
||||
|
||||
|
||||
def _pad_segment(input, segment_size):
|
||||
# Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py#L342
|
||||
# input is the features: (B, N, T)
|
||||
batch_size, dim, seq_len = input.shape
|
||||
segment_stride = segment_size // 2
|
||||
|
||||
rest = segment_size - (segment_stride + seq_len % segment_size) % segment_size
|
||||
if rest > 0:
|
||||
pad = Variable(torch.zeros(batch_size, dim, rest)).type(input.type())
|
||||
input = torch.cat([input, pad], 2)
|
||||
|
||||
pad_aux = Variable(torch.zeros(batch_size, dim, segment_stride)).type(input.type())
|
||||
input = torch.cat([pad_aux, input, pad_aux], 2)
|
||||
|
||||
return input, rest
|
||||
|
||||
|
||||
def split_feature(input, segment_size):
|
||||
# Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py#L358
|
||||
# split the feature into chunks of segment size
|
||||
# input is the features: (B, N, T)
|
||||
|
||||
input, rest = _pad_segment(input, segment_size)
|
||||
batch_size, dim, seq_len = input.shape
|
||||
segment_stride = segment_size // 2
|
||||
|
||||
segments1 = (
|
||||
input[:, :, :-segment_stride]
|
||||
.contiguous()
|
||||
.view(batch_size, dim, -1, segment_size)
|
||||
)
|
||||
segments2 = (
|
||||
input[:, :, segment_stride:]
|
||||
.contiguous()
|
||||
.view(batch_size, dim, -1, segment_size)
|
||||
)
|
||||
segments = (
|
||||
torch.cat([segments1, segments2], 3)
|
||||
.view(batch_size, dim, -1, segment_size)
|
||||
.transpose(2, 3)
|
||||
)
|
||||
|
||||
return segments.contiguous(), rest
|
||||
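# Reading note (hedged, not part of the original file): segments are extracted
# with a hop of segment_size // 2, i.e. 50% overlap between consecutive
# segments; `rest` records how much right-padding was added so that
# merge_feature() can trim it off again.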
|
||||
|
||||
def merge_feature(input, rest):
|
||||
# Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py#L385
|
||||
# merge the split features back into the full utterance
|
||||
# input is the features: (B, N, L, K)
|
||||
|
||||
batch_size, dim, segment_size, _ = input.shape
|
||||
segment_stride = segment_size // 2
|
||||
input = (
|
||||
input.transpose(2, 3).contiguous().view(batch_size, dim, -1, segment_size * 2)
|
||||
) # B, N, K, L
|
||||
|
||||
input1 = (
|
||||
input[:, :, :, :segment_size]
|
||||
.contiguous()
|
||||
.view(batch_size, dim, -1)[:, :, segment_stride:]
|
||||
)
|
||||
input2 = (
|
||||
input[:, :, :, segment_size:]
|
||||
.contiguous()
|
||||
.view(batch_size, dim, -1)[:, :, :-segment_stride]
|
||||
)
|
||||
|
||||
output = input1 + input2
|
||||
if rest > 0:
|
||||
output = output[:, :, :-rest]
|
||||
|
||||
return output.contiguous() # B, N, T
|
||||
|
||||
|
||||
class RNNEncoderLayer(nn.Module):
|
||||
"""
|
||||
RNNEncoderLayer consists of an LSTM layer with a residual connection, followed by normalization.
|
||||
Args:
|
||||
input_size:
|
||||
The number of expected features in the input (required).
|
||||
hidden_size:
|
||||
The hidden dimension of rnn layer.
|
||||
dropout:
|
||||
The dropout value (default=0.1).
|
||||
bidirectional:
|
||||
Whether to use a bidirectional LSTM (default=False).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
hidden_size: int,
|
||||
dropout: float = 0.1,
|
||||
bidirectional: bool = False,
|
||||
) -> None:
|
||||
super(RNNEncoderLayer, self).__init__()
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
assert hidden_size >= input_size, (hidden_size, input_size)
|
||||
self.lstm = ScaledLSTM(
|
||||
input_size=input_size,
|
||||
hidden_size=hidden_size // 2 if bidirectional else hidden_size,
|
||||
proj_size=0,
|
||||
num_layers=1,
|
||||
dropout=0.0,
|
||||
batch_first=True,
|
||||
bidirectional=bidirectional,
|
||||
)
|
||||
self.norm_final = BasicNorm(input_size)
|
||||
|
||||
# try to ensure the output is close to zero-mean (or at least, zero-median). # noqa
|
||||
self.balancer = ActivationBalancer(
|
||||
num_channels=input_size,
|
||||
channel_dim=-1,
|
||||
min_positive=0.45,
|
||||
max_positive=0.55,
|
||||
max_abs=6.0,
|
||||
)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src: torch.Tensor,
|
||||
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
warmup: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Pass the input through the encoder layer.
|
||||
Args:
|
||||
src:
|
||||
The sequence to the encoder layer (required).
|
||||
Its shape is (S, N, E), where S is the sequence length,
|
||||
N is the batch size, and E is the feature number.
|
||||
states:
|
||||
A tuple of 2 tensors (optional). It is for streaming inference.
|
||||
states[0] is the hidden states of all layers,
|
||||
with shape of (1, N, input_size);
|
||||
states[1] is the cell states of all layers,
|
||||
with shape of (1, N, hidden_size).
|
||||
"""
|
||||
src_orig = src
|
||||
|
||||
# alpha = 1.0 means fully use this encoder layer, 0.0 would mean
|
||||
# completely bypass it.
|
||||
alpha = warmup if self.training else 1.0
|
||||
|
||||
# lstm module
|
||||
src_lstm, new_states = self.lstm(src, states)
|
||||
src = self.dropout(src_lstm) + src
|
||||
src = self.norm_final(self.balancer(src))
|
||||
|
||||
if alpha != 1.0:
|
||||
src = alpha * src + (1 - alpha) * src_orig
|
||||
|
||||
return src
|
||||
|
||||
|
||||
# dual-path RNN
|
||||
class DPRNN(nn.Module):
|
||||
"""Deep dual-path RNN.
|
||||
Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py
|
||||
|
||||
args:
|
||||
input_size: int, dimension of the input feature. The input should have shape
|
||||
(batch, seq_len, input_size).
|
||||
hidden_size: int, dimension of the hidden state.
|
||||
output_size: int, dimension of the output size.
|
||||
dropout: float, dropout ratio. Default is 0.
|
||||
num_blocks: int, number of stacked RNN layers. Default is 1.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_dim,
|
||||
input_size,
|
||||
hidden_size,
|
||||
output_size,
|
||||
dropout=0.1,
|
||||
num_blocks=1,
|
||||
segment_size=50,
|
||||
chunk_width_randomization=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
self.segment_size = segment_size
|
||||
self.chunk_width_randomization = chunk_width_randomization
|
||||
|
||||
self.input_embed = nn.Sequential(
|
||||
ScaledLinear(feature_dim, input_size),
|
||||
BasicNorm(input_size),
|
||||
ActivationBalancer(
|
||||
num_channels=input_size,
|
||||
channel_dim=-1,
|
||||
min_positive=0.45,
|
||||
max_positive=0.55,
|
||||
),
|
||||
)
|
||||
|
||||
# dual-path RNN
|
||||
self.row_rnn = nn.ModuleList([])
|
||||
self.col_rnn = nn.ModuleList([])
|
||||
for _ in range(num_blocks):
|
||||
# intra-RNN is non-causal
|
||||
self.row_rnn.append(
|
||||
RNNEncoderLayer(
|
||||
input_size, hidden_size, dropout=dropout, bidirectional=True
|
||||
)
|
||||
)
|
||||
self.col_rnn.append(
|
||||
RNNEncoderLayer(
|
||||
input_size, hidden_size, dropout=dropout, bidirectional=False
|
||||
)
|
||||
)
|
||||
|
||||
# output layer
|
||||
self.out_embed = nn.Sequential(
|
||||
ScaledLinear(input_size, output_size),
|
||||
BasicNorm(output_size),
|
||||
ActivationBalancer(
|
||||
num_channels=output_size,
|
||||
channel_dim=-1,
|
||||
min_positive=0.45,
|
||||
max_positive=0.55,
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
# input shape: B, T, F
|
||||
input = self.input_embed(input)
|
||||
B, T, D = input.shape
|
||||
|
||||
if self.chunk_width_randomization and self.training:
|
||||
segment_size = random.randint(self.segment_size // 2, self.segment_size)
|
||||
else:
|
||||
segment_size = self.segment_size
|
||||
input, rest = split_feature(input.transpose(1, 2), segment_size)
|
||||
# input shape: batch, N, dim1, dim2
|
||||
# apply RNN on dim1 first and then dim2
|
||||
# output shape: B, output_size, dim1, dim2
|
||||
# input = input.to(device)
|
||||
batch_size, _, dim1, dim2 = input.shape
|
||||
output = input
|
||||
for i in range(len(self.row_rnn)):
|
||||
row_input = (
|
||||
output.permute(0, 3, 2, 1)
|
||||
.contiguous()
|
||||
.view(batch_size * dim2, dim1, -1)
|
||||
) # B*dim2, dim1, N
|
||||
output = self.row_rnn[i](row_input) # B*dim2, dim1, H
|
||||
output = (
|
||||
output.view(batch_size, dim2, dim1, -1).permute(0, 3, 2, 1).contiguous()
|
||||
) # B, N, dim1, dim2
|
||||
|
||||
col_input = (
|
||||
output.permute(0, 2, 3, 1)
|
||||
.contiguous()
|
||||
.view(batch_size * dim1, dim2, -1)
|
||||
) # B*dim1, dim2, N
|
||||
output = self.col_rnn[i](col_input) # B*dim1, dim2, H
|
||||
output = (
|
||||
output.view(batch_size, dim1, dim2, -1).permute(0, 3, 1, 2).contiguous()
|
||||
) # B, N, dim1, dim2
|
||||
|
||||
output = merge_feature(output, rest)
|
||||
output = output.transpose(1, 2)
|
||||
output = self.out_embed(output)
|
||||
|
||||
# Apply ReLU to the output
|
||||
output = torch.relu(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
model = DPRNN(
|
||||
80,
|
||||
256,
|
||||
256,
|
||||
160,
|
||||
dropout=0.1,
|
||||
num_blocks=4,
|
||||
segment_size=32,
|
||||
chunk_width_randomization=True,
|
||||
)
|
||||
input = torch.randn(2, 1002, 80)
|
||||
print(sum(p.numel() for p in model.parameters()))
|
||||
print(model(input).shape)
|
1
egs/libricss/SURT/dprnn_zipformer/encoder_interface.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/encoder_interface.py
|
306
egs/libricss/SURT/dprnn_zipformer/export.py
Executable file
@ -0,0 +1,306 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
|
||||
Usage:
|
||||
|
||||
(1) Export to torchscript model using torch.jit.script()
|
||||
|
||||
./dprnn_zipformer/export.py \
|
||||
--exp-dir ./dprnn_zipformer/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 30 \
|
||||
--avg 9 \
|
||||
--jit 1
|
||||
|
||||
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
|
||||
load it by `torch.jit.load("cpu_jit.pt")`.
|
||||
|
||||
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
|
||||
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
|
||||
|
||||
Check
|
||||
https://github.com/k2-fsa/sherpa
|
||||
for how to use the exported models outside of icefall.
|
||||
|
||||
(2) Export `model.state_dict()`
|
||||
|
||||
./dprnn_zipformer/export.py \
|
||||
--exp-dir ./dprnn_zipformer/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 30 \
|
||||
--avg 9
|
||||
|
||||
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
|
||||
load it by `icefall.checkpoint.load_checkpoint()`.
|
||||
|
||||
To use the generated file with `dprnn_zipformer/decode.py`,
|
||||
you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/librispeech/ASR
|
||||
./dprnn_zipformer/decode.py \
|
||||
--exp-dir ./dprnn_zipformer/exp \
|
||||
--epoch 9999 \
|
||||
--avg 1 \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search \
|
||||
--bpe-model data/lang_bpe_500/bpe.model
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import add_model_arguments, get_params, get_surt_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="dprnn_zipformer/exp",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to save a model after applying torch.jit.script.
|
||||
It will generate a file named cpu_jit.pt
|
||||
|
||||
Check ./jit_pretrained.py for how to use it.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_surt_model(params)
|
||||
|
||||
model.to(device)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
if params.jit is True:
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptable.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
model.save(str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
else:
|
||||
logging.info("Not using torchscript. Export model.state_dict()")
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = params.exp_dir / "pretrained.pt"
|
||||
torch.save({"model": model.state_dict()}, str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
1
egs/libricss/SURT/dprnn_zipformer/joiner.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py
|
316
egs/libricss/SURT/dprnn_zipformer/model.py
Normal file
@ -0,0 +1,316 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
|
||||
# Copyright 2023 Johns Hopkins University (author: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
|
||||
from icefall.utils import add_sos
|
||||
|
||||
|
||||
class SURT(nn.Module):
|
||||
"""It implements Streaming Unmixing and Recognition Transducer (SURT).
|
||||
https://arxiv.org/abs/2011.13148
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mask_encoder: nn.Module,
|
||||
encoder: EncoderInterface,
|
||||
joint_encoder_layer: Optional[nn.Module],
|
||||
decoder: nn.Module,
|
||||
joiner: nn.Module,
|
||||
num_channels: int,
|
||||
encoder_dim: int,
|
||||
decoder_dim: int,
|
||||
joiner_dim: int,
|
||||
vocab_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
mask_encoder:
|
||||
It is the masking network. It generates a mask for each channel of the
|
||||
encoder. These masks are applied to the input features, and then passed
|
||||
to the transcription network.
|
||||
encoder:
|
||||
It is the transcription network in the paper. It accepts
|
||||
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
|
||||
`logit_lens` of shape (N,).
|
||||
decoder:
|
||||
It is the prediction network in the paper. Its input shape
|
||||
is (N, U) and its output shape is (N, U, decoder_dim).
|
||||
It should contain one attribute: `blank_id`.
|
||||
joiner:
|
||||
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
|
||||
Its output shape is (N, T, U, vocab_size). Note that its output contains
|
||||
unnormalized probs, i.e., not processed by log-softmax.
|
||||
num_channels:
|
||||
It is the number of channels that the input features will be split into.
|
||||
In general, it should be equal to the maximum number of simultaneously
|
||||
active speakers. For most real scenarios, using 2 channels is sufficient.
|
||||
"""
|
||||
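# Data-flow sketch (hedged, added for readability; not in the original file):
# for an input feature of shape (B, T, F), mask_encoder produces num_channels
# masks of the same shape; each masked copy is passed through the shared
# encoder, decoder and joiner, so the transducer branch effectively sees a
# batch of size B * num_channels.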
super().__init__()
|
||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||
assert hasattr(decoder, "blank_id")
|
||||
|
||||
self.mask_encoder = mask_encoder
|
||||
self.encoder = encoder
|
||||
self.joint_encoder_layer = joint_encoder_layer
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
self.num_channels = num_channels
|
||||
|
||||
self.simple_am_proj = nn.Linear(
|
||||
encoder_dim,
|
||||
vocab_size,
|
||||
)
|
||||
self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
|
||||
|
||||
self.ctc_output = nn.Sequential(
|
||||
nn.Dropout(p=0.1),
|
||||
nn.Linear(encoder_dim, vocab_size),
|
||||
nn.LogSoftmax(dim=-1),
|
||||
)
|
||||
|
||||
def forward_helper(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
reduction: str = "sum",
|
||||
beam_size: int = 10,
|
||||
use_double_scores: bool = False,
|
||||
subsampling_factor: int = 1,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Compute transducer loss for one branch of the SURT model.
|
||||
"""
|
||||
encoder_out, x_lens = self.encoder(x, x_lens)
|
||||
assert torch.all(x_lens > 0)
|
||||
|
||||
if self.joint_encoder_layer is not None:
|
||||
encoder_out = self.joint_encoder_layer(encoder_out)
|
||||
|
||||
# compute ctc log-probs
|
||||
ctc_output = self.ctc_output(encoder_out)
|
||||
|
||||
# For the decoder, i.e., the prediction network
|
||||
row_splits = y.shape.row_splits(1)
|
||||
y_lens = row_splits[1:] - row_splits[:-1]
|
||||
|
||||
blank_id = self.decoder.blank_id
|
||||
sos_y = add_sos(y, sos_id=blank_id)
|
||||
|
||||
# sos_y_padded: [B, S + 1], start with SOS.
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
|
||||
# decoder_out: [B, S + 1, decoder_dim]
|
||||
decoder_out = self.decoder(sos_y_padded)
|
||||
|
||||
# Note: y does not start with SOS
|
||||
# y_padded : [B, S]
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
y_padded = y_padded.to(torch.int64)
|
||||
boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = x_lens
|
||||
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
symbols=y_padded,
|
||||
termination_symbol=blank_id,
|
||||
lm_only_scale=lm_scale,
|
||||
am_only_scale=am_scale,
|
||||
boundary=boundary,
|
||||
reduction=reduction,
|
||||
return_grad=True,
|
||||
)
|
||||
|
||||
# ranges : [B, T, prune_range]
|
||||
ranges = k2.get_rnnt_prune_ranges(
|
||||
px_grad=px_grad,
|
||||
py_grad=py_grad,
|
||||
boundary=boundary,
|
||||
s_range=prune_range,
|
||||
)
|
||||
|
||||
# am_pruned : [B, T, prune_range, encoder_dim]
|
||||
# lm_pruned : [B, T, prune_range, decoder_dim]
|
||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||
am=self.joiner.encoder_proj(encoder_out),
|
||||
lm=self.joiner.decoder_proj(decoder_out),
|
||||
ranges=ranges,
|
||||
)
|
||||
|
||||
# logits : [B, T, prune_range, vocab_size]
|
||||
|
||||
# project_input=False since we applied the decoder's input projections
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
ranges=ranges,
|
||||
termination_symbol=blank_id,
|
||||
boundary=boundary,
|
||||
reduction=reduction,
|
||||
)
|
||||
|
||||
# Compute ctc loss
|
||||
supervision_segments = torch.stack(
|
||||
(
|
||||
torch.arange(len(x_lens), device="cpu"),
|
||||
torch.zeros_like(x_lens, device="cpu"),
|
||||
torch.clone(x_lens).detach().cpu(),
|
||||
),
|
||||
dim=1,
|
||||
).to(torch.int32)
|
||||
# We need to sort supervision_segments in decreasing order of num_frames
|
||||
indices = torch.argsort(supervision_segments[:, 2], descending=True)
|
||||
supervision_segments = supervision_segments[indices]
|
||||
|
||||
# Works with a BPE model
|
||||
decoding_graph = k2.ctc_graph(y, modified=False, device=x.device)
|
||||
dense_fsa_vec = k2.DenseFsaVec(
|
||||
ctc_output,
|
||||
supervision_segments,
|
||||
allow_truncate=subsampling_factor - 1,
|
||||
)
|
||||
ctc_loss = k2.ctc_loss(
|
||||
decoding_graph=decoding_graph,
|
||||
dense_fsa_vec=dense_fsa_vec,
|
||||
output_beam=beam_size,
|
||||
reduction="none",
|
||||
use_double_scores=use_double_scores,
|
||||
)
|
||||
|
||||
return (simple_loss, pruned_loss, ctc_loss)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
reduction: str = "sum",
|
||||
beam_size: int = 10,
|
||||
use_double_scores: bool = False,
|
||||
subsampling_factor: int = 1,
|
||||
return_masks: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C).
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||
before padding.
|
||||
y:
|
||||
A ragged tensor of shape (N*num_channels, S). It contains the labels
|
||||
of the N utterances. The labels are in the range [0, vocab_size). All
|
||||
the channels are concatenated together one after another.
|
||||
prune_range:
|
||||
The prune range for rnnt loss, i.e., how many symbols (context)
|
||||
we are considering for each frame to compute the loss.
|
||||
am_scale:
|
||||
The scale to smooth the loss with am (output of encoder network)
|
||||
part
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
reduction:
|
||||
"sum" to sum the losses over all utterances in the batch.
|
||||
"none" to return the loss in a 1-D tensor for each utterance
|
||||
in the batch.
|
||||
beam_size:
|
||||
The beam size used in CTC decoding.
|
||||
use_double_scores:
|
||||
If True, use double precision for CTC decoding.
|
||||
subsampling_factor:
|
||||
The subsampling factor of the model. It is used to compute the
|
||||
supervision segments for CTC loss.
|
||||
return_masks:
|
||||
If True, return the masks as well as masked features.
|
||||
Returns:
|
||||
Return a tuple (simple_loss, pruned_loss, ctc_loss, x_masked); if return_masks is True, the masks are returned as well.
|
||||
|
||||
Note:
|
||||
Regarding am_scale & lm_scale, it will make the loss-function one of
|
||||
the form:
|
||||
lm_scale * lm_probs + am_scale * am_probs +
|
||||
(1-lm_scale-am_scale) * combined_probs
|
||||
"""
|
||||
assert x.ndim == 3, x.shape
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
assert y.num_axes == 2, y.num_axes
|
||||
|
||||
assert x.size(0) == x_lens.size(0), (x.size(), x_lens.size())
|
||||
|
||||
# Apply the mask encoder
|
||||
B, T, F = x.shape
|
||||
processed = self.mask_encoder(x) # B,T,F*num_channels
|
||||
masks = processed.view(B, T, F, self.num_channels).unbind(dim=-1)
|
||||
x_masked = [x * m for m in masks]
|
||||
|
||||
# Recognition
|
||||
# Stack the inputs along the batch axis
|
||||
h = torch.cat(x_masked, dim=0)
|
||||
h_lens = torch.cat([x_lens for _ in range(self.num_channels)], dim=0)
|
||||
|
||||
simple_loss, pruned_loss, ctc_loss = self.forward_helper(
|
||||
h,
|
||||
h_lens,
|
||||
y,
|
||||
prune_range,
|
||||
am_scale,
|
||||
lm_scale,
|
||||
reduction=reduction,
|
||||
beam_size=beam_size,
|
||||
use_double_scores=use_double_scores,
|
||||
subsampling_factor=subsampling_factor,
|
||||
)
|
||||
|
||||
# Chunk the outputs into num_channels parts along the batch axis and stack them along a new axis.
|
||||
simple_loss = torch.stack(
|
||||
torch.chunk(simple_loss, self.num_channels, dim=0), dim=0
|
||||
)
|
||||
pruned_loss = torch.stack(
|
||||
torch.chunk(pruned_loss, self.num_channels, dim=0), dim=0
|
||||
)
|
||||
ctc_loss = torch.stack(torch.chunk(ctc_loss, self.num_channels, dim=0), dim=0)
|
||||
|
||||
if return_masks:
|
||||
return (simple_loss, pruned_loss, ctc_loss, x_masked, masks)
|
||||
else:
|
||||
return (simple_loss, pruned_loss, ctc_loss, x_masked)
|
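A minimal sketch (not part of the recipe, shapes only) of the batch-axis bookkeeping in forward() above: the num_channels masked copies of the input are concatenated along the batch axis before the shared encoder, and the per-utterance losses are later chunked back and stacked along a new leading channel axis.

import torch

B, T, F, num_channels = 2, 50, 80, 2
x = torch.randn(B, T, F)
masks = [torch.rand(B, T, F) for _ in range(num_channels)]   # stand-ins for the mask_encoder output
x_masked = [x * m for m in masks]

h = torch.cat(x_masked, dim=0)                 # (num_channels * B, T, F), fed to the shared encoder
loss = torch.randn(num_channels * B)           # pretend per-utterance losses from forward_helper
loss = torch.stack(torch.chunk(loss, num_channels, dim=0), dim=0)
print(h.shape, loss.shape)                     # torch.Size([4, 50, 80]) torch.Size([2, 2])
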
1  egs/libricss/SURT/dprnn_zipformer/optim.py  (Symbolic link)
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/optim.py
|
1  egs/libricss/SURT/dprnn_zipformer/scaling.py  (Symbolic link)
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py
|
1  egs/libricss/SURT/dprnn_zipformer/scaling_converter.py  (Symbolic link)
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py
|
1452  egs/libricss/SURT/dprnn_zipformer/train.py  (Executable file)
File diff suppressed because it is too large
1343  egs/libricss/SURT/dprnn_zipformer/train_adapt.py  (Executable file)
File diff suppressed because it is too large
1  egs/libricss/SURT/dprnn_zipformer/zipformer.py  (Symbolic link)
@ -0,0 +1 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py
|
BIN  egs/libricss/SURT/heat.png  (Normal file)
Binary file not shown (new image, 298 KiB)
85  egs/libricss/SURT/local/add_source_feats.py  (Executable file)
@ -0,0 +1,85 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file adds source features as temporal arrays to the mixture manifests.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from lhotse import CutSet, LilcomChunkyWriter, load_manifest, load_manifest_lazy
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def add_source_feats(num_jobs=1):
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
|
||||
for type_affix in ["full", "ov40"]:
|
||||
logging.info(f"Adding source features for {type_affix}")
|
||||
mixed_name_clean = f"train_clean_{type_affix}"
|
||||
mixed_name_rvb = f"train_rvb_{type_affix}"
|
||||
|
||||
logging.info("Reading mixed cuts")
|
||||
mixed_cuts_clean = load_manifest_lazy(
|
||||
src_dir / f"cuts_{mixed_name_clean}.jsonl.gz"
|
||||
)
|
||||
mixed_cuts_rvb = load_manifest_lazy(src_dir / f"cuts_{mixed_name_rvb}.jsonl.gz")
|
||||
|
||||
logging.info("Reading source cuts")
|
||||
source_cuts = load_manifest(src_dir / "librispeech_cuts_train_trimmed.jsonl.gz")
|
||||
|
||||
logging.info("Adding source features to the mixed cuts")
|
||||
with tqdm() as pbar, CutSet.open_writer(
|
||||
src_dir / f"cuts_{mixed_name_clean}_sources.jsonl.gz"
|
||||
) as cut_writer_clean, CutSet.open_writer(
|
||||
src_dir / f"cuts_{mixed_name_rvb}_sources.jsonl.gz"
|
||||
) as cut_writer_rvb, LilcomChunkyWriter(
|
||||
output_dir / f"feats_train_{type_affix}_sources"
|
||||
) as source_feat_writer:
|
||||
for cut_clean, cut_rvb in zip(mixed_cuts_clean, mixed_cuts_rvb):
|
||||
assert cut_rvb.id == cut_clean.id + "_rvb"
|
||||
# Create source_feats and source_feat_offsets
|
||||
# (See `lhotse.datasets.K2SurtDataset` for details)
|
||||
source_feats = []
|
||||
source_feat_offsets = []
|
||||
cur_offset = 0
|
||||
for sup in sorted(
|
||||
cut_clean.supervisions, key=lambda s: (s.start, s.speaker)
|
||||
):
|
||||
source_cut = source_cuts[sup.id]
|
||||
source_feats.append(source_cut.load_features())
|
||||
source_feat_offsets.append(cur_offset)
|
||||
cur_offset += source_cut.num_frames
|
||||
cut_clean.source_feats = source_feat_writer.store_array(
|
||||
cut_clean.id, np.concatenate(source_feats, axis=0)
|
||||
)
|
||||
cut_clean.source_feat_offsets = source_feat_offsets
|
||||
cut_writer_clean.write(cut_clean)
|
||||
cut_rvb.source_feats = cut_clean.source_feats
|
||||
cut_rvb.source_feat_offsets = cut_clean.source_feat_offsets
|
||||
cut_writer_rvb.write(cut_rvb)
|
||||
pbar.update(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
add_source_feats()
|
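An illustrative sketch (random arrays instead of real cuts) of the offset bookkeeping used above: per-supervision feature matrices are concatenated along the time axis and source_feat_offsets records where each one starts, so a consumer such as lhotse's K2SurtDataset can slice them back out.

import numpy as np

num_frames = [120, 80, 200]                       # frames per supervision, made up for illustration
feats = [np.random.rand(n, 80) for n in num_frames]

offsets, cur = [], 0
for f in feats:
    offsets.append(cur)
    cur += f.shape[0]
concatenated = np.concatenate(feats, axis=0)      # what store_array() receives

# recover the second supervision's features from the concatenated array
start, end = offsets[1], offsets[1] + num_frames[1]
assert np.array_equal(concatenated[start:end], feats[1])
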
105  egs/libricss/SURT/local/compute_fbank_libricss.py  (Executable file)
@ -0,0 +1,105 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the LibriCSS dataset.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import pyloudnorm as pyln
|
||||
import torch
|
||||
import torch.multiprocessing
|
||||
from lhotse import LilcomChunkyWriter, load_manifest_lazy
|
||||
from lhotse.features.kaldifeat import (
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
KaldifeatFrameOptions,
|
||||
KaldifeatMelOptions,
|
||||
)
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slows things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||
|
||||
|
||||
def compute_fbank_libricss():
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
|
||||
sampling_rate = 16000
|
||||
num_mel_bins = 80
|
||||
|
||||
extractor = KaldifeatFbank(
|
||||
KaldifeatFbankConfig(
|
||||
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
|
||||
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
|
||||
device="cuda",
|
||||
)
|
||||
)
|
||||
|
||||
logging.info("Reading manifests")
|
||||
cuts_ihm_mix = load_manifest_lazy(
|
||||
src_dir / "libricss-ihm-mix_segments_all.jsonl.gz"
|
||||
)
|
||||
cuts_sdm = load_manifest_lazy(src_dir / "libricss-sdm_segments_all.jsonl.gz")
|
||||
|
||||
for name, cuts in [("ihm-mix", cuts_ihm_mix), ("sdm", cuts_sdm)]:
|
||||
dev_cuts = cuts.filter(lambda c: "session0" in c.id)
|
||||
test_cuts = cuts.filter(lambda c: "session0" not in c.id)
|
||||
|
||||
# If SDM cuts, apply loudness normalization
|
||||
if name == "sdm":
|
||||
dev_cuts = dev_cuts.normalize_loudness(target=-23.0)
|
||||
test_cuts = test_cuts.normalize_loudness(target=-23.0)
|
||||
|
||||
logging.info(f"Extracting fbank features for {name} dev cuts")
|
||||
_ = dev_cuts.compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=output_dir / f"libricss-{name}_feats_dev",
|
||||
manifest_path=src_dir / f"cuts_dev_libricss-{name}.jsonl.gz",
|
||||
batch_duration=500,
|
||||
num_workers=2,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
logging.info(f"Extracting fbank features for {name} test cuts")
|
||||
_ = test_cuts.compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=output_dir / f"libricss-{name}_feats_test",
|
||||
manifest_path=src_dir / f"cuts_test_libricss-{name}.jsonl.gz",
|
||||
batch_duration=2000,
|
||||
num_workers=4,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
compute_fbank_libricss()
|
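The dev/test split above is purely id based: cuts whose id contains "session0" go to dev and the rest to test (LibriCSS session 0 is commonly held out as a development set). A toy illustration of the same filter on made-up id strings:

cut_ids = [
    "OV40_session0_CH0_0001",   # hypothetical ids; only the "session0" substring matters
    "OV40_session3_CH0_0042",
    "0L_session7_CH0_0007",
]
dev = [c for c in cut_ids if "session0" in c]
test = [c for c in cut_ids if "session0" not in c]
print(dev)    # ['OV40_session0_CH0_0001']
print(test)   # ['OV40_session3_CH0_0042', '0L_session7_CH0_0007']
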
111  egs/libricss/SURT/local/compute_fbank_librispeech.py  (Executable file)
@ -0,0 +1,111 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the LibriSpeech dataset.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, LilcomChunkyWriter
|
||||
from lhotse.features.kaldifeat import (
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
KaldifeatFrameOptions,
|
||||
KaldifeatMelOptions,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slows things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||
|
||||
|
||||
def compute_fbank_librispeech():
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
num_mel_bins = 80
|
||||
|
||||
dataset_parts = (
|
||||
"train-clean-100",
|
||||
"train-clean-360",
|
||||
"train-other-500",
|
||||
)
|
||||
prefix = "librispeech"
|
||||
suffix = "jsonl.gz"
|
||||
manifests = read_manifests_if_cached(
|
||||
dataset_parts=dataset_parts,
|
||||
output_dir=src_dir,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
)
|
||||
assert manifests is not None
|
||||
|
||||
assert len(manifests) == len(dataset_parts), (
|
||||
len(manifests),
|
||||
len(dataset_parts),
|
||||
list(manifests.keys()),
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
extractor = KaldifeatFbank(
|
||||
KaldifeatFbankConfig(
|
||||
frame_opts=KaldifeatFrameOptions(sampling_rate=16000),
|
||||
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
|
||||
device="cuda",
|
||||
)
|
||||
)
|
||||
|
||||
for partition, m in manifests.items():
|
||||
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
|
||||
if (output_dir / cuts_filename).is_file():
|
||||
logging.info(f"{partition} already exists - skipping.")
|
||||
continue
|
||||
logging.info(f"Processing {partition}")
|
||||
cut_set = CutSet.from_manifests(
|
||||
recordings=m["recordings"],
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
|
||||
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
|
||||
cut_set = cut_set.compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
|
||||
manifest_path=f"{src_dir}/{cuts_filename}",
|
||||
batch_duration=4000,
|
||||
num_workers=2,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
compute_fbank_librispeech()
|
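A rough back-of-the-envelope for the three-way speed perturbation above, assuming roughly 960 hours of LibriSpeech training audio (an assumption, not a figure from this diff): perturbing by factor f scales durations by 1/f, so the combined set is slightly more than three times the original.

orig_hours = 960.0                      # approximate LibriSpeech train-* total; an assumption
factors = [1.0, 0.9, 1.1]               # unperturbed + the two perturb_speed() factors above
total = sum(orig_hours / f for f in factors)
print(round(total, 1))                  # ~2899.4 hours of training audio after perturbation
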
188  egs/libricss/SURT/local/compute_fbank_lsmix.py  (Executable file)
@ -0,0 +1,188 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the synthetically mixed LibriSpeech
|
||||
train and dev sets.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
import logging
|
||||
import random
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing
|
||||
from lhotse import LilcomChunkyWriter, load_manifest
|
||||
from lhotse.cut import MixedCut, MixTrack, MultiCut
|
||||
from lhotse.features.kaldifeat import (
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
KaldifeatFrameOptions,
|
||||
KaldifeatMelOptions,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
from lhotse.utils import fix_random_seed, uuid4
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slows things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
torch.multiprocessing.set_sharing_strategy("file_system")
|
||||
|
||||
|
||||
def compute_fbank_lsmix():
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
|
||||
sampling_rate = 16000
|
||||
num_mel_bins = 80
|
||||
|
||||
extractor = KaldifeatFbank(
|
||||
KaldifeatFbankConfig(
|
||||
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
|
||||
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
|
||||
device="cuda",
|
||||
)
|
||||
)
|
||||
|
||||
logging.info("Reading manifests")
|
||||
manifests = read_manifests_if_cached(
|
||||
dataset_parts=["train_clean_full", "train_clean_ov40"],
|
||||
types=["cuts"],
|
||||
output_dir=src_dir,
|
||||
prefix="lsmix",
|
||||
suffix="jsonl.gz",
|
||||
lazy=True,
|
||||
)
|
||||
|
||||
cs = {}
|
||||
cs["clean_full"] = manifests["train_clean_full"]["cuts"]
|
||||
cs["clean_ov40"] = manifests["train_clean_ov40"]["cuts"]
|
||||
|
||||
# only uses RIRs and noises from REVERB challenge
|
||||
real_rirs = load_manifest(src_dir / "real-rir_recordings_all.jsonl.gz").filter(
|
||||
lambda r: "RVB2014" in r.id
|
||||
)
|
||||
noises = load_manifest(src_dir / "iso-noise_recordings_all.jsonl.gz").filter(
|
||||
lambda r: "RVB2014" in r.id
|
||||
)
|
||||
|
||||
# Apply perturbation to the training cuts
|
||||
logging.info("Applying perturbation to the training cuts")
|
||||
cs["rvb_full"] = cs["clean_full"].map(
|
||||
lambda c: augment(
|
||||
c, perturb_snr=True, rirs=real_rirs, noises=noises, perturb_loudness=True
|
||||
)
|
||||
)
|
||||
cs["rvb_ov40"] = cs["clean_ov40"].map(
|
||||
lambda c: augment(
|
||||
c, perturb_snr=True, rirs=real_rirs, noises=noises, perturb_loudness=True
|
||||
)
|
||||
)
|
||||
|
||||
for type_affix in ["full", "ov40"]:
|
||||
for rvb_affix in ["clean", "rvb"]:
|
||||
logging.info(
|
||||
f"Extracting fbank features for {type_affix} {rvb_affix} training cuts"
|
||||
)
|
||||
cuts = cs[f"{rvb_affix}_{type_affix}"]
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
_ = cuts.compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=output_dir
|
||||
/ f"lsmix_feats_train_{rvb_affix}_{type_affix}",
|
||||
manifest_path=src_dir
|
||||
/ f"cuts_train_{rvb_affix}_{type_affix}.jsonl.gz",
|
||||
batch_duration=5000,
|
||||
num_workers=4,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
|
||||
def augment(cut, perturb_snr=False, rirs=None, noises=None, perturb_loudness=False):
|
||||
"""
|
||||
Given a mixed cut, this function optionally applies the following augmentations:
|
||||
- Perturbing the SNRs of the tracks (in range [-5, 5] dB)
|
||||
- Reverberation using a randomly selected RIR
|
||||
- Adding noise
|
||||
- Perturbing the loudness (to a target in the range [-25, -20] dB)
|
||||
"""
|
||||
out_cut = cut.drop_features()
|
||||
|
||||
# Perturb the SNRs (optional)
|
||||
if perturb_snr:
|
||||
snrs = [random.uniform(-5, 5) for _ in range(len(cut.tracks))]
|
||||
for i, (track, snr) in enumerate(zip(out_cut.tracks, snrs)):
|
||||
if i == 0:
|
||||
# Skip the first track since it is the reference
|
||||
continue
|
||||
track.snr = snr
|
||||
|
||||
# Reverberate the cut (optional)
|
||||
if rirs is not None:
|
||||
# Select an RIR at random
|
||||
rir = random.choice(rirs)
|
||||
# Select a channel at random
|
||||
rir_channel = random.choice(list(range(rir.num_channels)))
|
||||
# Reverberate the cut
|
||||
out_cut = out_cut.reverb_rir(rir_recording=rir, rir_channels=[rir_channel])
|
||||
|
||||
# Add noise (optional)
|
||||
if noises is not None:
|
||||
# Select a noise recording at random
|
||||
noise = random.choice(noises).to_cut()
|
||||
if isinstance(noise, MultiCut):
|
||||
noise = noise.to_mono()[0]
|
||||
# Select an SNR at random
|
||||
snr = random.uniform(10, 30)
|
||||
# Repeat the noise to match the duration of the cut
|
||||
noise = repeat_cut(noise, out_cut.duration)
|
||||
out_cut = MixedCut(
|
||||
id=out_cut.id,
|
||||
tracks=[
|
||||
MixTrack(cut=out_cut, type="MixedCut"),
|
||||
MixTrack(cut=noise, type="DataCut", snr=snr),
|
||||
],
|
||||
)
|
||||
|
||||
# Perturb the loudness (optional)
|
||||
if perturb_loudness:
|
||||
target_loudness = random.uniform(-20, -25)
|
||||
out_cut = out_cut.normalize_loudness(target_loudness, mix_first=True)
|
||||
return out_cut
|
||||
|
||||
|
||||
def repeat_cut(cut, duration):
|
||||
while cut.duration < duration:
|
||||
cut = cut.mix(cut, offset_other_by=cut.duration)
|
||||
return cut.truncate(duration=duration)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
|
||||
fix_random_seed(42)
|
||||
compute_fbank_lsmix()
|
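repeat_cut() above keeps mixing the noise cut with a copy of itself, offset by its current duration, until it covers the target, then truncates; each mix doubles the duration. A plain-number sketch with an illustrative 7.3 s noise clip and a 60 s mixture:

noise_dur, target = 7.3, 60.0           # seconds, illustrative values
dur, n_mixes = noise_dur, 0
while dur < target:
    dur *= 2                            # cut.mix(cut, offset_other_by=cut.duration) doubles the length
    n_mixes += 1
print(n_mixes, round(dur, 1))           # 4 116.8 -> then truncate(duration=60.0)
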
114  egs/libricss/SURT/local/compute_fbank_musan.py  (Executable file)
@ -0,0 +1,114 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file computes fbank features of the musan dataset.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, LilcomChunkyWriter, combine
|
||||
from lhotse.features.kaldifeat import (
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
KaldifeatFrameOptions,
|
||||
KaldifeatMelOptions,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slows things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
# even when we are not invoking the main (e.g. when spawning subprocesses).
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_musan():
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
|
||||
sampling_rate = 16000
|
||||
num_mel_bins = 80
|
||||
|
||||
dataset_parts = (
|
||||
"music",
|
||||
"speech",
|
||||
"noise",
|
||||
)
|
||||
prefix = "musan"
|
||||
suffix = "jsonl.gz"
|
||||
manifests = read_manifests_if_cached(
|
||||
dataset_parts=dataset_parts,
|
||||
output_dir=src_dir,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
)
|
||||
assert manifests is not None
|
||||
|
||||
assert len(manifests) == len(dataset_parts), (
|
||||
len(manifests),
|
||||
len(dataset_parts),
|
||||
list(manifests.keys()),
|
||||
dataset_parts,
|
||||
)
|
||||
|
||||
musan_cuts_path = src_dir / "musan_cuts.jsonl.gz"
|
||||
|
||||
if musan_cuts_path.is_file():
|
||||
logging.info(f"{musan_cuts_path} already exists - skipping")
|
||||
return
|
||||
|
||||
logging.info("Extracting features for Musan")
|
||||
|
||||
extractor = KaldifeatFbank(
|
||||
KaldifeatFbankConfig(
|
||||
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
|
||||
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
|
||||
device="cuda",
|
||||
)
|
||||
)
|
||||
|
||||
# create chunks of Musan with duration 5 - 10 seconds
|
||||
_ = (
|
||||
CutSet.from_manifests(
|
||||
recordings=combine(part["recordings"] for part in manifests.values())
|
||||
)
|
||||
.cut_into_windows(10.0)
|
||||
.filter(lambda c: c.duration > 5)
|
||||
.compute_and_store_features_batch(
|
||||
extractor=extractor,
|
||||
storage_path=output_dir / "musan_feats",
|
||||
manifest_path=musan_cuts_path,
|
||||
batch_duration=500,
|
||||
num_workers=4,
|
||||
storage_type=LilcomChunkyWriter,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
compute_fbank_musan()
|
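The MUSAN pipeline above chops every recording into 10 s windows and keeps only windows longer than 5 s, so a short final window is dropped. A toy emulation of that per-recording logic with made-up durations:

def kept_windows(duration, win=10.0, min_dur=5.0):
    # emulate cut_into_windows(win) followed by filter(duration > min_dur)
    windows, start = [], 0.0
    while start < duration:
        windows.append(min(win, duration - start))
        start += win
    return [w for w in windows if w > min_dur]

print(kept_windows(27.0))   # [10.0, 10.0, 7.0]
print(kept_windows(23.0))   # [10.0, 10.0]  (the 3 s tail is filtered out)
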
204  egs/libricss/SURT/prepare.sh  (Executable file)
@ -0,0 +1,204 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
stage=-1
|
||||
stop_stage=100
|
||||
|
||||
# We assume dl_dir (download dir) contains the following
|
||||
# directories and files. If not, they will be downloaded
|
||||
# by this script automatically.
|
||||
#
|
||||
# - $dl_dir/librispeech
|
||||
# You can find audio and transcripts for LibriSpeech in this path.
|
||||
#
|
||||
# - $dl_dir/libricss
|
||||
# You can find audio and transcripts for LibriCSS in this path.
|
||||
#
|
||||
# - $dl_dir/musan
|
||||
# This directory contains the following directories downloaded from
|
||||
# http://www.openslr.org/17/
|
||||
#
|
||||
# - music
|
||||
# - noise
|
||||
# - speech
|
||||
#
|
||||
# - $dl_dir/rirs_noises
|
||||
# This directory contains the RIRS_NOISES corpus downloaded from https://openslr.org/28/.
|
||||
#
|
||||
dl_dir=$PWD/download
|
||||
|
||||
. shared/parse_options.sh || exit 1
|
||||
|
||||
# All files generated by this script are saved in "data".
|
||||
# You can safely remove "data" and rerun this script to regenerate it.
|
||||
mkdir -p data
|
||||
vocab_size=500
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
log "dl_dir: $dl_dir"
|
||||
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "Stage 0: Download data"
|
||||
|
||||
# If you have pre-downloaded it to /path/to/librispeech,
|
||||
# you can create a symlink
|
||||
#
|
||||
# ln -sfv /path/to/librispeech $dl_dir/librispeech
|
||||
#
|
||||
if [ ! -d $dl_dir/librispeech ]; then
|
||||
lhotse download librispeech $dl_dir/librispeech
|
||||
fi
|
||||
|
||||
# If you have pre-downloaded it to /path/to/libricss,
|
||||
# you can create a symlink
|
||||
#
|
||||
# ln -sfv /path/to/libricss $dl_dir/libricss
|
||||
#
|
||||
if [ ! -d $dl_dir/libricss ]; then
|
||||
lhotse download libricss $dl_dir/libricss
|
||||
fi
|
||||
|
||||
# If you have pre-downloaded it to /path/to/musan,
|
||||
# you can create a symlink
|
||||
#
|
||||
# ln -sfv /path/to/musan $dl_dir/
|
||||
#
|
||||
if [ ! -d $dl_dir/musan ]; then
|
||||
lhotse download musan $dl_dir
|
||||
fi
|
||||
|
||||
# If you have pre-downloaded it to /path/to/rirs_noises,
|
||||
# you can create a symlink
|
||||
#
|
||||
# ln -sfv /path/to/rirs_noises $dl_dir/
|
||||
#
|
||||
if [ ! -d $dl_dir/rirs_noises ]; then
|
||||
lhotse download rirs_noises $dl_dir
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
log "Stage 1: Prepare LibriSpeech manifests"
|
||||
# We assume that you have downloaded the LibriSpeech corpus
|
||||
# to $dl_dir/librispeech. We perform text normalization for the transcripts.
|
||||
# NOTE: Alignments are required for this recipe.
|
||||
mkdir -p data/manifests
|
||||
lhotse prepare librispeech -p train-clean-100 -p train-clean-360 -p train-other-500 -p dev-clean \
|
||||
-j 4 --alignments-dir $dl_dir/libri_alignments/LibriSpeech $dl_dir/librispeech data/manifests/
|
||||
fi
|
||||
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "Stage 2: Prepare LibriCSS manifests"
|
||||
# We assume that you have downloaded the LibriCSS corpus
|
||||
# to $dl_dir/libricss. We perform text normalization for the transcripts.
|
||||
mkdir -p data/manifests
|
||||
for mic in sdm ihm-mix; do
|
||||
lhotse prepare libricss --type $mic --segmented $dl_dir/libricss data/manifests/
|
||||
done
|
||||
fi
|
||||
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "Stage 3: Prepare musan manifest and RIRs"
|
||||
# We assume that you have downloaded the musan corpus
|
||||
# to $dl_dir/musan
|
||||
mkdir -p data/manifests
|
||||
lhotse prepare musan $dl_dir/musan data/manifests
|
||||
|
||||
# We assume that you have downloaded the RIRS_NOISES corpus
|
||||
# to $dl_dir/rirs_noises
|
||||
lhotse prepare rir-noise -p real_rir -p iso_noise $dl_dir/rirs_noises data/manifests
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "Stage 4: Extract features for LibriSpeech, trim to alignments, and shuffle the cuts"
|
||||
python local/compute_fbank_librispeech.py
|
||||
lhotse combine data/manifests/librispeech_cuts_train* - |\
|
||||
lhotse cut trim-to-alignments --type word --max-pause 0.2 - - |\
|
||||
shuf | gzip -c > data/manifests/librispeech_cuts_train_trimmed.jsonl.gz
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Create simulated mixtures from LibriSpeech (train and dev). This may take a while."
|
||||
# We create a high overlap set which will be used during the model warmup phase, and a
|
||||
# full training set that will be used for the subsequent training.
|
||||
|
||||
gunzip -c data/manifests/libricss-sdm_supervisions_all.jsonl.gz |\
|
||||
grep -v "0L" | grep -v "OV10" |\
|
||||
gzip -c > data/manifests/libricss-sdm_supervisions_all_v1.jsonl.gz
|
||||
|
||||
gunzip -c data/manifests/libricss-sdm_supervisions_all.jsonl.gz |\
|
||||
grep "OV40" |\
|
||||
gzip -c > data/manifests/libricss-sdm_supervisions_ov40.jsonl.gz
|
||||
|
||||
# Warmup mixtures (100k) based on high overlap (OV40)
|
||||
log "Generating 100k anechoic train mixtures for warmup"
|
||||
lhotse workflows simulate-meetings \
|
||||
--method conversational \
|
||||
--fit-to-supervisions data/manifests/libricss-sdm_supervisions_ov40.jsonl.gz \
|
||||
--num-meetings 100000 \
|
||||
--num-speakers-per-meeting 2,3 \
|
||||
--max-duration-per-speaker 15.0 \
|
||||
--max-utterances-per-speaker 3 \
|
||||
--seed 1234 \
|
||||
--num-jobs 4 \
|
||||
data/manifests/librispeech_cuts_train_trimmed.jsonl.gz \
|
||||
data/manifests/lsmix_cuts_train_clean_ov40.jsonl.gz
|
||||
|
||||
# Full training set (2,3 speakers) anechoic
|
||||
log "Generating anechoic ${part} set (full)"
|
||||
lhotse workflows simulate-meetings \
|
||||
--method conversational \
|
||||
--fit-to-supervisions data/manifests/libricss-sdm_supervisions_all_v1.jsonl.gz \
|
||||
--num-repeats 1 \
|
||||
--num-speakers-per-meeting 2,3 \
|
||||
--max-duration-per-speaker 15.0 \
|
||||
--max-utterances-per-speaker 3 \
|
||||
--seed 1234 \
|
||||
--num-jobs 4 \
|
||||
data/manifests/librispeech_cuts_train_trimmed.jsonl.gz \
|
||||
data/manifests/lsmix_cuts_train_clean_full.jsonl.gz
|
||||
fi
|
||||
|
||||
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
||||
log "Stage 6: Compute fbank features for musan"
|
||||
mkdir -p data/fbank
|
||||
python local/compute_fbank_musan.py
|
||||
fi
|
||||
|
||||
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||
log "Stage 7: Compute fbank features for simulated Libri-mix"
|
||||
mkdir -p data/fbank
|
||||
python local/compute_fbank_lsmix.py
|
||||
fi
|
||||
|
||||
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
||||
log "Stage 8: Add source feats to mixtures (useful for auxiliary tasks)"
|
||||
python local/add_source_feats.py
|
||||
|
||||
log "Combining lsmix-clean and lsmix-rvb"
|
||||
for type in full ov40; do
|
||||
cat <(gunzip -c data/manifests/cuts_train_clean_${type}_sources.jsonl.gz) \
|
||||
<(gunzip -c data/manifests/cuts_train_rvb_${type}_sources.jsonl.gz) |\
|
||||
shuf | gzip -c > data/manifests/cuts_train_comb_${type}_sources.jsonl.gz
|
||||
done
|
||||
fi
|
||||
|
||||
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
||||
log "Stage 9: Compute fbank features for LibriCSS"
|
||||
mkdir -p data/fbank
|
||||
python local/compute_fbank_libricss.py
|
||||
fi
|
||||
|
||||
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
|
||||
log "Stage 10: Download LibriSpeech BPE model from HuggingFace."
|
||||
mkdir -p data/lang_bpe_500
|
||||
pushd data/lang_bpe_500
|
||||
wget https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/resolve/main/data/lang_bpe_500/bpe.model
|
||||
popd
|
||||
fi
|
1  egs/libricss/SURT/shared  (Symbolic link)
@ -0,0 +1 @@
|
||||
../../../icefall/shared
|
BIN  egs/libricss/SURT/surt.png  (Normal file)
Binary file not shown (new image, 112 KiB)
@ -81,20 +81,20 @@ class FrameReducer(nn.Module):
|
||||
fake_limit_indexes = torch.topk(
|
||||
ctc_output[:, :, blank_id], max_limit_len
|
||||
).indices
|
||||
T = (
|
||||
T_arange = (
|
||||
torch.arange(max_limit_len)
|
||||
.expand_as(
|
||||
fake_limit_indexes,
|
||||
)
|
||||
.to(device=x.device)
|
||||
)
|
||||
T = torch.remainder(T, limit_lens.unsqueeze(1))
|
||||
limit_indexes = torch.gather(fake_limit_indexes, 1, T)
|
||||
T_arange = torch.remainder(T_arange, limit_lens.unsqueeze(1))
|
||||
limit_indexes = torch.gather(fake_limit_indexes, 1, T_arange)
|
||||
limit_mask = torch.full_like(
|
||||
non_blank_mask,
|
||||
False,
|
||||
0,
|
||||
device=x.device,
|
||||
).scatter_(1, limit_indexes, True)
|
||||
).scatter_(1, limit_indexes, 1)
|
||||
|
||||
non_blank_mask = non_blank_mask | ~limit_mask
|
||||
|
||||
@ -108,9 +108,9 @@ class FrameReducer(nn.Module):
|
||||
)
|
||||
- out_lens
|
||||
)
|
||||
max_pad_len = pad_lens_list.max()
|
||||
max_pad_len = int(pad_lens_list.max())
|
||||
|
||||
out = F.pad(x, (0, 0, 0, max_pad_len))
|
||||
out = F.pad(x, [0, 0, 0, max_pad_len])
|
||||
|
||||
valid_pad_mask = ~make_pad_mask(pad_lens_list)
|
||||
total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1)
|
||||
|
@ -856,6 +856,10 @@ def main():
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptable.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
model.encoder.__class__.non_streaming_forward = model.encoder.__class__.forward
|
||||
model.encoder.__class__.non_streaming_forward = torch.jit.export(
|
||||
model.encoder.__class__.non_streaming_forward
|
||||
)
|
||||
model.encoder.__class__.forward = model.encoder.__class__.streaming_forward
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
|
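The export hunk above re-points bound class methods before torch.jit.script, so the scripted encoder's forward becomes the streaming path while the offline path stays reachable under an exported name. A minimal, self-contained sketch of the same pattern on a toy module (names and behaviour are illustrative, and exact TorchScript behaviour can vary across PyTorch versions):

import torch
import torch.nn as nn

class Toy(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2                      # stands in for the offline path

    def streaming_forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1                      # stands in for the streaming path

# keep the offline path reachable under an exported name,
# then make `forward` point at the streaming implementation
Toy.non_streaming_forward = torch.jit.export(Toy.forward)
Toy.forward = Toy.streaming_forward

m = torch.jit.script(Toy())
x = torch.full((3,), 3.0)
print(m(x))                               # streaming path: tensor([4., 4., 4.])
print(m.non_streaming_forward(x))         # offline path:   tensor([6., 6., 6.])
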
@ -252,7 +252,7 @@ def main():
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
encoder_out, encoder_out_lens = model.encoder.non_streaming_forward(
|
||||
x=features,
|
||||
x_lens=feature_lengths,
|
||||
)
|
||||
|
@ -264,7 +264,7 @@ def main():
|
||||
params.update(vars(args))
|
||||
|
||||
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||
params.vocab_size = num_tokens(token_table)
|
||||
params.vocab_size = num_tokens(token_table) + 1
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
|
@ -25,6 +25,11 @@ import math
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
|
||||
max_value = torch.max(x, y)
|
||||
diff = torch.abs(x - y)
|
||||
return max_value + torch.log1p(torch.exp(-diff))
|
||||
|
||||
|
||||
# RuntimeError: Exporting the operator logaddexp to ONNX opset version
|
||||
# 14 is not supported. Please feel free to request support or submit
|
||||
@ -33,10 +38,22 @@ from torch import Tensor
|
||||
# The following function is to solve the above error when exporting
|
||||
# models to ONNX via torch.jit.trace()
|
||||
def logaddexp(x: Tensor, y: Tensor) -> Tensor:
|
||||
if not torch.jit.is_tracing():
|
||||
# Caution(fangjun): Put torch.jit.is_scripting() before
|
||||
# torch.onnx.is_in_onnx_export();
|
||||
# otherwise, it will cause errors for torch.jit.script().
|
||||
#
|
||||
# torch.logaddexp() works for both torch.jit.script() and
|
||||
# torch.jit.trace() but it causes errors for ONNX export.
|
||||
#
|
||||
if torch.jit.is_scripting():
|
||||
# Note: We cannot use torch.jit.is_tracing() here as it also
|
||||
# matches torch.onnx.export().
|
||||
return torch.logaddexp(x, y)
|
||||
elif torch.onnx.is_in_onnx_export():
|
||||
return logaddexp_onnx(x, y)
|
||||
else:
|
||||
return (x.exp() + y.exp()).log()
|
||||
# for torch.jit.trace()
|
||||
return torch.logaddexp(x, y)
|
||||
|
||||
class PiecewiseLinear(object):
|
||||
"""
|
||||
@ -1334,6 +1351,13 @@ class SwooshL(torch.nn.Module):
|
||||
return k2.swoosh_l(x)
|
||||
# return SwooshLFunction.apply(x)
|
||||
|
||||
class SwooshLOnnx(torch.nn.Module):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Return Swoosh-L activation.
|
||||
"""
|
||||
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
||||
return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035
|
||||
|
||||
|
||||
class SwooshRFunction(torch.autograd.Function):
|
||||
"""
|
||||
@ -1400,6 +1424,13 @@ class SwooshR(torch.nn.Module):
|
||||
return k2.swoosh_r(x)
|
||||
# return SwooshRFunction.apply(x)
|
||||
|
||||
class SwooshROnnx(torch.nn.Module):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Return Swoosh-R activation.
|
||||
"""
|
||||
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
||||
return logaddexp_onnx(zero, x - 1.) - 0.08 * x - 0.313261687
|
||||
|
||||
|
||||
# simple version of SwooshL that does not redefine the backprop, used in
|
||||
# ActivationDropoutAndLinearFunction.
|
||||
|
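A quick numerical sanity check (not part of the diff; it re-declares logaddexp_onnx locally so it runs stand-alone) that the ONNX-friendly helper matches torch.logaddexp and that the SwooshLOnnx / SwooshROnnx expressions above equal their closed forms log(1 + exp(x - 4)) - 0.08x - 0.035 and log(1 + exp(x - 1)) - 0.08x - 0.313261687:

import torch

def logaddexp_onnx(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # same formula as the helper above: max(x, y) + log1p(exp(-|x - y|))
    return torch.max(x, y) + torch.log1p(torch.exp(-torch.abs(x - y)))

x = torch.linspace(-10.0, 10.0, steps=201)
zero = torch.zeros_like(x)

assert torch.allclose(logaddexp_onnx(x, zero), torch.logaddexp(x, zero), atol=1e-6)

swoosh_l = logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035
swoosh_r = logaddexp_onnx(zero, x - 1.0) - 0.08 * x - 0.313261687
assert torch.allclose(swoosh_l, torch.log1p(torch.exp(x - 4.0)) - 0.08 * x - 0.035, atol=1e-6)
assert torch.allclose(swoosh_r, torch.log1p(torch.exp(x - 1.0)) - 0.08 * x - 0.313261687, atol=1e-6)
print("ok")
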
Some files were not shown because too many files have changed in this diff.