From 3b257dd5ae79bff99470ec1cbbeaa8fae84f956a Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 25 Jul 2024 16:46:24 +0800 Subject: [PATCH 01/59] Add docker images for torch 2.4 (#1704) --- .../scripts/docker/generate_build_matrix.py | 6 +- .github/workflows/build-docker-image.yml | 34 ++++++++- .github/workflows/run-docker-image.yml | 34 ++++++++- docker/torch2.0.0-cuda11.7.dockerfile | 2 +- docker/torch2.1.0-cuda11.8.dockerfile | 2 +- docker/torch2.1.0-cuda12.1.dockerfile | 2 +- docker/torch2.2.0-cuda11.8.dockerfile | 2 +- docker/torch2.2.0-cuda12.1.dockerfile | 2 +- docker/torch2.2.1-cuda11.8.dockerfile | 2 +- docker/torch2.2.1-cuda12.1.dockerfile | 2 +- docker/torch2.2.2-cuda11.8.dockerfile | 2 +- docker/torch2.2.2-cuda12.1.dockerfile | 2 +- docker/torch2.3.1-cuda11.8.dockerfile | 2 +- docker/torch2.3.1-cuda12.1.dockerfile | 2 +- docker/torch2.4.0-cuda11.8.dockerfile | 73 +++++++++++++++++++ docker/torch2.4.0-cuda12.1.dockerfile | 73 +++++++++++++++++++ docker/torch2.4.0-cuda12.4.dockerfile | 73 +++++++++++++++++++ 17 files changed, 301 insertions(+), 14 deletions(-) create mode 100644 docker/torch2.4.0-cuda11.8.dockerfile create mode 100644 docker/torch2.4.0-cuda12.1.dockerfile create mode 100644 docker/torch2.4.0-cuda12.4.dockerfile diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py index 7f13c59bd..5a763e044 100755 --- a/.github/scripts/docker/generate_build_matrix.py +++ b/.github/scripts/docker/generate_build_matrix.py @@ -45,7 +45,7 @@ def get_torchaudio_version(torch_version): def get_matrix(): k2_version = "1.24.4.dev20240223" kaldifeat_version = "1.25.4.dev20240223" - version = "20240606" + version = "20240725" python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"] torch_version = [] torch_version += ["1.13.0", "1.13.1"] @@ -53,6 +53,7 @@ def get_matrix(): torch_version += ["2.1.0", "2.1.1", "2.1.2"] torch_version += ["2.2.0", "2.2.1", "2.2.2"] torch_version += ["2.3.0", "2.3.1"] + torch_version += ["2.4.0"] matrix = [] for p in python_version: @@ -78,6 +79,9 @@ def get_matrix(): elif t == "2.3.1": k2_version_2 = "1.24.4.dev20240606" kaldifeat_version_2 = "1.25.4.dev20240606" + elif t == "2.4.0": + k2_version_2 = "1.24.4.dev20240725" + kaldifeat_version_2 = "1.25.4.dev20240725" matrix.append( { diff --git a/.github/workflows/build-docker-image.yml b/.github/workflows/build-docker-image.yml index 23dcb519f..77480bd3e 100644 --- a/.github/workflows/build-docker-image.yml +++ b/.github/workflows/build-docker-image.yml @@ -16,7 +16,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - image: ["torch2.3.1-cuda12.1", "torch2.3.1-cuda11.8", "torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] + image: ["torch2.4.0-cuda12.4", "torch2.4.0-cuda12.1", "torch2.4.0-cuda11.8", "torch2.3.1-cuda12.1", "torch2.3.1-cuda11.8", "torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] steps: # refer to https://github.com/actions/checkout @@ -37,6 +37,38 @@ jobs: rm -rf /opt/hostedtoolcache df -h + - name: Free more space + shell: bash + run: | + # https://github.com/orgs/community/discussions/25678 + cd /opt + find . -maxdepth 1 -mindepth 1 '!' -path ./containerd '!' -path ./actionarchivecache '!' -path ./runner '!' -path ./runner-cache -exec rm -rf '{}' ';' + + sudo rm -rf /usr/share/dotnet + sudo rm -rf "/usr/local/share/boost" + sudo rm -rf "$AGENT_TOOLSDIRECTORY" + + - name: Free Disk Space (Ubuntu) + uses: jlumbroso/free-disk-space@main + with: + # this might remove tools that are actually needed, + # if set to "true" but frees about 6 GB + tool-cache: false + + # all of these default to true, but feel free to set to + # "false" if necessary for your workflow + android: true + dotnet: true + haskell: true + large-packages: true + docker-images: false + swap-storage: true + + - name: Check space + shell: bash + run: | + df -h + - name: Log in to Docker Hub uses: docker/login-action@v2 with: diff --git a/.github/workflows/run-docker-image.yml b/.github/workflows/run-docker-image.yml index 336d930ca..05c630ad5 100644 --- a/.github/workflows/run-docker-image.yml +++ b/.github/workflows/run-docker-image.yml @@ -14,7 +14,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - image: ["torch2.3.1-cuda12.1", "torch2.3.1-cuda11.8", "torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] + image: ["torch2.4.0-cuda12.4", "torch2.4.0-cuda12.1", "torch2.4.0-cuda11.8", "torch2.3.1-cuda12.1", "torch2.3.1-cuda11.8", "torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] steps: # refer to https://github.com/actions/checkout - uses: actions/checkout@v2 @@ -28,6 +28,38 @@ jobs: rm -rf /opt/hostedtoolcache df -h + - name: Free more space + shell: bash + run: | + # https://github.com/orgs/community/discussions/25678 + cd /opt + find . -maxdepth 1 -mindepth 1 '!' -path ./containerd '!' -path ./actionarchivecache '!' -path ./runner '!' -path ./runner-cache -exec rm -rf '{}' ';' + + sudo rm -rf /usr/share/dotnet + sudo rm -rf "/usr/local/share/boost" + sudo rm -rf "$AGENT_TOOLSDIRECTORY" + + - name: Free Disk Space (Ubuntu) + uses: jlumbroso/free-disk-space@main + with: + # this might remove tools that are actually needed, + # if set to "true" but frees about 6 GB + tool-cache: false + + # all of these default to true, but feel free to set to + # "false" if necessary for your workflow + android: true + dotnet: true + haskell: true + large-packages: true + docker-images: false + swap-storage: true + + - name: Check space + shell: bash + run: | + df -h + - name: Run the build process with Docker uses: addnab/docker-run-action@v3 with: diff --git a/docker/torch2.0.0-cuda11.7.dockerfile b/docker/torch2.0.0-cuda11.7.dockerfile index e2e27b55d..22f0a7a95 100644 --- a/docker/torch2.0.0-cuda11.7.dockerfile +++ b/docker/torch2.0.0-cuda11.7.dockerfile @@ -41,7 +41,7 @@ RUN apt-get update && \ # Install dependencies RUN pip install --no-cache-dir \ - torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ diff --git a/docker/torch2.1.0-cuda11.8.dockerfile b/docker/torch2.1.0-cuda11.8.dockerfile index de1e07e69..e87e99468 100644 --- a/docker/torch2.1.0-cuda11.8.dockerfile +++ b/docker/torch2.1.0-cuda11.8.dockerfile @@ -41,7 +41,7 @@ RUN apt-get update && \ # Install dependencies RUN pip install --no-cache-dir \ - torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ diff --git a/docker/torch2.1.0-cuda12.1.dockerfile b/docker/torch2.1.0-cuda12.1.dockerfile index 89303797a..b2628ef9c 100644 --- a/docker/torch2.1.0-cuda12.1.dockerfile +++ b/docker/torch2.1.0-cuda12.1.dockerfile @@ -41,7 +41,7 @@ RUN apt-get update && \ # Install dependencies RUN pip install --no-cache-dir \ - torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ diff --git a/docker/torch2.2.0-cuda11.8.dockerfile b/docker/torch2.2.0-cuda11.8.dockerfile index 3364477a8..0f65f9595 100644 --- a/docker/torch2.2.0-cuda11.8.dockerfile +++ b/docker/torch2.2.0-cuda11.8.dockerfile @@ -41,7 +41,7 @@ RUN apt-get update && \ # Install dependencies RUN pip install --no-cache-dir \ - torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ diff --git a/docker/torch2.2.0-cuda12.1.dockerfile b/docker/torch2.2.0-cuda12.1.dockerfile index 3cc41902d..7a544c0b2 100644 --- a/docker/torch2.2.0-cuda12.1.dockerfile +++ b/docker/torch2.2.0-cuda12.1.dockerfile @@ -41,7 +41,7 @@ RUN apt-get update && \ # Install dependencies RUN pip install --no-cache-dir \ - torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ diff --git a/docker/torch2.2.1-cuda11.8.dockerfile b/docker/torch2.2.1-cuda11.8.dockerfile index 76b785622..0c04314a7 100644 --- a/docker/torch2.2.1-cuda11.8.dockerfile +++ b/docker/torch2.2.1-cuda11.8.dockerfile @@ -41,7 +41,7 @@ RUN apt-get update && \ # Install dependencies RUN pip install --no-cache-dir \ - torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ diff --git a/docker/torch2.2.1-cuda12.1.dockerfile b/docker/torch2.2.1-cuda12.1.dockerfile index 55bdfa4d7..5c4c9a99a 100644 --- a/docker/torch2.2.1-cuda12.1.dockerfile +++ b/docker/torch2.2.1-cuda12.1.dockerfile @@ -41,7 +41,7 @@ RUN apt-get update && \ # Install dependencies RUN pip install --no-cache-dir \ - torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ diff --git a/docker/torch2.2.2-cuda11.8.dockerfile b/docker/torch2.2.2-cuda11.8.dockerfile index 02de82c50..d712dd57a 100644 --- a/docker/torch2.2.2-cuda11.8.dockerfile +++ b/docker/torch2.2.2-cuda11.8.dockerfile @@ -41,7 +41,7 @@ RUN apt-get update && \ # Install dependencies RUN pip install --no-cache-dir \ - torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ diff --git a/docker/torch2.2.2-cuda12.1.dockerfile b/docker/torch2.2.2-cuda12.1.dockerfile index 44ad38b8e..af0e966e7 100644 --- a/docker/torch2.2.2-cuda12.1.dockerfile +++ b/docker/torch2.2.2-cuda12.1.dockerfile @@ -41,7 +41,7 @@ RUN apt-get update && \ # Install dependencies RUN pip install --no-cache-dir \ - torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ diff --git a/docker/torch2.3.1-cuda11.8.dockerfile b/docker/torch2.3.1-cuda11.8.dockerfile index 545b42e9f..ee07a4c24 100644 --- a/docker/torch2.3.1-cuda11.8.dockerfile +++ b/docker/torch2.3.1-cuda11.8.dockerfile @@ -41,7 +41,7 @@ RUN apt-get update && \ # Install dependencies RUN pip install --no-cache-dir \ - torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ diff --git a/docker/torch2.3.1-cuda12.1.dockerfile b/docker/torch2.3.1-cuda12.1.dockerfile index ca13752e4..f5bac35a2 100644 --- a/docker/torch2.3.1-cuda12.1.dockerfile +++ b/docker/torch2.3.1-cuda12.1.dockerfile @@ -41,7 +41,7 @@ RUN apt-get update && \ # Install dependencies RUN pip install --no-cache-dir \ - torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ diff --git a/docker/torch2.4.0-cuda11.8.dockerfile b/docker/torch2.4.0-cuda11.8.dockerfile new file mode 100644 index 000000000..a5ffc0bb5 --- /dev/null +++ b/docker/torch2.4.0-cuda11.8.dockerfile @@ -0,0 +1,73 @@ +FROM pytorch/pytorch:2.4.0-cuda11.8-cudnn9-devel +# python 3.10 + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +# python 3.10 +ARG K2_VERSION="1.24.4.dev20240725+cuda11.8.torch2.4.0" +ARG KALDIFEAT_VERSION="1.25.4.dev20240725+cuda11.8.torch2.4.0" +ARG TORCHAUDIO_VERSION="2.4.0+cu118" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + onnxoptimizer \ + onnxsim \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall diff --git a/docker/torch2.4.0-cuda12.1.dockerfile b/docker/torch2.4.0-cuda12.1.dockerfile new file mode 100644 index 000000000..01208ce2d --- /dev/null +++ b/docker/torch2.4.0-cuda12.1.dockerfile @@ -0,0 +1,73 @@ +FROM pytorch/pytorch:2.4.0-cuda12.1-cudnn9-devel +# python 3.10 + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +# python 3.10 +ARG K2_VERSION="1.24.4.dev20240725+cuda12.1.torch2.4.0" +ARG KALDIFEAT_VERSION="1.25.4.dev20240725+cuda12.1.torch2.4.0" +ARG TORCHAUDIO_VERSION="2.4.0+cu121" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + onnxoptimizer \ + onnxsim \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall diff --git a/docker/torch2.4.0-cuda12.4.dockerfile b/docker/torch2.4.0-cuda12.4.dockerfile new file mode 100644 index 000000000..d0d300cfa --- /dev/null +++ b/docker/torch2.4.0-cuda12.4.dockerfile @@ -0,0 +1,73 @@ +FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel +# python 3.10 + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +# python 3.10 +ARG K2_VERSION="1.24.4.dev20240725+cuda12.4.torch2.4.0" +ARG KALDIFEAT_VERSION="1.25.4.dev20240725+cuda12.4.torch2.4.0" +ARG TORCHAUDIO_VERSION="2.4.0+cu124" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + onnxoptimizer \ + onnxsim \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall From 1730fce688aa4cb6c3162ed860e29c6a72da1604 Mon Sep 17 00:00:00 2001 From: Karel Vesely Date: Tue, 13 Aug 2024 17:02:14 +0200 Subject: [PATCH 02/59] split `save_results()` -> `save_asr_output()` + `save_wer_results()` (#1712) - the idea is to support `--skip-scoring` argument passed to a decoding script - created for Transducer decoding (non-streaming, streaming) - it can be done also for CTC decoding... (not yet) - also added `--label` for extra label in `streaming_decode.py` - and also added `set_caching_enabled(True)`, which has no effect on librispeech, but it leads to faster runtime on DBs with long recordings (assuming `librispeech/zipformer` scripts are the example scripts for other setups) --- egs/librispeech/ASR/zipformer/ctc_decode.py | 96 +++++++++---- egs/librispeech/ASR/zipformer/decode.py | 136 +++++++++++------- .../ASR/zipformer/streaming_decode.py | 88 +++++++++--- 3 files changed, 217 insertions(+), 103 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/ctc_decode.py b/egs/librispeech/ASR/zipformer/ctc_decode.py index 435a79e7f..9db429959 100755 --- a/egs/librispeech/ASR/zipformer/ctc_decode.py +++ b/egs/librispeech/ASR/zipformer/ctc_decode.py @@ -120,6 +120,7 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule +from lhotse import set_caching_enabled from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( @@ -296,6 +297,13 @@ def get_parser(): """, ) + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""" + ) + add_model_arguments(parser) return parser @@ -455,7 +463,7 @@ def decode_one_batch( # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] hyps = [s.split() for s in hyps] key = "ctc-decoding" - return {key: hyps} + return {key: hyps} # note: returns words if params.decoding_method == "attention-decoder-rescoring-no-ngram": best_path_dict = rescore_with_attention_decoder_no_ngram( @@ -492,7 +500,7 @@ def decode_one_batch( ) hyps = get_texts(best_path) hyps = [[word_table[i] for i in ids] for ids in hyps] - key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + key = f"oracle_{params.num_paths}_nbest-scale-{params.nbest_scale}" # noqa return {key: hyps} if params.decoding_method in ["1best", "nbest"]: @@ -500,7 +508,7 @@ def decode_one_batch( best_path = one_best_decoding( lattice=lattice, use_double_scores=params.use_double_scores ) - key = "no_rescore" + key = "no-rescore" else: best_path = nbest_decoding( lattice=lattice, @@ -508,11 +516,11 @@ def decode_one_batch( use_double_scores=params.use_double_scores, nbest_scale=params.nbest_scale, ) - key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + key = f"no-rescore_nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa hyps = get_texts(best_path) hyps = [[word_table[i] for i in ids] for ids in hyps] - return {key: hyps} + return {key: hyps} # note: returns BPE tokens assert params.decoding_method in [ "nbest-rescoring", @@ -646,7 +654,27 @@ def decode_dataset( return results -def save_results( +def save_asr_output( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + """ + Save text produced by ASR. + """ + for key, results in results_dict.items(): + + recogs_filename = ( + params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + ) + + results = sorted(results) + store_transcripts(filename=recogs_filename, texts=results) + + logging.info(f"The transcripts are stored in {recogs_filename}") + + +def save_wer_results( params: AttributeDict, test_set_name: str, results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], @@ -661,32 +689,30 @@ def save_results( test_set_wers = dict() for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - if enable_log: - logging.info(f"The transcripts are stored in {recog_path}") - # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}-{key}", results) + with open(errs_filename, "w", encoding="utf8") as fd: + wer = write_error_stats( + fd, f"{test_set_name}_{key}", results, enable_log=enable_log + ) test_set_wers[key] = wer - if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + logging.info(f"Wrote detailed error stats to {errs_filename}") test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) + wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) + for key, val in test_set_wers: + print(f"{key}\t{val}", file=fd) + + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) + s += f"{key}\t{val}{note}\n" note = "" logging.info(s) @@ -705,6 +731,9 @@ def main(): params.update(get_decoding_params()) params.update(vars(args)) + # enable AudioCache + set_caching_enabled(True) # lhotse + assert params.decoding_method in ( "ctc-greedy-search", "ctc-decoding", @@ -719,9 +748,9 @@ def main(): params.res_dir = params.exp_dir / params.decoding_method if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" + params.suffix = f"iter-{params.iter}_avg-{params.avg}" else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" if params.causal: assert ( @@ -730,11 +759,11 @@ def main(): assert ( "," not in params.left_context_frames ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" if params.use_averaged_model: - params.suffix += "-use-averaged-model" + params.suffix += "_use-averaged-model" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -940,12 +969,19 @@ def main(): G=G, ) - save_results( + save_asr_output( params=params, test_set_name=test_set, results_dict=results_dict, ) + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + logging.info("Done!") diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index df2d555a0..cbfb3728e 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -121,6 +121,7 @@ from beam_search import ( modified_beam_search_lm_shallow_fusion, modified_beam_search_LODR, ) +from lhotse import set_caching_enabled from train import add_model_arguments, get_model, get_params from icefall import ContextGraph, LmScorer, NgramLm @@ -369,6 +370,14 @@ def get_parser(): modified_beam_search_LODR. """, ) + + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""", + ) + add_model_arguments(parser) return parser @@ -590,21 +599,23 @@ def decode_one_batch( ) hyps.append(sp.decode(hyp).split()) + # prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" ) + prefix = f"{params.decoding_method}" if params.decoding_method == "greedy_search": return {"greedy_search": hyps} elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" + prefix += f"_beam-{params.beam}" + prefix += f"_max-contexts-{params.max_contexts}" + prefix += f"_max-states-{params.max_states}" if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" + prefix += f"_num-paths-{params.num_paths}" + prefix += f"_nbest-scale-{params.nbest_scale}" if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + prefix += f"_ngram-lm-scale-{params.ngram_lm_scale}" - return {key: hyps} + return {prefix: hyps} elif "modified_beam_search" in params.decoding_method: - prefix = f"beam_size_{params.beam_size}" + prefix += f"_beam-size-{params.beam_size}" if params.decoding_method in ( "modified_beam_search_lm_rescore", "modified_beam_search_lm_rescore_LODR", @@ -617,10 +628,11 @@ def decode_one_batch( return ans else: if params.has_contexts: - prefix += f"-context-score-{params.context_score}" + prefix += f"_context-score-{params.context_score}" return {prefix: hyps} else: - return {f"beam_size_{params.beam_size}": hyps} + prefix += f"_beam-size-{params.beam_size}" + return {prefix: hyps} def decode_dataset( @@ -707,46 +719,58 @@ def decode_dataset( return results -def save_results( +def save_asr_output( params: AttributeDict, test_set_name: str, results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): + """ + Save text produced by ASR. + """ + for key, results in results_dict.items(): + + recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + + results = sorted(results) + store_transcripts(filename=recogs_filename, texts=results) + + logging.info(f"The transcripts are stored in {recogs_filename}") + + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]], +): + """ + Save WER and per-utterance word alignments. + """ test_set_wers = dict() for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w", encoding="utf8") as fd: wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True + fd, f"{test_set_name}-{key}", results, enable_log=True ) test_set_wers[key] = wer - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info(f"Wrote detailed error stats to {errs_filename}") test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) + wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) + for key, val in test_set_wers: + print(f"{key}\t{val}", file=fd) + + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) + s += f"{key}\t{val}{note}\n" note = "" logging.info(s) @@ -762,6 +786,9 @@ def main(): params = get_params() params.update(vars(args)) + # enable AudioCache + set_caching_enabled(True) # lhotse + assert params.decoding_method in ( "greedy_search", "beam_search", @@ -783,9 +810,9 @@ def main(): params.has_contexts = False if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" + params.suffix = f"iter-{params.iter}_avg-{params.avg}" else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" if params.causal: assert ( @@ -794,20 +821,20 @@ def main(): assert ( "," not in params.left_context_frames ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" + params.suffix += f"_beam-{params.beam}" + params.suffix += f"_max-contexts-{params.max_contexts}" + params.suffix += f"_max-states-{params.max_states}" if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" + params.suffix += f"_nbest-scale-{params.nbest_scale}" + params.suffix += f"_num-paths-{params.num_paths}" if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + params.suffix += f"_ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}" if params.decoding_method in ( "modified_beam_search", "modified_beam_search_LODR", @@ -815,19 +842,19 @@ def main(): if params.has_contexts: params.suffix += f"-context-score-{params.context_score}" else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + params.suffix += f"_context-{params.context_size}" + params.suffix += f"_max-sym-per-frame-{params.max_sym_per_frame}" if params.use_shallow_fusion: - params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + params.suffix += f"_{params.lm_type}-lm-scale-{params.lm_scale}" if "LODR" in params.decoding_method: params.suffix += ( - f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" ) if params.use_averaged_model: - params.suffix += "-use-averaged-model" + params.suffix += "_use-averaged-model" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -1038,12 +1065,19 @@ def main(): ngram_lm_scale=ngram_lm_scale, ) - save_results( + save_asr_output( params=params, test_set_name=test_set, results_dict=results_dict, ) + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + logging.info("Done!") diff --git a/egs/librispeech/ASR/zipformer/streaming_decode.py b/egs/librispeech/ASR/zipformer/streaming_decode.py index 360523b8e..ebcafbf87 100755 --- a/egs/librispeech/ASR/zipformer/streaming_decode.py +++ b/egs/librispeech/ASR/zipformer/streaming_decode.py @@ -43,7 +43,7 @@ import torch from asr_datamodule import LibriSpeechAsrDataModule from decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet +from lhotse import CutSet, set_caching_enabled from streaming_beam_search import ( fast_beam_search_one_best, greedy_search, @@ -76,6 +76,13 @@ def get_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter ) + parser.add_argument( + "--label", + type=str, + default="", + help="""Extra label of the decoding run.""", + ) + parser.add_argument( "--epoch", type=int, @@ -188,6 +195,14 @@ def get_parser(): help="The number of streams that can be decoded parallel.", ) + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""" + ) + + add_model_arguments(parser) return parser @@ -640,46 +655,60 @@ def decode_dataset( return {key: decode_results} -def save_results( +def save_asr_output( params: AttributeDict, test_set_name: str, results_dict: Dict[str, List[Tuple[List[str], List[str]]]], ): - test_set_wers = dict() + """ + Save text produced by ASR. + """ for key, results in results_dict.items(): - recog_path = ( + recogs_filename = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") + store_transcripts(filename=recogs_filename, texts=results) + logging.info(f"The transcripts are stored in {recogs_filename}") + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + """ + Save WER and per-utterance word alignments. + """ + test_set_wers = dict() + for key, results in results_dict.items(): # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" ) - with open(errs_filename, "w") as f: + with open(errs_filename, "w", encoding="utf8") as fd: wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True + fd, f"{test_set_name}-{key}", results, enable_log=True ) test_set_wers[key] = wer - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info(f"Wrote detailed error stats to {errs_filename}") test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( + + wer_filename = ( params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) + print(f"{key}\t{val}", file=fd) - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) + s += f"{key}\t{val}{note}\n" note = "" logging.info(s) @@ -694,6 +723,9 @@ def main(): params = get_params() params.update(vars(args)) + # enable AudioCache + set_caching_enabled(True) # lhotse + params.res_dir = params.exp_dir / "streaming" / params.decoding_method if params.iter > 0: @@ -706,18 +738,21 @@ def main(): assert ( "," not in params.left_context_frames ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" # for fast_beam_search if params.decoding_method == "fast_beam_search": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" + params.suffix += f"_beam-{params.beam}" + params.suffix += f"_max-contexts-{params.max_contexts}" + params.suffix += f"_max-states-{params.max_states}" if params.use_averaged_model: params.suffix += "-use-averaged-model" + if params.label: + params.suffix += f"-{params.label}" + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -845,12 +880,21 @@ def main(): decoding_graph=decoding_graph, ) - save_results( + + save_asr_output( params=params, test_set_name=test_set, results_dict=results_dict, ) + + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + logging.info("Done!") From 6ac3343ce53fa6685fca0f876f2c6245af4caac5 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Fri, 16 Aug 2024 20:13:02 +0800 Subject: [PATCH 03/59] fix path in README.md (#1722) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 31e514606..81cfc03ce 100644 --- a/README.md +++ b/README.md @@ -375,7 +375,7 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad [libricss]: egs/libricss/SURT [libriheavy]: egs/libriheavy/ASR [mgb2]: egs/mgb2/ASR -[peoplespeech]: egs/peoplespeech/ASR +[peoplespeech]: egs/peoples_speech/ASR [spgispeech]: egs/spgispeech/ASR [voxpopuli]: egs/voxpopuli/ASR [xbmu-amdo31]: egs/xbmu-amdo31/ASR From 595297229405fa74ec0dd53e0e7d0ce051802148 Mon Sep 17 00:00:00 2001 From: Xiaoyu Yang <45973641+marcoyang1998@users.noreply.github.com> Date: Sat, 17 Aug 2024 13:24:38 +0800 Subject: [PATCH 04/59] Keep the custom fields in libriheavy manifest (#1719) --- egs/libriheavy/ASR/local/prepare_manifest.py | 10 +++++++--- egs/libriheavy/ASR/prepare.sh | 7 ++++++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/egs/libriheavy/ASR/local/prepare_manifest.py b/egs/libriheavy/ASR/local/prepare_manifest.py index 42f392cae..d7e184d86 100755 --- a/egs/libriheavy/ASR/local/prepare_manifest.py +++ b/egs/libriheavy/ASR/local/prepare_manifest.py @@ -29,17 +29,21 @@ def simple_cleanup(text: str) -> str: # Assign text of the supervisions and remove unnecessary entries. def main(): - assert len(sys.argv) == 3, "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR" + assert ( + len(sys.argv) == 4 + ), "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR KEEP_CUSTOM_FIELDS" fname = Path(sys.argv[1]).name oname = Path(sys.argv[2]) / fname + keep_custom_fields = bool(sys.argv[3]) with gzip.open(sys.argv[1], "r") as fin, gzip.open(oname, "w") as fout: for line in fin: cut = json.loads(line) cut["supervisions"][0]["text"] = simple_cleanup( cut["supervisions"][0]["custom"]["texts"][0] ) - del cut["supervisions"][0]["custom"] - del cut["custom"] + if not keep_custom_fields: + del cut["supervisions"][0]["custom"] + del cut["custom"] fout.write((json.dumps(cut) + "\n").encode()) diff --git a/egs/libriheavy/ASR/prepare.sh b/egs/libriheavy/ASR/prepare.sh index b0736c98b..366a1459f 100755 --- a/egs/libriheavy/ASR/prepare.sh +++ b/egs/libriheavy/ASR/prepare.sh @@ -29,6 +29,11 @@ export CUDA_VISIBLE_DEVICES="" # - speech dl_dir=$PWD/download +# If you want to do PromptASR experiments, please set it to True +# as this will keep the texts and pre_text information required for +# the training of PromptASR. +keep_custom_fields=False + . shared/parse_options.sh || exit 1 # vocab size for sentence piece models. @@ -134,7 +139,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then for subset in small medium large dev test_clean test_other; do if [ ! -e $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then log "Prepare manifest for subset : ${subset}" - ./local/prepare_manifest.py $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz $manifests_dir + ./local/prepare_manifest.py $dl_dir/libriheavy/libriheavy_cuts_${subset}.jsonl.gz $manifests_dir $keep_custom_fields fi done fi From 3fc06cc2b9120a79a3e061bf35cef8d7220a42f3 Mon Sep 17 00:00:00 2001 From: Xiaoyu Yang <45973641+marcoyang1998@users.noreply.github.com> Date: Thu, 22 Aug 2024 15:27:25 +0800 Subject: [PATCH 05/59] Support AudioSet training with weighted sampler (#1727) --- egs/audioset/AT/RESULTS.md | 36 +++++-- egs/audioset/AT/local/compute_weight.py | 73 ++++++++++++++ egs/audioset/AT/prepare.sh | 13 ++- egs/audioset/AT/zipformer/at_datamodule.py | 107 ++++++++++++++++----- egs/audioset/AT/zipformer/train.py | 11 ++- 5 files changed, 207 insertions(+), 33 deletions(-) create mode 100644 egs/audioset/AT/local/compute_weight.py diff --git a/egs/audioset/AT/RESULTS.md b/egs/audioset/AT/RESULTS.md index 0128b7018..36613db03 100644 --- a/egs/audioset/AT/RESULTS.md +++ b/egs/audioset/AT/RESULTS.md @@ -35,16 +35,40 @@ python zipformer/train.py \ --master-port 13455 ``` +We recommend that you train the model with weighted sampler, as the model converges +faster with better performance: + +| Model | mAP | +| ------ | ------- | +| Zipformer-AT, train with weighted sampler | 46.6 | + The evaluation command is: ```bash -python zipformer/evaluate.py \ - --epoch 32 \ - --avg 8 \ - --exp-dir zipformer/exp_at_as_full \ - --max-duration 500 +export CUDA_VISIBLE_DEVICES="4,5,6,7" +subset=full +weighted_sampler=1 +bucket_sampler=0 +lr_epochs=15 + +python zipformer/train.py \ + --world-size 4 \ + --audioset-subset $subset \ + --num-epochs 120 \ + --start-epoch 1 \ + --use-fp16 1 \ + --num-events 527 \ + --lr-epochs $lr_epochs \ + --exp-dir zipformer/exp_AS_${subset}_weighted_sampler${weighted_sampler} \ + --weighted-sampler $weighted_sampler \ + --bucketing-sampler $bucket_sampler \ + --max-duration 1000 \ + --enable-musan True \ + --master-port 13452 ``` +The command for evaluation is the same. The pre-trained model can be downloaded from https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-M-weighted-sampler + #### small-scaled model, number of model parameters: 22125218, i.e., 22.13 M @@ -92,4 +116,4 @@ python zipformer/evaluate.py \ --encoder-unmasked-dim 192,192,192,192,192,192 \ --exp-dir zipformer/exp_small_at_as_full \ --max-duration 500 -``` \ No newline at end of file +``` diff --git a/egs/audioset/AT/local/compute_weight.py b/egs/audioset/AT/local/compute_weight.py new file mode 100644 index 000000000..a0deddc0c --- /dev/null +++ b/egs/audioset/AT/local/compute_weight.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file generates the manifest and computes the fbank features for AudioSet +dataset. The generated manifests and features are stored in data/fbank. +""" + +import argparse + +import lhotse +from lhotse import load_manifest + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--input-manifest", type=str, default="data/fbank/cuts_audioset_full.jsonl.gz" + ) + + parser.add_argument( + "--output", + type=str, + required=True, + ) + return parser + + +def main(): + # Reference: https://github.com/YuanGongND/ast/blob/master/egs/audioset/gen_weight_file.py + parser = get_parser() + args = parser.parse_args() + + cuts = load_manifest(args.input_manifest) + + print(f"A total of {len(cuts)} cuts.") + + label_count = [0] * 527 # a total of 527 classes + for c in cuts: + audio_event = c.supervisions[0].audio_event + labels = list(map(int, audio_event.split(";"))) + for label in labels: + label_count[label] += 1 + + with open(args.output, "w") as f: + for c in cuts: + audio_event = c.supervisions[0].audio_event + labels = list(map(int, audio_event.split(";"))) + weight = 0 + for label in labels: + weight += 1000 / (label_count[label] + 0.01) + f.write(f"{c.id} {weight}\n") + + +if __name__ == "__main__": + main() diff --git a/egs/audioset/AT/prepare.sh b/egs/audioset/AT/prepare.sh index f7f73a008..8beaf2d86 100755 --- a/egs/audioset/AT/prepare.sh +++ b/egs/audioset/AT/prepare.sh @@ -10,6 +10,7 @@ stage=-1 stop_stage=4 dl_dir=$PWD/download +fbank_dir=data/fbank # we assume that you have your downloaded the AudioSet and placed # it under $dl_dir/audioset, the folder structure should look like @@ -49,7 +50,6 @@ fi if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "Stage 0: Construct the audioset manifest and compute the fbank features for balanced set" - fbank_dir=data/fbank if [! -e $fbank_dir/.balanced.done]; then python local/generate_audioset_manifest.py \ --dataset-dir $dl_dir/audioset \ @@ -102,3 +102,14 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then touch data/fbank/.musan.done fi fi + +# The following stages are required to do weighted-sampling training +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Prepare for weighted-sampling training" + if [ ! -e $fbank_dir/cuts_audioset_full.jsonl.gz ]; then + lhotse combine $fbank_dir/cuts_audioset_balanced.jsonl.gz $fbank_dir/cuts_audioset_unbalanced.jsonl.gz $fbank_dir/cuts_audioset_full.jsonl.gz + fi + python ./local/compute_weight.py \ + --input-manifest $fbank_dir/cuts_audioset_full.jsonl.gz \ + --output $fbank_dir/sampling_weights_full.txt +fi diff --git a/egs/audioset/AT/zipformer/at_datamodule.py b/egs/audioset/AT/zipformer/at_datamodule.py index ac8671fa6..b7df01539 100644 --- a/egs/audioset/AT/zipformer/at_datamodule.py +++ b/egs/audioset/AT/zipformer/at_datamodule.py @@ -31,6 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures PrecomputedFeatures, SimpleCutSampler, SpecAugment, + WeightedSimpleCutSampler, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples AudioSamples, @@ -99,6 +100,20 @@ class AudioSetATDatamodule: help="Maximum pooled recordings duration (seconds) in a " "single batch. You can reduce it if it causes CUDA OOM.", ) + group.add_argument( + "--weighted-sampler", + type=str2bool, + default=False, + help="When enabled, samples are drawn from by their weights. " + "It cannot be used together with bucketing sampler", + ) + group.add_argument( + "--num-samples", + type=int, + default=200000, + help="The number of samples to be drawn in each epoch. Only be used" + "for weighed sampler", + ) group.add_argument( "--bucketing-sampler", type=str2bool, @@ -295,6 +310,9 @@ class AudioSetATDatamodule: ) if self.args.bucketing_sampler: + assert ( + not self.args.weighted_sampler + ), "weighted sampling is not supported in bucket sampler" logging.info("Using DynamicBucketingSampler.") train_sampler = DynamicBucketingSampler( cuts_train, @@ -304,13 +322,26 @@ class AudioSetATDatamodule: drop_last=self.args.drop_last, ) else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - drop_last=self.args.drop_last, - ) + if self.args.weighted_sampler: + # assert self.args.audioset_subset == "full", "Only use weighted sampling for full audioset" + logging.info("Using weighted SimpleCutSampler") + weights = self.audioset_sampling_weights() + train_sampler = WeightedSimpleCutSampler( + cuts_train, + weights, + num_samples=self.args.num_samples, + max_duration=self.args.max_duration, + shuffle=False, # do not support shuffle + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + drop_last=self.args.drop_last, + ) logging.info("About to create train dataloader") if sampler_state_dict is not None: @@ -373,11 +404,9 @@ class AudioSetATDatamodule: def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.debug("About to create test dataset") test = AudioTaggingDataset( - input_strategy=( - OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)() - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) sampler = DynamicBucketingSampler( @@ -397,21 +426,30 @@ class AudioSetATDatamodule: @lru_cache() def audioset_train_cuts(self) -> CutSet: logging.info("About to get the audioset training cuts.") - balanced_cuts = load_manifest_lazy( - self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz" - ) - if self.args.audioset_subset == "full": - unbalanced_cuts = load_manifest_lazy( - self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz" - ) - cuts = CutSet.mux( - balanced_cuts, - unbalanced_cuts, - weights=[20000, 2000000], - stop_early=True, + if not self.args.weighted_sampler: + balanced_cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz" ) + if self.args.audioset_subset == "full": + unbalanced_cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz" + ) + cuts = CutSet.mux( + balanced_cuts, + unbalanced_cuts, + weights=[20000, 2000000], + stop_early=True, + ) + else: + cuts = balanced_cuts else: - cuts = balanced_cuts + # assert self.args.audioset_subset == "full", "Only do weighted sampling for full AudioSet" + cuts = load_manifest( + self.args.manifest_dir + / f"cuts_audioset_{self.args.audioset_subset}.jsonl.gz" + ) + logging.info(f"Get {len(cuts)} cuts in total.") + return cuts @lru_cache() @@ -420,3 +458,22 @@ class AudioSetATDatamodule: return load_manifest_lazy( self.args.manifest_dir / "cuts_audioset_eval.jsonl.gz" ) + + @lru_cache() + def audioset_sampling_weights(self): + logging.info( + f"About to get the sampling weight for {self.args.audioset_subset} in AudioSet" + ) + weights = [] + with open( + self.args.manifest_dir / f"sample_weights_{self.args.audioset_subset}.txt", + "r", + ) as f: + while True: + line = f.readline() + if not line: + break + weight = float(line.split()[1]) + weights.append(weight) + logging.info(f"Get the sampling weight for {len(weights)} cuts") + return weights diff --git a/egs/audioset/AT/zipformer/train.py b/egs/audioset/AT/zipformer/train.py index 2d193030a..67c703364 100644 --- a/egs/audioset/AT/zipformer/train.py +++ b/egs/audioset/AT/zipformer/train.py @@ -789,12 +789,14 @@ def train_one_epoch( rank=0, ) + num_samples = 0 for batch_idx, batch in enumerate(train_dl): if batch_idx % 10 == 0: set_batch_count(model, get_adjusted_batch_count(params)) params.batch_idx_train += 1 batch_size = batch["inputs"].size(0) + num_samples += batch_size try: with torch.cuda.amp.autocast(enabled=params.use_fp16): @@ -919,6 +921,12 @@ def train_one_epoch( tb_writer, "train/valid_", params.batch_idx_train ) + if num_samples > params.num_samples: + logging.info( + f"Number of training samples exceeds {params.num_samples} in this epoch, move on to next epoch" + ) + break + loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value if params.train_loss < params.best_train_loss: @@ -1032,7 +1040,8 @@ def run(rank, world_size, args): return True - train_cuts = train_cuts.filter(remove_short_and_long_utt) + if not params.weighted_sampler: + train_cuts = train_cuts.filter(remove_short_and_long_utt) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint From 3b434fe83c40eaf3c4739c26d11aa3a3b8af8ddc Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Fri, 23 Aug 2024 09:33:46 +0800 Subject: [PATCH 06/59] fix triton onnx export (#1730) --- egs/librispeech/ASR/zipformer/export-onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py index ed8a0ef0f..ca3cbf0d5 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx.py +++ b/egs/librispeech/ASR/zipformer/export-onnx.py @@ -218,7 +218,7 @@ class OnnxEncoder(nn.Module): - encoder_out_lens, A 1-D tensor of shape (N,) """ x, x_lens = self.encoder_embed(x, x_lens) - src_key_padding_mask = make_pad_mask(x_lens) + src_key_padding_mask = make_pad_mask(x_lens, x.shape[1]) x = x.permute(1, 0, 2) encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) encoder_out = encoder_out.permute(1, 0, 2) From a6c02a4d8c5c5db3f30899ca622a813640aba63f Mon Sep 17 00:00:00 2001 From: Xiaoyu Yang <45973641+marcoyang1998@users.noreply.github.com> Date: Fri, 23 Aug 2024 09:42:22 +0800 Subject: [PATCH 07/59] zipformer BF16 training recipe (#1700) Support Zipformer AMP +BF16 training --- egs/librispeech/ASR/RESULTS.md | 17 +++++++++ egs/librispeech/ASR/zipformer/scaling.py | 12 +++--- egs/librispeech/ASR/zipformer/train.py | 47 ++++++++++++++++++------ 3 files changed, 59 insertions(+), 17 deletions(-) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 66b147764..bc7d8a5ef 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -307,6 +307,23 @@ done To decode with external language models, please refer to the documentation [here](https://k2-fsa.github.io/icefall/decoding-with-langugage-models/index.html). +We also support training Zipformer with AMP+bf16 format (requires bf16 support). See [here](https://github.com/k2-fsa/icefall/pull/1700) for more details and pre-trained models. **The same command can be used for decoding and exporting the model.** + +The amp+bf16 training command is: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 0 \ + --use-bf16 1 \ + --exp-dir zipformer/exp_amp_bf16 \ + --causal 0 \ + --full-libri 1 \ + --max-duration 1000 +``` + ##### small-scaled model, number of model parameters: 23285615, i.e., 23.3 M The tensorboard log can be found at diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 164cc7bfd..2a40b8d64 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -297,7 +297,7 @@ class SoftmaxFunction(torch.autograd.Function): # (presumably) that op does not support float16, and autocast # is enabled. if torch.is_autocast_enabled(): - ans = ans.to(torch.float16) + ans = ans.to(torch.get_autocast_gpu_dtype()) ctx.save_for_backward(ans) ctx.x_dtype = x.dtype ctx.dim = dim @@ -1234,7 +1234,7 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor) -> Tensor: requires_grad = x.requires_grad - if x.dtype == torch.float16: + if x.dtype == torch.float16 or x.dtype == torch.bfloat16: x = x.to(torch.float32) s = torch.sigmoid(x - 1.0) @@ -1346,7 +1346,7 @@ class SwooshLFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor) -> Tensor: requires_grad = x.requires_grad - if x.dtype == torch.float16: + if x.dtype == torch.float16 or x.dtype == torch.bfloat16: x = x.to(torch.float32) zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) @@ -1379,7 +1379,7 @@ class SwooshLFunction(torch.autograd.Function): d_int = d_scaled.to(torch.uint8) ctx.save_for_backward(d_int) if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) + y = y.to(torch.get_autocast_gpu_dtype()) return y @staticmethod @@ -1425,7 +1425,7 @@ class SwooshRFunction(torch.autograd.Function): def forward(ctx, x: Tensor) -> Tensor: requires_grad = x.requires_grad - if x.dtype == torch.float16: + if x.dtype == torch.float16 or x.dtype == torch.bfloat16: x = x.to(torch.float32) zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) @@ -1455,7 +1455,7 @@ class SwooshRFunction(torch.autograd.Function): d_int = d_scaled.to(torch.uint8) ctx.save_for_backward(d_int) if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) + y = y.to(torch.get_autocast_gpu_dtype()) return y @staticmethod diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 9b6f4a93a..9c1c7f5a7 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -521,6 +521,13 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--use-bf16", + type=str2bool, + default=False, + help="Whether to use bf16 in AMP.", + ) + add_model_arguments(parser) return parser @@ -1027,7 +1034,9 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.cuda.amp.autocast( + enabled=params.use_autocast, dtype=params.dtype + ): loss, loss_info = compute_loss( params=params, model=model, @@ -1047,9 +1056,7 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except Exception as e: - logging.info( - f"Caught exception: {e}." - ) + logging.info(f"Caught exception: {e}.") save_bad_model() display_and_save_batch(batch, params=params, sp=sp) raise @@ -1090,7 +1097,7 @@ def train_one_epoch( rank=rank, ) - if batch_idx % 100 == 0 and params.use_fp16: + if batch_idx % 100 == 0 and params.use_autocast: # If the grad scale was less than 1, try increasing it. The _growth_interval # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. @@ -1109,14 +1116,14 @@ def train_one_epoch( if batch_idx % params.log_interval == 0: cur_lr = max(scheduler.get_last_lr()) - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0 logging.info( f"Epoch {params.cur_epoch}, " f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "") ) if tb_writer is not None: @@ -1128,7 +1135,7 @@ def train_one_epoch( tb_writer, "train/current_", params.batch_idx_train ) tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: + if params.use_autocast: tb_writer.add_scalar( "train/grad_scale", cur_grad_scale, params.batch_idx_train ) @@ -1204,9 +1211,25 @@ def run(rank, world_size, args): params.ctc_loss_scale = 1.0 else: assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, ( - params.ctc_loss_scale, params.attention_decoder_loss_scale + params.ctc_loss_scale, + params.attention_decoder_loss_scale, ) + if params.use_bf16: # amp + bf16 + assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!" + assert not params.use_fp16, "You can only use either fp16 or bf16" + params.dtype = torch.bfloat16 + params.use_autocast = True + elif params.use_fp16: # amp + fp16 + params.dtype = torch.float16 + params.use_autocast = True + else: # fp32 + params.dtype = torch.float32 + params.use_autocast = False + + logging.info(f"Using dtype={params.dtype}") + logging.info(f"Use AMP={params.use_autocast}") + logging.info(params) logging.info("About to create model") @@ -1339,7 +1362,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1439,7 +1462,9 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.cuda.amp.autocast( + enabled=params.use_autocast, dtype=params.dtype + ): loss, _ = compute_loss( params=params, model=model, From cea0dbe7b1cd4d5b7512c7974e53034ef456dd70 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Wed, 28 Aug 2024 12:15:01 +0800 Subject: [PATCH 08/59] fix gigaspeech_prepare.sh (#1734) --- egs/gigaspeech/ASR/prepare.sh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/egs/gigaspeech/ASR/prepare.sh b/egs/gigaspeech/ASR/prepare.sh index 5e54b669a..219197e13 100755 --- a/egs/gigaspeech/ASR/prepare.sh +++ b/egs/gigaspeech/ASR/prepare.sh @@ -161,14 +161,14 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Split XL subset into pieces (may take 30 minutes)" split_dir=data/fbank/XL_split if [ ! -f $split_dir/.split_completed ]; then - lhotse split-lazy ./data/fbank/cuts_XL_raw.jsonl.gz $split_dir $num_per_split + lhotse split-lazy ./data/fbank/gigaspeech_cuts_XL_raw.jsonl.gz $split_dir $num_per_split touch $split_dir/.split_completed fi fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then log "Stage 6: Compute features for XL" - num_splits=$(find data/fbank/XL_split -name "cuts_XL_raw.*.jsonl.gz" | wc -l) + num_splits=$(find data/fbank/XL_split -name "gigaspeech_cuts_XL_raw.*.jsonl.gz" | wc -l) python3 ./local/compute_fbank_gigaspeech_splits.py \ --num-workers 20 \ --batch-duration 600 \ @@ -177,9 +177,9 @@ fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then log "Stage 7: Combine features for XL (may take 3 hours)" - if [ ! -f data/fbank/cuts_XL.jsonl.gz ]; then - pieces=$(find data/fbank/XL_split -name "cuts_XL.*.jsonl.gz") - lhotse combine $pieces data/fbank/cuts_XL.jsonl.gz + if [ ! -f data/fbank/gigaspeech_cuts_XL.jsonl.gz ]; then + pieces=$(find data/fbank/XL_split -name "gigaspeech_cuts_XL.*.jsonl.gz") + lhotse combine $pieces data/fbank/gigaspeech_cuts_XL.jsonl.gz fi fi From f233ffa02ae248e4ad2c526d5c35c4a9ade601f5 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 7 Sep 2024 18:17:04 +0800 Subject: [PATCH 09/59] Add docker images for torch 2.4.1 (#1743) --- .../scripts/docker/generate_build_matrix.py | 6 +- .github/workflows/build-docker-image.yml | 2 +- docker/torch2.4.1-cuda11.8.dockerfile | 73 +++++++++++++++++++ docker/torch2.4.1-cuda12.1.dockerfile | 73 +++++++++++++++++++ docker/torch2.4.1-cuda12.4.dockerfile | 73 +++++++++++++++++++ docs/source/docker/intro.rst | 6 ++ 6 files changed, 231 insertions(+), 2 deletions(-) create mode 100644 docker/torch2.4.1-cuda11.8.dockerfile create mode 100644 docker/torch2.4.1-cuda12.1.dockerfile create mode 100644 docker/torch2.4.1-cuda12.4.dockerfile diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py index 5a763e044..492d3ed47 100755 --- a/.github/scripts/docker/generate_build_matrix.py +++ b/.github/scripts/docker/generate_build_matrix.py @@ -45,7 +45,7 @@ def get_torchaudio_version(torch_version): def get_matrix(): k2_version = "1.24.4.dev20240223" kaldifeat_version = "1.25.4.dev20240223" - version = "20240725" + version = "20240905" python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"] torch_version = [] torch_version += ["1.13.0", "1.13.1"] @@ -54,6 +54,7 @@ def get_matrix(): torch_version += ["2.2.0", "2.2.1", "2.2.2"] torch_version += ["2.3.0", "2.3.1"] torch_version += ["2.4.0"] + torch_version += ["2.4.1"] matrix = [] for p in python_version: @@ -82,6 +83,9 @@ def get_matrix(): elif t == "2.4.0": k2_version_2 = "1.24.4.dev20240725" kaldifeat_version_2 = "1.25.4.dev20240725" + elif t == "2.4.1": + k2_version_2 = "1.24.4.dev20240905" + kaldifeat_version_2 = "1.25.4.dev20240905" matrix.append( { diff --git a/.github/workflows/build-docker-image.yml b/.github/workflows/build-docker-image.yml index 77480bd3e..a473590a3 100644 --- a/.github/workflows/build-docker-image.yml +++ b/.github/workflows/build-docker-image.yml @@ -16,7 +16,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - image: ["torch2.4.0-cuda12.4", "torch2.4.0-cuda12.1", "torch2.4.0-cuda11.8", "torch2.3.1-cuda12.1", "torch2.3.1-cuda11.8", "torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] + image: ["torch2.4.1-cuda12.4", "torch2.4.1-cuda12.1", "torch2.4.1-cuda11.8", "torch2.4.0-cuda12.4", "torch2.4.0-cuda12.1", "torch2.4.0-cuda11.8", "torch2.3.1-cuda12.1", "torch2.3.1-cuda11.8", "torch2.2.2-cuda12.1", "torch2.2.2-cuda11.8", "torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"] steps: # refer to https://github.com/actions/checkout diff --git a/docker/torch2.4.1-cuda11.8.dockerfile b/docker/torch2.4.1-cuda11.8.dockerfile new file mode 100644 index 000000000..bc1782b0d --- /dev/null +++ b/docker/torch2.4.1-cuda11.8.dockerfile @@ -0,0 +1,73 @@ +FROM pytorch/pytorch:2.4.1-cuda11.8-cudnn9-devel +# python 3.10 + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +# python 3.10 +ARG K2_VERSION="1.24.4.dev20240905+cuda11.8.torch2.4.1" +ARG KALDIFEAT_VERSION="1.25.4.dev20240905+cuda11.8.torch2.4.1" +ARG TORCHAUDIO_VERSION="2.4.1+cu118" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + onnxoptimizer \ + onnxsim \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall diff --git a/docker/torch2.4.1-cuda12.1.dockerfile b/docker/torch2.4.1-cuda12.1.dockerfile new file mode 100644 index 000000000..df2ea61a4 --- /dev/null +++ b/docker/torch2.4.1-cuda12.1.dockerfile @@ -0,0 +1,73 @@ +FROM pytorch/pytorch:2.4.1-cuda12.1-cudnn9-devel +# python 3.10 + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +# python 3.10 +ARG K2_VERSION="1.24.4.dev20240905+cuda12.1.torch2.4.1" +ARG KALDIFEAT_VERSION="1.25.4.dev20240905+cuda12.1.torch2.4.1" +ARG TORCHAUDIO_VERSION="2.4.1+cu121" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + onnxoptimizer \ + onnxsim \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall diff --git a/docker/torch2.4.1-cuda12.4.dockerfile b/docker/torch2.4.1-cuda12.4.dockerfile new file mode 100644 index 000000000..4d6da2804 --- /dev/null +++ b/docker/torch2.4.1-cuda12.4.dockerfile @@ -0,0 +1,73 @@ +FROM pytorch/pytorch:2.4.1-cuda12.4-cudnn9-devel +# python 3.10 + +ENV LC_ALL C.UTF-8 + +ARG DEBIAN_FRONTEND=noninteractive + +# python 3.10 +ARG K2_VERSION="1.24.4.dev20240905+cuda12.4.torch2.4.1" +ARG KALDIFEAT_VERSION="1.25.4.dev20240905+cuda12.4.torch2.4.1" +ARG TORCHAUDIO_VERSION="2.4.1+cu124" + +LABEL authors="Fangjun Kuang " +LABEL k2_version=${K2_VERSION} +LABEL kaldifeat_version=${KALDIFEAT_VERSION} +LABEL github_repo="https://github.com/k2-fsa/icefall" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + curl \ + vim \ + libssl-dev \ + autoconf \ + automake \ + bzip2 \ + ca-certificates \ + ffmpeg \ + g++ \ + gfortran \ + git \ + libtool \ + make \ + patch \ + sox \ + subversion \ + unzip \ + valgrind \ + wget \ + zlib1g-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install dependencies +RUN pip install --no-cache-dir \ + torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torchaudio/ \ + k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \ + git+https://github.com/lhotse-speech/lhotse \ + kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \ + kaldi_native_io \ + kaldialign \ + kaldifst \ + kaldilm \ + sentencepiece>=0.1.96 \ + tensorboard \ + typeguard \ + dill \ + onnx \ + onnxruntime \ + onnxmltools \ + onnxoptimizer \ + onnxsim \ + multi_quantization \ + typeguard \ + numpy \ + pytest \ + graphviz + +RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \ + cd /workspace/icefall && \ + pip install --no-cache-dir -r requirements.txt + +ENV PYTHONPATH /workspace/icefall:$PYTHONPATH + +WORKDIR /workspace/icefall diff --git a/docs/source/docker/intro.rst b/docs/source/docker/intro.rst index f3d2b0727..5fc3fa4d5 100644 --- a/docs/source/docker/intro.rst +++ b/docs/source/docker/intro.rst @@ -34,6 +34,12 @@ which will give you something like below: .. code-block:: bash + "torch2.4.1-cuda12.4" + "torch2.4.1-cuda12.1" + "torch2.4.1-cuda11.8" + "torch2.4.0-cuda12.4" + "torch2.4.0-cuda12.1" + "torch2.4.0-cuda11.8" "torch2.3.1-cuda12.1" "torch2.3.1-cuda11.8" "torch2.2.2-cuda12.1" From d4b43236999da5314e889544524782fecafe8ddc Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 7 Sep 2024 19:21:26 +0800 Subject: [PATCH 10/59] Fix github actions CI tests (#1744) --- .github/scripts/docker/generate_build_matrix.py | 9 +++++---- ...gigaspeech-pruned-transducer-stateless2-2022-05-12.sh | 4 ++-- .github/scripts/test-onnx-export.sh | 1 + .github/workflows/build-doc.yml | 2 ++ .github/workflows/run-gigaspeech-2022-05-13.yml | 4 +++- .../workflows/run-gigaspeech-zipformer-2023-10-17.yml | 4 ++-- ...librispeech-lstm-transducer-stateless2-2022-09-03.yml | 4 +++- .github/workflows/run-multi-corpora-zipformer.yml | 2 ++ .github/workflows/run-ptb-rnn-lm.yml | 4 +++- .github/workflows/run-swbd-conformer-ctc.yml | 2 ++ .../run-wenetspeech-pruned-transducer-stateless2.yml | 2 ++ .github/workflows/style_check.yml | 2 ++ .github/workflows/test-ncnn-export.yml | 2 ++ .github/workflows/test-onnx-export.yml | 2 ++ .github/workflows/test.yml | 2 +- 15 files changed, 34 insertions(+), 12 deletions(-) diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py index 492d3ed47..08281151e 100755 --- a/.github/scripts/docker/generate_build_matrix.py +++ b/.github/scripts/docker/generate_build_matrix.py @@ -48,10 +48,11 @@ def get_matrix(): version = "20240905" python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"] torch_version = [] - torch_version += ["1.13.0", "1.13.1"] - torch_version += ["2.0.0", "2.0.1"] - torch_version += ["2.1.0", "2.1.1", "2.1.2"] - torch_version += ["2.2.0", "2.2.1", "2.2.2"] + # torch_version += ["1.13.0", "1.13.1"] + # torch_version += ["2.0.0", "2.0.1"] + # torch_version += ["2.1.0", "2.1.1", "2.1.2"] + # torch_version += ["2.2.0", "2.2.1", "2.2.2"] + # Test only torch >= 2.3.0 torch_version += ["2.3.0", "2.3.1"] torch_version += ["2.4.0"] torch_version += ["2.4.1"] diff --git a/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh b/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh index b61a9d7b6..c9e798a68 100755 --- a/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh +++ b/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh @@ -29,8 +29,8 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == ls -lh data/fbank ls -lh pruned_transducer_stateless2/exp - ln -s data/fbank/cuts_DEV.jsonl.gz data/fbank/gigaspeech_cuts_DEV.jsonl.gz - ln -s data/fbank/cuts_TEST.jsonl.gz data/fbank/gigaspeech_cuts_TEST.jsonl.gz + ln -sf data/fbank/cuts_DEV.jsonl.gz data/fbank/gigaspeech_cuts_DEV.jsonl.gz + ln -sf data/fbank/cuts_TEST.jsonl.gz data/fbank/gigaspeech_cuts_TEST.jsonl.gz log "Decoding dev and test" diff --git a/.github/scripts/test-onnx-export.sh b/.github/scripts/test-onnx-export.sh index fcfc11fa6..3252c37f1 100755 --- a/.github/scripts/test-onnx-export.sh +++ b/.github/scripts/test-onnx-export.sh @@ -25,6 +25,7 @@ popd log "Export via torch.jit.script()" ./zipformer/export.py \ + --use-averaged-model 0 \ --exp-dir $repo/exp \ --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ diff --git a/.github/workflows/build-doc.yml b/.github/workflows/build-doc.yml index c622476f2..ca96e6de5 100644 --- a/.github/workflows/build-doc.yml +++ b/.github/workflows/build-doc.yml @@ -26,6 +26,8 @@ on: pull_request: types: [labeled] + workflow_dispatch: + concurrency: group: build_doc-${{ github.ref }} cancel-in-progress: true diff --git a/.github/workflows/run-gigaspeech-2022-05-13.yml b/.github/workflows/run-gigaspeech-2022-05-13.yml index 3121520c1..2c1d44fbf 100644 --- a/.github/workflows/run-gigaspeech-2022-05-13.yml +++ b/.github/workflows/run-gigaspeech-2022-05-13.yml @@ -33,6 +33,8 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" + workflow_dispatch: + concurrency: group: run_gigaspeech_2022_05_13-${{ github.ref }} cancel-in-progress: true @@ -119,7 +121,7 @@ jobs: find exp/greedy_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2 - name: Upload decoding results for gigaspeech pruned_transducer_stateless2 - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' with: name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-gigaspeech-pruned_transducer_stateless2-2022-05-12 diff --git a/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml b/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml index 87090e310..4ecc2aea0 100644 --- a/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml +++ b/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml @@ -42,7 +42,7 @@ concurrency: jobs: run_gigaspeech_2023_10_17_zipformer: - if: github.event.label.name == 'zipformer' ||github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' + if: github.event.label.name == 'zipformer' ||github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' runs-on: ${{ matrix.os }} strategy: matrix: @@ -133,7 +133,7 @@ jobs: find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - name: Upload decoding results for gigaspeech zipformer - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' with: name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11 diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml index 501fae38c..6a3f4eb40 100644 --- a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml +++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml @@ -16,6 +16,8 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" + workflow_dispatch: + concurrency: group: run_librispeech_lstm_transducer_stateless2_2022_09_03-${{ github.ref }} cancel-in-progress: true @@ -156,7 +158,7 @@ jobs: find modified_beam_search_LODR -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - name: Upload decoding results for lstm_transducer_stateless2 - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'LODR' with: name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-lstm_transducer_stateless2-2022-09-03 diff --git a/.github/workflows/run-multi-corpora-zipformer.yml b/.github/workflows/run-multi-corpora-zipformer.yml index 38f7eb908..84f9f3a0d 100644 --- a/.github/workflows/run-multi-corpora-zipformer.yml +++ b/.github/workflows/run-multi-corpora-zipformer.yml @@ -23,6 +23,8 @@ on: pull_request: types: [labeled] + workflow_dispatch: + concurrency: group: run_multi-corpora_zipformer-${{ github.ref }} cancel-in-progress: true diff --git a/.github/workflows/run-ptb-rnn-lm.yml b/.github/workflows/run-ptb-rnn-lm.yml index f8d9c02c5..6e4077cf4 100644 --- a/.github/workflows/run-ptb-rnn-lm.yml +++ b/.github/workflows/run-ptb-rnn-lm.yml @@ -16,6 +16,8 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" + workflow_dispatch: + concurrency: group: run_ptb_rnn_lm_training-${{ github.ref }} cancel-in-progress: true @@ -64,7 +66,7 @@ jobs: ./train-rnn-lm.sh --world-size 1 --num-epochs 5 --use-epoch 4 --use-avg 2 - name: Upload pretrained models - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 if: github.event.label.name == 'ready' || github.event.label.name == 'rnnlm' || github.event_name == 'push' || github.event_name == 'schedule' with: name: python-${{ matrix.python-version }}-ubuntu-rnn-lm-ptb diff --git a/.github/workflows/run-swbd-conformer-ctc.yml b/.github/workflows/run-swbd-conformer-ctc.yml index 842691d38..b0178bedd 100644 --- a/.github/workflows/run-swbd-conformer-ctc.yml +++ b/.github/workflows/run-swbd-conformer-ctc.yml @@ -23,6 +23,8 @@ on: pull_request: types: [labeled] + workflow_dispatch: + concurrency: group: run-swbd-conformer_ctc-${{ github.ref }} cancel-in-progress: true diff --git a/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml b/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml index 319a5558a..e76497ec3 100644 --- a/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml +++ b/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml @@ -23,6 +23,8 @@ on: pull_request: types: [labeled] + workflow_dispatch: + concurrency: group: run_wenetspeech_pruned_transducer_stateless2-${{ github.ref }} cancel-in-progress: true diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index 1c37f13ed..0681ece60 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -24,6 +24,8 @@ on: branches: - master + workflow_dispatch: + concurrency: group: style_check-${{ github.ref }} cancel-in-progress: true diff --git a/.github/workflows/test-ncnn-export.yml b/.github/workflows/test-ncnn-export.yml index 5709f8ebb..ec419d65f 100644 --- a/.github/workflows/test-ncnn-export.yml +++ b/.github/workflows/test-ncnn-export.yml @@ -16,6 +16,8 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" + workflow_dispatch: + concurrency: group: test_ncnn_export-${{ github.ref }} cancel-in-progress: true diff --git a/.github/workflows/test-onnx-export.yml b/.github/workflows/test-onnx-export.yml index c05cde3ba..646ca0569 100644 --- a/.github/workflows/test-onnx-export.yml +++ b/.github/workflows/test-onnx-export.yml @@ -16,6 +16,8 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" + workflow_dispatch: + concurrency: group: test_onnx_export-${{ github.ref }} cancel-in-progress: true diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 659681b37..9eb7e403c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -105,7 +105,7 @@ jobs: cd ../zipformer pytest -v -s - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v4 with: path: egs/librispeech/ASR/zipformer/swoosh.pdf name: swoosh.pdf From 559c8a716039bc1f3da2a4d1487292830fd21f06 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sun, 8 Sep 2024 17:10:17 +0800 Subject: [PATCH 11/59] fixed a typo in `prepare.sh` for alimeeting recipes (#1747) --- egs/alimeeting/ASR/prepare.sh | 2 +- egs/alimeeting/ASR_v2/prepare.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/alimeeting/ASR/prepare.sh b/egs/alimeeting/ASR/prepare.sh index 996a1da2d..55f9f019b 100755 --- a/egs/alimeeting/ASR/prepare.sh +++ b/egs/alimeeting/ASR/prepare.sh @@ -87,7 +87,7 @@ fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Prepare musan manifest" # We assume that you have downloaded the musan corpus - # to data/musan + # to $dl_dir/musan if [ ! -f data/manifests/.musan_manifests.done ]; then log "It may take 6 minutes" mkdir -p data/manifests diff --git a/egs/alimeeting/ASR_v2/prepare.sh b/egs/alimeeting/ASR_v2/prepare.sh index 15c20692d..1881cd75c 100755 --- a/egs/alimeeting/ASR_v2/prepare.sh +++ b/egs/alimeeting/ASR_v2/prepare.sh @@ -65,7 +65,7 @@ fi if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 2: Prepare musan manifest" # We assume that you have downloaded the musan corpus - # to data/musan + # to $dl_dir/musan mkdir -p data/manifests lhotse prepare musan $dl_dir/musan data/manifests fi From 2ff0bb6a884c8f5aafa48551fba8c7d0eeb15b96 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 8 Sep 2024 17:42:55 +0800 Subject: [PATCH 12/59] fix CI tests (#1748) --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9eb7e403c..c22f2edb5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -108,4 +108,4 @@ jobs: - uses: actions/upload-artifact@v4 with: path: egs/librispeech/ASR/zipformer/swoosh.pdf - name: swoosh.pdf + name: swoosh-${{ matrix.python-version }}-${{ matrix.torch-version }} From 65b8a6c730568ed12fccccb244e013f6ae3d7745 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sun, 8 Sep 2024 20:34:49 +0800 Subject: [PATCH 13/59] fixed wrong default value for the `alimeeting` recipe (#1750) --- .../pruned_transducer_stateless7/asr_datamodule.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py index 6b56c8a6a..9da820315 100644 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/asr_datamodule.py @@ -82,7 +82,7 @@ class AlimeetingAsrDataModule: group.add_argument( "--manifest-dir", type=Path, - default=Path("data/manifests"), + default=Path("data/fbank"), help="Path to directory with train/valid/test cuts.", ) group.add_argument( @@ -327,9 +327,11 @@ class AlimeetingAsrDataModule: def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.debug("About to create test dataset") test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), + input_strategy=( + OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures() + ), return_cuts=True, ) sampler = DynamicBucketingSampler( From a394bf74742c0242f35a514e016df74d6ba42505 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sun, 8 Sep 2024 20:35:07 +0800 Subject: [PATCH 14/59] fixed gss scripts for `alimeeting` and `ami` recipes (#1749) --- egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh | 4 ++-- egs/ami/ASR/local/prepare_ami_gss.sh | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh b/egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh index 76db19832..bd25bc9e5 100755 --- a/egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh +++ b/egs/alimeeting/ASR_v2/local/prepare_alimeeting_gss.sh @@ -58,7 +58,7 @@ if [ $stage -le 4 ]; then # for train, we use smaller context and larger batches to speed-up processing for JOB in $(seq $nj); do gss enhance cuts $EXP_DIR/cuts_train.jsonl.gz \ - $EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.JOB.jsonl.gz $EXP_DIR/enhanced \ + $EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.$JOB.jsonl.gz $EXP_DIR/enhanced \ --bss-iterations 10 \ --context-duration 5.0 \ --use-garbage-class \ @@ -77,7 +77,7 @@ if [ $stage -le 5 ]; then for part in eval test; do for JOB in $(seq $nj); do gss enhance cuts $EXP_DIR/cuts_${part}.jsonl.gz \ - $EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.JOB.jsonl.gz \ + $EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.$JOB.jsonl.gz \ $EXP_DIR/enhanced \ --bss-iterations 10 \ --context-duration 15.0 \ diff --git a/egs/ami/ASR/local/prepare_ami_gss.sh b/egs/ami/ASR/local/prepare_ami_gss.sh index d5422458b..414c22b12 100755 --- a/egs/ami/ASR/local/prepare_ami_gss.sh +++ b/egs/ami/ASR/local/prepare_ami_gss.sh @@ -58,7 +58,7 @@ if [ $stage -le 4 ]; then # for train, we use smaller context and larger batches to speed-up processing for JOB in $(seq $nj); do gss enhance cuts $EXP_DIR/cuts_train.jsonl.gz \ - $EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.JOB.jsonl.gz $EXP_DIR/enhanced \ + $EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.$JOB.jsonl.gz $EXP_DIR/enhanced \ --bss-iterations 10 \ --context-duration 5.0 \ --use-garbage-class \ @@ -77,7 +77,7 @@ if [ $stage -le 5 ]; then for part in dev test; do for JOB in $(seq $nj); do gss enhance cuts $EXP_DIR/cuts_${part}.jsonl.gz \ - $EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.JOB.jsonl.gz \ + $EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.$JOB.jsonl.gz \ $EXP_DIR/enhanced \ --bss-iterations 10 \ --context-duration 15.0 \ From 329e34ac204bfedf7d4169ca4ccd295de7ff8aac Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 10 Sep 2024 19:29:19 +0800 Subject: [PATCH 15/59] Test export onnx models for multi-zh-hans (#1752) --- .github/scripts/multi-zh-hans.sh | 42 ++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/.github/scripts/multi-zh-hans.sh b/.github/scripts/multi-zh-hans.sh index 427d8887b..e254419ff 100755 --- a/.github/scripts/multi-zh-hans.sh +++ b/.github/scripts/multi-zh-hans.sh @@ -16,6 +16,48 @@ log "pwd: $PWD" cd egs/multi_zh-hans/ASR +repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-2023-9-2 +log "Downloading pre-trained model from $repo_url" +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) +pushd $repo +cd exp +git lfs pull --include pretrained.pt +ln -s pretrained.pt epoch-99.pt +cd ../data/lang_bpe_2000 +ls -lh +git lfs pull --include L.pt L_disambig.pt Linv.pt bpe.model +git lfs pull --include "*.model" +ls -lh +popd + +log "--------------------------------------------" +log "Export non-streaming ONNX transducer models " +log "--------------------------------------------" +./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --causal False + +ls -lh $repo/exp + +./zipformer/onnx_pretrained.py \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav \ + $repo/test_wavs/TEST_MEETING_T0000000113.wav \ + $repo/test_wavs/TEST_MEETING_T0000000219.wav \ + $repo/test_wavs/TEST_MEETING_T0000000351.wav + +rm -rf $repo + repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05 log "Downloading pre-trained model from $repo_url" GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url From 6f1abd832dc290b62adfdd0f615010c2f3c274a5 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 11 Sep 2024 21:04:52 +0800 Subject: [PATCH 16/59] Fix exporting streaming zipformer models. (#1755) --- .../ASR/zipformer/export-onnx-streaming.py | 2 +- egs/librispeech/ASR/zipformer/zipformer.py | 41 +++++++++++++++---- 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index e5ceb3683..88c58f581 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -74,7 +74,6 @@ import onnx import torch import torch.nn as nn from decoder import Decoder -from onnxconverter_common import float16 from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_model, get_params @@ -756,6 +755,7 @@ def main(): logging.info(f"Exported joiner to {joiner_filename}") if(params.fp16) : + from onnxconverter_common import float16 logging.info("Generate fp16 models") encoder = onnx.load(encoder_filename) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 69059287b..2a0ae0129 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -191,6 +191,7 @@ class Zipformer2(EncoderInterface): dim=encoder_dim[i], downsample=downsampling_factor[i], dropout=dropout, + causal=causal, ) encoders.append(encoder) @@ -198,7 +199,10 @@ class Zipformer2(EncoderInterface): self.encoders = nn.ModuleList(encoders) self.downsample_output = SimpleDownsample( - max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout + max(encoder_dim), + downsample=output_downsampling_factor, + dropout=dropout, + causal=causal, ) def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]: @@ -1217,11 +1221,16 @@ class DownsampledZipformer2Encoder(nn.Module): """ def __init__( - self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike + self, + encoder: nn.Module, + dim: int, + downsample: int, + dropout: FloatLike, + causal: bool, ): super(DownsampledZipformer2Encoder, self).__init__() self.downsample_factor = downsample - self.downsample = SimpleDownsample(dim, downsample, dropout) + self.downsample = SimpleDownsample(dim, downsample, dropout, causal) self.num_layers = encoder.num_layers self.encoder = encoder self.upsample = SimpleUpsample(dim, downsample) @@ -1310,9 +1319,12 @@ class SimpleDownsample(torch.nn.Module): Does downsampling with attention, by weighted sum, and a projection.. """ - def __init__(self, channels: int, downsample: int, dropout: FloatLike): + def __init__( + self, channels: int, downsample: int, dropout: FloatLike, causal: bool + ): super(SimpleDownsample, self).__init__() + self.causal = causal self.bias = nn.Parameter(torch.zeros(downsample)) self.name = None # will be set from training code @@ -1333,9 +1345,18 @@ class SimpleDownsample(torch.nn.Module): # Pad to an exact multiple of self.downsample # right-pad src, repeating the last element. pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) - src = torch.cat((src, src_extra), dim=0) - assert src.shape[0] == d_seq_len * ds + + if self.causal and torch.jit.is_tracing(): + assert ( + pad == 0 + ), f"pad should be zero for exporting streaming models. Given {pad}" + + # If we are exporting a streaming model, then we skip the if statement + if not self.causal or not torch.jit.is_tracing(): + src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src = torch.cat((src, src_extra), dim=0) + + assert src.shape[0] == d_seq_len * ds, (src.shape, d_seq_len, ds) src = src.reshape(d_seq_len, ds, batch_size, in_channels) @@ -1609,7 +1630,11 @@ class RelPositionMultiheadAttentionWeights(nn.Module): k = x[..., query_dim : 2 * query_dim] # p is the position-encoding query p = x[..., 2 * query_dim :] - assert p.shape[-1] == num_heads * pos_head_dim, (p.shape[-1], num_heads, pos_head_dim) + assert p.shape[-1] == num_heads * pos_head_dim, ( + p.shape[-1], + num_heads, + pos_head_dim, + ) q = self.copy_query(q) # for diagnostics only, does nothing. k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. From 5c04c31292b87dc95fe0a9b498dc753f509ebda1 Mon Sep 17 00:00:00 2001 From: Yu Lianjie Date: Fri, 20 Sep 2024 12:38:52 +0800 Subject: [PATCH 17/59] fix open-commands path (#1714) --- egs/wenetspeech/KWS/prepare.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/wenetspeech/KWS/prepare.sh b/egs/wenetspeech/KWS/prepare.sh index dcc65fab4..e52e1a9d1 100755 --- a/egs/wenetspeech/KWS/prepare.sh +++ b/egs/wenetspeech/KWS/prepare.sh @@ -63,8 +63,8 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then ln -svf $(realpath ./open-commands/CN/small/commands.txt) commands_small.txt ln -svf $(realpath ./open-commands/CN/large/commands.txt) commands_large.txt pushd open-commands - ./script/prepare.sh --stage 1 --stop-stage 1 - ./script/prepare.sh --stage 3 --stop-stage 5 + ./scripts/prepare.sh --stage 1 --stop-stage 1 + ./scripts/prepare.sh --stage 3 --stop-stage 5 popd popd pushd data/fbank From d9844d847ffe5cf4c16136276ed000f7eb7bf314 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 9 Oct 2024 00:50:12 -0700 Subject: [PATCH 18/59] Update prepare.sh (#1768) --- egs/commonvoice/ASR/prepare.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/commonvoice/ASR/prepare.sh b/egs/commonvoice/ASR/prepare.sh index 4e76ef041..200114a86 100755 --- a/egs/commonvoice/ASR/prepare.sh +++ b/egs/commonvoice/ASR/prepare.sh @@ -339,7 +339,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then # 2. chmod +x ./jq # 3. cp jq /usr/bin gunzip -c ${file} \ - | jq '.text' | sed 's/"//g' > $lang_dir/transcript_words.txt + | jq '.supervisions[].text' | sed 's/"//g' > $lang_dir/transcript_words.txt # Ensure space only appears once sed -i 's/\t/ /g' $lang_dir/transcript_words.txt From fbba712887e54adc1a8e6eb6cebf2bcd72f27b4c Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Sat, 12 Oct 2024 19:09:05 +0800 Subject: [PATCH 19/59] Fix issue with eval mode in ActivationDropoutLinear (#1770) * Fix issue with eval mode in ActivationDropoutLinear --------- Co-authored-by: Daniel Povey --- egs/librispeech/ASR/zipformer/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 2a40b8d64..d345c2931 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -1635,7 +1635,7 @@ class ActivationDropoutAndLinear(torch.nn.Module): self.dropout_shared_dim = dropout_shared_dim def forward(self, x: Tensor): - if torch.jit.is_scripting() or torch.jit.is_tracing(): + if not self.training or torch.jit.is_scripting() or torch.jit.is_tracing(): if self.activation == "SwooshL": x = SwooshLForward(x) elif self.activation == "SwooshR": From 2653df5bda2c10e02ad3da404013fcad466e3567 Mon Sep 17 00:00:00 2001 From: zzasdf <68544676+zzasdf@users.noreply.github.com> Date: Sat, 12 Oct 2024 19:14:28 +0800 Subject: [PATCH 20/59] fix the mismatch in batch_idx_train (#1757) --- icefall/checkpoint.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index c83c56a53..308a06b1f 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -424,8 +424,12 @@ def average_checkpoints_with_averaged_model( state_dict_start = torch.load(filename_start, map_location=device) state_dict_end = torch.load(filename_end, map_location=device) + average_period = state_dict_start["average_period"] + batch_idx_train_start = state_dict_start["batch_idx_train"] + batch_idx_train_start = (batch_idx_train_start // average_period) * average_period batch_idx_train_end = state_dict_end["batch_idx_train"] + batch_idx_train_end = (batch_idx_train_end // average_period) * average_period interval = batch_idx_train_end - batch_idx_train_start assert interval > 0, interval weight_end = batch_idx_train_end / interval From f84270c93528f4b77b99ada9ac0c9f7fb231d6a4 Mon Sep 17 00:00:00 2001 From: KIM7AZEN <1556709171@qq.com> Date: Wed, 16 Oct 2024 17:19:24 +0800 Subject: [PATCH 21/59] fix the fixed num_splits (#1772) --- egs/wenetspeech/ASR/prepare.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh index 45912985b..74f213707 100755 --- a/egs/wenetspeech/ASR/prepare.sh +++ b/egs/wenetspeech/ASR/prepare.sh @@ -161,7 +161,7 @@ fi if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then log "Stage 11: Combine features for S" if [ ! -f data/fbank/cuts_S.jsonl.gz ]; then - pieces=$(find data/fbank/S_split_1000 -name "cuts_S.*.jsonl.gz") + pieces=$(find data/fbank/S_split_${num_splits} -name "cuts_S.*.jsonl.gz") lhotse combine $pieces data/fbank/cuts_S.jsonl.gz fi fi @@ -169,7 +169,7 @@ fi if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then log "Stage 12: Combine features for M" if [ ! -f data/fbank/cuts_M.jsonl.gz ]; then - pieces=$(find data/fbank/M_split_1000 -name "cuts_M.*.jsonl.gz") + pieces=$(find data/fbank/M_split_${num_splits} -name "cuts_M.*.jsonl.gz") lhotse combine $pieces data/fbank/cuts_M.jsonl.gz fi fi @@ -177,7 +177,7 @@ fi if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then log "Stage 13: Combine features for L" if [ ! -f data/fbank/cuts_L.jsonl.gz ]; then - pieces=$(find data/fbank/L_split_1000 -name "cuts_L.*.jsonl.gz") + pieces=$(find data/fbank/L_split_${num_splits} -name "cuts_L.*.jsonl.gz") lhotse combine $pieces data/fbank/cuts_L.jsonl.gz fi fi From 693d84a3011b1bda51ac6f95c3002af93efa772d Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Mon, 21 Oct 2024 10:35:26 +0800 Subject: [PATCH 22/59] Add Consistency-Regularized CTC (#1766) * support consistency-regularized CTC * update arguments of cr-ctc * set default value of cr_loss_masked_scale to 1.0 * minor fix * refactor codes * update RESULTS.md --- egs/librispeech/ASR/README.md | 8 +- egs/librispeech/ASR/RESULTS.md | 310 +++++++++++++++++++++++++ egs/librispeech/ASR/zipformer/model.py | 123 +++++++++- egs/librispeech/ASR/zipformer/train.py | 95 +++++++- icefall/utils.py | 40 ++++ 5 files changed, 556 insertions(+), 20 deletions(-) diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index 8b87ee19b..0dbfdc931 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -50,7 +50,7 @@ We place an additional Conv1d layer right after the input embedding layer. | `conformer-ctc2` | Reworked Conformer | Use auxiliary attention head | | `conformer-ctc3` | Reworked Conformer | Streaming version + delay penalty | | `zipformer-ctc` | Zipformer | Use auxiliary attention head | -| `zipformer` | Upgraded Zipformer | Use auxiliary transducer head / attention-decoder head | The latest recipe | +| `zipformer` | Upgraded Zipformer | Use auxiliary transducer head / attention-decoder head (the latest recipe) | # MMI @@ -58,3 +58,9 @@ We place an additional Conv1d layer right after the input embedding layer. |------------------------------|-----------|---------------------------------------------------| | `conformer-mmi` | Conformer | | | `zipformer-mmi` | Zipformer | CTC warmup + use HP as decoding graph for decoding | + +# CR-CTC + +| | Encoder | Comment | +|------------------------------|--------------------|------------------------------| +| `zipformer` | Upgraded Zipformer | Could also be an auxiliary loss to improve transducer or CTC/AED (the latest recipe) | diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index bc7d8a5ef..6a669f072 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1,5 +1,315 @@ ## Results +### zipformer (zipformer + pruned-transducer w/ CR-CTC) + +See for more details. + +[zipformer](./zipformer) + +#### Non-streaming + +##### large-scale model, number of model parameters: 148824074, i.e., 148.8 M + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +You can use to deploy it. + +| decoding method | test-clean | test-other | comment | +|--------------------------------------|------------|------------|---------------------| +| greedy_search | 1.9 | 3.96 | --epoch 50 --avg 26 | +| modified_beam_search | 1.88 | 3.95 | --epoch 50 --avg 26 | + +The training command using 2 80G-A100 GPUs is: +```bash +export CUDA_VISIBLE_DEVICES="0,1" +# for non-streaming model training: +./zipformer/train.py \ + --world-size 2 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp-large-cr-ctc-rnnt \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 1 \ + --use-attention-decoder 0 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --ctc-loss-scale 0.1 \ + --enable-spec-aug 0 \ + --cr-loss-scale 0.02 \ + --time-mask-ratio 2.5 \ + --full-libri 1 \ + --max-duration 1400 \ + --master-port 12345 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in greedy_search modified_beam_search; do + ./zipformer/decode.py \ + --epoch 50 \ + --avg 26 \ + --exp-dir zipformer/exp-large-cr-ctc-rnnt \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 1 \ + --use-attention-decoder 0 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --max-duration 300 \ + --decoding-method $m +done +``` + +### zipformer (zipformer + CR-CTC-AED) + +See for more details. + +[zipformer](./zipformer) + +#### Non-streaming + +##### large-scale model, number of model parameters: 174319650, i.e., 174.3 M + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +You can use to deploy it. + +| decoding method | test-clean | test-other | comment | +|--------------------------------------|------------|------------|---------------------| +| attention-decoder-rescoring-no-ngram | 1.96 | 4.08 | --epoch 50 --avg 20 | + +The training command using 2 80G-A100 GPUs is: +```bash +export CUDA_VISIBLE_DEVICES="0,1" +# for non-streaming model training: +./zipformer/train.py \ + --world-size 2 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp-large-cr-ctc-aed \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 1 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --ctc-loss-scale 0.1 \ + --attention-decoder-loss-scale 0.9 \ + --enable-spec-aug 0 \ + --cr-loss-scale 0.02 \ + --time-mask-ratio 2.5 \ + --full-libri 1 \ + --max-duration 1200 \ + --master-port 12345 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +./zipformer/ctc_decode.py \ + --epoch 50 \ + --avg 20 \ + --exp-dir zipformer/exp-large-cr-ctc-aed/ \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 1 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --max-duration 200 \ + --decoding-method attention-decoder-rescoring-no-ngram +done +``` + +### zipformer (zipformer + CR-CTC) + +See for more details. + +[zipformer](./zipformer) + +#### Non-streaming + +##### small-scale model, number of model parameters: 22118279, i.e., 22.1 M + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +You can use to deploy it. + +| decoding method | test-clean | test-other | comment | +|--------------------------------------|------------|------------|---------------------| +| ctc-greedy-decoding | 2.57 | 5.95 | --epoch 50 --avg 25 | + +The training command using 2 32G-V100 GPUs is: +```bash +export CUDA_VISIBLE_DEVICES="0,1" +# for non-streaming model training: +./zipformer/train.py \ + --world-size 2 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp-small/ \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 0 \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,768,768,768,768 \ + --encoder-dim 192,256,256,256,256,256 \ + --encoder-unmasked-dim 192,192,192,192,192,192 \ + --base-lr 0.04 \ + --enable-spec-aug 0 \ + --cr-loss-scale 0.2 \ + --time-mask-ratio 2.5 \ + --full-libri 1 \ + --max-duration 850 \ + --master-port 12345 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in ctc-greedy-search; do + ./zipformer/ctc_decode.py \ + --epoch 50 \ + --avg 25 \ + --exp-dir zipformer/exp-small \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 0 \ + --num-encoder-layers 2,2,2,2,2,2 \ + --feedforward-dim 512,768,768,768,768,768 \ + --encoder-dim 192,256,256,256,256,256 \ + --encoder-unmasked-dim 192,192,192,192,192,192 \ + --max-duration 600 \ + --decoding-method $m +done +``` + +##### medium-scale model, number of model parameters: 64250603, i.e., 64.3 M + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +You can use to deploy it. + +| decoding method | test-clean | test-other | comment | +|--------------------------------------|------------|------------|---------------------| +| ctc-greedy-decoding | 2.12 | 4.62 | --epoch 50 --avg 24 | + +The training command using 4 32G-V100 GPUs is: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 0 \ + --enable-spec-aug 0 \ + --cr-loss-scale 0.2 \ + --time-mask-ratio 2.5 \ + --full-libri 1 \ + --max-duration 700 \ + --master-port 12345 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in ctc-greedy-search; do + ./zipformer/ctc_decode.py \ + --epoch 50 \ + --avg 24 \ + --exp-dir zipformer/exp \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 0 \ + --max-duration 600 \ + --decoding-method $m +done +``` + +##### large-scale model, number of model parameters: 147010094, i.e., 147.0 M + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +You can use to deploy it. + +| decoding method | test-clean | test-other | comment | +|--------------------------------------|------------|------------|---------------------| +| ctc-greedy-decoding | 2.03 | 4.37 | --epoch 50 --avg 26 | + +The training command using 2 80G-A100 GPUs is: +```bash +export CUDA_VISIBLE_DEVICES="0,1" +# For non-streaming model training: +./zipformer/train.py \ + --world-size 2 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp-large \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 0 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --enable-spec-aug 0 \ + --cr-loss-scale 0.2 \ + --time-mask-ratio 2.5 \ + --full-libri 1 \ + --max-duration 1400 \ + --master-port 12345 +``` + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in ctc-greedy-search; do + ./zipformer/ctc_decode.py \ + --epoch 50 \ + --avg 26 \ + --exp-dir zipformer/exp-large \ + --use-cr-ctc 1 \ + --use-ctc 1 \ + --use-transducer 0 \ + --use-attention-decoder 0 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --max-duration 600 \ + --decoding-method $m +done +``` + ### zipformer (zipformer + CTC/AED) See for more details. diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index bd1ed26d8..deebb2a75 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -24,7 +24,8 @@ import torch.nn as nn from encoder_interface import EncoderInterface from scaling import ScaledLinear -from icefall.utils import add_sos, make_pad_mask +from icefall.utils import add_sos, make_pad_mask, time_warp +from lhotse.dataset import SpecAugment class AsrModel(nn.Module): @@ -181,6 +182,49 @@ class AsrModel(nn.Module): ) return ctc_loss + def forward_cr_ctc( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + targets: torch.Tensor, + target_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute CTC loss with consistency regularization loss. + Args: + encoder_out: + Encoder output, of shape (2 * N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (2 * N,). + targets: + Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed + to be un-padded and concatenated within 1 dimension. + """ + # Compute CTC loss + ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C) + ctc_loss = torch.nn.functional.ctc_loss( + log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C) + targets=targets.cpu(), + input_lengths=encoder_out_lens.cpu(), + target_lengths=target_lengths.cpu(), + reduction="sum", + ) + + # Compute consistency regularization loss + exchanged_targets = ctc_output.detach().chunk(2, dim=0) + exchanged_targets = torch.cat( + [exchanged_targets[1], exchanged_targets[0]], dim=0 + ) # exchange: [x1, x2] -> [x2, x1] + cr_loss = nn.functional.kl_div( + input=ctc_output, + target=exchanged_targets, + reduction="none", + log_target=True, + ) # (2 * N, T, C) + length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1) + cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum() + + return ctc_loss, cr_loss + def forward_transducer( self, encoder_out: torch.Tensor, @@ -296,7 +340,12 @@ class AsrModel(nn.Module): prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + use_cr_ctc: bool = False, + use_spec_aug: bool = False, + spec_augment: Optional[SpecAugment] = None, + supervision_segments: Optional[torch.Tensor] = None, + time_warp_factor: Optional[int] = 80, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x: @@ -316,9 +365,26 @@ class AsrModel(nn.Module): lm_scale: The scale to smooth the loss with lm (output of predictor network) part + use_cr_ctc: + Whether use consistency-regularized CTC. + use_spec_aug: + Whether apply spec-augment manually, used only if use_cr_ctc is True. + spec_augment: + The SpecAugment instance that returns time masks, + used only if use_cr_ctc is True. + supervision_segments: + An int tensor of shape ``(S, 3)``. ``S`` is the number of + supervision segments that exist in ``features``. + Used only if use_cr_ctc is True. + time_warp_factor: + Parameter for the time warping; larger values mean more warping. + Set to ``None``, or less than ``1``, to disable. + Used only if use_cr_ctc is True. + Returns: - Return the transducer losses and CTC loss, - in form of (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss) + Return the transducer losses, CTC loss, AED loss, + and consistency-regularization loss in form of + (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss) Note: Regarding am_scale & lm_scale, it will make the loss-function one of @@ -334,6 +400,24 @@ class AsrModel(nn.Module): device = x.device + if use_cr_ctc: + assert self.use_ctc + if use_spec_aug: + assert spec_augment is not None and spec_augment.time_warp_factor < 1 + # Apply time warping before input duplicating + assert supervision_segments is not None + x = time_warp( + x, + time_warp_factor=time_warp_factor, + supervision_segments=supervision_segments, + ) + # Independently apply frequency masking and time masking to the two copies + x = spec_augment(x.repeat(2, 1, 1)) + else: + x = x.repeat(2, 1, 1) + x_lens = x_lens.repeat(2) + y = k2.ragged.cat([y, y], axis=0) + # Compute encoder outputs encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) @@ -351,6 +435,9 @@ class AsrModel(nn.Module): am_scale=am_scale, lm_scale=lm_scale, ) + if use_cr_ctc: + simple_loss = simple_loss * 0.5 + pruned_loss = pruned_loss * 0.5 else: simple_loss = torch.empty(0) pruned_loss = torch.empty(0) @@ -358,14 +445,26 @@ class AsrModel(nn.Module): if self.use_ctc: # Compute CTC loss targets = y.values - ctc_loss = self.forward_ctc( - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - targets=targets, - target_lengths=y_lens, - ) + if not use_cr_ctc: + ctc_loss = self.forward_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) + cr_loss = torch.empty(0) + else: + ctc_loss, cr_loss = self.forward_cr_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=targets, + target_lengths=y_lens, + ) + ctc_loss = ctc_loss * 0.5 + cr_loss = cr_loss * 0.5 else: ctc_loss = torch.empty(0) + cr_loss = torch.empty(0) if self.use_attention_decoder: attention_decoder_loss = self.attention_decoder.calc_att_loss( @@ -374,7 +473,9 @@ class AsrModel(nn.Module): ys=y.to(device), ys_lens=y_lens.to(device), ) + if use_cr_ctc: + attention_decoder_loss = attention_decoder_loss * 0.5 else: attention_decoder_loss = torch.empty(0) - return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss + return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 9c1c7f5a7..c074c32ec 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -45,11 +45,10 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --max-duration 1000 It supports training with: - - transducer loss (default), with `--use-transducer True --use-ctc False` - - ctc loss (not recommended), with `--use-transducer False --use-ctc True` - - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` - - ctc loss & attention decoder loss, no transducer loss, - with `--use-transducer False --use-ctc True --use-attention-decoder True` + - transducer loss (default) + - ctc loss + - attention decoder loss + - cr-ctc loss (should use half the max-duration compared to regular ctc) """ @@ -72,6 +71,7 @@ from attention_decoder import AttentionDecoderModel from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut +from lhotse.dataset import SpecAugment from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import AsrModel @@ -304,6 +304,13 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="If True, use attention-decoder head.", ) + parser.add_argument( + "--use-cr-ctc", + type=str2bool, + default=False, + help="If True, use consistency-regularized CTC.", + ) + def get_parser(): parser = argparse.ArgumentParser( @@ -449,6 +456,20 @@ def get_parser(): help="Scale for CTC loss.", ) + parser.add_argument( + "--cr-loss-scale", + type=float, + default=0.2, + help="Scale for consistency-regularization loss.", + ) + + parser.add_argument( + "--time-mask-ratio", + type=float, + default=2.5, + help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.", + ) + parser.add_argument( "--attention-decoder-loss-scale", type=float, @@ -717,6 +738,24 @@ def get_model(params: AttributeDict) -> nn.Module: return model +def get_spec_augment(params: AttributeDict) -> SpecAugment: + num_frame_masks = int(10 * params.time_mask_ratio) + max_frames_mask_fraction = 0.15 * params.time_mask_ratio + logging.info( + f"num_frame_masks: {num_frame_masks}, " + f"max_frames_mask_fraction: {max_frames_mask_fraction}" + ) + spec_augment = SpecAugment( + time_warp_factor=0, # Do time warping in model.py + num_frame_masks=num_frame_masks, # default: 10 + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15 + ) + return spec_augment + + def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, @@ -839,6 +878,7 @@ def compute_loss( sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, + spec_augment: Optional[SpecAugment] = None, ) -> Tuple[Tensor, MetricsTracker]: """ Compute loss given the model and its inputs. @@ -855,8 +895,8 @@ def compute_loss( True for training. False for validation. When it is True, this function enables autograd during computation; when it is False, it disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. """ device = model.device if isinstance(model, DDP) else next(model.parameters()).device feature = batch["inputs"] @@ -874,14 +914,34 @@ def compute_loss( y = sp.encode(texts, out_type=int) y = k2.RaggedTensor(y) + use_cr_ctc = params.use_cr_ctc + use_spec_aug = use_cr_ctc and is_training + if use_spec_aug: + supervision_intervals = batch["supervisions"] + supervision_segments = torch.stack( + [ + supervision_intervals["sequence_idx"], + supervision_intervals["start_frame"], + supervision_intervals["num_frames"], + ], + dim=1, + ) # shape: (S, 3) + else: + supervision_segments = None + with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss, attention_decoder_loss = model( + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = model( x=feature, x_lens=feature_lens, y=y, prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, + use_cr_ctc=use_cr_ctc, + use_spec_aug=use_spec_aug, + spec_augment=spec_augment, + supervision_segments=supervision_segments, + time_warp_factor=params.spec_aug_time_warp_factor, ) loss = 0.0 @@ -904,6 +964,8 @@ def compute_loss( if params.use_ctc: loss += params.ctc_loss_scale * ctc_loss + if use_cr_ctc: + loss += params.cr_loss_scale * cr_loss if params.use_attention_decoder: loss += params.attention_decoder_loss_scale * attention_decoder_loss @@ -922,6 +984,8 @@ def compute_loss( info["pruned_loss"] = pruned_loss.detach().cpu().item() if params.use_ctc: info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.use_cr_ctc: + info["cr_loss"] = cr_loss.detach().cpu().item() if params.use_attention_decoder: info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() @@ -971,6 +1035,7 @@ def train_one_epoch( train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, + spec_augment: Optional[SpecAugment] = None, model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, @@ -997,6 +1062,8 @@ def train_one_epoch( Dataloader for the validation dataset. scaler: The scaler used for mix precision training. + spec_augment: + The SpecAugment instance used only when use_cr_ctc is True. model_avg: The stored model averaged from the start of training. tb_writer: @@ -1043,6 +1110,7 @@ def train_one_epoch( sp=sp, batch=batch, is_training=True, + spec_augment=spec_augment, ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info @@ -1238,6 +1306,13 @@ def run(rank, world_size, args): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + if params.use_cr_ctc: + assert params.use_ctc + assert not params.enable_spec_aug # we will do spec_augment in model.py + spec_augment = get_spec_augment(params) + else: + spec_augment = None + assert params.save_every_n >= params.average_period model_avg: Optional[nn.Module] = None if rank == 0: @@ -1360,6 +1435,7 @@ def run(rank, world_size, args): optimizer=optimizer, sp=sp, params=params, + spec_augment=spec_augment, ) scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) @@ -1387,6 +1463,7 @@ def run(rank, world_size, args): train_dl=train_dl, valid_dl=valid_dl, scaler=scaler, + spec_augment=spec_augment, tb_writer=tb_writer, world_size=world_size, rank=rank, @@ -1452,6 +1529,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, params: AttributeDict, + spec_augment: Optional[SpecAugment] = None, ): from lhotse.dataset import find_pessimistic_batches @@ -1471,6 +1549,7 @@ def scan_pessimistic_batches_for_oom( sp=sp, batch=batch, is_training=True, + spec_augment=spec_augment, ) loss.backward() optimizer.zero_grad() diff --git a/icefall/utils.py b/icefall/utils.py index 1dbb954de..b0a42cefa 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -21,6 +21,7 @@ import argparse import collections import logging import os +import random import re import subprocess from collections import defaultdict @@ -38,6 +39,7 @@ import sentencepiece as spm import torch import torch.distributed as dist import torch.nn as nn +from lhotse.dataset.signal_transforms import time_warp as time_warp_impl from pypinyin import lazy_pinyin, pinyin from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials from torch.utils.tensorboard import SummaryWriter @@ -2271,3 +2273,41 @@ def num_tokens( if 0 in ans: num_tokens -= 1 return num_tokens + + +# Based on https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py +def time_warp( + features: torch.Tensor, + p: float = 0.9, + time_warp_factor: Optional[int] = 80, + supervision_segments: Optional[torch.Tensor] = None, +): + """Apply time warping on a batch of features + """ + if time_warp_factor is None or time_warp_factor < 1: + return features + assert len(features.shape) == 3, ( + "SpecAugment only supports batches of single-channel feature matrices." + ) + features = features.clone() + if supervision_segments is None: + # No supervisions - apply spec augment to full feature matrices. + for sequence_idx in range(features.size(0)): + if random.random() > p: + # Randomly choose whether this transform is applied + continue + features[sequence_idx] = time_warp_impl( + features[sequence_idx], factor=time_warp_factor + ) + else: + # Supervisions provided - we will apply time warping only on the supervised areas. + for sequence_idx, start_frame, num_frames in supervision_segments: + if random.random() > p: + # Randomly choose whether this transform is applied + continue + end_frame = start_frame + num_frames + features[sequence_idx, start_frame:end_frame] = time_warp_impl( + features[sequence_idx, start_frame:end_frame], factor=time_warp_factor + ) + + return features From e8b6b920c08978a7b9e10a5e6c1e436212de4e84 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 21 Oct 2024 11:30:14 +0800 Subject: [PATCH 23/59] A LibriTTS recipe on both ASR & Neural Codec Tasks (#1746) * added ASR & CODEC recipes for LibriTTS corpus --- .../ASR/zipformer/attention_decoder.py | 26 +- .../ASR/zipformer/export-onnx-streaming.py | 12 +- egs/librispeech/ASR/zipformer/export-onnx.py | 8 +- egs/libritts/ASR/README.md | 26 + egs/libritts/ASR/RESULTS.md | 58 + egs/libritts/ASR/local/compile_hlg.py | 1 + egs/libritts/ASR/local/compile_lg.py | 1 + .../ASR/local/compute_fbank_libritts.py | 160 ++ egs/libritts/ASR/local/compute_fbank_musan.py | 1 + .../convert_transcript_words_to_tokens.py | 1 + .../ASR/local/display_manifest_statistics.py | 341 ++++ egs/libritts/ASR/local/download_lm.py | 1 + egs/libritts/ASR/local/norm_text.py | 1 + egs/libritts/ASR/local/prepare_lang.py | 1 + egs/libritts/ASR/local/prepare_lang_bpe.py | 1 + egs/libritts/ASR/local/prepare_lang_fst.py | 1 + .../ASR/local/prepare_lm_training_data.py | 1 + egs/libritts/ASR/local/train_bpe_model.py | 1 + .../ASR/local/validate_bpe_lexicon.py | 1 + egs/libritts/ASR/local/validate_manifest.py | 71 + egs/libritts/ASR/prepare.sh | 194 +++ egs/libritts/ASR/prepare_lm.sh | 264 +++ egs/libritts/ASR/shared | 1 + egs/libritts/ASR/zipformer/.gitignore | 1 + egs/libritts/ASR/zipformer/asr_datamodule.py | 459 +++++ .../ASR/zipformer/attention_decoder.py | 1 + egs/libritts/ASR/zipformer/beam_search.py | 1 + egs/libritts/ASR/zipformer/ctc_decode.py | 992 +++++++++++ egs/libritts/ASR/zipformer/decode.py | 1086 ++++++++++++ egs/libritts/ASR/zipformer/decode_stream.py | 1 + egs/libritts/ASR/zipformer/decoder.py | 1 + .../ASR/zipformer/encoder_interface.py | 1 + egs/libritts/ASR/zipformer/export-onnx-ctc.py | 1 + .../zipformer/export-onnx-streaming-ctc.py | 1 + .../ASR/zipformer/export-onnx-streaming.py | 1 + egs/libritts/ASR/zipformer/export-onnx.py | 1 + egs/libritts/ASR/zipformer/export.py | 1 + .../ASR/zipformer/generate_averaged_model.py | 1 + egs/libritts/ASR/zipformer/jit_pretrained.py | 1 + .../ASR/zipformer/jit_pretrained_ctc.py | 1 + .../ASR/zipformer/jit_pretrained_streaming.py | 1 + egs/libritts/ASR/zipformer/joiner.py | 1 + egs/libritts/ASR/zipformer/label_smoothing.py | 1 + egs/libritts/ASR/zipformer/model.py | 1 + egs/libritts/ASR/zipformer/my_profile.py | 1 + egs/libritts/ASR/zipformer/onnx_check.py | 1 + egs/libritts/ASR/zipformer/onnx_decode.py | 326 ++++ .../onnx_pretrained-streaming-ctc.py | 1 + .../zipformer/onnx_pretrained-streaming.py | 1 + egs/libritts/ASR/zipformer/onnx_pretrained.py | 1 + .../ASR/zipformer/onnx_pretrained_ctc.py | 1 + .../ASR/zipformer/onnx_pretrained_ctc_H.py | 1 + .../ASR/zipformer/onnx_pretrained_ctc_HL.py | 1 + .../ASR/zipformer/onnx_pretrained_ctc_HLG.py | 1 + .../onnx_pretrained_ctc_HLG_streaming.py | 1 + egs/libritts/ASR/zipformer/optim.py | 1 + egs/libritts/ASR/zipformer/pretrained.py | 1 + egs/libritts/ASR/zipformer/pretrained_ctc.py | 1 + egs/libritts/ASR/zipformer/scaling.py | 1 + .../ASR/zipformer/scaling_converter.py | 1 + .../ASR/zipformer/streaming_beam_search.py | 1 + .../ASR/zipformer/streaming_decode.py | 901 ++++++++++ egs/libritts/ASR/zipformer/subsampling.py | 1 + egs/libritts/ASR/zipformer/train.py | 1527 +++++++++++++++++ egs/libritts/ASR/zipformer/zipformer.py | 1 + .../CODEC/encodec/base_discriminators.py | 251 +++ egs/libritts/CODEC/encodec/binary.py | 161 ++ .../CODEC/encodec/codec_datamodule.py | 336 ++++ egs/libritts/CODEC/encodec/discriminators.py | 123 ++ egs/libritts/CODEC/encodec/encodec.py | 359 ++++ egs/libritts/CODEC/encodec/infer.py | 352 ++++ egs/libritts/CODEC/encodec/loss.py | 321 ++++ .../CODEC/encodec/modules/__init__.py | 20 + egs/libritts/CODEC/encodec/modules/conv.py | 334 ++++ egs/libritts/CODEC/encodec/modules/lstm.py | 27 + egs/libritts/CODEC/encodec/modules/norm.py | 28 + egs/libritts/CODEC/encodec/modules/seanet.py | 368 ++++ .../CODEC/encodec/modules/transformer.py | 141 ++ .../CODEC/encodec/quantization/__init__.py | 7 + egs/libritts/CODEC/encodec/quantization/ac.py | 311 ++++ .../CODEC/encodec/quantization/core_vq.py | 377 ++++ .../CODEC/encodec/quantization/distrib.py | 126 ++ egs/libritts/CODEC/encodec/quantization/vq.py | 121 ++ egs/libritts/CODEC/encodec/scheduler.py | 171 ++ egs/libritts/CODEC/encodec/train.py | 1188 +++++++++++++ egs/libritts/CODEC/encodec/utils.py | 1 + .../local/compute_spectrogram_libritts.py | 147 ++ .../local/display_manifest_statistics.py | 341 ++++ egs/libritts/CODEC/local/validate_manifest.py | 1 + egs/libritts/CODEC/prepare.sh | 78 + egs/libritts/CODEC/shared | 1 + 91 files changed, 12174 insertions(+), 17 deletions(-) create mode 100644 egs/libritts/ASR/README.md create mode 100644 egs/libritts/ASR/RESULTS.md create mode 120000 egs/libritts/ASR/local/compile_hlg.py create mode 120000 egs/libritts/ASR/local/compile_lg.py create mode 100755 egs/libritts/ASR/local/compute_fbank_libritts.py create mode 120000 egs/libritts/ASR/local/compute_fbank_musan.py create mode 120000 egs/libritts/ASR/local/convert_transcript_words_to_tokens.py create mode 100755 egs/libritts/ASR/local/display_manifest_statistics.py create mode 120000 egs/libritts/ASR/local/download_lm.py create mode 120000 egs/libritts/ASR/local/norm_text.py create mode 120000 egs/libritts/ASR/local/prepare_lang.py create mode 120000 egs/libritts/ASR/local/prepare_lang_bpe.py create mode 120000 egs/libritts/ASR/local/prepare_lang_fst.py create mode 120000 egs/libritts/ASR/local/prepare_lm_training_data.py create mode 120000 egs/libritts/ASR/local/train_bpe_model.py create mode 120000 egs/libritts/ASR/local/validate_bpe_lexicon.py create mode 100755 egs/libritts/ASR/local/validate_manifest.py create mode 100755 egs/libritts/ASR/prepare.sh create mode 100755 egs/libritts/ASR/prepare_lm.sh create mode 120000 egs/libritts/ASR/shared create mode 100644 egs/libritts/ASR/zipformer/.gitignore create mode 100644 egs/libritts/ASR/zipformer/asr_datamodule.py create mode 120000 egs/libritts/ASR/zipformer/attention_decoder.py create mode 120000 egs/libritts/ASR/zipformer/beam_search.py create mode 100755 egs/libritts/ASR/zipformer/ctc_decode.py create mode 100755 egs/libritts/ASR/zipformer/decode.py create mode 120000 egs/libritts/ASR/zipformer/decode_stream.py create mode 120000 egs/libritts/ASR/zipformer/decoder.py create mode 120000 egs/libritts/ASR/zipformer/encoder_interface.py create mode 120000 egs/libritts/ASR/zipformer/export-onnx-ctc.py create mode 120000 egs/libritts/ASR/zipformer/export-onnx-streaming-ctc.py create mode 120000 egs/libritts/ASR/zipformer/export-onnx-streaming.py create mode 120000 egs/libritts/ASR/zipformer/export-onnx.py create mode 120000 egs/libritts/ASR/zipformer/export.py create mode 120000 egs/libritts/ASR/zipformer/generate_averaged_model.py create mode 120000 egs/libritts/ASR/zipformer/jit_pretrained.py create mode 120000 egs/libritts/ASR/zipformer/jit_pretrained_ctc.py create mode 120000 egs/libritts/ASR/zipformer/jit_pretrained_streaming.py create mode 120000 egs/libritts/ASR/zipformer/joiner.py create mode 120000 egs/libritts/ASR/zipformer/label_smoothing.py create mode 120000 egs/libritts/ASR/zipformer/model.py create mode 120000 egs/libritts/ASR/zipformer/my_profile.py create mode 120000 egs/libritts/ASR/zipformer/onnx_check.py create mode 100755 egs/libritts/ASR/zipformer/onnx_decode.py create mode 120000 egs/libritts/ASR/zipformer/onnx_pretrained-streaming-ctc.py create mode 120000 egs/libritts/ASR/zipformer/onnx_pretrained-streaming.py create mode 120000 egs/libritts/ASR/zipformer/onnx_pretrained.py create mode 120000 egs/libritts/ASR/zipformer/onnx_pretrained_ctc.py create mode 120000 egs/libritts/ASR/zipformer/onnx_pretrained_ctc_H.py create mode 120000 egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HL.py create mode 120000 egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HLG.py create mode 120000 egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py create mode 120000 egs/libritts/ASR/zipformer/optim.py create mode 120000 egs/libritts/ASR/zipformer/pretrained.py create mode 120000 egs/libritts/ASR/zipformer/pretrained_ctc.py create mode 120000 egs/libritts/ASR/zipformer/scaling.py create mode 120000 egs/libritts/ASR/zipformer/scaling_converter.py create mode 120000 egs/libritts/ASR/zipformer/streaming_beam_search.py create mode 100755 egs/libritts/ASR/zipformer/streaming_decode.py create mode 120000 egs/libritts/ASR/zipformer/subsampling.py create mode 100755 egs/libritts/ASR/zipformer/train.py create mode 120000 egs/libritts/ASR/zipformer/zipformer.py create mode 100644 egs/libritts/CODEC/encodec/base_discriminators.py create mode 100644 egs/libritts/CODEC/encodec/binary.py create mode 100644 egs/libritts/CODEC/encodec/codec_datamodule.py create mode 100644 egs/libritts/CODEC/encodec/discriminators.py create mode 100644 egs/libritts/CODEC/encodec/encodec.py create mode 100755 egs/libritts/CODEC/encodec/infer.py create mode 100644 egs/libritts/CODEC/encodec/loss.py create mode 100644 egs/libritts/CODEC/encodec/modules/__init__.py create mode 100644 egs/libritts/CODEC/encodec/modules/conv.py create mode 100644 egs/libritts/CODEC/encodec/modules/lstm.py create mode 100644 egs/libritts/CODEC/encodec/modules/norm.py create mode 100644 egs/libritts/CODEC/encodec/modules/seanet.py create mode 100644 egs/libritts/CODEC/encodec/modules/transformer.py create mode 100644 egs/libritts/CODEC/encodec/quantization/__init__.py create mode 100644 egs/libritts/CODEC/encodec/quantization/ac.py create mode 100644 egs/libritts/CODEC/encodec/quantization/core_vq.py create mode 100644 egs/libritts/CODEC/encodec/quantization/distrib.py create mode 100644 egs/libritts/CODEC/encodec/quantization/vq.py create mode 100644 egs/libritts/CODEC/encodec/scheduler.py create mode 100755 egs/libritts/CODEC/encodec/train.py create mode 120000 egs/libritts/CODEC/encodec/utils.py create mode 100755 egs/libritts/CODEC/local/compute_spectrogram_libritts.py create mode 100755 egs/libritts/CODEC/local/display_manifest_statistics.py create mode 120000 egs/libritts/CODEC/local/validate_manifest.py create mode 100755 egs/libritts/CODEC/prepare.sh create mode 120000 egs/libritts/CODEC/shared diff --git a/egs/librispeech/ASR/zipformer/attention_decoder.py b/egs/librispeech/ASR/zipformer/attention_decoder.py index 81682e87b..bff536f90 100644 --- a/egs/librispeech/ASR/zipformer/attention_decoder.py +++ b/egs/librispeech/ASR/zipformer/attention_decoder.py @@ -236,7 +236,7 @@ class TransformerDecoder(nn.Module): causal_mask = subsequent_mask(x.shape[0], device=x.device) # (seq_len, seq_len) attn_mask = torch.logical_or( padding_mask.unsqueeze(1), # (batch, 1, seq_len) - torch.logical_not(causal_mask).unsqueeze(0) # (1, seq_len, seq_len) + torch.logical_not(causal_mask).unsqueeze(0), # (1, seq_len, seq_len) ) # (batch, seq_len, seq_len) if memory is not None: @@ -367,7 +367,9 @@ class MultiHeadAttention(nn.Module): self.num_heads = num_heads self.head_dim = attention_dim // num_heads assert self.head_dim * num_heads == attention_dim, ( - self.head_dim, num_heads, attention_dim + self.head_dim, + num_heads, + attention_dim, ) self.dropout = dropout self.name = None # will be overwritten in training code; for diagnostics. @@ -437,15 +439,19 @@ class MultiHeadAttention(nn.Module): if key_padding_mask is not None: assert key_padding_mask.shape == (batch, src_len), key_padding_mask.shape attn_weights = attn_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"), + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), ) if attn_mask is not None: - assert ( - attn_mask.shape == (batch, 1, src_len) - or attn_mask.shape == (batch, tgt_len, src_len) + assert attn_mask.shape == (batch, 1, src_len) or attn_mask.shape == ( + batch, + tgt_len, + src_len, ), attn_mask.shape - attn_weights = attn_weights.masked_fill(attn_mask.unsqueeze(1), float("-inf")) + attn_weights = attn_weights.masked_fill( + attn_mask.unsqueeze(1), float("-inf") + ) attn_weights = attn_weights.view(batch * num_heads, tgt_len, src_len) attn_weights = nn.functional.softmax(attn_weights, dim=-1) @@ -456,7 +462,11 @@ class MultiHeadAttention(nn.Module): # (batch * head, tgt_len, head_dim) attn_output = torch.bmm(attn_weights, v) - assert attn_output.shape == (batch * num_heads, tgt_len, head_dim), attn_output.shape + assert attn_output.shape == ( + batch * num_heads, + tgt_len, + head_dim, + ), attn_output.shape attn_output = attn_output.transpose(0, 1).contiguous() attn_output = attn_output.view(tgt_len, batch, num_heads * head_dim) diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py index 88c58f581..a35eb5287 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py +++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py @@ -487,6 +487,7 @@ def export_encoder_model_onnx( add_meta_data(filename=encoder_filename, meta_data=meta_data) + def export_decoder_model_onnx( decoder_model: OnnxDecoder, decoder_filename: str, @@ -754,30 +755,31 @@ def main(): ) logging.info(f"Exported joiner to {joiner_filename}") - if(params.fp16) : + if params.fp16: from onnxconverter_common import float16 + logging.info("Generate fp16 models") encoder = onnx.load(encoder_filename) encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True) encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx" - onnx.save(encoder_fp16,encoder_filename_fp16) + onnx.save(encoder_fp16, encoder_filename_fp16) decoder = onnx.load(decoder_filename) decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True) decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx" - onnx.save(decoder_fp16,decoder_filename_fp16) + onnx.save(decoder_fp16, decoder_filename_fp16) joiner = onnx.load(joiner_filename) joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True) joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx" - onnx.save(joiner_fp16,joiner_filename_fp16) + onnx.save(joiner_fp16, joiner_filename_fp16) # Generate int8 quantization models # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection logging.info("Generate int8 quantization models") - + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" quantize_dynamic( model_input=encoder_filename, diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py index ca3cbf0d5..a56a7a3e6 100755 --- a/egs/librispeech/ASR/zipformer/export-onnx.py +++ b/egs/librispeech/ASR/zipformer/export-onnx.py @@ -592,23 +592,23 @@ def main(): ) logging.info(f"Exported joiner to {joiner_filename}") - if(params.fp16) : + if params.fp16: logging.info("Generate fp16 models") encoder = onnx.load(encoder_filename) encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True) encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx" - onnx.save(encoder_fp16,encoder_filename_fp16) + onnx.save(encoder_fp16, encoder_filename_fp16) decoder = onnx.load(decoder_filename) decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True) decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx" - onnx.save(decoder_fp16,decoder_filename_fp16) + onnx.save(decoder_fp16, decoder_filename_fp16) joiner = onnx.load(joiner_filename) joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True) joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx" - onnx.save(joiner_fp16,joiner_filename_fp16) + onnx.save(joiner_fp16, joiner_filename_fp16) # Generate int8 quantization models # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection diff --git a/egs/libritts/ASR/README.md b/egs/libritts/ASR/README.md new file mode 100644 index 000000000..138f4ae80 --- /dev/null +++ b/egs/libritts/ASR/README.md @@ -0,0 +1,26 @@ +# Introduction + +LibriTTS is a multi-speaker English corpus of approximately 585 hours of read English speech at 24kHz sampling rate, prepared by Heiga Zen with the assistance of Google Speech and Google Brain team members. +The LibriTTS corpus is designed for TTS research. It is derived from the original materials (mp3 audio files from LibriVox and text files from Project Gutenberg) of the LibriSpeech corpus. +The main differences from the LibriSpeech corpus are listed below: +1. The audio files are at 24kHz sampling rate. +2. The speech is split at sentence breaks. +3. Both original and normalized texts are included. +4. Contextual information (e.g., neighbouring sentences) can be extracted. +5. Utterances with significant background noise are excluded. +For more information, refer to the paper "LibriTTS: A Corpus Derived from LibriSpeech for Text-to-Speech", Heiga Zen, Viet Dang, Rob Clark, Yu Zhang, Ron J. Weiss, Ye Jia, Zhifeng Chen, and Yonghui Wu, arXiv, 2019. If you use the LibriTTS corpus in your work, please cite this paper where it was introduced. + + +This recipe includes some different ASR models trained with [LibriTTS](https://openslr.org/60/). + +[./RESULTS.md](./RESULTS.md) contains the latest results. + +# Transducers + +| | Encoder | Decoder | +|---------------------------------------|---------------------|--------------------| +| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | + +The decoder is modified from the paper +[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). +We place an additional Conv1d layer right after the input embedding layer. diff --git a/egs/libritts/ASR/RESULTS.md b/egs/libritts/ASR/RESULTS.md new file mode 100644 index 000000000..574f81eb6 --- /dev/null +++ b/egs/libritts/ASR/RESULTS.md @@ -0,0 +1,58 @@ +# Results + +## zipformer (zipformer + pruned stateless transducer) + +See for more details. + +[zipformer](./zipformer) + +### Non-streaming + +#### normal-scaled model, number of model parameters: 65549011, i.e., 65.55 M + +You can find a pretrained model, training logs, decoding logs, and decoding results at: + + +You can use to deploy it. + +| decoding method | test-clean | test-other | comment | +|----------------------|------------|------------|--------------------| +| greedy_search | 2.83 | 5.91 | --epoch 30 --avg 5 | +| modified_beam_search | 2.80 | 5.87 | --epoch 30 --avg 5 | +| fast_beam_search | 2.87 | 5.86 | --epoch 30 --avg 5 | +| greedy_search | 2.76 | 5.68 | --epoch 40 --avg 16| +| modified_beam_search | 2.74 | 5.66 | --epoch 40 --avg 16| +| fast_beam_search | 2.75 | 5.67 | --epoch 40 --avg 16| +| greedy_search | 2.74 | 5.67 | --epoch 50 --avg 30| +| modified_beam_search | 2.73 | 5.58 | --epoch 50 --avg 30| +| fast_beam_search | 2.78 | 5.61 | --epoch 50 --avg 30| + + +The training command is: +```bash +export CUDA_VISIBLE_DEVICES="0,1" +./zipformer/train.py \ + --world-size 2 \ + --num-epochs 50 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 0 \ + --full-libri 1 \ + --max-duration 3600 +``` +This was used on 2 Nvidia A800 GPUs, you'll need to adjust the `CUDA_VISIBLE_DEVICES`, `--world-size` and `--max-duration` according to your hardware. + +The decoding command is: +```bash +export CUDA_VISIBLE_DEVICES="0" +for m in greedy_search modified_beam_search fast_beam_search; do + ./zipformer/decode.py \ + --epoch 50 \ + --avg 30 \ + --use-averaged-model 1 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method $m +done +``` diff --git a/egs/libritts/ASR/local/compile_hlg.py b/egs/libritts/ASR/local/compile_hlg.py new file mode 120000 index 000000000..471aa7fb4 --- /dev/null +++ b/egs/libritts/ASR/local/compile_hlg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/compile_lg.py b/egs/libritts/ASR/local/compile_lg.py new file mode 120000 index 000000000..462d6d3fb --- /dev/null +++ b/egs/libritts/ASR/local/compile_lg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/compute_fbank_libritts.py b/egs/libritts/ASR/local/compute_fbank_libritts.py new file mode 100755 index 000000000..b6e2a4c43 --- /dev/null +++ b/egs/libritts/ASR/local/compute_fbank_libritts.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao,) +# 2024 The Chinese Univ. of HK (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the LibriTTS dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=True, + help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", + ) + parser.add_argument( + "--sampling-rate", + type=int, + default=24000, + help="""Sampling rate of the audio for computing fbank, the default value for LibriTTS is 24000, audio files will be resampled if a different sample rate is provided""", + ) + + return parser.parse_args() + + +def compute_fbank_libritts( + dataset: Optional[str] = None, + sampling_rate: int = 24000, + perturb_speed: Optional[bool] = True, +): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + num_jobs = min(32, os.cpu_count()) + + num_mel_bins = 80 + + if dataset is None: + dataset_parts = ( + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + "train-clean-360", + "train-other-500", + ) + else: + dataset_parts = dataset.split(" ", -1) + + prefix = "libritts" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. + for partition, m in manifests.items(): + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + if sampling_rate != 24000: + logging.info(f"Resampling audio to {sampling_rate}Hz") + cut_set = cut_set.resample(sampling_rate) + if "train" in partition: + if perturb_speed: + logging.info(f"Doing speed perturb") + cut_set = ( + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) + ) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + + compute_fbank_libritts( + dataset=args.dataset, + sampling_rate=args.sampling_rate, + perturb_speed=args.perturb_speed, + ) diff --git a/egs/libritts/ASR/local/compute_fbank_musan.py b/egs/libritts/ASR/local/compute_fbank_musan.py new file mode 120000 index 000000000..5833f2484 --- /dev/null +++ b/egs/libritts/ASR/local/compute_fbank_musan.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/convert_transcript_words_to_tokens.py b/egs/libritts/ASR/local/convert_transcript_words_to_tokens.py new file mode 120000 index 000000000..2ce13fd69 --- /dev/null +++ b/egs/libritts/ASR/local/convert_transcript_words_to_tokens.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/display_manifest_statistics.py b/egs/libritts/ASR/local/display_manifest_statistics.py new file mode 100755 index 000000000..ddd022c96 --- /dev/null +++ b/egs/libritts/ASR/local/display_manifest_statistics.py @@ -0,0 +1,341 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# 2024 The Chinese Univ. of HK (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. +""" + + +from lhotse import load_manifest_lazy + + +def main(): + paths = [ + "./data/fbank/libritts_cuts_train-clean-100.jsonl.gz", + "./data/fbank/libritts_cuts_train-clean-360.jsonl.gz", + "./data/fbank/libritts_cuts_train-other-500.jsonl.gz", + "./data/fbank/libritts_cuts_dev-clean.jsonl.gz", + "./data/fbank/libritts_cuts_dev-other.jsonl.gz", + "./data/fbank/libritts_cuts_test-clean.jsonl.gz", + "./data/fbank/libritts_cuts_test-other.jsonl.gz", + ] + for path in paths: + cuts = load_manifest_lazy(path) + cuts.describe() + + +if __name__ == "__main__": + main() + +""" +./data/fbank/libritts_cuts_train-clean-100.jsonl.gz statistics: +________________________________________ +_ Cuts count: _ 33236 _ +________________________________________ +_ Total duration (hh:mm:ss) _ 53:47:18 _ +________________________________________ +_ mean _ 5.8 _ +________________________________________ +_ std _ 4.6 _ +________________________________________ +_ min _ 0.2 _ +________________________________________ +_ 25% _ 2.4 _ +________________________________________ +_ 50% _ 4.5 _ +________________________________________ +_ 75% _ 7.9 _ +________________________________________ +_ 99% _ 21.4 _ +________________________________________ +_ 99.5% _ 23.7 _ +________________________________________ +_ 99.9% _ 27.8 _ +________________________________________ +_ max _ 33.2 _ +________________________________________ +_ Recordings available: _ 33236 _ +________________________________________ +_ Features available: _ 33236 _ +________________________________________ +_ Supervisions available: _ 33236 _ +________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +__________________________________________________________________ +_ Total speech duration _ 53:47:18 _ 100.00% of recording _ +__________________________________________________________________ +_ Total speaking time duration _ 53:47:18 _ 100.00% of recording _ +__________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +__________________________________________________________________ + +./data/fbank/libritts_cuts_train-clean-360.jsonl.gz statistics: +_________________________________________ +_ Cuts count: _ 116500 _ +_________________________________________ +_ Total duration (hh:mm:ss) _ 191:17:42 _ +_________________________________________ +_ mean _ 5.9 _ +_________________________________________ +_ std _ 4.6 _ +_________________________________________ +_ min _ 0.1 _ +_________________________________________ +_ 25% _ 2.4 _ +_________________________________________ +_ 50% _ 4.6 _ +_________________________________________ +_ 75% _ 8.1 _ +_________________________________________ +_ 99% _ 21.3 _ +_________________________________________ +_ 99.5% _ 23.4 _ +_________________________________________ +_ 99.9% _ 27.4 _ +_________________________________________ +_ max _ 40.4 _ +_________________________________________ +_ Recordings available: _ 116500 _ +_________________________________________ +_ Features available: _ 116500 _ +_________________________________________ +_ Supervisions available: _ 116500 _ +_________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +___________________________________________________________________ +_ Total speech duration _ 191:17:42 _ 100.00% of recording _ +___________________________________________________________________ +_ Total speaking time duration _ 191:17:42 _ 100.00% of recording _ +___________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +___________________________________________________________________ + +./data/fbank/libritts_cuts_train-other-500.jsonl.gz statistics: +_________________________________________ +_ Cuts count: _ 205043 _ +_________________________________________ +_ Total duration (hh:mm:ss) _ 310:04:36 _ +_________________________________________ +_ mean _ 5.4 _ +_________________________________________ +_ std _ 4.4 _ +_________________________________________ +_ min _ 0.1 _ +_________________________________________ +_ 25% _ 2.3 _ +_________________________________________ +_ 50% _ 4.2 _ +_________________________________________ +_ 75% _ 7.3 _ +_________________________________________ +_ 99% _ 20.6 _ +_________________________________________ +_ 99.5% _ 22.8 _ +_________________________________________ +_ 99.9% _ 27.4 _ +_________________________________________ +_ max _ 43.9 _ +_________________________________________ +_ Recordings available: _ 205043 _ +_________________________________________ +_ Features available: _ 205043 _ +_________________________________________ +_ Supervisions available: _ 205043 _ +_________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +___________________________________________________________________ +_ Total speech duration _ 310:04:36 _ 100.00% of recording _ +___________________________________________________________________ +_ Total speaking time duration _ 310:04:36 _ 100.00% of recording _ +___________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +___________________________________________________________________ + +./data/fbank/libritts_cuts_dev-clean.jsonl.gz statistics: +________________________________________ +_ Cuts count: _ 5736 _ +________________________________________ +_ Total duration (hh:mm:ss) _ 08:58:13 _ +________________________________________ +_ mean _ 5.6 _ +________________________________________ +_ std _ 4.3 _ +________________________________________ +_ min _ 0.3 _ +________________________________________ +_ 25% _ 2.4 _ +________________________________________ +_ 50% _ 4.4 _ +________________________________________ +_ 75% _ 7.8 _ +________________________________________ +_ 99% _ 19.9 _ +________________________________________ +_ 99.5% _ 21.9 _ +________________________________________ +_ 99.9% _ 26.3 _ +________________________________________ +_ max _ 30.1 _ +________________________________________ +_ Recordings available: _ 5736 _ +________________________________________ +_ Features available: _ 5736 _ +________________________________________ +_ Supervisions available: _ 5736 _ +________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +__________________________________________________________________ +_ Total speech duration _ 08:58:13 _ 100.00% of recording _ +__________________________________________________________________ +_ Total speaking time duration _ 08:58:13 _ 100.00% of recording _ +__________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +__________________________________________________________________ + +./data/fbank/libritts_cuts_dev-other.jsonl.gz statistics: +________________________________________ +_ Cuts count: _ 4613 _ +________________________________________ +_ Total duration (hh:mm:ss) _ 06:25:52 _ +________________________________________ +_ mean _ 5.0 _ +________________________________________ +_ std _ 4.1 _ +________________________________________ +_ min _ 0.3 _ +________________________________________ +_ 25% _ 2.2 _ +________________________________________ +_ 50% _ 3.8 _ +________________________________________ +_ 75% _ 6.5 _ +________________________________________ +_ 99% _ 19.7 _ +________________________________________ +_ 99.5% _ 24.5 _ +________________________________________ +_ 99.9% _ 31.0 _ +________________________________________ +_ max _ 32.6 _ +________________________________________ +_ Recordings available: _ 4613 _ +________________________________________ +_ Features available: _ 4613 _ +________________________________________ +_ Supervisions available: _ 4613 _ +________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +__________________________________________________________________ +_ Total speech duration _ 06:25:52 _ 100.00% of recording _ +__________________________________________________________________ +_ Total speaking time duration _ 06:25:52 _ 100.00% of recording _ +__________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +__________________________________________________________________ + +./data/fbank/libritts_cuts_test-clean.jsonl.gz statistics: +________________________________________ +_ Cuts count: _ 4837 _ +________________________________________ +_ Total duration (hh:mm:ss) _ 08:34:09 _ +________________________________________ +_ mean _ 6.4 _ +________________________________________ +_ std _ 5.1 _ +________________________________________ +_ min _ 0.3 _ +________________________________________ +_ 25% _ 2.4 _ +________________________________________ +_ 50% _ 4.8 _ +________________________________________ +_ 75% _ 8.9 _ +________________________________________ +_ 99% _ 22.6 _ +________________________________________ +_ 99.5% _ 24.4 _ +________________________________________ +_ 99.9% _ 29.6 _ +________________________________________ +_ max _ 36.7 _ +________________________________________ +_ Recordings available: _ 4837 _ +________________________________________ +_ Features available: _ 4837 _ +________________________________________ +_ Supervisions available: _ 4837 _ +________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +__________________________________________________________________ +_ Total speech duration _ 08:34:09 _ 100.00% of recording _ +__________________________________________________________________ +_ Total speaking time duration _ 08:34:09 _ 100.00% of recording _ +__________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +__________________________________________________________________ + +./data/fbank/libritts_cuts_test-other.jsonl.gz statistics: +________________________________________ +_ Cuts count: _ 5120 _ +________________________________________ +_ Total duration (hh:mm:ss) _ 06:41:31 _ +________________________________________ +_ mean _ 4.7 _ +________________________________________ +_ std _ 3.8 _ +________________________________________ +_ min _ 0.3 _ +________________________________________ +_ 25% _ 1.8 _ +________________________________________ +_ 50% _ 3.6 _ +________________________________________ +_ 75% _ 6.5 _ +________________________________________ +_ 99% _ 17.8 _ +________________________________________ +_ 99.5% _ 20.4 _ +________________________________________ +_ 99.9% _ 23.8 _ +________________________________________ +_ max _ 27.3 _ +________________________________________ +_ Recordings available: _ 5120 _ +________________________________________ +_ Features available: _ 5120 _ +________________________________________ +_ Supervisions available: _ 5120 _ +________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +__________________________________________________________________ +_ Total speech duration _ 06:41:31 _ 100.00% of recording _ +__________________________________________________________________ +_ Total speaking time duration _ 06:41:31 _ 100.00% of recording _ +__________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +__________________________________________________________________ +""" diff --git a/egs/libritts/ASR/local/download_lm.py b/egs/libritts/ASR/local/download_lm.py new file mode 120000 index 000000000..c9668bd2d --- /dev/null +++ b/egs/libritts/ASR/local/download_lm.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/download_lm.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/norm_text.py b/egs/libritts/ASR/local/norm_text.py new file mode 120000 index 000000000..dea3c051f --- /dev/null +++ b/egs/libritts/ASR/local/norm_text.py @@ -0,0 +1 @@ +../../../libriheavy/ASR/local/norm_text.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/prepare_lang.py b/egs/libritts/ASR/local/prepare_lang.py new file mode 120000 index 000000000..747f2ab39 --- /dev/null +++ b/egs/libritts/ASR/local/prepare_lang.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/prepare_lang_bpe.py b/egs/libritts/ASR/local/prepare_lang_bpe.py new file mode 120000 index 000000000..36b40e7fc --- /dev/null +++ b/egs/libritts/ASR/local/prepare_lang_bpe.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/prepare_lang_fst.py b/egs/libritts/ASR/local/prepare_lang_fst.py new file mode 120000 index 000000000..c5787c534 --- /dev/null +++ b/egs/libritts/ASR/local/prepare_lang_fst.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_fst.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/prepare_lm_training_data.py b/egs/libritts/ASR/local/prepare_lm_training_data.py new file mode 120000 index 000000000..abc00d421 --- /dev/null +++ b/egs/libritts/ASR/local/prepare_lm_training_data.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lm_training_data.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/train_bpe_model.py b/egs/libritts/ASR/local/train_bpe_model.py new file mode 120000 index 000000000..6fad36421 --- /dev/null +++ b/egs/libritts/ASR/local/train_bpe_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/validate_bpe_lexicon.py b/egs/libritts/ASR/local/validate_bpe_lexicon.py new file mode 120000 index 000000000..721bb48e7 --- /dev/null +++ b/egs/libritts/ASR/local/validate_bpe_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/libritts/ASR/local/validate_manifest.py b/egs/libritts/ASR/local/validate_manifest.py new file mode 100755 index 000000000..abd4da88a --- /dev/null +++ b/egs/libritts/ASR/local/validate_manifest.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +# Copyright 2022-2024 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao,) +# 2024 The Chinese Univ. of HK (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script checks the following assumptions of the generated manifest: + +- Single supervision per cut + +We will add more checks later if needed. + +Usage example: + + python3 ./local/validate_manifest.py \ + ./data/fbank/libritts_cuts_train-all-shuf.jsonl.gz + +""" + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet, load_manifest +from lhotse.dataset.speech_recognition import validate_for_asr + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "manifest", + type=Path, + help="Path to the manifest file", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + manifest = args.manifest + logging.info(f"Validating {manifest}") + + assert manifest.is_file(), f"{manifest} does not exist" + cut_set = load_manifest(manifest) + assert isinstance(cut_set, CutSet) + + validate_for_asr(cut_set) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/libritts/ASR/prepare.sh b/egs/libritts/ASR/prepare.sh new file mode 100755 index 000000000..9d9ce8f87 --- /dev/null +++ b/egs/libritts/ASR/prepare.sh @@ -0,0 +1,194 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=0 +stop_stage=100 +sampling_rate=16000 +nj=32 +perturb_speed=true +vocab_sizes=( + # 5000 + # 2000 + # 1000 + 500 +) + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: Download LM" # we directly use the librispeech lm here + mkdir -p $dl_dir/lm + if [ ! -e $dl_dir/lm/.done ]; then + ./local/download_lm.py --out-dir=$dl_dir/lm + touch $dl_dir/lm/.done + fi +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/LibriTTS, + # you can create a symlink + # + # ln -sfv /path/to/LibriTTS $dl_dir/LibriTTS + # + if [ ! -d $dl_dir/LibriTTS ]; then + lhotse download libritts $dl_dir + fi + + # If you have pre-downloaded it to /path/to/musan, + # you can create a symlink + # + # ln -sfv /path/to/musan $dl_dir/musan + # + if [ ! -d $dl_dir/musan ]; then + lhotse download musan $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare LibriTTS manifest" + # We assume that you have downloaded the LibriTTS corpus + # to $dl_dir/LibriTTS + mkdir -p data/manifests + if [ ! -e data/manifests/.libritts.done ]; then + lhotse prepare libritts --num-jobs 32 $dl_dir/LibriTTS data/manifests + touch data/manifests/.libritts.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to data/musan + if [ ! -f data/manifests/.musan_manifests.done ]; then + log "It may take 6 minutes" + mkdir -p data/manifests + lhotse prepare musan $dl_dir/musan data/manifests + touch data/manifests/.musan_manifests.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Compute Fbank for LibriTTS" + mkdir -p data/fbank + if [ ! -e data/fbank/.libritts.done ]; then + ./local/compute_fbank_libritts.py \ + --sampling-rate $sampling_rate \ + --perturb-speed $perturb_speed + touch data/fbank/.libritts.done + fi + + # Here we shuffle and combine the train-clean-100, train-clean-360 and + # train-other-500 together to form the training set. + if [ ! -f data/fbank/libritts_cuts_train-all-shuf.jsonl.gz ]; then + cat <(gunzip -c data/fbank/libritts_cuts_train-clean-100.jsonl.gz) \ + <(gunzip -c data/fbank/libritts_cuts_train-clean-360.jsonl.gz) \ + <(gunzip -c data/fbank/libritts_cuts_train-other-500.jsonl.gz) | \ + shuf | gzip -c > data/fbank/libritts_cuts_train-all-shuf.jsonl.gz + fi + + if [ ! -e data/fbank/.libritts-validated.done ]; then + log "Validating data/fbank for LibriTTS" + ./local/validate_manifest.py \ + data/fbank/libritts_cuts_train-all-shuf.jsonl.gz + touch data/fbank/.libritts-validated.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute fbank for musan" + if [ ! -f data/fbank/.msuan.done ]; then + mkdir -p data/fbank + ./local/compute_fbank_musan.py + touch data/fbank/.msuan.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Train BPE model for normalized text" + + if [ ! -f data/text ]; then + gunzip -c data/manifests/libritts_supervisions_train-clean-100.jsonl.gz \ + | jq ".text" | sed 's/"//g' \ + | ./local/norm_text.py > data/text + + gunzip -c data/manifests/libritts_supervisions_train-clean-360.jsonl.gz \ + | jq ".text" | sed 's/"//g' \ + | ./local/norm_text.py >> data/text + + gunzip -c data/manifests/libritts_supervisions_train-other-500.jsonl.gz \ + | jq ".text" | sed 's/"//g' \ + | ./local/norm_text.py >> data/text + fi + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir + + cp data/text $lang_dir/text + + if [ ! -f $lang_dir/bpe.model ]; then + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/text + fi + done +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Prepare phone based lang" + lang_dir=data/lang_phone + mkdir -p $lang_dir + + if [ ! -f $dl_dir/lm/librispeech-lexicon.txt ]; then + log "No lexicon file in $dl_dir/lm, please run :" + log "prepare.sh --stage -1 --stop-stage -1" + exit -1 + fi + + if [ ! -f $lang_dir/lexicon.txt ]; then + (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) | + cat - $dl_dir/lm/librispeech-lexicon.txt | + sort | uniq > $lang_dir/lexicon.txt + fi + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang.py --lang-dir $lang_dir + fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi +fi diff --git a/egs/libritts/ASR/prepare_lm.sh b/egs/libritts/ASR/prepare_lm.sh new file mode 100755 index 000000000..1c690983b --- /dev/null +++ b/egs/libritts/ASR/prepare_lm.sh @@ -0,0 +1,264 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +# This script generate Ngram LM / NNLM and related files that needed by decoding. + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/lm +# This directory contains the following files downloaded from +# http://www.openslr.org/resources/11 +# +# - 3-gram.pruned.1e-7.arpa.gz +# - 3-gram.pruned.1e-7.arpa +# - 4-gram.arpa.gz +# - 4-gram.arpa +# - librispeech-vocab.txt +# - librispeech-lexicon.txt +# - librispeech-lm-norm.txt.gz +# + +. prepare.sh --stage -1 --stop-stage 6 || exit 1 + +log "Running prepare_lm.sh" + +stage=0 +stop_stage=100 + +. shared/parse_options.sh || exit 1 + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Prepare BPE based lexicon." + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + # We reuse words.txt from phone based lexicon + # so that the two can share G.pt later. + cp data/lang_phone/words.txt $lang_dir + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + + log "Validating $lang_dir/lexicon.txt" + ./local/validate_bpe_lexicon.py \ + --lexicon $lang_dir/lexicon.txt \ + --bpe-model $lang_dir/bpe.model + fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi + done +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare word level G" + # We assume you have installed kaldilm, if not, please install + # it using: pip install kaldilm + + mkdir -p data/lm + if [ ! -f data/lm/G_3_gram.fst.txt ]; then + # It is used in building HLG + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + $dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt + fi + + if [ ! -f data/lm/G_4_gram.fst.txt ]; then + # It is used for LM rescoring + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=4 \ + $dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt + fi + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + + if [ ! -f $lang_dir/HL.fst ]; then + ./local/prepare_lang_fst.py \ + --lang-dir $lang_dir \ + --ngram-G ./data/lm/G_3_gram.fst.txt + fi + done +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Compile HLG" + ./local/compile_hlg.py --lang-dir data/lang_phone + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_hlg.py --lang-dir $lang_dir + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir + done +fi + +# Compile LG for RNN-T fast_beam_search decoding +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Compile LG" + ./local/compile_lg.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_lg.py --lang-dir $lang_dir + done +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Prepare token level ngram G" + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + + if [ ! -f $lang_dir/transcript_tokens.txt ]; then + ./local/convert_transcript_words_to_tokens.py \ + --lexicon $lang_dir/lexicon.txt \ + --transcript $lang_dir/transcript_words.txt \ + --oov "" \ + > $lang_dir/transcript_tokens.txt + fi + + for ngram in 2 3 4 5; do + if [ ! -f $lang_dir/${ngram}gram.arpa ]; then + ./shared/make_kn_lm.py \ + -ngram-order ${ngram} \ + -text $lang_dir/transcript_tokens.txt \ + -lm $lang_dir/${ngram}gram.arpa + fi + + if [ ! -f $lang_dir/${ngram}gram.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="$lang_dir/tokens.txt" \ + --disambig-symbol='#0' \ + --max-order=${ngram} \ + $lang_dir/${ngram}gram.arpa > $lang_dir/${ngram}gram.fst.txt + fi + done + done +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Generate NNLM training data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + lang_dir=data/lang_bpe_${vocab_size} + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $dl_dir/lm/librispeech-lm-norm.txt \ + --lm-archive $out_dir/lm_data.pt + done +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Generate NNLM validation data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! -f $out_dir/valid.txt ]; then + gunzip -c data/manifests/libritts_supervisions_dev-clean.jsonl.gz \ + | jq ".text" | sed 's/"//g' \ + | ./local/norm_text.py > $out_dir/valid.txt + + gunzip -c data/manifests/libritts_supervisions_dev-other.jsonl.gz \ + | jq ".text" | sed 's/"//g' \ + | ./local/norm_text.py >> $out_dir/valid.txt + fi + + lang_dir=data/lang_bpe_${vocab_size} + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $out_dir/valid.txt \ + --lm-archive $out_dir/lm_data-valid.pt + done +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Generate NNLM test data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! -f $out_dir/test.txt ]; then + gunzip -c data/manifests/libritts_supervisions_test-clean.jsonl.gz \ + | jq ".text" | sed 's/"//g' \ + | ./local/norm_text.py > $out_dir/test.txt + + gunzip -c data/manifests/libritts_supervisions_test-other.jsonl.gz \ + | jq ".text" | sed 's/"//g' \ + | ./local/norm_text.py >> $out_dir/test.txt + fi + + lang_dir=data/lang_bpe_${vocab_size} + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $out_dir/test.txt \ + --lm-archive $out_dir/lm_data-test.pt + done +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Sort NNLM training data" + # Sort LM training data by sentence length in descending order + # for ease of training. + # + # Sentence length equals to the number of BPE tokens + # in a sentence. + + for vocab_size in ${vocab_sizes[@]}; do + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data.pt \ + --out-lm-data $out_dir/sorted_lm_data.pt \ + --out-statistics $out_dir/statistics.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-valid.pt \ + --out-lm-data $out_dir/sorted_lm_data-valid.pt \ + --out-statistics $out_dir/statistics-valid.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-test.pt \ + --out-lm-data $out_dir/sorted_lm_data-test.pt \ + --out-statistics $out_dir/statistics-test.txt + done +fi diff --git a/egs/libritts/ASR/shared b/egs/libritts/ASR/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/libritts/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/.gitignore b/egs/libritts/ASR/zipformer/.gitignore new file mode 100644 index 000000000..e47ac1582 --- /dev/null +++ b/egs/libritts/ASR/zipformer/.gitignore @@ -0,0 +1 @@ +swoosh.pdf diff --git a/egs/libritts/ASR/zipformer/asr_datamodule.py b/egs/libritts/ASR/zipformer/asr_datamodule.py new file mode 100644 index 000000000..dab834303 --- /dev/null +++ b/egs/libritts/ASR/zipformer/asr_datamodule.py @@ -0,0 +1,459 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2024 The Chinese Univ. of HK (Author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriTTSAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. libritts test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="""When enabled, use the entire LibriTTS training set. + Otherwise, use the 100h subset.""", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=( + OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)() + ), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz" + ) diff --git a/egs/libritts/ASR/zipformer/attention_decoder.py b/egs/libritts/ASR/zipformer/attention_decoder.py new file mode 120000 index 000000000..384e1b95e --- /dev/null +++ b/egs/libritts/ASR/zipformer/attention_decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/attention_decoder.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/beam_search.py b/egs/libritts/ASR/zipformer/beam_search.py new file mode 120000 index 000000000..e24eca39f --- /dev/null +++ b/egs/libritts/ASR/zipformer/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/ctc_decode.py b/egs/libritts/ASR/zipformer/ctc_decode.py new file mode 100755 index 000000000..d77aa5962 --- /dev/null +++ b/egs/libritts/ASR/zipformer/ctc_decode.py @@ -0,0 +1,992 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Liyong Guo, +# Quandong Wang, +# Zengwei Yao) +# Copyright 2024 The Chinese Univ. of HK (Author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +(1) ctc-greedy-search +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --decoding-method ctc-greedy-search + +(2) ctc-decoding +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --decoding-method ctc-decoding + +(3) 1best +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --decoding-method 1best + +(4) nbest +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --decoding-method nbest + +(5) nbest-rescoring +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method nbest-rescoring + +(6) whole-lattice-rescoring +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --max-duration 600 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method whole-lattice-rescoring + +(7) attention-decoder-rescoring-no-ngram +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --use-attention-decoder 1 \ + --max-duration 100 \ + --decoding-method attention-decoder-rescoring-no-ngram + +(8) attention-decoder-rescoring-with-ngram +./zipformer/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --use-ctc 1 \ + --use-attention-decoder 1 \ + --max-duration 100 \ + --hlg-scale 0.6 \ + --nbest-scale 1.0 \ + --lm-dir data/lm \ + --decoding-method attention-decoder-rescoring-with-ngram +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriTTSAsrDataModule +from lhotse import set_caching_enabled +from train import add_model_arguments, get_model, get_params, normalize_text + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import ( + ctc_greedy_search, + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_attention_decoder_no_ngram, + rescore_with_attention_decoder_with_ngram, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="ctc-decoding", + help="""Decoding method. + Supported values are: + - (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (2) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (3) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (4) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (5) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (6) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + you have trained an RNN LM using ./rnn_lm/train.py + - (7) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + - (8) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding + lattice, rescore them with the attention decoder. + - (9) attention-decoder-rescoring-with-ngram. Extract n paths from the LM + rescored lattice, rescore them with the attention decoder. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=1.0, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--hlg-scale", + type=float, + default=0.6, + help="""The scale to be applied to `hlg.scores`. + """, + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""", + ) + + add_model_arguments(parser) + + return parser + + +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "frame_shift_ms": 10, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + + Args: + params: + It's the return value of :func:`get_params`. + + - params.decoding_method is "1best", it uses 1best decoding without LM rescoring. + - params.decoding_method is "nbest", it uses nbest decoding without LM rescoring. + - params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring. + - params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + ctc_output = model.ctc_output(encoder_out) # (N, T, C) + + if params.decoding_method == "ctc-greedy-search": + hyps = ctc_greedy_search(ctc_output, encoder_out_lens) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(hyps) + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-greedy-search" + return {key: hyps} + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + torch.div( + supervisions["start_frame"], + params.subsampling_factor, + rounding_mode="floor", + ), + torch.div( + supervisions["num_frames"], + params.subsampling_factor, + rounding_mode="floor", + ), + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=ctc_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.decoding_method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} # note: returns words + + if params.decoding_method == "attention-decoder-rescoring-no-ngram": + best_path_dict = rescore_with_attention_decoder_no_ngram( + lattice=lattice, + num_paths=params.num_paths, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + nbest_scale=params.nbest_scale, + ) + ans = dict() + for a_scale_str, best_path in best_path_dict.items(): + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + ans[a_scale_str] = hyps + return ans + + if params.decoding_method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. + best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest-scale-{params.nbest_scale}" # noqa + return {key: hyps} + + if params.decoding_method in ["1best", "nbest"]: + if params.decoding_method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no-rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no-rescore_nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + return {key: hyps} # note: returns BPE tokens + + assert params.decoding_method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder-rescoring-with-ngram", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.decoding_method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.decoding_method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "attention-decoder-rescoring-with-ngram": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + best_path_dict = rescore_with_attention_decoder_with_ngram( + lattice=rescored_lattice, + num_paths=params.num_paths, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + nbest_scale=params.nbest_scale, + ) + else: + assert False, f"Unsupported decoding method: {params.decoding_method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + word_table: + It is the word symbol table. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_asr_output( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + """ + Save text produced by ASR. + """ + for key, results in results_dict.items(): + + recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + + results = sorted(results) + store_transcripts(filename=recogs_filename, texts=results) + + logging.info(f"The transcripts are stored in {recogs_filename}") + + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + if params.decoding_method in ( + "attention-decoder-rescoring-with-ngram", + "whole-lattice-rescoring", + ): + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True + + test_set_wers = dict() + for key, results in results_dict.items(): + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w", encoding="utf8") as fd: + wer = write_error_stats( + fd, f"{test_set_name}_{key}", results, enable_log=enable_log + ) + test_set_wers[key] = wer + + logging.info(f"Wrote detailed error stats to {errs_filename}") + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + + wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) + for key, val in test_set_wers: + print(f"{key}\t{val}", file=fd) + + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" + for key, val in test_set_wers: + s += f"{key}\t{val}{note}\n" + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriTTSAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + # enable AudioCache + set_caching_enabled(True) # lhotse + + assert params.decoding_method in ( + "ctc-greedy-search", + "ctc-decoding", + "1best", + "nbest", + "nbest-rescoring", + "whole-lattice-rescoring", + "nbest-oracle", + "attention-decoder-rescoring-no-ngram", + "attention-decoder-rescoring-with-ngram", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}_avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" + + if params.use_averaged_model: + params.suffix += "_use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + params.vocab_size = num_classes + # and are defined in local/train_bpe_model.py + params.blank_id = 0 + params.eos_id = 1 + params.sos_id = 1 + + if params.decoding_method in [ + "ctc-greedy-search", + "ctc-decoding", + "attention-decoder-rescoring-no-ngram", + ]: + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + HLG.scores *= params.hlg_scale + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.decoding_method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder-rescoring-with-ngram", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. + G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.decoding_method in [ + "whole-lattice-rescoring", + "attention-decoder-rescoring-with-ngram", + ]: + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + else: + G = None + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + libritts = LibriTTSAsrDataModule(args) + + test_clean_cuts = libritts.test_clean_cuts().map(normalize_text) + test_other_cuts = libritts.test_other_cuts().map(normalize_text) + + test_clean_dl = libritts.test_dataloaders(test_clean_cuts) + test_other_dl = libritts.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + ) + + save_asr_output( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/libritts/ASR/zipformer/decode.py b/egs/libritts/ASR/zipformer/decode.py new file mode 100755 index 000000000..759d9d50a --- /dev/null +++ b/egs/libritts/ASR/zipformer/decode.py @@ -0,0 +1,1086 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# Copyright 2024 The Chinese Univ. of HK (Author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriTTSAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from lhotse import set_caching_enabled +from train import add_model_arguments, get_model, get_params, normalize_text + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + # prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" ) + prefix = f"{params.decoding_method}" + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + prefix += f"_beam-{params.beam}" + prefix += f"_max-contexts-{params.max_contexts}" + prefix += f"_max-states-{params.max_states}" + if "nbest" in params.decoding_method: + prefix += f"_num-paths-{params.num_paths}" + prefix += f"_nbest-scale-{params.nbest_scale}" + if "LG" in params.decoding_method: + prefix += f"_ngram-lm-scale-{params.ngram_lm_scale}" + + return {prefix: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix += f"_beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"_context-score-{params.context_score}" + return {prefix: hyps} + else: + prefix += f"_beam-size-{params.beam_size}" + return {prefix: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_asr_output( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + """ + Save text produced by ASR. + """ + for key, results in results_dict.items(): + + recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + + results = sorted(results) + store_transcripts(filename=recogs_filename, texts=results) + + logging.info(f"The transcripts are stored in {recogs_filename}") + + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]], +): + """ + Save WER and per-utterance word alignments. + """ + test_set_wers = dict() + for key, results in results_dict.items(): + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w", encoding="utf8") as fd: + wer = write_error_stats( + fd, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info(f"Wrote detailed error stats to {errs_filename}") + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + + wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) + for key, val in test_set_wers: + print(f"{key}\t{val}", file=fd) + + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" + for key, val in test_set_wers: + s += f"{key}\t{val}{note}\n" + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriTTSAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + # enable AudioCache + set_caching_enabled(True) # lhotse + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}_avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"_beam-{params.beam}" + params.suffix += f"_max-contexts-{params.max_contexts}" + params.suffix += f"_max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"_nbest-scale-{params.nbest_scale}" + params.suffix += f"_num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"_ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += f"_context-{params.context_size}" + params.suffix += f"_max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"_{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "_use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append((sp.encode(line.strip()), 0.0)) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + libritts = LibriTTSAsrDataModule(args) + + test_clean_cuts = libritts.test_clean_cuts().map(normalize_text) + test_other_cuts = libritts.test_other_cuts().map(normalize_text) + + test_clean_dl = libritts.test_dataloaders(test_clean_cuts) + test_other_dl = libritts.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_asr_output( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/libritts/ASR/zipformer/decode_stream.py b/egs/libritts/ASR/zipformer/decode_stream.py new file mode 120000 index 000000000..b8d8ddfc4 --- /dev/null +++ b/egs/libritts/ASR/zipformer/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/decoder.py b/egs/libritts/ASR/zipformer/decoder.py new file mode 120000 index 000000000..5a8018680 --- /dev/null +++ b/egs/libritts/ASR/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/encoder_interface.py b/egs/libritts/ASR/zipformer/encoder_interface.py new file mode 120000 index 000000000..653c5b09a --- /dev/null +++ b/egs/libritts/ASR/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/export-onnx-ctc.py b/egs/libritts/ASR/zipformer/export-onnx-ctc.py new file mode 120000 index 000000000..f9d756352 --- /dev/null +++ b/egs/libritts/ASR/zipformer/export-onnx-ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-ctc.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/export-onnx-streaming-ctc.py b/egs/libritts/ASR/zipformer/export-onnx-streaming-ctc.py new file mode 120000 index 000000000..652346001 --- /dev/null +++ b/egs/libritts/ASR/zipformer/export-onnx-streaming-ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming-ctc.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/export-onnx-streaming.py b/egs/libritts/ASR/zipformer/export-onnx-streaming.py new file mode 120000 index 000000000..2962eb784 --- /dev/null +++ b/egs/libritts/ASR/zipformer/export-onnx-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/export-onnx.py b/egs/libritts/ASR/zipformer/export-onnx.py new file mode 120000 index 000000000..70a15683c --- /dev/null +++ b/egs/libritts/ASR/zipformer/export-onnx.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/export.py b/egs/libritts/ASR/zipformer/export.py new file mode 120000 index 000000000..dfc1bec08 --- /dev/null +++ b/egs/libritts/ASR/zipformer/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/generate_averaged_model.py b/egs/libritts/ASR/zipformer/generate_averaged_model.py new file mode 120000 index 000000000..5a015ee6c --- /dev/null +++ b/egs/libritts/ASR/zipformer/generate_averaged_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/generate_averaged_model.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/jit_pretrained.py b/egs/libritts/ASR/zipformer/jit_pretrained.py new file mode 120000 index 000000000..25108391f --- /dev/null +++ b/egs/libritts/ASR/zipformer/jit_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/jit_pretrained_ctc.py b/egs/libritts/ASR/zipformer/jit_pretrained_ctc.py new file mode 120000 index 000000000..9a8da5844 --- /dev/null +++ b/egs/libritts/ASR/zipformer/jit_pretrained_ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained_ctc.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/jit_pretrained_streaming.py b/egs/libritts/ASR/zipformer/jit_pretrained_streaming.py new file mode 120000 index 000000000..1962351e9 --- /dev/null +++ b/egs/libritts/ASR/zipformer/jit_pretrained_streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/joiner.py b/egs/libritts/ASR/zipformer/joiner.py new file mode 120000 index 000000000..5b8a36332 --- /dev/null +++ b/egs/libritts/ASR/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/label_smoothing.py b/egs/libritts/ASR/zipformer/label_smoothing.py new file mode 120000 index 000000000..175c633cc --- /dev/null +++ b/egs/libritts/ASR/zipformer/label_smoothing.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/label_smoothing.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/model.py b/egs/libritts/ASR/zipformer/model.py new file mode 120000 index 000000000..cd7e07d72 --- /dev/null +++ b/egs/libritts/ASR/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/my_profile.py b/egs/libritts/ASR/zipformer/my_profile.py new file mode 120000 index 000000000..3a90b2628 --- /dev/null +++ b/egs/libritts/ASR/zipformer/my_profile.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/my_profile.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_check.py b/egs/libritts/ASR/zipformer/onnx_check.py new file mode 120000 index 000000000..f3dd42004 --- /dev/null +++ b/egs/libritts/ASR/zipformer/onnx_check.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_decode.py b/egs/libritts/ASR/zipformer/onnx_decode.py new file mode 100755 index 000000000..6f09cc8f7 --- /dev/null +++ b/egs/libritts/ASR/zipformer/onnx_decode.py @@ -0,0 +1,326 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# Copyright 2024 The Chinese Univ. of HK (Author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX exported models and uses them to decode the test sets. + +We use the pre-trained model from +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained.pt" + +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. Export the model to ONNX + +./zipformer/export-onnx.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp \ + --causal False + +It will generate the following 3 files inside $repo/exp: + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +2. Run this file + +./zipformer/onnx_decode.py \ + --exp-dir $repo/exp \ + --max-duration 600 \ + --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ + --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ + --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ +""" + + +import argparse +import logging +import time +from pathlib import Path +from typing import List, Tuple + +import torch +import torch.nn as nn +from asr_datamodule import LibriTTSAsrDataModule +from k2 import SymbolTable +from onnx_pretrained import OnnxModel, greedy_search +from train import normalize_text + +from icefall.utils import setup_logger, store_transcripts, write_error_stats + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="Valid values are greedy_search and modified_beam_search", + ) + + return parser + + +def decode_one_batch( + model: OnnxModel, token_table: SymbolTable, batch: dict +) -> List[List[str]]: + """Decode one batch and return the result. + Currently it only greedy_search is supported. + + Args: + model: + The neural model. + token_table: + The token table. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + Return the decoded results for each utterance. + """ + feature = batch["inputs"] + assert feature.ndim == 3 + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(dtype=torch.int64) + + encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) + + hyps = greedy_search( + model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens + ) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + + hyps = [token_ids_to_words(h).split() for h in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + model: nn.Module, + token_table: SymbolTable, +) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + model: + The neural model. + token_table: + The token table. + + Returns: + - A list of tuples. Each tuple contains three elements: + - cut_id, + - reference transcript, + - predicted result. + - The total duration (in seconds) of the dataset. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 10 + total_duration = 0 + + results = [] + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) + + hyps = decode_one_batch(model=model, token_table=token_table, batch=batch) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return results, total_duration + + +def save_results( + res_dir: Path, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], +): + recog_path = res_dir / f"recogs-{test_set_name}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = res_dir / f"errs-{test_set_name}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + errs_info = res_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("WER", file=f) + print(wer, file=f) + + s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriTTSAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert ( + args.decoding_method == "greedy_search" + ), "Only supports greedy_search currently." + res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" + + setup_logger(f"{res_dir}/log-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + logging.info(f"Device: {device}") + + token_table = SymbolTable.from_file(args.tokens) + + logging.info(vars(args)) + + logging.info("About to create model") + model = OnnxModel( + encoder_model_filename=args.encoder_model_filename, + decoder_model_filename=args.decoder_model_filename, + joiner_model_filename=args.joiner_model_filename, + ) + + # we need cut ids to display recognition results. + args.return_cuts = True + libritts = LibriTTSAsrDataModule(args) + + test_clean_cuts = libritts.test_clean_cuts().map(normalize_text) + test_other_cuts = libritts.test_other_cuts().map(normalize_text) + + test_clean_dl = libritts.test_dataloaders(test_clean_cuts) + test_other_dl = libritts.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + start_time = time.time() + results, total_duration = decode_dataset( + dl=test_dl, model=model, token_table=token_table + ) + end_time = time.time() + elapsed_seconds = end_time - start_time + rtf = elapsed_seconds / total_duration + + logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") + logging.info(f"Wave duration: {total_duration:.3f} s") + logging.info( + f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" + ) + + save_results(res_dir=res_dir, test_set_name=test_set, results=results) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/libritts/ASR/zipformer/onnx_pretrained-streaming-ctc.py b/egs/libritts/ASR/zipformer/onnx_pretrained-streaming-ctc.py new file mode 120000 index 000000000..d623a8462 --- /dev/null +++ b/egs/libritts/ASR/zipformer/onnx_pretrained-streaming-ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained-streaming-ctc.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_pretrained-streaming.py b/egs/libritts/ASR/zipformer/onnx_pretrained-streaming.py new file mode 120000 index 000000000..cfea104c2 --- /dev/null +++ b/egs/libritts/ASR/zipformer/onnx_pretrained-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_pretrained.py b/egs/libritts/ASR/zipformer/onnx_pretrained.py new file mode 120000 index 000000000..8f32f4ee7 --- /dev/null +++ b/egs/libritts/ASR/zipformer/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_pretrained_ctc.py b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc.py new file mode 120000 index 000000000..a3183ebf6 --- /dev/null +++ b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained_ctc.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_H.py b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_H.py new file mode 120000 index 000000000..a4fd76ac2 --- /dev/null +++ b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_H.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HL.py b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HL.py new file mode 120000 index 000000000..f805e3761 --- /dev/null +++ b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HL.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HLG.py b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HLG.py new file mode 120000 index 000000000..8343d5079 --- /dev/null +++ b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HLG.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py new file mode 120000 index 000000000..3568e7cab --- /dev/null +++ b/egs/libritts/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG_streaming.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/optim.py b/egs/libritts/ASR/zipformer/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/libritts/ASR/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/pretrained.py b/egs/libritts/ASR/zipformer/pretrained.py new file mode 120000 index 000000000..0bd71dde4 --- /dev/null +++ b/egs/libritts/ASR/zipformer/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/pretrained.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/pretrained_ctc.py b/egs/libritts/ASR/zipformer/pretrained_ctc.py new file mode 120000 index 000000000..c2f6f6fc3 --- /dev/null +++ b/egs/libritts/ASR/zipformer/pretrained_ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/pretrained_ctc.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/scaling.py b/egs/libritts/ASR/zipformer/scaling.py new file mode 120000 index 000000000..6f398f431 --- /dev/null +++ b/egs/libritts/ASR/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/scaling_converter.py b/egs/libritts/ASR/zipformer/scaling_converter.py new file mode 120000 index 000000000..b0ecee05e --- /dev/null +++ b/egs/libritts/ASR/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/streaming_beam_search.py b/egs/libritts/ASR/zipformer/streaming_beam_search.py new file mode 120000 index 000000000..b1ed54557 --- /dev/null +++ b/egs/libritts/ASR/zipformer/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/streaming_beam_search.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/streaming_decode.py b/egs/libritts/ASR/zipformer/streaming_decode.py new file mode 100755 index 000000000..b21018788 --- /dev/null +++ b/egs/libritts/ASR/zipformer/streaming_decode.py @@ -0,0 +1,901 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, +# Fangjun Kuang, +# Zengwei Yao) +# Copyright 2024 The Chinese Univ. of HK (Author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./zipformer/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --causal 1 \ + --chunk-size 32 \ + --left-context-frames 256 \ + --exp-dir ./zipformer/exp \ + --decoding-method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +from asr_datamodule import LibriTTSAsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet, set_caching_enabled +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch import Tensor, nn +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params, normalize_text + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--label", + type=str, + default="", + help="""Extra label of the decoding run.""", + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + parser.add_argument( + "--skip-scoring", + type=str2bool, + default=False, + help="""Skip scoring, but still save the ASR output (for eval sets).""", + ) + + add_model_arguments(parser) + + return parser + + +def get_init_states( + model: nn.Module, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), +) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = model.encoder.get_init_states(batch_size, device) + + embed_states = model.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. + + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. For element-n, + state_list[n] is a list of cached tensors of all encoder layers. For layer-i, + state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, + cached_val2, cached_conv1, cached_conv2). + state_list[n][-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + state_list[n][-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Note: + It is the inverse of :func:`unstack_states`. + """ + batch_size = len(state_list) + assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) + tot_num_layers = (len(state_list[0]) - 2) // 6 + + batch_states = [] + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key = torch.cat( + [state_list[i][layer_offset] for i in range(batch_size)], dim=1 + ) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn = torch.cat( + [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1 = torch.cat( + [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2 = torch.cat( + [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1 = torch.cat( + [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2 = torch.cat( + [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 + ) + batch_states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + cached_embed_left_pad = torch.cat( + [state_list[i][-2] for i in range(batch_size)], dim=0 + ) + batch_states.append(cached_embed_left_pad) + + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) + batch_states.append(processed_lens) + + return batch_states + + +def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + batch_states: A list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + state_list[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Returns: + state_list: A list of list. Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. + """ + assert (len(batch_states) - 2) % 6 == 0, len(batch_states) + tot_num_layers = (len(batch_states) - 2) // 6 + + processed_lens = batch_states[-1] + batch_size = processed_lens.shape[0] + + state_list = [[] for _ in range(batch_size)] + + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( + chunks=batch_size, dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1_list = batch_states[layer_offset + 2].chunk( + chunks=batch_size, dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2_list = batch_states[layer_offset + 3].chunk( + chunks=batch_size, dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1_list = batch_states[layer_offset + 4].chunk( + chunks=batch_size, dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2_list = batch_states[layer_offset + 5].chunk( + chunks=batch_size, dim=0 + ) + for i in range(batch_size): + state_list[i] += [ + cached_key_list[i], + cached_nonlin_attn_list[i], + cached_val1_list[i], + cached_val2_list[i], + cached_conv1_list[i], + cached_conv2_list[i], + ] + + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(cached_embed_left_pad_list[i]) + + processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(processed_lens_list[i]) + + return state_list + + +def streaming_forward( + features: Tensor, + feature_lens: Tensor, + model: nn.Module, + states: List[Tensor], + chunk_size: int, + left_context_len: int, +) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Returns encoder outputs, output lengths, and updated states. + """ + cached_embed_left_pad = states[-2] + (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = model.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + chunk_size = int(params.chunk_size) + left_context_len = int(params.left_context_frames) + + features = [] + feature_lens = [] + states = [] + processed_lens = [] # Used in fast-beam-search + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(chunk_size * 2) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # Make sure the length after encoder_embed is at least 1. + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + tail_length = chunk_size * 2 + 7 + 2 * 3 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + + encoder_out, encoder_out_lens, new_states = streaming_forward( + features=features, + feature_lens=feature_lens, + model=model, + states=states, + chunk_size=chunk_size, + left_context_len=left_context_len, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = torch.tensor(processed_lens, device=device) + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 100 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. + initial_states = get_init_states(model=model, batch_size=1, device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=30) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_asr_output( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + """ + Save text produced by ASR. + """ + for key, results in results_dict.items(): + recogs_filename = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recogs_filename, texts=results) + logging.info(f"The transcripts are stored in {recogs_filename}") + + +def save_wer_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + """ + Save WER and per-utterance word alignments. + """ + test_set_wers = dict() + for key, results in results_dict.items(): + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w", encoding="utf8") as fd: + wer = write_error_stats( + fd, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info(f"Wrote detailed error stats to {errs_filename}") + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + + wer_filename = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(wer_filename, "w", encoding="utf8") as fd: + print("settings\tWER", file=fd) + for key, val in test_set_wers: + print(f"{key}\t{val}", file=fd) + + s = f"\nFor {test_set_name}, WER of different settings are:\n" + note = f"\tbest for {test_set_name}" + for key, val in test_set_wers: + s += f"{key}\t{val}{note}\n" + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriTTSAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + # enable AudioCache + set_caching_enabled(True) # lhotse + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + assert params.causal, params.causal + assert "," not in params.chunk_size, "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"_chunk-{params.chunk_size}" + params.suffix += f"_left-context-{params.left_context_frames}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"_beam-{params.beam}" + params.suffix += f"_max-contexts-{params.max_contexts}" + params.suffix += f"_max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + if params.label: + params.suffix += f"-{params.label}" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + libritts = LibriTTSAsrDataModule(args) + + test_clean_cuts = libritts.test_clean_cuts().map(normalize_text) + test_other_cuts = libritts.test_other_cuts().map(normalize_text) + + test_sets = ["test-clean", "test-other"] + test_cuts = [test_clean_cuts, test_other_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_asr_output( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + if not params.skip_scoring: + save_wer_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/libritts/ASR/zipformer/subsampling.py b/egs/libritts/ASR/zipformer/subsampling.py new file mode 120000 index 000000000..01ae9002c --- /dev/null +++ b/egs/libritts/ASR/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/libritts/ASR/zipformer/train.py b/egs/libritts/ASR/zipformer/train.py new file mode 100755 index 000000000..5485eaf0a --- /dev/null +++ b/egs/libritts/ASR/zipformer/train.py @@ -0,0 +1,1527 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# Copyright 2024 The Chinese Univ. of HK (author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --full-libri 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` + - ctc loss & attention decoder loss, no transducer loss, + with `--use-transducer False --use-ctc True --use-attention-decoder True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriTTSAsrDataModule +from attention_decoder import AttentionDecoderModel +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.err import raise_grad_scale_is_too_small_error +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--attention-decoder-dim", + type=int, + default=512, + help="""Dimension used in the attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-num-layers", + type=int, + default=6, + help="""Number of transformer layers used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-attention-dim", + type=int, + default=512, + help="""Attention dimension used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-num-heads", + type=int, + default=8, + help="""Number of attention heads used in attention decoder""", + ) + + parser.add_argument( + "--attention-decoder-feedforward-dim", + type=int, + default=2048, + help="""Feedforward dimension used in attention decoder""", + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + parser.add_argument( + "--use-attention-decoder", + type=str2bool, + default=False, + help="If True, use attention-decoder head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--attention-decoder-loss-scale", + type=float, + default=0.8, + help="Scale for attention-decoder loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--use-bf16", + type=str2bool, + default=False, + help="Whether to use bf16 in AMP.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + # parameters for attention-decoder + "ignore_id": -1, + "label_smoothing": 0.1, + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def normalize_text(c: Cut): + def remove_punc_to_upper(text: str) -> str: + text = text.replace("‘", "'") + text = text.replace("’", "'") + tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'") + s_list = [x.upper() if x in tokens else " " for x in text] + s = " ".join("".join(s_list).split()).strip() + return s + + text = remove_punc_to_upper(c.supervisions[0].text) + c.supervisions[0].text = text + return c + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_attention_decoder_model(params: AttributeDict) -> nn.Module: + decoder = AttentionDecoderModel( + vocab_size=params.vocab_size, + decoder_dim=params.attention_decoder_dim, + num_decoder_layers=params.attention_decoder_num_layers, + attention_dim=params.attention_decoder_attention_dim, + num_heads=params.attention_decoder_num_heads, + feedforward_dim=params.attention_decoder_feedforward_dim, + memory_dim=max(_to_int_tuple(params.encoder_dim)), + sos_id=params.sos_id, + eos_id=params.eos_id, + ignore_id=params.ignore_id, + label_smoothing=params.label_smoothing, + ) + return decoder + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + if params.use_attention_decoder: + attention_decoder = get_attention_decoder_model(params) + else: + attention_decoder = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + attention_decoder=attention_decoder, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + use_attention_decoder=params.use_attention_decoder, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss, attention_decoder_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + if params.use_attention_decoder: + loss += params.attention_decoder_loss_scale * attention_decoder_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.use_attention_decoder: + info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast( + enabled=params.use_autocast, dtype=params.dtype + ): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except Exception as e: + logging.info(f"Caught exception: {e}.") + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_autocast: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise_grad_scale_is_too_small_error(cur_grad_scale) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_autocast: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.sos_id = params.eos_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + if not params.use_attention_decoder: + params.ctc_loss_scale = 1.0 + else: + assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, ( + params.ctc_loss_scale, + params.attention_decoder_loss_scale, + ) + + if params.use_bf16: # amp + bf16 + assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!" + assert not params.use_fp16, "You can only use either fp16 or bf16" + params.dtype = torch.bfloat16 + params.use_autocast = True + elif params.use_fp16: # amp + fp16 + params.dtype = torch.float16 + params.use_autocast = True + else: # fp32 + params.dtype = torch.float32 + params.use_autocast = False + + logging.info(f"Using dtype={params.dtype}") + logging.info(f"Use AMP={params.use_autocast}") + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + libritts = LibriTTSAsrDataModule(args) + + if params.full_libri: + train_cuts = libritts.train_all_shuf_cuts() + + # previously we used the following code to load all training cuts, + # strictly speaking, shuffled training cuts should be used instead, + # but we leave the code here to demonstrate that there is an option + # like this to combine multiple cutsets + + # train_cuts = libritts.train_clean_100_cuts() + # train_cuts += libritts.train_clean_360_cuts() + # train_cuts += libritts.train_other_500_cuts() + else: + train_cuts = libritts.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.map(normalize_text) + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = libritts.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = libritts.dev_clean_cuts().map(normalize_text) + valid_cuts += libritts.dev_other_cuts().map(normalize_text) + valid_dl = libritts.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast( + enabled=params.use_autocast, dtype=params.dtype + ): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriTTSAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/libritts/ASR/zipformer/zipformer.py b/egs/libritts/ASR/zipformer/zipformer.py new file mode 120000 index 000000000..23011dda7 --- /dev/null +++ b/egs/libritts/ASR/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/egs/libritts/CODEC/encodec/base_discriminators.py b/egs/libritts/CODEC/encodec/base_discriminators.py new file mode 100644 index 000000000..7bc035554 --- /dev/null +++ b/egs/libritts/CODEC/encodec/base_discriminators.py @@ -0,0 +1,251 @@ +#!/usr/bin/env python3 +# Copyright 2024 The Chinese University of HK (Author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio +from einops import rearrange +from modules.conv import NormConv1d, NormConv2d + + +def get_padding(kernel_size, dilation=1) -> int: + return int((kernel_size * dilation - dilation) / 2) + + +def get_2d_padding(kernel_size: Tuple[int, int], dilation: Tuple[int, int] = (1, 1)): + return ( + ((kernel_size[0] - 1) * dilation[0]) // 2, + ((kernel_size[1] - 1) * dilation[1]) // 2, + ) + + +class DiscriminatorP(nn.Module): + def __init__( + self, + period, + kernel_size=5, + stride=3, + activation: str = "LeakyReLU", + activation_params: dict = {"negative_slope": 0.2}, + ): + super(DiscriminatorP, self).__init__() + + self.period = period + self.activation = getattr(torch.nn, activation)(**activation_params) + self.convs = nn.ModuleList( + [ + NormConv2d( + 1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0) + ), + NormConv2d( + 32, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ), + NormConv2d( + 32, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ), + NormConv2d( + 32, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ), + NormConv2d(32, 32, (kernel_size, 1), 1, padding=(2, 0)), + ] + ) + self.conv_post = NormConv2d(32, 1, (3, 1), 1, padding=(1, 0)) + + def forward(self, x): + fmap = [] + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = self.activation(x) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorS(nn.Module): + def __init__( + self, + activation: str = "LeakyReLU", + activation_params: dict = {"negative_slope": 0.2}, + ): + super(DiscriminatorS, self).__init__() + self.activation = getattr(torch.nn, activation)(**activation_params) + self.convs = nn.ModuleList( + [ + NormConv1d(1, 32, 15, 1, padding=7), + NormConv1d(32, 32, 41, 2, groups=4, padding=20), + NormConv1d(32, 32, 41, 2, groups=16, padding=20), + NormConv1d(32, 32, 41, 4, groups=16, padding=20), + NormConv1d(32, 32, 41, 4, groups=16, padding=20), + NormConv1d(32, 32, 41, 1, groups=16, padding=20), + NormConv1d(32, 32, 5, 1, padding=2), + ] + ) + self.conv_post = NormConv1d(32, 1, 3, 1, padding=1) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = self.activation(x) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + return x, fmap + + +class DiscriminatorSTFT(nn.Module): + """STFT sub-discriminator. + Args: + filters (int): Number of filters in convolutions + in_channels (int): Number of input channels. Default: 1 + out_channels (int): Number of output channels. Default: 1 + n_fft (int): Size of FFT for each scale. Default: 1024 + hop_length (int): Length of hop between STFT windows for each scale. Default: 256 + kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)`` + stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)`` + dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]`` + win_length (int): Window size for each scale. Default: 1024 + normalized (bool): Whether to normalize by magnitude after stft. Default: True + norm (str): Normalization method. Default: `'weight_norm'` + activation (str): Activation function. Default: `'LeakyReLU'` + activation_params (dict): Parameters to provide to the activation function. + growth (int): Growth factor for the filters. Default: 1 + """ + + def __init__( + self, + n_filters: int, + in_channels: int = 1, + out_channels: int = 1, + n_fft: int = 1024, + hop_length: int = 256, + win_length: int = 1024, + max_filters: int = 1024, + filters_scale: int = 1, + kernel_size: Tuple[int, int] = (3, 9), + dilations: List[int] = [1, 2, 4], + stride: Tuple[int, int] = (1, 2), + normalized: bool = True, + norm: str = "weight_norm", + activation: str = "LeakyReLU", + activation_params: dict = {"negative_slope": 0.2}, + ): + super().__init__() + assert len(kernel_size) == 2 + assert len(stride) == 2 + self.filters = n_filters + self.in_channels = in_channels + self.out_channels = out_channels + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.normalized = normalized + self.activation = getattr(torch.nn, activation)(**activation_params) + self.spec_transform = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window_fn=torch.hann_window, + normalized=self.normalized, + center=False, + pad_mode=None, + power=None, + ) + spec_channels = 2 * self.in_channels + self.convs = nn.ModuleList() + self.convs.append( + NormConv2d( + spec_channels, + self.filters, + kernel_size=kernel_size, + padding=get_2d_padding(kernel_size), + ) + ) + in_chs = min(filters_scale * self.filters, max_filters) + for i, dilation in enumerate(dilations): + out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters) + self.convs.append( + NormConv2d( + in_chs, + out_chs, + kernel_size=kernel_size, + stride=stride, + dilation=(dilation, 1), + padding=get_2d_padding(kernel_size, (dilation, 1)), + norm=norm, + ) + ) + in_chs = out_chs + out_chs = min( + (filters_scale ** (len(dilations) + 1)) * self.filters, max_filters + ) + self.convs.append( + NormConv2d( + in_chs, + out_chs, + kernel_size=(kernel_size[0], kernel_size[0]), + padding=get_2d_padding((kernel_size[0], kernel_size[0])), + norm=norm, + ) + ) + self.conv_post = NormConv2d( + out_chs, + self.out_channels, + kernel_size=(kernel_size[0], kernel_size[0]), + padding=get_2d_padding((kernel_size[0], kernel_size[0])), + norm=norm, + ) + + def forward(self, x: torch.Tensor): + fmap = [] + z = self.spec_transform(x) # [B, 2, Freq, Frames, 2] + z = torch.cat([z.real, z.imag], dim=1) + z = rearrange(z, "b c w t -> b c t w") + for i, layer in enumerate(self.convs): + z = layer(z) + z = self.activation(z) + fmap.append(z) + z = self.conv_post(z) + return z, fmap diff --git a/egs/libritts/CODEC/encodec/binary.py b/egs/libritts/CODEC/encodec/binary.py new file mode 100644 index 000000000..003bcfaf5 --- /dev/null +++ b/egs/libritts/CODEC/encodec/binary.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE +"""Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`.""" + +import io +import json +import struct +from typing import IO, Any, List, Optional + +# format is `ECDC` magic code, followed by the header size as uint32. +# Then an uint8 indicates the protocol version (0.) +# The header is then provided as json and should contain all required +# informations for decoding. A raw stream of bytes is then provided +# and should be interpretable using the json header. +_encodec_header_struct = struct.Struct("!4sBI") +_ENCODEC_MAGIC = b"ECDC" + + +def write_ecdc_header(fo: IO[bytes], metadata: Any): + meta_dumped = json.dumps(metadata).encode("utf-8") + version = 0 + header = _encodec_header_struct.pack(_ENCODEC_MAGIC, version, len(meta_dumped)) + fo.write(header) + fo.write(meta_dumped) + fo.flush() + + +def _read_exactly(fo: IO[bytes], size: int) -> bytes: + buf = b"" + while len(buf) < size: + new_buf = fo.read(size) + if not new_buf: + raise EOFError( + "Impossible to read enough data from the stream, " + f"{size} bytes remaining." + ) + buf += new_buf + size -= len(new_buf) + return buf + + +def read_ecdc_header(fo: IO[bytes]): + header_bytes = _read_exactly(fo, _encodec_header_struct.size) + magic, version, meta_size = _encodec_header_struct.unpack(header_bytes) + if magic != _ENCODEC_MAGIC: + raise ValueError("File is not in ECDC format.") + if version != 0: + raise ValueError("Version not supported.") + meta_bytes = _read_exactly(fo, meta_size) + return json.loads(meta_bytes.decode("utf-8")) + + +class BitPacker: + """Simple bit packer to handle ints with a non standard width, e.g. 10 bits. + Note that for some bandwidth (1.5, 3), the codebook representation + will not cover an integer number of bytes. + + Args: + bits (int): number of bits per value that will be pushed. + fo (IO[bytes]): file-object to push the bytes to. + """ + + def __init__(self, bits: int, fo: IO[bytes]): + self._current_value = 0 + self._current_bits = 0 + self.bits = bits + self.fo = fo + + def push(self, value: int): + """Push a new value to the stream. This will immediately + write as many uint8 as possible to the underlying file-object.""" + self._current_value += value << self._current_bits + self._current_bits += self.bits + while self._current_bits >= 8: + lower_8bits = self._current_value & 0xFF + self._current_bits -= 8 + self._current_value >>= 8 + self.fo.write(bytes([lower_8bits])) + + def flush(self): + """Flushes the remaining partial uint8, call this at the end + of the stream to encode.""" + if self._current_bits: + self.fo.write(bytes([self._current_value])) + self._current_value = 0 + self._current_bits = 0 + self.fo.flush() + + +class BitUnpacker: + """BitUnpacker does the opposite of `BitPacker`. + + Args: + bits (int): number of bits of the values to decode. + fo (IO[bytes]): file-object to push the bytes to. + """ + + def __init__(self, bits: int, fo: IO[bytes]): + self.bits = bits + self.fo = fo + self._mask = (1 << bits) - 1 + self._current_value = 0 + self._current_bits = 0 + + def pull(self) -> Optional[int]: + """ + Pull a single value from the stream, potentially reading some + extra bytes from the underlying file-object. + Returns `None` when reaching the end of the stream. + """ + while self._current_bits < self.bits: + buf = self.fo.read(1) + if not buf: + return None + character = buf[0] + self._current_value += character << self._current_bits + self._current_bits += 8 + + out = self._current_value & self._mask + self._current_value >>= self.bits + self._current_bits -= self.bits + return out + + +def test(): + import torch + + torch.manual_seed(1234) + for rep in range(4): + length: int = torch.randint(10, 2_000, (1,)).item() + bits: int = torch.randint(1, 16, (1,)).item() + tokens: List[int] = torch.randint(2**bits, (length,)).tolist() + rebuilt: List[int] = [] + buf = io.BytesIO() + packer = BitPacker(bits, buf) + for token in tokens: + packer.push(token) + packer.flush() + buf.seek(0) + unpacker = BitUnpacker(bits, buf) + while True: + value = unpacker.pull() + if value is None: + break + rebuilt.append(value) + assert len(rebuilt) >= len(tokens), (len(rebuilt), len(tokens)) + # The flushing mechanism might lead to "ghost" values at the end of the stream. + assert len(rebuilt) <= len(tokens) + 8 // bits, ( + len(rebuilt), + len(tokens), + bits, + ) + for idx, (a, b) in enumerate(zip(tokens, rebuilt)): + assert a == b, (idx, a, b) + + +if __name__ == "__main__": + test() diff --git a/egs/libritts/CODEC/encodec/codec_datamodule.py b/egs/libritts/CODEC/encodec/codec_datamodule.py new file mode 100644 index 000000000..e77a255e5 --- /dev/null +++ b/egs/libritts/CODEC/encodec/codec_datamodule.py @@ -0,0 +1,336 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, + SpeechSynthesisDataset, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LibriTTSCodecDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="Codec data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="""When enabled, use the entire LibriTTS training set. + Otherwise, use the clean-100 subset.""", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/spectrogram"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=8, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + world_size: Optional[int] = None, + rank: Optional[int] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=False, + return_spk_ids=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + world_size=world_size, + rank=rank, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + world_size=world_size, + rank=rank, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=True, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders( + self, + cuts_valid: CutSet, + world_size: Optional[int] = None, + rank: Optional[int] = None, + ) -> DataLoader: + logging.info("About to create dev dataset") + + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=False, + return_spk_ids=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + world_size=world_size, + rank=rank, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=1, + drop_last=False, + persistent_workers=True, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=False, + return_spk_ids=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train-clean-360.jsonl.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train-other-500.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz" + ) diff --git a/egs/libritts/CODEC/encodec/discriminators.py b/egs/libritts/CODEC/encodec/discriminators.py new file mode 100644 index 000000000..e6b7f0929 --- /dev/null +++ b/egs/libritts/CODEC/encodec/discriminators.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List + +import torch +import torch.nn as nn +from base_discriminators import DiscriminatorP, DiscriminatorS, DiscriminatorSTFT +from torch.nn import AvgPool1d + + +class MultiPeriodDiscriminator(nn.Module): + def __init__(self): + super(MultiPeriodDiscriminator, self).__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorP(2), + DiscriminatorP(3), + DiscriminatorP(5), + DiscriminatorP(7), + DiscriminatorP(11), + ] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class MultiScaleDiscriminator(nn.Module): + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorS(), + DiscriminatorS(), + DiscriminatorS(), + ] + ) + self.meanpools = nn.ModuleList( + [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class MultiScaleSTFTDiscriminator(nn.Module): + """Multi-Scale STFT (MS-STFT) discriminator. + Args: + filters (int): Number of filters in convolutions + in_channels (int): Number of input channels. Default: 1 + out_channels (int): Number of output channels. Default: 1 + n_ffts (Sequence[int]): Size of FFT for each scale + hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale + win_lengths (Sequence[int]): Window size for each scale + **kwargs: additional args for STFTDiscriminator + """ + + def __init__( + self, + n_filters: int, + in_channels: int = 1, + out_channels: int = 1, + n_ffts: List[int] = [1024, 2048, 512, 256, 128], + hop_lengths: List[int] = [256, 512, 128, 64, 32], + win_lengths: List[int] = [1024, 2048, 512, 256, 128], + **kwargs + ): + super().__init__() + assert len(n_ffts) == len(hop_lengths) == len(win_lengths) + self.discriminators = nn.ModuleList( + [ + DiscriminatorSTFT( + n_filters, + in_channels=in_channels, + out_channels=out_channels, + n_fft=n_ffts[i], + win_length=win_lengths[i], + hop_length=hop_lengths[i], + **kwargs + ) + for i in range(len(n_ffts)) + ] + ) + self.num_discriminators = len(self.discriminators) + + def forward(self, x: torch.Tensor): + logits = [] + fmaps = [] + for disc in self.discriminators: + logit, fmap = disc(x) + logits.append(logit) + fmaps.append(fmap) + return logits, fmaps diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py new file mode 100644 index 000000000..f21d494b6 --- /dev/null +++ b/egs/libritts/CODEC/encodec/encodec.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python3 +# Copyright 2024 The Chinese University of HK (Author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import random +from typing import List, Optional + +import numpy as np +import torch +from loss import ( + DiscriminatorAdversarialLoss, + FeatureLoss, + GeneratorAdversarialLoss, + MelSpectrogramReconstructionLoss, + WavReconstructionLoss, +) +from torch import nn +from torch.cuda.amp import autocast + + +class Encodec(nn.Module): + def __init__( + self, + sampling_rate: int, + target_bandwidths: List[float], + params: dict, + encoder: nn.Module, + quantizer: nn.Module, + decoder: nn.Module, + multi_scale_discriminator: nn.Module, + multi_period_discriminator: Optional[nn.Module] = None, + multi_scale_stft_discriminator: Optional[nn.Module] = None, + cache_generator_outputs: bool = False, + ): + super(Encodec, self).__init__() + + self.params = params + + # setup the generator + self.sampling_rate = sampling_rate + self.encoder = encoder + self.quantizer = quantizer + self.decoder = decoder + + self.ratios = encoder.ratios + self.hop_length = np.prod(self.ratios) + self.frame_rate = math.ceil(self.sampling_rate / np.prod(self.ratios)) + self.target_bandwidths = target_bandwidths + + # discriminators + self.multi_scale_discriminator = multi_scale_discriminator + self.multi_period_discriminator = multi_period_discriminator + self.multi_scale_stft_discriminator = multi_scale_stft_discriminator + + # cache + self.cache_generator_outputs = cache_generator_outputs + self._cache = None + + # construct loss functions + self.generator_adversarial_loss = GeneratorAdversarialLoss( + average_by_discriminators=True, loss_type="hinge" + ) + self.discriminator_adversarial_loss = DiscriminatorAdversarialLoss( + average_by_discriminators=True, loss_type="hinge" + ) + self.feature_match_loss = FeatureLoss() + self.wav_reconstruction_loss = WavReconstructionLoss() + self.mel_reconstruction_loss = MelSpectrogramReconstructionLoss( + sampling_rate=self.sampling_rate + ) + + def _forward_generator( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + return_sample: bool = False, + ): + """Perform generator forward. + + Args: + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). + return_sample (bool): Return the generator output. + + Returns: + * loss (Tensor): Loss scalar tensor. + * stats (Dict[str, float]): Statistics to be monitored. + """ + # setup + speech = speech.unsqueeze(1) + + # calculate generator outputs + reuse_cache = True + if not self.cache_generator_outputs or self._cache is None: + reuse_cache = False + e = self.encoder(speech) + index = torch.tensor( + random.randint(0, len(self.target_bandwidths) - 1), + device=speech.device, + ) + if torch.distributed.is_initialized(): + torch.distributed.broadcast(index, src=0) + bw = self.target_bandwidths[index.item()] + quantized, codes, bandwidth, commit_loss = self.quantizer( + e, self.frame_rate, bw + ) + speech_hat = self.decoder(quantized) + else: + speech_hat = self._cache + # store cache + if self.training and self.cache_generator_outputs and not reuse_cache: + self._cache = speech_hat + + # calculate discriminator outputs + y_hat, fmap_hat = self.multi_scale_stft_discriminator(speech_hat.contiguous()) + with torch.no_grad(): + # do not store discriminator gradient in generator turn + y, fmap = self.multi_scale_stft_discriminator(speech.contiguous()) + + gen_period_adv_loss = torch.tensor(0.0) + feature_period_loss = torch.tensor(0.0) + if self.multi_period_discriminator is not None: + y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator( + speech.contiguous(), + speech_hat.contiguous(), + ) + + gen_scale_adv_loss = torch.tensor(0.0) + feature_scale_loss = torch.tensor(0.0) + if self.multi_scale_discriminator is not None: + y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator( + speech.contiguous(), + speech_hat.contiguous(), + ) + + # calculate losses + with autocast(enabled=False): + gen_stft_adv_loss = self.generator_adversarial_loss(outputs=y_hat) + + if self.multi_period_discriminator is not None: + gen_period_adv_loss = self.generator_adversarial_loss(outputs=y_p_hat) + if self.multi_scale_discriminator is not None: + gen_scale_adv_loss = self.generator_adversarial_loss(outputs=y_s_hat) + + feature_stft_loss = self.feature_match_loss(feats=fmap, feats_hat=fmap_hat) + + if self.multi_period_discriminator is not None: + feature_period_loss = self.feature_match_loss( + feats=fmap_p, feats_hat=fmap_p_hat + ) + if self.multi_scale_discriminator is not None: + feature_scale_loss = self.feature_match_loss( + feats=fmap_s, feats_hat=fmap_s_hat + ) + + wav_reconstruction_loss = self.wav_reconstruction_loss( + x=speech, x_hat=speech_hat + ) + mel_reconstruction_loss = self.mel_reconstruction_loss( + x=speech, x_hat=speech_hat + ) + + stats = dict( + generator_wav_reconstruction_loss=wav_reconstruction_loss.item(), + generator_mel_reconstruction_loss=mel_reconstruction_loss.item(), + generator_feature_stft_loss=feature_stft_loss.item(), + generator_feature_period_loss=feature_period_loss.item(), + generator_feature_scale_loss=feature_scale_loss.item(), + generator_stft_adv_loss=gen_stft_adv_loss.item(), + generator_period_adv_loss=gen_period_adv_loss.item(), + generator_scale_adv_loss=gen_scale_adv_loss.item(), + generator_commit_loss=commit_loss.item(), + ) + + if return_sample: + stats["returned_sample"] = ( + speech_hat.cpu(), + speech.cpu(), + fmap_hat[0][0].data.cpu(), + fmap[0][0].data.cpu(), + ) + + # reset cache + if reuse_cache or not self.training: + self._cache = None + return ( + commit_loss, + gen_stft_adv_loss, + gen_period_adv_loss, + gen_scale_adv_loss, + feature_stft_loss, + feature_period_loss, + feature_scale_loss, + wav_reconstruction_loss, + mel_reconstruction_loss, + stats, + ) + + def _forward_discriminator( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + ): + """ + Args: + speech (Tensor): Speech waveform tensor (B, T_wav). + speech_lengths (Tensor): Speech length tensor (B,). + + Returns: + * loss (Tensor): Loss scalar tensor. + * stats (Dict[str, float]): Statistics to be monitored. + """ + # setup + speech = speech.unsqueeze(1) + + # calculate generator outputs + reuse_cache = True + if not self.cache_generator_outputs or self._cache is None: + reuse_cache = False + e = self.encoder(speech) + index = torch.tensor( + random.randint(0, len(self.target_bandwidths) - 1), + device=speech.device, + ) + if torch.distributed.is_initialized(): + torch.distributed.broadcast(index, src=0) + bw = self.target_bandwidths[index.item()] + quantized, codes, bandwidth, commit_loss = self.quantizer( + e, self.frame_rate, bw + ) + speech_hat = self.decoder(quantized) + else: + speech_hat = self._cache + + # store cache + if self.training and self.cache_generator_outputs and not reuse_cache: + self._cache = speech_hat + + # calculate discriminator outputs + y, fmap = self.multi_scale_stft_discriminator(speech.contiguous()) + y_hat, fmap_hat = self.multi_scale_stft_discriminator( + speech_hat.contiguous().detach() + ) + + disc_period_real_adv_loss = torch.tensor(0.0) + disc_period_fake_adv_loss = torch.tensor(0.0) + if self.multi_period_discriminator is not None: + y_p, y_p_hat, fmap_p, fmap_p_hat = self.multi_period_discriminator( + speech.contiguous(), + speech_hat.contiguous().detach(), + ) + + disc_scale_real_adv_loss = torch.tensor(0.0) + disc_scale_fake_adv_loss = torch.tensor(0.0) + if self.multi_scale_discriminator is not None: + y_s, y_s_hat, fmap_s, fmap_s_hat = self.multi_scale_discriminator( + speech.contiguous(), + speech_hat.contiguous().detach(), + ) + # calculate losses + with autocast(enabled=False): + ( + disc_stft_real_adv_loss, + disc_stft_fake_adv_loss, + ) = self.discriminator_adversarial_loss(outputs=y, outputs_hat=y_hat) + if self.multi_period_discriminator is not None: + ( + disc_period_real_adv_loss, + disc_period_fake_adv_loss, + ) = self.discriminator_adversarial_loss( + outputs=y_p, outputs_hat=y_p_hat + ) + if self.multi_scale_discriminator is not None: + ( + disc_scale_real_adv_loss, + disc_scale_fake_adv_loss, + ) = self.discriminator_adversarial_loss( + outputs=y_s, outputs_hat=y_s_hat + ) + + stats = dict( + discriminator_stft_real_adv_loss=disc_stft_real_adv_loss.item(), + discriminator_period_real_adv_loss=disc_period_real_adv_loss.item(), + discriminator_scale_real_adv_loss=disc_scale_real_adv_loss.item(), + discriminator_stft_fake_adv_loss=disc_stft_fake_adv_loss.item(), + discriminator_period_fake_adv_loss=disc_period_fake_adv_loss.item(), + discriminator_scale_fake_adv_loss=disc_scale_fake_adv_loss.item(), + ) + + # reset cache + if reuse_cache or not self.training: + self._cache = None + + return ( + disc_stft_real_adv_loss, + disc_stft_fake_adv_loss, + disc_period_real_adv_loss, + disc_period_fake_adv_loss, + disc_scale_real_adv_loss, + disc_scale_fake_adv_loss, + stats, + ) + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + return_sample: bool, + forward_generator: bool, + ): + if forward_generator: + return self._forward_generator( + speech=speech, + speech_lengths=speech_lengths, + return_sample=return_sample, + ) + else: + return self._forward_discriminator( + speech=speech, + speech_lengths=speech_lengths, + ) + + def encode(self, x, target_bw=None, st=None): + e = self.encoder(x) + if target_bw is None: + bw = self.target_bandwidths[-1] + else: + bw = target_bw + if st is None: + st = 0 + codes = self.quantizer.encode(e, self.frame_rate, bw, st) + return codes + + def decode(self, codes): + quantized = self.quantizer.decode(codes) + x_hat = self.decoder(quantized) + return x_hat + + def inference(self, x, target_bw=None, st=None): + # setup + x = x.unsqueeze(1) + + codes = self.encode(x, target_bw, st) + x_hat = self.decode(codes) + return codes, x_hat diff --git a/egs/libritts/CODEC/encodec/infer.py b/egs/libritts/CODEC/encodec/infer.py new file mode 100755 index 000000000..3c6ea15f9 --- /dev/null +++ b/egs/libritts/CODEC/encodec/infer.py @@ -0,0 +1,352 @@ +#!/usr/bin/env python3 +# +# Copyright 2024 The Chinese University of HK (Author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script performs model inference on test set. + +Usage: +./codec/infer.py \ + --epoch 300 \ + --exp-dir ./codec/exp \ + --max-duration 500 +""" + + +import argparse +import logging +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from statistics import mean +from typing import List, Tuple + +import numpy as np +import torch +import torchaudio +from codec_datamodule import LibriTTSCodecDataModule +from pesq import pesq +from pystoi import stoi +from scipy import signal +from torch import nn +from train import get_model, get_params + +from icefall.checkpoint import load_checkpoint +from icefall.utils import AttributeDict, setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=1000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="encodec/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--target-bw", + type=float, + default=24, + help="The target bandwidth for the generator", + ) + + return parser + + +# implementation from https://github.com/yangdongchao/AcademiCodec/blob/master/academicodec/models/encodec/test.py +def remove_encodec_weight_norm(model) -> None: + from modules import SConv1d + from modules.seanet import SConvTranspose1d, SEANetResnetBlock + from torch.nn.utils import remove_weight_norm + + encoder = model.encoder.model + for key in encoder._modules: + if isinstance(encoder._modules[key], SEANetResnetBlock): + remove_weight_norm(encoder._modules[key].shortcut.conv.conv) + block_modules = encoder._modules[key].block._modules + for skey in block_modules: + if isinstance(block_modules[skey], SConv1d): + remove_weight_norm(block_modules[skey].conv.conv) + elif isinstance(encoder._modules[key], SConv1d): + remove_weight_norm(encoder._modules[key].conv.conv) + + decoder = model.decoder.model + for key in decoder._modules: + if isinstance(decoder._modules[key], SEANetResnetBlock): + remove_weight_norm(decoder._modules[key].shortcut.conv.conv) + block_modules = decoder._modules[key].block._modules + for skey in block_modules: + if isinstance(block_modules[skey], SConv1d): + remove_weight_norm(block_modules[skey].conv.conv) + elif isinstance(decoder._modules[key], SConvTranspose1d): + remove_weight_norm(decoder._modules[key].convtr.convtr) + elif isinstance(decoder._modules[key], SConv1d): + remove_weight_norm(decoder._modules[key].conv.conv) + + +def compute_pesq(ref_wav: np.ndarray, gen_wav: np.ndarray) -> float: + """Compute PESQ score between reference and generated audio.""" + DEFAULT_SAMPLING_RATE = 16000 + ref = signal.resample(ref_wav, DEFAULT_SAMPLING_RATE) + deg = signal.resample(gen_wav, DEFAULT_SAMPLING_RATE) + return pesq(fs=DEFAULT_SAMPLING_RATE, ref=ref, deg=deg, mode="wb") + + +def compute_stoi(ref_wav: np.ndarray, gen_wav: np.ndarray, sampling_rate: int) -> float: + """Compute STOI score between reference and generated audio.""" + return stoi(x=ref_wav, y=gen_wav, fs_sig=sampling_rate, extended=False) + + +def infer_dataset( + dl: torch.utils.data.DataLoader, + subset: str, + params: AttributeDict, + model: nn.Module, +) -> Tuple[float, float]: + """Decode dataset. + The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + subset: + The name of the subset. + params: + It is returned by :func:`get_params`. + model: + The neural model. + + Returns: + The average PESQ and STOI scores. + """ + + # Background worker save audios to disk. + def _save_worker( + subset: str, + batch_size: int, + cut_ids: List[str], + audio: torch.Tensor, + audio_pred: torch.Tensor, + audio_lens: List[int], + ): + for i in range(batch_size): + torchaudio.save( + str(params.save_wav_dir / subset / f"{cut_ids[i]}_gt.wav"), + audio[i : i + 1, : audio_lens[i]], + sample_rate=params.sampling_rate, + ) + torchaudio.save( + str(params.save_wav_dir / subset / f"{cut_ids[i]}_recon.wav"), + audio_pred[i : i + 1, : audio_lens[i]], + sample_rate=params.sampling_rate, + ) + + device = next(model.parameters()).device + num_cuts = 0 + log_interval = 5 + + pesq_wb_scores = [] + stoi_scores = [] + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + futures = [] + with ThreadPoolExecutor(max_workers=1) as executor: + for batch_idx, batch in enumerate(dl): + batch_size = len(batch["audio"]) + + audios = batch["audio"] + audio_lens = batch["audio_lens"].tolist() + cut_ids = [cut.id for cut in batch["cut"]] + + codes, audio_hats = model.inference( + audios.to(device), target_bw=params.target_bw + ) + audio_hats = audio_hats.squeeze(1).cpu() + + for cut_id, audio, audio_hat, audio_len in zip( + cut_ids, audios, audio_hats, audio_lens + ): + try: + pesq_wb = compute_pesq( + ref_wav=audio[:audio_len].numpy(), + gen_wav=audio_hat[:audio_len].numpy(), + ) + pesq_wb_scores.append(pesq_wb) + except Exception as e: + logging.error(f"Error while computing PESQ for cut {cut_id}: {e}") + + stoi_score = compute_stoi( + ref_wav=audio[:audio_len].numpy(), + gen_wav=audio_hat[:audio_len].numpy(), + sampling_rate=params.sampling_rate, + ) + stoi_scores.append(stoi_score) + + futures.append( + executor.submit( + _save_worker, + subset, + batch_size, + cut_ids, + audios, + audio_hats, + audio_lens, + ) + ) + + num_cuts += batch_size + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + # return results + for f in futures: + f.result() + return mean(pesq_wb_scores), mean(stoi_scores) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriTTSCodecDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.suffix = f"epoch-{params.epoch}" + + params.res_dir = params.exp_dir / "infer" / params.suffix + params.save_wav_dir = params.res_dir / "wav" + params.save_wav_dir.mkdir(parents=True, exist_ok=True) + + setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") + logging.info("Infer started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + # we need cut ids to display results of both constructed and ground-truth audio + args.return_cuts = True + libritts = LibriTTSCodecDataModule(args) + + logging.info(f"Device: {device}") + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + remove_encodec_weight_norm(model) + + model.to(device) + model.eval() + + encoder = model.encoder + decoder = model.decoder + quantizer = model.quantizer + multi_scale_discriminator = model.multi_scale_discriminator + multi_period_discriminator = model.multi_period_discriminator + multi_scale_stft_discriminator = model.multi_scale_stft_discriminator + + num_param_e = sum([p.numel() for p in encoder.parameters()]) + logging.info(f"Number of parameters in encoder: {num_param_e}") + num_param_d = sum([p.numel() for p in decoder.parameters()]) + logging.info(f"Number of parameters in decoder: {num_param_d}") + num_param_q = sum([p.numel() for p in quantizer.parameters()]) + logging.info(f"Number of parameters in quantizer: {num_param_q}") + num_param_ds = ( + sum([p.numel() for p in multi_scale_discriminator.parameters()]) + if multi_scale_discriminator is not None + else 0 + ) + logging.info(f"Number of parameters in multi_scale_discriminator: {num_param_ds}") + num_param_dp = ( + sum([p.numel() for p in multi_period_discriminator.parameters()]) + if multi_period_discriminator is not None + else 0 + ) + logging.info(f"Number of parameters in multi_period_discriminator: {num_param_dp}") + num_param_dstft = sum( + [p.numel() for p in multi_scale_stft_discriminator.parameters()] + ) + logging.info( + f"Number of parameters in multi_scale_stft_discriminator: {num_param_dstft}" + ) + logging.info( + f"Total number of parameters: {num_param_e + num_param_d + num_param_q + num_param_ds + num_param_dp + num_param_dstft}" + ) + + test_clean_cuts = libritts.test_clean_cuts() + test_clean = libritts.test_dataloaders(test_clean_cuts) + + test_other_cuts = libritts.test_other_cuts() + test_other = libritts.test_dataloaders(test_other_cuts) + + dev_clean_cuts = libritts.dev_clean_cuts() + dev_clean = libritts.valid_dataloaders(dev_clean_cuts) + + dev_other_cuts = libritts.dev_other_cuts() + dev_other = libritts.valid_dataloaders(dev_other_cuts) + + infer_sets = { + "test-clean": test_clean, + "test-other": test_other, + "dev-clean": dev_clean, + "dev-other": dev_other, + } + + for subset, dl in infer_sets.items(): + save_wav_dir = params.res_dir / "wav" / subset + save_wav_dir.mkdir(parents=True, exist_ok=True) + + logging.info(f"Processing {subset} set, saving to {save_wav_dir}") + + pesq_wb, stoi = infer_dataset( + dl=dl, + subset=subset, + params=params, + model=model, + ) + logging.info(f"{subset}: PESQ-WB: {pesq_wb:.4f}, STOI: {stoi:.4f}") + + logging.info(f"Wav files are saved to {params.save_wav_dir}") + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/libritts/CODEC/encodec/loss.py b/egs/libritts/CODEC/encodec/loss.py new file mode 100644 index 000000000..9cf1d42d2 --- /dev/null +++ b/egs/libritts/CODEC/encodec/loss.py @@ -0,0 +1,321 @@ +# Modified from egs/ljspeech/TTS/vits/loss.py by: Zengrui JIN (Tsinghua University) +# original implementation is from https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/hifigan/loss.py + +# Copyright 2021 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Encodec-related loss modules. + +This code is modified from https://github.com/kan-bayashi/ParallelWaveGAN. + +""" + +from typing import List, Tuple, Union + +import torch +import torch.nn.functional as F +from torchaudio.transforms import MelSpectrogram + + +class GeneratorAdversarialLoss(torch.nn.Module): + """Generator adversarial loss module.""" + + def __init__( + self, + average_by_discriminators: bool = True, + loss_type: str = "hinge", + ): + """Initialize GeneratorAversarialLoss module. + + Args: + average_by_discriminators (bool): Whether to average the loss by + the number of discriminators. + loss_type (str): Loss type, "mse" or "hinge". + + """ + super().__init__() + self.average_by_discriminators = average_by_discriminators + assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." + if loss_type == "mse": + self.criterion = self._mse_loss + else: + self.criterion = self._hinge_loss + + def forward( + self, + outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], + ) -> torch.Tensor: + """Calcualate generator adversarial loss. + + Args: + outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator + outputs, list of discriminator outputs, or list of list of discriminator + outputs.. + + Returns: + Tensor: Generator adversarial loss value. + + """ + adv_loss = 0.0 + if isinstance(outputs, (tuple, list)): + for i, outputs_ in enumerate(outputs): + if isinstance(outputs_, (tuple, list)): + # NOTE(kan-bayashi): case including feature maps + outputs_ = outputs_[-1] + adv_loss += self.criterion(outputs_) + if self.average_by_discriminators: + adv_loss /= i + 1 + else: + for i, outputs_ in enumerate(outputs): + adv_loss += self.criterion(outputs_) + adv_loss /= i + 1 + return adv_loss + + def _mse_loss(self, x): + return F.mse_loss(x, x.new_ones(x.size())) + + def _hinge_loss(self, x): + return F.relu(1 - x).mean() + + +class DiscriminatorAdversarialLoss(torch.nn.Module): + """Discriminator adversarial loss module.""" + + def __init__( + self, + average_by_discriminators: bool = True, + loss_type: str = "hinge", + ): + """Initialize DiscriminatorAversarialLoss module. + + Args: + average_by_discriminators (bool): Whether to average the loss by + the number of discriminators. + loss_type (str): Loss type, "mse" or "hinge". + + """ + super().__init__() + self.average_by_discriminators = average_by_discriminators + assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." + if loss_type == "mse": + self.fake_criterion = self._mse_fake_loss + self.real_criterion = self._mse_real_loss + else: + self.fake_criterion = self._hinge_fake_loss + self.real_criterion = self._hinge_real_loss + + def forward( + self, + outputs_hat: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], + outputs: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Calcualate discriminator adversarial loss. + + Args: + outputs_hat (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator + outputs, list of discriminator outputs, or list of list of discriminator + outputs calculated from generator. + outputs (Union[List[List[Tensor]], List[Tensor], Tensor]): Discriminator + outputs, list of discriminator outputs, or list of list of discriminator + outputs calculated from groundtruth. + + Returns: + Tensor: Discriminator real loss value. + Tensor: Discriminator fake loss value. + + """ + real_loss = 0.0 + fake_loss = 0.0 + if isinstance(outputs, (tuple, list)): + for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)): + if isinstance(outputs_hat_, (tuple, list)): + # NOTE(kan-bayashi): case including feature maps + outputs_hat_ = outputs_hat_[-1] + outputs_ = outputs_[-1] + real_loss += self.real_criterion(outputs_) + fake_loss += self.fake_criterion(outputs_hat_) + if self.average_by_discriminators: + fake_loss /= i + 1 + real_loss /= i + 1 + else: + for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)): + real_loss += self.real_criterion(outputs_) + fake_loss += self.fake_criterion(outputs_hat_) + fake_loss /= i + 1 + real_loss /= i + 1 + + return real_loss, fake_loss + + def _mse_real_loss(self, x: torch.Tensor) -> torch.Tensor: + return F.mse_loss(x, x.new_ones(x.size())) + + def _mse_fake_loss(self, x: torch.Tensor) -> torch.Tensor: + return F.mse_loss(x, x.new_zeros(x.size())) + + def _hinge_real_loss(self, x: torch.Tensor) -> torch.Tensor: + return F.relu(torch.ones_like(x) - x).mean() + + def _hinge_fake_loss(self, x: torch.Tensor) -> torch.Tensor: + return F.relu(torch.ones_like(x) + x).mean() + + +class FeatureLoss(torch.nn.Module): + """Feature loss module.""" + + def __init__( + self, + average_by_layers: bool = True, + average_by_discriminators: bool = True, + include_final_outputs: bool = True, + ): + """Initialize FeatureMatchLoss module. + + Args: + average_by_layers (bool): Whether to average the loss by the number + of layers. + average_by_discriminators (bool): Whether to average the loss by + the number of discriminators. + include_final_outputs (bool): Whether to include the final output of + each discriminator for loss calculation. + + """ + super().__init__() + self.average_by_layers = average_by_layers + self.average_by_discriminators = average_by_discriminators + self.include_final_outputs = include_final_outputs + + def forward( + self, + feats_hat: Union[List[List[torch.Tensor]], List[torch.Tensor]], + feats: Union[List[List[torch.Tensor]], List[torch.Tensor]], + ) -> torch.Tensor: + """Calculate feature matching loss. + + Args: + feats_hat (Union[List[List[Tensor]], List[Tensor]]): List of list of + discriminator outputs or list of discriminator outputs calcuated + from generator's outputs. + feats (Union[List[List[Tensor]], List[Tensor]]): List of list of + discriminator outputs or list of discriminator outputs calcuated + from groundtruth.. + + Returns: + Tensor: Feature matching loss value. + + """ + feat_match_loss = 0.0 + for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)): + feat_match_loss_ = 0.0 + if not self.include_final_outputs: + feats_hat_ = feats_hat_[:-1] + feats_ = feats_[:-1] + for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)): + feat_match_loss_ += ( + F.l1_loss(feat_hat_, feat_.detach()) / (feat_.detach().abs().mean()) + ).mean() + if self.average_by_layers: + feat_match_loss_ /= j + 1 + feat_match_loss += feat_match_loss_ + if self.average_by_discriminators: + feat_match_loss /= i + 1 + + return feat_match_loss + + +class MelSpectrogramReconstructionLoss(torch.nn.Module): + """Mel Spec Reconstruction loss.""" + + def __init__( + self, + sampling_rate: int = 22050, + n_mels: int = 64, + use_fft_mag: bool = True, + return_mel: bool = False, + ): + super().__init__() + self.wav_to_specs = [] + for i in range(5, 12): + s = 2**i + self.wav_to_specs.append( + MelSpectrogram( + sample_rate=sampling_rate, + n_fft=max(s, 512), + win_length=s, + hop_length=s // 4, + n_mels=n_mels, + ) + ) + self.return_mel = return_mel + + def forward( + self, + x_hat: torch.Tensor, + x: torch.Tensor, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]: + """Calculate Mel-spectrogram loss. + + Args: + x_hat (Tensor): Generated waveform tensor (B, 1, T). + x (Tensor): Groundtruth waveform tensor (B, 1, T). + spec (Optional[Tensor]): Groundtruth linear amplitude spectrum tensor + (B, T, n_fft // 2 + 1). if provided, use it instead of groundtruth + waveform. + + Returns: + Tensor: Mel-spectrogram loss value. + + """ + mel_loss = 0.0 + + for i, wav_to_spec in enumerate(self.wav_to_specs): + s = 2 ** (i + 5) + wav_to_spec.to(x.device) + + mel_hat = wav_to_spec(x_hat.squeeze(1)) + mel = wav_to_spec(x.squeeze(1)) + + mel_loss += ( + F.l1_loss(mel_hat, mel, reduce=True, reduction="mean") + + ( + ( + (torch.log(mel.abs() + 1e-7) - torch.log(mel_hat.abs() + 1e-7)) + ** 2 + ).mean(dim=-2) + ** 0.5 + ).mean() + ) + + # mel_hat = self.wav_to_spec(x_hat.squeeze(1)) + # mel = self.wav_to_spec(x.squeeze(1)) + # mel_loss = F.l1_loss(mel_hat, mel) + F.mse_loss(mel_hat, mel) + + if self.return_mel: + return mel_loss, (mel_hat, mel) + + return mel_loss + + +class WavReconstructionLoss(torch.nn.Module): + """Wav Reconstruction loss.""" + + def __init__(self): + super().__init__() + + def forward( + self, + x_hat: torch.Tensor, + x: torch.Tensor, + ) -> torch.Tensor: + """Calculate wav loss. + + Args: + x_hat (Tensor): Generated waveform tensor (B, 1, T). + x (Tensor): Groundtruth waveform tensor (B, 1, T). + + Returns: + Tensor: Wav loss value. + + """ + wav_loss = F.l1_loss(x, x_hat) + + return wav_loss diff --git a/egs/libritts/CODEC/encodec/modules/__init__.py b/egs/libritts/CODEC/encodec/modules/__init__.py new file mode 100644 index 000000000..b903a28b0 --- /dev/null +++ b/egs/libritts/CODEC/encodec/modules/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE +"""Torch modules.""" +# flake8: noqa +from .conv import ( + NormConv1d, + NormConv2d, + NormConvTranspose1d, + NormConvTranspose2d, + SConv1d, + SConvTranspose1d, + pad1d, + unpad1d, +) +from .lstm import SLSTM +from .seanet import SEANetDecoder, SEANetEncoder +from .transformer import StreamingTransformerEncoder diff --git a/egs/libritts/CODEC/encodec/modules/conv.py b/egs/libritts/CODEC/encodec/modules/conv.py new file mode 100644 index 000000000..a70a5c67f --- /dev/null +++ b/egs/libritts/CODEC/encodec/modules/conv.py @@ -0,0 +1,334 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE +"""Convolutional layers wrappers and utilities.""" +import logging +import math +from typing import Any, Dict, Tuple + +from torch import Tensor, nn +from torch.nn import functional as F +from torch.nn.utils import spectral_norm, weight_norm + +from .norm import ConvLayerNorm + +CONV_NORMALIZATIONS = frozenset( + [ + "none", + "weight_norm", + "spectral_norm", + "time_layer_norm", + "layer_norm", + "time_group_norm", + ] +) + + +def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module: + assert norm in CONV_NORMALIZATIONS + if norm == "weight_norm": + return weight_norm(module) + elif norm == "spectral_norm": + return spectral_norm(module) + else: + # We already check was in CONV_NORMALIZATION, so any other choice + # doesn't need reparametrization. + return module + + +def get_norm_module( + module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs +) -> nn.Module: + """Return the proper normalization module. If causal is True, this will ensure the returned + module is causal, or return an error if the normalization doesn't support causal evaluation. + """ + assert norm in CONV_NORMALIZATIONS + if norm == "layer_norm": + assert isinstance(module, nn.modules.conv._ConvNd) + return ConvLayerNorm(module.out_channels, **norm_kwargs) + elif norm == "time_group_norm": + if causal: + raise ValueError("GroupNorm doesn't support causal evaluation.") + assert isinstance(module, nn.modules.conv._ConvNd) + return nn.GroupNorm(1, module.out_channels, **norm_kwargs) + else: + return nn.Identity() + + +def get_extra_padding_for_conv1d( + x: Tensor, kernel_size: int, stride: int, padding_total: int = 0 +) -> int: + """See `pad_for_conv1d`.""" + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad_for_conv1d(x: Tensor, kernel_size: int, stride: int, padding_total: int = 0): + """Pad for a convolution to make sure that the last window is full. + Extra padding is added at the end. This is required to ensure that we can rebuild + an output of the same length, as otherwise, even with padding, some time steps + might get removed. + For instance, with total padding = 4, kernel size = 4, stride = 2: + 0 0 1 2 3 4 5 0 0 # (0s are padding) + 1 2 3 # (output frames of a convolution, last 0 is never used) + 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) + 1 2 3 4 # once you removed padding, we are missing one time step ! + """ + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + return F.pad(x, (0, extra_padding)) + + +def pad1d( + x: Tensor, + paddings: Tuple[int, int], + mode: str = "zero", + value: float = 0.0, +): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happen. + """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == "reflect": + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +def unpad1d(x: Tensor, paddings: Tuple[int, int]): + """Remove padding from x, handling properly zero padding. Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left:end] + + +class NormConv1d(nn.Module): + """Wrapper around Conv1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + + def __init__( + self, + *args, + causal: bool = False, + norm: str = "none", + norm_kwargs: Dict[str, Any] = {}, + **kwargs, + ): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConv2d(nn.Module): + """Wrapper around Conv2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + + def __init__( + self, + *args, + norm: str = "none", + norm_kwargs: Dict[str, Any] = {}, + **kwargs, + ): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConvTranspose1d(nn.Module): + """Wrapper around ConvTranspose1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + + def __init__( + self, + *args, + causal: bool = False, + norm: str = "none", + norm_kwargs: Dict[str, Any] = {}, + **kwargs, + ): + super().__init__() + self.convtr = apply_parametrization_norm( + nn.ConvTranspose1d(*args, **kwargs), norm + ) + self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class NormConvTranspose2d(nn.Module): + """Wrapper around ConvTranspose2d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + + def __init__( + self, + *args, + norm: str = "none", + norm_kwargs: Dict[str, Any] = {}, + **kwargs, + ): + super().__init__() + self.convtr = apply_parametrization_norm( + nn.ConvTranspose2d(*args, **kwargs), norm + ) + self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs) + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class SConv1d(nn.Module): + """Conv1d with some builtin handling of asymmetric or causal padding + and normalization. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causal: bool = False, + norm: str = "none", + norm_kwargs: Dict[str, Any] = {}, + pad_mode: str = "reflect", + ): + super().__init__() + # warn user on unusual setup between dilation and stride + if stride > 1 and dilation > 1: + logging.warning( + "SConv1d has been initialized with stride > 1 and dilation > 1" + f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})." + ) + self.conv = NormConv1d( + in_channels, + out_channels, + kernel_size, + stride, + dilation=dilation, + groups=groups, + bias=bias, + causal=causal, + norm=norm, + norm_kwargs=norm_kwargs, + ) + self.causal = causal + self.pad_mode = pad_mode + + def forward(self, x): + B, C, T = x.shape + kernel_size = self.conv.conv.kernel_size[0] + stride = self.conv.conv.stride[0] + dilation = self.conv.conv.dilation[0] + padding_total = (kernel_size - 1) * dilation - (stride - 1) + extra_padding = get_extra_padding_for_conv1d( + x, kernel_size, stride, padding_total + ) + if self.causal: + # Left padding for causal + x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + x = pad1d( + x, (padding_left, padding_right + extra_padding), mode=self.pad_mode + ) + return self.conv(x) + + +class SConvTranspose1d(nn.Module): + """ConvTranspose1d with some builtin handling of asymmetric or causal padding + and normalization. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + causal: bool = False, + norm: str = "none", + trim_right_ratio: float = 1.0, + norm_kwargs: Dict[str, Any] = {}, + ): + super().__init__() + self.convtr = NormConvTranspose1d( + in_channels, + out_channels, + kernel_size, + stride, + causal=causal, + norm=norm, + norm_kwargs=norm_kwargs, + ) + self.causal = causal + self.trim_right_ratio = trim_right_ratio + assert ( + self.causal or self.trim_right_ratio == 1.0 + ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" + assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0 + + def forward(self, x): + kernel_size = self.convtr.convtr.kernel_size[0] + stride = self.convtr.convtr.stride[0] + padding_total = kernel_size - stride + + y = self.convtr(x) + + # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be + # removed at the very end, when keeping only the right length for the output, + # as removing it here would require also passing the length at the matching layer + # in the encoder. + if self.causal: + # Trim the padding on the right according to the specified ratio + # if trim_right_ratio = 1.0, trim everything from right + padding_right = math.ceil(padding_total * self.trim_right_ratio) + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + return y diff --git a/egs/libritts/CODEC/encodec/modules/lstm.py b/egs/libritts/CODEC/encodec/modules/lstm.py new file mode 100644 index 000000000..5307552c0 --- /dev/null +++ b/egs/libritts/CODEC/encodec/modules/lstm.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE +"""LSTM layers module.""" +from torch import nn + + +class SLSTM(nn.Module): + """ + LSTM without worrying about the hidden state, nor the layout of the data. + Expects input as convolutional layout. + """ + + def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True): + super().__init__() + self.skip = skip + self.lstm = nn.LSTM(dimension, dimension, num_layers) + + def forward(self, x): + x = x.permute(2, 0, 1) + y, _ = self.lstm(x) + if self.skip: + y = y + x + y = y.permute(1, 2, 0) + return y diff --git a/egs/libritts/CODEC/encodec/modules/norm.py b/egs/libritts/CODEC/encodec/modules/norm.py new file mode 100644 index 000000000..3002b3a26 --- /dev/null +++ b/egs/libritts/CODEC/encodec/modules/norm.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE +"""Normalization modules.""" + +from typing import List, Union + +import einops +import torch +from torch import nn + + +class ConvLayerNorm(nn.LayerNorm): + """ + Convolution-friendly LayerNorm that moves channels to last dimensions + before running the normalization and moves them back to original position right after. + """ + + def __init__(self, normalized_shape: Union[int, List[int], torch.Size], **kwargs): + super().__init__(normalized_shape, **kwargs) + + def forward(self, x): + x = einops.rearrange(x, "b ... t -> b t ...") + x = super().forward(x) + x = einops.rearrange(x, "b t ... -> b ... t") + return diff --git a/egs/libritts/CODEC/encodec/modules/seanet.py b/egs/libritts/CODEC/encodec/modules/seanet.py new file mode 100644 index 000000000..76999b298 --- /dev/null +++ b/egs/libritts/CODEC/encodec/modules/seanet.py @@ -0,0 +1,368 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE +"""Encodec SEANet-based encoder and decoder implementation.""" + +from typing import Any, Dict, List, Optional + +import numpy as np +import torch.nn as nn +from modules import SLSTM, SConv1d, SConvTranspose1d + + +class SEANetResnetBlock(nn.Module): + """Residual block from SEANet model. + Args: + dim (int): Dimension of the input/output + kernel_sizes (list): List of kernel sizes for the convolutions. + dilations (list): List of dilations for the convolutions. + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + compress (int): Reduced dimensionality in residual branches (from Demucs v3) + true_skip (bool): Whether to use true skip connection or a simple convolution as the skip connection. + """ + + def __init__( + self, + dim: int, + kernel_sizes: List[int] = [3, 1], + dilations: List[int] = [1, 1], + activation: str = "ELU", + activation_params: Dict = {"alpha": 1.0}, + norm: str = "weight_norm", + norm_params: Dict[str, Any] = {}, + causal: bool = False, + pad_mode: str = "reflect", + compress: int = 2, + true_skip: bool = True, + ): + super().__init__() + assert len(kernel_sizes) == len( + dilations + ), "Number of kernel sizes should match number of dilations" + act = getattr(nn, activation) + hidden = dim // compress + block = [] + for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)): + in_chs = dim if i == 0 else hidden + out_chs = dim if i == len(kernel_sizes) - 1 else hidden + block += [ + act(**activation_params), + SConv1d( + in_chs, + out_chs, + kernel_size=kernel_size, + dilation=dilation, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), + ] + self.block = nn.Sequential(*block) + self.shortcut: nn.Module + if true_skip: + self.shortcut = nn.Identity() + else: + self.shortcut = SConv1d( + dim, + dim, + kernel_size=1, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) + + def forward(self, x): + return self.shortcut(x) + self.block(x) + + +class SEANetEncoder(nn.Module): + """SEANet encoder. + Args: + channels (int): Audio channels. + dimension (int): Intermediate representation dimension. + n_filters (int): Base width for the model. + n_residual_layers (int): nb of residual layers. + ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of + upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here + that must match the decoder order + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + kernel_size (int): Kernel size for the initial convolution. + last_kernel_size (int): Kernel size for the initial convolution. + residual_kernel_size (int): Kernel size for the residual layers. + dilation_base (int): How much to increase the dilation with each layer. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + true_skip (bool): Whether to use true skip connection or a simple + (streamable) convolution as the skip connection in the residual network blocks. + compress (int): Reduced dimensionality in residual branches (from Demucs v3). + lstm (int): Number of LSTM layers at the end of the encoder. + """ + + def __init__( + self, + channels: int = 1, + dimension: int = 128, + n_filters: int = 32, + n_residual_layers: int = 1, + ratios: List[int] = [8, 5, 4, 2], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + norm: str = "weight_norm", + norm_params: Dict[str, Any] = {}, + kernel_size: int = 7, + last_kernel_size: int = 7, + residual_kernel_size: int = 3, + dilation_base: int = 2, + causal: bool = False, + pad_mode: str = "reflect", + true_skip: bool = False, + compress: int = 2, + lstm: int = 2, + ): + super().__init__() + self.channels = channels + self.dimension = dimension + self.n_filters = n_filters + self.ratios = list(reversed(ratios)) + del ratios + self.n_residual_layers = n_residual_layers + self.hop_length = np.prod(self.ratios) # 计算乘积 + + act = getattr(nn, activation) + mult = 1 + model: List[nn.Module] = [ + SConv1d( + channels, + mult * n_filters, + kernel_size, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) + ] + # Downsample to raw audio scale + for i, ratio in enumerate(self.ratios): + # Add residual layers + for j in range(n_residual_layers): + model += [ + SEANetResnetBlock( + mult * n_filters, + kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base**j, 1], + norm=norm, + norm_params=norm_params, + activation=activation, + activation_params=activation_params, + causal=causal, + pad_mode=pad_mode, + compress=compress, + true_skip=true_skip, + ) + ] + + # Add downsampling layers + model += [ + act(**activation_params), + SConv1d( + mult * n_filters, + mult * n_filters * 2, + kernel_size=ratio * 2, + stride=ratio, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), + ] + mult *= 2 + + if lstm: + model += [SLSTM(mult * n_filters, num_layers=lstm)] + + model += [ + act(**activation_params), + SConv1d( + mult * n_filters, + dimension, + last_kernel_size, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), + ] + + self.model = nn.Sequential(*model) + + def forward(self, x): + return self.model(x) + + +class SEANetDecoder(nn.Module): + """SEANet decoder. + Args: + channels (int): Audio channels. + dimension (int): Intermediate representation dimension. + n_filters (int): Base width for the model. + n_residual_layers (int): nb of residual layers. + ratios (Sequence[int]): kernel size and stride ratios + activation (str): Activation function. + activation_params (dict): Parameters to provide to the activation function + final_activation (str): Final activation function after all convolutions. + final_activation_params (dict): Parameters to provide to the activation function + norm (str): Normalization method. + norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution. + kernel_size (int): Kernel size for the initial convolution. + last_kernel_size (int): Kernel size for the initial convolution. + residual_kernel_size (int): Kernel size for the residual layers. + dilation_base (int): How much to increase the dilation with each layer. + causal (bool): Whether to use fully causal convolution. + pad_mode (str): Padding mode for the convolutions. + true_skip (bool): Whether to use true skip connection or a simple + (streamable) convolution as the skip connection in the residual network blocks. + compress (int): Reduced dimensionality in residual branches (from Demucs v3). + lstm (int): Number of LSTM layers at the end of the encoder. + trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup. + If equal to 1.0, it means that all the trimming is done at the right. + """ + + def __init__( + self, + channels: int = 1, + dimension: int = 128, + n_filters: int = 32, + n_residual_layers: int = 1, + ratios: List[int] = [8, 5, 4, 2], + activation: str = "ELU", + activation_params: dict = {"alpha": 1.0}, + final_activation: Optional[str] = None, + final_activation_params: Optional[dict] = None, + norm: str = "weight_norm", + norm_params: Dict[str, Any] = {}, + kernel_size: int = 7, + last_kernel_size: int = 7, + residual_kernel_size: int = 3, + dilation_base: int = 2, + causal: bool = False, + pad_mode: str = "reflect", + true_skip: bool = False, + compress: int = 2, + lstm: int = 2, + trim_right_ratio: float = 1.0, + ): + super().__init__() + self.dimension = dimension + self.channels = channels + self.n_filters = n_filters + self.ratios = ratios + del ratios + self.n_residual_layers = n_residual_layers + self.hop_length = np.prod(self.ratios) + + act = getattr(nn, activation) + mult = int(2 ** len(self.ratios)) + model: List[nn.Module] = [ + SConv1d( + dimension, + mult * n_filters, + kernel_size, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ) + ] + + if lstm: + model += [SLSTM(mult * n_filters, num_layers=lstm)] + + # Upsample to raw audio scale + for i, ratio in enumerate(self.ratios): + # Add upsampling layers + model += [ + act(**activation_params), + SConvTranspose1d( + mult * n_filters, + mult * n_filters // 2, + kernel_size=ratio * 2, + stride=ratio, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + trim_right_ratio=trim_right_ratio, + ), + ] + # Add residual layers + for j in range(n_residual_layers): + model += [ + SEANetResnetBlock( + mult * n_filters // 2, + kernel_sizes=[residual_kernel_size, 1], + dilations=[dilation_base**j, 1], + activation=activation, + activation_params=activation_params, + norm=norm, + norm_params=norm_params, + causal=causal, + pad_mode=pad_mode, + compress=compress, + true_skip=true_skip, + ) + ] + + mult //= 2 + + # Add final layers + model += [ + act(**activation_params), + SConv1d( + n_filters, + channels, + last_kernel_size, + norm=norm, + norm_kwargs=norm_params, + causal=causal, + pad_mode=pad_mode, + ), + ] + # Add optional final activation to decoder (eg. tanh) + if final_activation is not None: + final_act = getattr(nn, final_activation) + final_activation_params = final_activation_params or {} + model += [final_act(**final_activation_params)] + self.model = nn.Sequential(*model) + + def forward(self, z): + y = self.model(z) + return y + + +def test(): + import torch + + encoder = SEANetEncoder() + decoder = SEANetDecoder() + x = torch.randn(1, 1, 24000) + z = encoder(x) + print("z ", z.shape) + assert 1 == 2 + assert list(z.shape) == [1, 128, 75], z.shape + y = decoder(z) + assert y.shape == x.shape, (x.shape, y.shape) + + +if __name__ == "__main__": + test() diff --git a/egs/libritts/CODEC/encodec/modules/transformer.py b/egs/libritts/CODEC/encodec/modules/transformer.py new file mode 100644 index 000000000..1768d88f9 --- /dev/null +++ b/egs/libritts/CODEC/encodec/modules/transformer.py @@ -0,0 +1,141 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE +"""A streamable transformer.""" +import typing as tp +from typing import Any, List, Optional, Union + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +def create_sin_embedding(positions: Tensor, dim: int, max_period: float = 10000): + """Create time embedding for the given positions, target dimension `dim`.""" + # We aim for BTC format + assert dim % 2 == 0 + half_dim = dim // 2 + adim = torch.arange(half_dim, device=positions.device).view(1, 1, -1) + phase = positions / (max_period ** (adim / (half_dim - 1))) + return torch.cat( + [ + torch.cos(phase), + torch.sin(phase), + ], + dim=-1, + ) + + +class StreamingTransformerEncoderLayer(nn.TransformerEncoderLayer): + def forward(self, x: Tensor, x_past: Tensor, past_context: int): # type: ignore + if self.norm_first: + sa_input = self.norm1(x) + x = x + self._sa_block(sa_input, x_past, past_context) + x = x + self._ff_block(self.norm2(x)) + else: + sa_input = x + x = self.norm1(x + self._sa_block(sa_input, x_past, past_context)) + x = self.norm2(x + self._ff_block(x)) + + return x, sa_input + + # self-attention block + def _sa_block(self, x: Tensor, x_past: Tensor, past_context: int): # type: ignore + _, T, _ = x.shape + _, H, _ = x_past.shape + + queries = x + keys = torch.cat([x_past, x], dim=1) + values = keys + + queries_pos = torch.arange(H, T + H, device=x.device).view(-1, 1) + keys_pos = torch.arange(T + H, device=x.device).view(1, -1) + delta = queries_pos - keys_pos + valid_access = (delta >= 0) & (delta <= past_context) + x = self.self_attn( + queries, keys, values, attn_mask=~valid_access, need_weights=False + )[0] + return self.dropout1(x) + + +class StreamingTransformerEncoder(nn.Module): + """TransformerEncoder with streaming support. + + Args: + dim (int): dimension of the data. + hidden_scale (int): intermediate dimension of FF module is this times the dimension. + num_heads (int): number of heads. + num_layers (int): number of layers. + max_period (float): maxium period of cosines in the positional embedding. + past_context (int or None): receptive field for the causal mask, infinite if None. + gelu (bool): if true uses GeLUs, otherwise use ReLUs. + norm_in (bool): normalize the input. + dropout (float): dropout probability. + **kwargs: See `nn.TransformerEncoderLayer`. + """ + + def __init__( + self, + dim, + hidden_scale: float = 4.0, + num_heads: int = 8, + num_layers: int = 5, + max_period: float = 10000, + past_context: int = 1000, + gelu: bool = True, + norm_in: bool = True, + dropout: float = 0.0, + **kwargs + ): + super().__init__() + assert dim % num_heads == 0 + hidden_dim = int(dim * hidden_scale) + + self.max_period = max_period + self.past_context = past_context + activation: Any = F.gelu if gelu else F.relu + + self.norm_in: nn.Module + if norm_in: + self.norm_in = nn.LayerNorm(dim) + else: + self.norm_in = nn.Identity() + + self.layers = nn.ModuleList() + for idx in range(num_layers): + self.layers.append( + StreamingTransformerEncoderLayer( + dim, + num_heads, + hidden_dim, + activation=activation, + batch_first=True, + dropout=dropout, + **kwargs + ) + ) + + def forward( + self, + x: Tensor, + states: Optional[List[Tensor]] = None, + offset: Union[int, Tensor] = 0, + ): + B, T, C = x.shape + if states is None: + states = [torch.zeros_like(x[:, :1]) for _ in range(1 + len(self.layers))] + + positions = torch.arange(T, device=x.device).view(1, -1, 1) + offset + pos_emb = create_sin_embedding(positions, C, max_period=self.max_period) + + new_state: List[Tensor] = [] + x = self.norm_in(x) + x = x + pos_emb + + for layer_state, layer in zip(states, self.layers): + x, new_layer_state = layer(x, layer_state, self.past_context) + new_layer_state = torch.cat([layer_state, new_layer_state], dim=1) + new_state.append(new_layer_state[:, -self.past_context :, :]) + return x, new_state, offset + T diff --git a/egs/libritts/CODEC/encodec/quantization/__init__.py b/egs/libritts/CODEC/encodec/quantization/__init__.py new file mode 100644 index 000000000..82d744f5f --- /dev/null +++ b/egs/libritts/CODEC/encodec/quantization/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE +# flake8: noqa +from .vq import QuantizedResult, ResidualVectorQuantizer diff --git a/egs/libritts/CODEC/encodec/quantization/ac.py b/egs/libritts/CODEC/encodec/quantization/ac.py new file mode 100644 index 000000000..8d8a770ca --- /dev/null +++ b/egs/libritts/CODEC/encodec/quantization/ac.py @@ -0,0 +1,311 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE +"""Arithmetic coder.""" +import io +import math +import random +from typing import IO, Any, List, Optional + +import torch +from torch import Tensor + +from ..binary import BitPacker, BitUnpacker + + +def build_stable_quantized_cdf( + pdf: Tensor, + total_range_bits: int, + roundoff: float = 1e-8, + min_range: int = 2, + check: bool = True, +) -> Tensor: + """Turn the given PDF into a quantized CDF that splits + [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional + to the PDF. + + Args: + pdf (Tensor): probability distribution, shape should be `[N]`. + total_range_bits (int): see `ArithmeticCoder`, the typical range we expect + during the coding process is `[0, 2 ** total_range_bits - 1]`. + roundoff (float): will round the pdf up to that level to remove difference coming + from e.g. evaluating the Language Model on different architectures. + min_range (int): minimum range width. Should always be at least 2 for numerical + stability. Use this to avoid pathological behavior is a value + that is expected to be rare actually happens in real life. + check (bool): if True, checks that nothing bad happened, can be deactivated for speed. + """ + pdf = pdf.detach() + if roundoff: + pdf = (pdf / roundoff).floor() * roundoff + # interpolate with uniform distribution to achieve desired minimum probability. + total_range = 2**total_range_bits + cardinality = len(pdf) + alpha = min_range * cardinality / total_range + assert alpha <= 1, "you must reduce min_range" + ranges = (((1 - alpha) * total_range) * pdf).floor().long() + ranges += min_range + quantized_cdf = torch.cumsum(ranges, dim=-1) + if min_range < 2: + raise ValueError("min_range must be at least 2.") + if check: + assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1] + if ( + (quantized_cdf[1:] - quantized_cdf[:-1]) < min_range + ).any() or quantized_cdf[0] < min_range: + raise ValueError("You must increase your total_range_bits.") + return quantized_cdf + + +class ArithmeticCoder: + """ArithmeticCoder, + Let us take a distribution `p` over `N` symbols, and assume we have a stream + of random variables `s_t` sampled from `p`. Let us assume that we have a budget + of `B` bits that we can afford to write on device. There are `2**B` possible numbers, + corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single + sequence `(s_t)` by doing the following: + + 1) Initialize the current range to` [0 ** 2 B - 1]`. + 2) For each time step t, split the current range into contiguous chunks, + one for each possible outcome, with size roughly proportional to `p`. + For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks + would be `{[0, 2], [3, 3]}`. + 3) Select the chunk corresponding to `s_t`, and replace the current range with this. + 4) When done encoding all the values, just select any value remaining in the range. + + You will notice that this procedure can fail: for instance if at any point in time + the range is smaller than `N`, then we can no longer assign a non-empty chunk to each + possible outcome. Intuitively, the more likely a value is, the less the range width + will reduce, and the longer we can go on encoding values. This makes sense: for any efficient + coding scheme, likely outcomes would take less bits, and more of them can be coded + with a fixed budget. + + In practice, we do not know `B` ahead of time, but we have a way to inject new bits + when the current range decreases below a given limit (given by `total_range_bits`), without + having to redo all the computations. If we encode mostly likely values, we will seldom + need to inject new bits, but a single rare value can deplete our stock of entropy! + + In this explanation, we assumed that the distribution `p` was constant. In fact, the present + code works for any sequence `(p_t)` possibly different for each timestep. + We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller + the KL between the true distribution and `p_t`, the most efficient the coding will be. + + Args: + fo (IO[bytes]): file-like object to which the bytes will be written to. + total_range_bits (int): the range `M` described above is `2 ** total_range_bits. + Any time the current range width fall under this limit, new bits will + be injected to rescale the initial range. + """ + + def __init__(self, fo: IO[bytes], total_range_bits: int = 24): + assert total_range_bits <= 30 + self.total_range_bits = total_range_bits + self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time. + self.low: int = 0 + self.high: int = 0 + self.max_bit: int = -1 + self._dbg: List[Any] = [] + self._dbg2: List[Any] = [] + + @property + def delta(self) -> int: + """Return the current range width.""" + return self.high - self.low + 1 + + def _flush_common_prefix(self): + # If self.low and self.high start with the sames bits, + # those won't change anymore as we always just increase the range + # by powers of 2, and we can flush them out to the bit stream. + assert self.high >= self.low, (self.low, self.high) + assert self.high < 2 ** (self.max_bit + 1) + while self.max_bit >= 0: + b1 = self.low >> self.max_bit + b2 = self.high >> self.max_bit + if b1 == b2: + self.low -= b1 << self.max_bit + self.high -= b1 << self.max_bit + assert self.high >= self.low, (self.high, self.low, self.max_bit) + assert self.low >= 0 + self.max_bit -= 1 + self.packer.push(b1) + else: + break + + def push(self, symbol: int, quantized_cdf: Tensor): + """Push the given symbol on the stream, flushing out bits + if possible. + + Args: + symbol (int): symbol to encode with the AC. + quantized_cdf (Tensor): use `build_stable_quantized_cdf` + to build this from your pdf estimate. + """ + while self.delta < 2**self.total_range_bits: + self.low *= 2 + self.high = self.high * 2 + 1 + self.max_bit += 1 + + range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item() + range_high = quantized_cdf[symbol].item() - 1 + effective_low = int( + math.ceil(range_low * (self.delta / (2**self.total_range_bits))) + ) + effective_high = int( + math.floor(range_high * (self.delta / (2**self.total_range_bits))) + ) + assert self.low <= self.high + self.high = self.low + effective_high + self.low = self.low + effective_low + assert self.low <= self.high, ( + effective_low, + effective_high, + range_low, + range_high, + ) + self._dbg.append((self.low, self.high)) + self._dbg2.append((self.low, self.high)) + outs = self._flush_common_prefix() + assert self.low <= self.high + assert self.max_bit >= -1 + assert self.max_bit <= 61, self.max_bit + return outs + + def flush(self): + """Flush the remaining information to the stream.""" + while self.max_bit >= 0: + b1 = (self.low >> self.max_bit) & 1 + self.packer.push(b1) + self.max_bit -= 1 + self.packer.flush() + + +class ArithmeticDecoder: + """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation. + + Note that this must be called with **exactly** the same parameters and sequence + of quantized cdf as the arithmetic encoder or the wrong values will be decoded. + + If the AC encoder current range is [L, H], with `L` and `H` having the some common + prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream. + For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside + `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained + for a specific sequence of symbols and a binary-search allows us to decode those symbols. + At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols, + and we will need to read new bits from the stream and repeat the process. + + """ + + def __init__(self, fo: IO[bytes], total_range_bits: int = 24): + self.total_range_bits = total_range_bits + self.low: int = 0 + self.high: int = 0 + self.current: int = 0 + self.max_bit: int = -1 + self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time. + # Following is for debugging + self._dbg: List[Any] = [] + self._dbg2: List[Any] = [] + self._last: Any = None + + @property + def delta(self) -> int: + return self.high - self.low + 1 + + def _flush_common_prefix(self): + # Given the current range [L, H], if both have a common prefix, + # we know we can remove it from our representation to avoid handling large numbers. + while self.max_bit >= 0: + b1 = self.low >> self.max_bit + b2 = self.high >> self.max_bit + if b1 == b2: + self.low -= b1 << self.max_bit + self.high -= b1 << self.max_bit + self.current -= b1 << self.max_bit + assert self.high >= self.low + assert self.low >= 0 + self.max_bit -= 1 + else: + break + + def pull(self, quantized_cdf: Tensor) -> Optional[int]: + """Pull a symbol, reading as many bits from the stream as required. + This returns `None` when the stream has been exhausted. + + Args: + quantized_cdf (Tensor): use `build_stable_quantized_cdf` + to build this from your pdf estimate. This must be **exatly** + the same cdf as the one used at encoding time. + """ + while self.delta < 2**self.total_range_bits: + bit = self.unpacker.pull() + if bit is None: + return None + self.low *= 2 + self.high = self.high * 2 + 1 + self.current = self.current * 2 + bit + self.max_bit += 1 + + def bin_search(low_idx: int, high_idx: int): + # Binary search is not just for coding interviews :) + if high_idx < low_idx: + raise RuntimeError("Binary search failed") + mid = (low_idx + high_idx) // 2 + range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0 + range_high = quantized_cdf[mid].item() - 1 + effective_low = int( + math.ceil(range_low * (self.delta / (2**self.total_range_bits))) + ) + effective_high = int( + math.floor(range_high * (self.delta / (2**self.total_range_bits))) + ) + low = effective_low + self.low + high = effective_high + self.low + if self.current >= low: + if self.current <= high: + return (mid, low, high, self.current) + else: + return bin_search(mid + 1, high_idx) + else: + return bin_search(low_idx, mid - 1) + + self._last = (self.low, self.high, self.current, self.max_bit) + sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1) + self._dbg.append((self.low, self.high, self.current)) + self._flush_common_prefix() + self._dbg2.append((self.low, self.high, self.current)) + + return sym + + +def test(): + torch.manual_seed(1234) + random.seed(1234) + for _ in range(4): + pdfs = [] + cardinality = random.randrange(4000) + steps = random.randrange(100, 500) + fo = io.BytesIO() + encoder = ArithmeticCoder(fo) + symbols = [] + for step in range(steps): + pdf = torch.softmax(torch.randn(cardinality), dim=0) + pdfs.append(pdf) + q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) + symbol = torch.multinomial(pdf, 1).item() + symbols.append(symbol) + encoder.push(symbol, q_cdf) + encoder.flush() + + fo.seek(0) + decoder = ArithmeticDecoder(fo) + for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)): + q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) + decoded_symbol = decoder.pull(q_cdf) + assert decoded_symbol == symbol, idx + assert decoder.pull(torch.zeros(1)) is None + + +if __name__ == "__main__": + test() diff --git a/egs/libritts/CODEC/encodec/quantization/core_vq.py b/egs/libritts/CODEC/encodec/quantization/core_vq.py new file mode 100644 index 000000000..4719e20f7 --- /dev/null +++ b/egs/libritts/CODEC/encodec/quantization/core_vq.py @@ -0,0 +1,377 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# This implementation is inspired from +# https://github.com/lucidrains/vector-quantize-pytorch +# which is released under MIT License. Hereafter, the original license: +# MIT License +# +# Copyright (c) 2020 Phil Wang +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +"""Core vector quantization implementation.""" + +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import nn + +from .distrib import broadcast_tensors + + +def default(val: Any, d: Any) -> Any: + return val if val is not None else d + + +def ema_inplace(moving_avg, new, decay: float): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): + return (x + epsilon) / (x.sum() + n_categories * epsilon) + + +def uniform_init(*shape: int): + t = torch.empty(shape) + nn.init.kaiming_uniform_(t) + return t + + +def sample_vectors(samples, num: int): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters: int, num_iters: int = 10): + dim, dtype = samples.shape[-1], samples.dtype + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d") + dists = -(diffs**2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EuclideanCodebook(nn.Module): + """Codebook with Euclidean distance. + Args: + dim (int): Dimension. + codebook_size (int): Codebook size. + kmeans_init (bool): Whether to use k-means to initialize the codebooks. + If set to true, run the k-means algorithm on the first training batch and use + the learned centroids as initialization. + kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + + def __init__( + self, + dim: int, + codebook_size: int, + kmeans_init: int = False, + kmeans_iters: int = 10, + decay: float = 0.99, + epsilon: float = 1e-5, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.decay = decay + init_fn: Union[Callable[..., torch.Tensor], Any] = ( + uniform_init if not kmeans_init else torch.zeros + ) + embed = init_fn(codebook_size, dim) + + self.codebook_size = codebook_size + + self.kmeans_iters = kmeans_iters + self.epsilon = epsilon + self.threshold_ema_dead_code = threshold_ema_dead_code + + self.register_buffer("inited", torch.Tensor([not kmeans_init])) + self.register_buffer("cluster_size", torch.zeros(codebook_size)) + self.register_buffer("embed", embed) + self.register_buffer("embed_avg", embed.clone()) + + @torch.jit.ignore + def init_embed_(self, data): + if self.inited: + return + + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed.clone()) + self.cluster_size.data.copy_(cluster_size) + self.inited.data.copy_(torch.Tensor([True])) + # Make sure all buffers across workers are in sync after initialization + broadcast_tensors(self.buffers()) + + def replace_(self, samples, mask): + modified_codebook = torch.where( + mask[..., None], sample_vectors(samples, self.codebook_size), self.embed + ) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + + batch_samples = rearrange(batch_samples, "... d -> (...) d") + self.replace_(batch_samples, mask=expired_codes) + broadcast_tensors(self.buffers()) + + def preprocess(self, x): + x = rearrange(x, "... d -> (...) d") + return x + + def quantize(self, x): + embed = self.embed.t() + dist = -( + x.pow(2).sum(1, keepdim=True) + - 2 * x @ embed + + embed.pow(2).sum(0, keepdim=True) + ) + embed_ind = dist.max(dim=-1).indices + return embed_ind + + def postprocess_emb(self, embed_ind, shape): + return embed_ind.view(*shape[:-1]) + + def dequantize(self, embed_ind): + quantize = F.embedding(embed_ind, self.embed) + return quantize + + def encode(self, x): + shape = x.shape + # pre-process + x = self.preprocess(x) + # quantize + embed_ind = self.quantize(x) + # post-process + embed_ind = self.postprocess_emb(embed_ind, shape) + return embed_ind + + def decode(self, embed_ind): + quantize = self.dequantize(embed_ind) + return quantize + + def forward(self, x): + shape, dtype = x.shape, x.dtype + x = self.preprocess(x) + + self.init_embed_(x) + + embed_ind = self.quantize(x) + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) + embed_ind = self.postprocess_emb(embed_ind, shape) + quantize = self.dequantize(embed_ind) + + if self.training: + # We do the expiry of code at that point as buffers are in sync + # and all the workers will take the same decision. + self.expire_codes_(x) + ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) + embed_sum = x.t() @ embed_onehot + ema_inplace(self.embed_avg, embed_sum.t(), self.decay) + cluster_size = ( + laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) + * self.cluster_size.sum() + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + + return quantize, embed_ind + + +class VectorQuantization(nn.Module): + """Vector quantization implementation. + Currently supports only euclidean distance. + Args: + dim (int): Dimension + codebook_size (int): Codebook size + codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + commitment_weight (float): Weight for commitment loss. + """ + + def __init__( + self, + dim: int, + codebook_size: int, + codebook_dim: Optional[int] = None, + decay: float = 0.99, + epsilon: float = 1e-5, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + commitment_weight: float = 1.0, + ): + super().__init__() + _codebook_dim: int = default(codebook_dim, dim) + + requires_projection = _codebook_dim != dim + self.project_in = ( + nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity() + ) + self.project_out = ( + nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity() + ) + + self.epsilon = epsilon + self.commitment_weight = commitment_weight + + self._codebook = EuclideanCodebook( + dim=_codebook_dim, + codebook_size=codebook_size, + kmeans_init=kmeans_init, + kmeans_iters=kmeans_iters, + decay=decay, + epsilon=epsilon, + threshold_ema_dead_code=threshold_ema_dead_code, + ) + self.codebook_size = codebook_size + + @property + def codebook(self): + return self._codebook.embed + + def encode(self, x): + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + embed_in = self._codebook.encode(x) + return embed_in + + def decode(self, embed_ind): + quantize = self._codebook.decode(embed_ind) + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize + + def forward(self, x): + device = x.device + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + + quantize, embed_ind = self._codebook(x) + + if self.training: + quantize = x + (quantize - x).detach() + + loss = torch.tensor([0.0], device=device, requires_grad=self.training) + + if self.training: + if self.commitment_weight > 0: + commit_loss = F.mse_loss(quantize.detach(), x) + loss = loss + commit_loss * self.commitment_weight + + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize, embed_ind, loss + + +class ResidualVectorQuantization(nn.Module): + """Residual vector quantization implementation. + Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf + """ + + def __init__(self, *, num_quantizers, **kwargs): + super().__init__() + self.layers = nn.ModuleList( + [VectorQuantization(**kwargs) for _ in range(num_quantizers)] + ) + + def forward(self, x, n_q: Optional[int] = None): + quantized_out = 0.0 + residual = x + + all_losses = [] + all_indices = [] + + n_q = n_q or len(self.layers) + + for layer in self.layers[:n_q]: + quantized, indices, loss = layer(residual) + residual = residual - quantized + quantized_out = quantized_out + quantized + + all_indices.append(indices) + all_losses.append(loss) + + out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) + return quantized_out, out_indices, out_losses + + def encode( + self, x: torch.Tensor, n_q: Optional[int] = None, st: Optional[int] = None + ) -> torch.Tensor: + residual = x + all_indices = [] + n_q = n_q or len(self.layers) + st = st or 0 + for layer in self.layers[st:n_q]: + indices = layer.encode(residual) + quantized = layer.decode(indices) + residual = residual - quantized + all_indices.append(indices) + out_indices = torch.stack(all_indices) + return out_indices + + def decode(self, q_indices: torch.Tensor) -> torch.Tensor: + quantized_out = torch.tensor(0.0, device=q_indices.device) + for i, indices in enumerate(q_indices): + layer = self.layers[i] + quantized = layer.decode(indices) + quantized_out = quantized_out + quantized + return quantized_out diff --git a/egs/libritts/CODEC/encodec/quantization/distrib.py b/egs/libritts/CODEC/encodec/quantization/distrib.py new file mode 100644 index 000000000..41ac7525f --- /dev/null +++ b/egs/libritts/CODEC/encodec/quantization/distrib.py @@ -0,0 +1,126 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE +"""Torch distributed utilities.""" +from typing import Dict, Iterable, List + +import torch +from torch import distributed as dist + + +def rank(): + if dist.is_initialized(): + return dist.get_rank() + else: + return 0 + + +def world_size(): + if dist.is_initialized(): + return dist.get_world_size() + else: + return 1 + + +def is_distributed(): + return world_size() > 1 + + +def all_reduce(tensor: torch.Tensor, op=dist.ReduceOp.SUM): + if is_distributed(): + return dist.all_reduce(tensor, op) + + +def _is_complex_or_float(tensor): + return torch.is_floating_point(tensor) or torch.is_complex(tensor) + + +def _check_number_of_params(params: List[torch.Tensor]): + # utility function to check that the number of params in all workers is the same, + # and thus avoid a deadlock with distributed all reduce. + if not is_distributed() or not params: + return + # print('params[0].device ', params[0].device) + tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long) + all_reduce(tensor) + if tensor.item() != len(params) * world_size(): + # If not all the workers have the same number, for at least one of them, + # this inequality will be verified. + raise RuntimeError( + f"Mismatch in number of params: ours is {len(params)}, " + "at least one worker has a different one." + ) + + +def broadcast_tensors(tensors: Iterable[torch.Tensor], src: int = 0): + """Broadcast the tensors from the given parameters to all workers. + This can be used to ensure that all workers have the same model to start with. + """ + if not is_distributed(): + return + tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)] + _check_number_of_params(tensors) + handles = [] + for tensor in tensors: + # src = int(rank()) # added code + handle = dist.broadcast(tensor.data, src=src, async_op=True) + handles.append(handle) + for handle in handles: + handle.wait() + + +def sync_buffer(buffers, average=True): + """ + Sync grad for buffers. If average is False, broadcast instead of averaging. + """ + if not is_distributed(): + return + handles = [] + for buffer in buffers: + if torch.is_floating_point(buffer.data): + if average: + handle = dist.all_reduce( + buffer.data, op=dist.ReduceOp.SUM, async_op=True + ) + else: + handle = dist.broadcast(buffer.data, src=0, async_op=True) + handles.append((buffer, handle)) + for buffer, handle in handles: + handle.wait() + if average: + buffer.data /= world_size + + +def sync_grad(params): + """ + Simpler alternative to DistributedDataParallel, that doesn't rely + on any black magic. For simple models it can also be as fast. + Just call this on your model parameters after the call to backward! + """ + if not is_distributed(): + return + handles = [] + for p in params: + if p.grad is not None: + handle = dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM, async_op=True) + handles.append((p, handle)) + for p, handle in handles: + handle.wait() + p.grad.data /= world_size() + + +def average_metrics(metrics: Dict[str, float], count=1.0): + """Average a dictionary of metrics across all workers, using the optional + `count` as unormalized weight. + """ + if not is_distributed(): + return metrics + keys, values = zip(*metrics.items()) + device = "cuda" if torch.cuda.is_available() else "cpu" + tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32) + tensor *= count + all_reduce(tensor) + averaged = (tensor[:-1] / tensor[-1]).cpu().tolist() + return dict(zip(keys, averaged)) diff --git a/egs/libritts/CODEC/encodec/quantization/vq.py b/egs/libritts/CODEC/encodec/quantization/vq.py new file mode 100644 index 000000000..8e59887a6 --- /dev/null +++ b/egs/libritts/CODEC/encodec/quantization/vq.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file at https://github.com/facebookresearch/encodec/blob/main/LICENSE +"""Residual vector quantizer implementation.""" +import math +from dataclasses import dataclass, field +from typing import Optional + +import torch +from torch import Tensor, nn + +from .core_vq import ResidualVectorQuantization + + +@dataclass +class QuantizedResult: + quantized: Tensor + codes: Tensor + bandwidth: Tensor # bandwidth in kb/s used, per batch item. + penalty: Optional[Tensor] = None + metrics: dict = field(default_factory=dict) + + +class ResidualVectorQuantizer(nn.Module): + """Residual Vector Quantizer. + Args: + dimension (int): Dimension of the codebooks. + n_q (int): Number of residual vector quantizers used. + bins (int): Codebook size. + decay (float): Decay for exponential moving average over the codebooks. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + + def __init__( + self, + dimension: int = 256, + n_q: int = 8, + bins: int = 1024, + decay: float = 0.99, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.n_q = n_q + self.dimension = dimension + self.bins = bins + self.decay = decay + self.kmeans_init = kmeans_init + self.kmeans_iters = kmeans_iters + self.threshold_ema_dead_code = threshold_ema_dead_code + self.vq = ResidualVectorQuantization( + dim=self.dimension, + codebook_size=self.bins, + num_quantizers=self.n_q, + decay=self.decay, + kmeans_init=self.kmeans_init, + kmeans_iters=self.kmeans_iters, + threshold_ema_dead_code=self.threshold_ema_dead_code, + ) + + def forward( + self, x: Tensor, sample_rate: int, bandwidth: Optional[float] = None + ) -> QuantizedResult: + """Residual vector quantization on the given input tensor. + Args: + x (Tensor): Input tensor. + sample_rate (int): Sample rate of the input tensor. + bandwidth (float): Target bandwidth. + Returns: + QuantizedResult: + The quantized (or approximately quantized) representation with + the associated bandwidth and any penalty term for the loss. + """ + bw_per_q = self.get_bandwidth_per_quantizer(sample_rate) + n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth) + quantized, codes, commit_loss = self.vq(x, n_q=n_q) + bw = torch.tensor(n_q * bw_per_q).to(x) + return quantized, codes, bw, torch.mean(commit_loss) + # return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss)) + + def get_num_quantizers_for_bandwidth( + self, sample_rate: int, bandwidth: Optional[float] = None + ) -> int: + """Return n_q based on specified target bandwidth.""" + bw_per_q = self.get_bandwidth_per_quantizer(sample_rate) + n_q = self.n_q + if bandwidth and bandwidth > 0.0: + n_q = int(max(1, math.floor(bandwidth / bw_per_q))) + return n_q + + def get_bandwidth_per_quantizer(self, sample_rate: int): + """Return bandwidth per quantizer for a given input sample rate.""" + return math.log2(self.bins) * sample_rate / 1000 + + def encode( + self, + x: Tensor, + sample_rate: int, + bandwidth: Optional[float] = None, + st: Optional[int] = None, + ) -> Tensor: + """Encode a given input tensor with the specified sample rate at the given bandwidth. + The RVQ encode method sets the appropriate number of quantizer to use + and returns indices for each quantizer. + """ + n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth) + st = st or 0 + codes = self.vq.encode(x, n_q=n_q, st=st) + return codes + + def decode(self, codes: Tensor) -> Tensor: + """Decode the given codes to the quantized representation.""" + quantized = self.vq.decode(codes) + return quantized diff --git a/egs/libritts/CODEC/encodec/scheduler.py b/egs/libritts/CODEC/encodec/scheduler.py new file mode 100644 index 000000000..00ef9882a --- /dev/null +++ b/egs/libritts/CODEC/encodec/scheduler.py @@ -0,0 +1,171 @@ +# original implementation is from https://github.com/ZhikangNiu/encodec-pytorch/blob/main/scheduler.py + +# Copyright 2024 Zhi-Kang Niu +# MIT License + +import math +from bisect import bisect_right + +from torch.optim.lr_scheduler import _LRScheduler + + +# It will be replaced with huggingface optimization +class WarmUpLR(_LRScheduler): + """warmup_training learning rate scheduler + Args: + optimizer: optimzier(e.g. SGD) + total_iters: totoal_iters of warmup phase + """ + + def __init__(self, optimizer, iter_per_epoch, warmup_epoch, last_epoch=-1): + + self.total_iters = iter_per_epoch * warmup_epoch + self.iter_per_epoch = iter_per_epoch + super().__init__(optimizer, last_epoch) + + def get_lr(self): + """we will use the first m batches, and set the learning + rate to base_lr * m / total_iters + """ + return [ + base_lr * self.last_epoch / (self.total_iters + 1e-8) + for base_lr in self.base_lrs + ] + + +class WarmupLrScheduler(_LRScheduler): + def __init__( + self, + optimizer, + warmup_iter=500, + warmup_ratio=5e-4, + warmup="exp", + last_epoch=-1, + ): + self.warmup_iter = warmup_iter + self.warmup_ratio = warmup_ratio + self.warmup = warmup + super(WarmupLrScheduler, self).__init__(optimizer, last_epoch) + + def get_lr(self): + ratio = self.get_lr_ratio() + lrs = [ratio * lr for lr in self.base_lrs] + return lrs + + def get_lr_ratio(self): + if self.last_epoch < self.warmup_iter: + ratio = self.get_warmup_ratio() + else: + ratio = self.get_main_ratio() + return ratio + + def get_main_ratio(self): + raise NotImplementedError + + def get_warmup_ratio(self): + assert self.warmup in ("linear", "exp") + alpha = self.last_epoch / self.warmup_iter + if self.warmup == "linear": + ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha + elif self.warmup == "exp": + ratio = self.warmup_ratio ** (1.0 - alpha) + return ratio + + +class WarmupPolyLrScheduler(WarmupLrScheduler): + def __init__( + self, + optimizer, + power, + max_iter, + warmup_iter=500, + warmup_ratio=5e-4, + warmup="exp", + last_epoch=-1, + ): + self.power = power + self.max_iter = max_iter + super(WarmupPolyLrScheduler, self).__init__( + optimizer, warmup_iter, warmup_ratio, warmup, last_epoch + ) + + def get_main_ratio(self): + real_iter = self.last_epoch - self.warmup_iter + real_max_iter = self.max_iter - self.warmup_iter + alpha = real_iter / real_max_iter + ratio = (1 - alpha) ** self.power + return ratio + + +class WarmupExpLrScheduler(WarmupLrScheduler): + def __init__( + self, + optimizer, + gamma, + interval=1, + warmup_iter=500, + warmup_ratio=5e-4, + warmup="exp", + last_epoch=-1, + ): + self.gamma = gamma + self.interval = interval + super(WarmupExpLrScheduler, self).__init__( + optimizer, warmup_iter, warmup_ratio, warmup, last_epoch + ) + + def get_main_ratio(self): + real_iter = self.last_epoch - self.warmup_iter + ratio = self.gamma ** (real_iter // self.interval) + return ratio + + +class WarmupCosineLrScheduler(WarmupLrScheduler): + def __init__( + self, + optimizer, + max_iter, + eta_ratio=0, + warmup_iter=500, + warmup_ratio=5e-4, + warmup="exp", + last_epoch=-1, + ): + self.eta_ratio = eta_ratio + self.max_iter = max_iter + super(WarmupCosineLrScheduler, self).__init__( + optimizer, warmup_iter, warmup_ratio, warmup, last_epoch + ) + + def get_main_ratio(self): + real_iter = self.last_epoch - self.warmup_iter + real_max_iter = self.max_iter - self.warmup_iter + return ( + self.eta_ratio + + (1 - self.eta_ratio) + * (1 + math.cos(math.pi * self.last_epoch / real_max_iter)) + / 2 + ) + + +class WarmupStepLrScheduler(WarmupLrScheduler): + def __init__( + self, + optimizer, + milestones: list, + gamma=0.1, + warmup_iter=500, + warmup_ratio=5e-4, + warmup="exp", + last_epoch=-1, + ): + self.milestones = milestones + self.gamma = gamma + super(WarmupStepLrScheduler, self).__init__( + optimizer, warmup_iter, warmup_ratio, warmup, last_epoch + ) + + def get_main_ratio(self): + real_iter = self.last_epoch - self.warmup_iter + ratio = self.gamma ** bisect_right(self.milestones, real_iter) + return ratio diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py new file mode 100755 index 000000000..bf231c5b6 --- /dev/null +++ b/egs/libritts/CODEC/encodec/train.py @@ -0,0 +1,1188 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (Author: Zengwei Yao) +# 2024 The Chinese University of HK (Author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import itertools +import logging +import math +import random +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from codec_datamodule import LibriTTSCodecDataModule +from encodec import Encodec +from lhotse.utils import fix_random_seed +from scheduler import WarmupCosineLrScheduler +from torch import nn +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.utils.tensorboard import SummaryWriter +from utils import MetricsTracker, save_checkpoint + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, setup_logger, str2bool + +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-samples", + type=int, + default=3, + help="Number of samples to generate for tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=500, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="encodec/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lr", type=float, default=3.0e-4, help="The base learning rate." + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=1, + help="""Save checkpoint after processing this number of epochs" + periodically. We save checkpoint to exp-dir/ whenever + params.cur_epoch % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. + Since it will take around 1000 epochs, we suggest using a large + save_every_n to save disk space. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + + """ + params = AttributeDict( + { + # training params + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": -1, # 0 + "log_interval": 50, + "valid_interval": 200, + "env_info": get_env_info(), + "sampling_rate": 24000, + "audio_normalization": False, + "lambda_adv": 3.0, # loss scaling coefficient for adversarial loss + "lambda_wav": 0.1, # loss scaling coefficient for waveform loss + "lambda_feat": 4.0, # loss scaling coefficient for feat loss + "lambda_rec": 1.0, # loss scaling coefficient for reconstruction loss + "lambda_com": 1000.0, # loss scaling coefficient for commitment loss + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, model: nn.Module +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint(filename, model=model) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def get_model(params: AttributeDict) -> nn.Module: + """Get the model based on the configuration.""" + + from discriminators import ( + MultiPeriodDiscriminator, + MultiScaleDiscriminator, + MultiScaleSTFTDiscriminator, + ) + from modules.seanet import SEANetDecoder, SEANetEncoder + from quantization import ResidualVectorQuantizer + + # generator_params = { + # "generator_n_filters": 32, + # "dimension": 512, + # "ratios": [2, 2, 2, 4], + # "target_bandwidths": [7.5, 15], + # "bins": 1024, + # } + # discriminator_params = { + # "stft_discriminator_n_filters": 32, + # "discriminator_epoch_start": 5, + # } + # inference_params = { + # "target_bw": 7.5, + # } + + generator_params = { + "generator_n_filters": 32, + "dimension": 512, + "ratios": [8, 5, 4, 2], + "target_bandwidths": [1.5, 3, 6, 12, 24], + "bins": 1024, + } + discriminator_params = { + "stft_discriminator_n_filters": 32, + "discriminator_epoch_start": 5, + "n_ffts": [1024, 2048, 512], + "hop_lengths": [256, 512, 128], + "win_lengths": [1024, 2048, 512], + } + inference_params = { + "target_bw": 6, + } + + params.update(generator_params) + params.update(discriminator_params) + params.update(inference_params) + + hop_length = np.prod(params.ratios) + n_q = int( + 1000 + * params.target_bandwidths[-1] + // (math.ceil(params.sampling_rate / hop_length) * 10) + ) + + encoder = SEANetEncoder( + n_filters=params.generator_n_filters, + dimension=params.dimension, + ratios=params.ratios, + ) + decoder = SEANetDecoder( + n_filters=params.generator_n_filters, + dimension=params.dimension, + ratios=params.ratios, + ) + quantizer = ResidualVectorQuantizer( + dimension=params.dimension, n_q=n_q, bins=params.bins + ) + + model = Encodec( + params=params, + sampling_rate=params.sampling_rate, + target_bandwidths=params.target_bandwidths, + encoder=encoder, + quantizer=quantizer, + decoder=decoder, + multi_scale_discriminator=None, + multi_period_discriminator=None, + multi_scale_stft_discriminator=MultiScaleSTFTDiscriminator( + n_filters=params.stft_discriminator_n_filters, + n_ffts=params.n_ffts, + hop_lengths=params.hop_lengths, + win_lengths=params.win_lengths, + ), + ) + return model + + +def prepare_input( + params: AttributeDict, + batch: dict, + device: torch.device, + is_training: bool = True, +): + """Parse batch data""" + audio = batch["audio"].to(device, memory_format=torch.contiguous_format) + features = batch["features"].to(device, memory_format=torch.contiguous_format) + audio_lens = batch["audio_lens"].to(device) + features_lens = batch["features_lens"].to(device) + + if is_training: + audio_dims = audio.size(-1) + start_idx = random.randint(0, max(0, audio_dims - params.sampling_rate)) + audio = audio[:, start_idx : params.sampling_rate + start_idx] + else: + # NOTE(zengrui): a very coarse setup + audio = audio[ + :, params.sampling_rate : params.sampling_rate + params.sampling_rate + ] + + if params.audio_normalization: + mean = audio.mean(dim=-1, keepdim=True) + std = audio.std(dim=-1, keepdim=True) + audio = (audio - mean) / (std + 1e-7) + + return audio, audio_lens, features, features_lens + + +def train_discriminator(weight, global_step, threshold=0, value=0.0): + if global_step < threshold: + weight = value + return weight + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer_g: Optimizer, + optimizer_d: Optimizer, + scheduler_g: LRSchedulerType, + scheduler_d: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model to be trained. + optimizer_g: + The optimizer for generator. + optimizer_d: + The optimizer for discriminator. + scheduler_g: + The learning rate scheduler for generator, we call step() every epoch. + scheduler_d: + The learning rate scheduler for discriminator, we call step() every epoch. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + params=params, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + + batch_size = len(batch["audio"]) + ( + audio, + audio_lens, + _, + _, + ) = prepare_input(params, batch, device) + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + try: + with autocast(enabled=params.use_fp16): + d_weight = train_discriminator( + params.lambda_adv, + params.cur_epoch, + threshold=params.discriminator_epoch_start, + ) + # forward discriminator + ( + disc_stft_real_adv_loss, + disc_stft_fake_adv_loss, + disc_period_real_adv_loss, + disc_period_fake_adv_loss, + disc_scale_real_adv_loss, + disc_scale_fake_adv_loss, + stats_d, + ) = model( + speech=audio, + speech_lengths=audio_lens, + return_sample=False, + forward_generator=False, + ) + disc_loss = ( + disc_stft_real_adv_loss + + disc_stft_fake_adv_loss + + disc_period_real_adv_loss + + disc_period_fake_adv_loss + + disc_scale_real_adv_loss + + disc_scale_fake_adv_loss + ) * d_weight + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + # update discriminator + optimizer_d.zero_grad() + scaler.scale(disc_loss).backward() + scaler.step(optimizer_d) + + with autocast(enabled=params.use_fp16): + g_weight = train_discriminator( + params.lambda_adv, + params.cur_epoch, + threshold=params.discriminator_epoch_start, + ) + # forward generator + ( + commit_loss, + gen_stft_adv_loss, + gen_period_adv_loss, + gen_scale_adv_loss, + feature_stft_loss, + feature_period_loss, + feature_scale_loss, + wav_reconstruction_loss, + mel_reconstruction_loss, + stats_g, + ) = model( + speech=audio, + speech_lengths=audio_lens, + forward_generator=True, + return_sample=params.batch_idx_train % params.log_interval == 0, + ) + gen_adv_loss = ( + gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss + ) * g_weight + feature_loss = ( + feature_stft_loss + feature_period_loss + feature_scale_loss + ) + reconstruction_loss = ( + params.lambda_wav * wav_reconstruction_loss + + params.lambda_rec * mel_reconstruction_loss + ) + gen_loss = ( + gen_adv_loss + + reconstruction_loss + + params.lambda_feat * feature_loss + + params.lambda_com * commit_loss + ) + loss_info["generator_loss"] = gen_loss + for k, v in stats_g.items(): + if "returned_sample" not in k: + loss_info[k] = v * batch_size + # update generator + optimizer_g.zero_grad() + scaler.scale(gen_loss).backward() + scaler.step(optimizer_g) + scaler.update() + + # summary stats + tot_loss = tot_loss + loss_info + except: # noqa + save_bad_model() + raise + + # step per iteration + scheduler_g.step() + scheduler_d.step() + + if params.print_diagnostics and batch_idx == 5: + return + + if params.batch_idx_train % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or ( + cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 + ): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if params.batch_idx_train % params.log_interval == 0: + cur_lr_g = max(scheduler_g.get_last_lr()) + cur_lr_d = max(scheduler_d.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " + f"loss[{loss_info}], tot_loss[{tot_loss}], " + f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate_g", cur_lr_g, params.batch_idx_train + ) + tb_writer.add_scalar( + "train/learning_rate_d", cur_lr_d, params.batch_idx_train + ) + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + if "returned_sample" in stats_g: + # speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] + speech_hat_, speech_, _, _ = stats_g["returned_sample"] + + speech_hat_i = speech_hat_[0] + speech_i = speech_[0] + if speech_hat_i.dim() > 1: + speech_hat_i = speech_hat_i.squeeze(0) + speech_i = speech_i.squeeze(0) + tb_writer.add_audio( + f"train/speech_hat_", + speech_hat_i, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_audio( + f"train/speech_", + speech_i, + params.batch_idx_train, + params.sampling_rate, + ) + # tb_writer.add_image( + # "train/mel_hat_", + # plot_feature(mel_hat_), + # params.batch_idx_train, + # dataformats="HWC", + # ) + # tb_writer.add_image( + # "train/mel_", + # plot_feature(mel_), + # params.batch_idx_train, + # dataformats="HWC", + # ) + + if ( + params.batch_idx_train % params.valid_interval == 0 + and not params.print_diagnostics + ): + logging.info("Computing validation loss") + valid_info, (speech_hat, speech) = compute_validation_loss( + params=params, + model=model, + valid_dl=valid_dl, + world_size=world_size, + rank=rank, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None and rank == 0: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + for index in range(params.num_samples): # 3 + speech_hat_i = speech_hat[index] + speech_i = speech[index] + if speech_hat_i.dim() > 1: + speech_hat_i = speech_hat_i.squeeze(0) + speech_i = speech_i.squeeze(0) + tb_writer.add_audio( + f"train/valid_speech_hat_{index}", + speech_hat_i, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_audio( + f"train/valid_speech_{index}", + speech_i, + params.batch_idx_train, + params.sampling_rate, + ) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, + rank: int = 0, +) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]: + """Run the validation process.""" + model.eval() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations + tot_loss = MetricsTracker() + returned_sample = (None, None) + + with torch.no_grad(): + for batch_idx, batch in enumerate(valid_dl): + batch_size = len(batch["audio"]) + ( + audio, + audio_lens, + _, + _, + ) = prepare_input(params, batch, device, is_training=False) + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + d_weight = train_discriminator( + params.lambda_adv, + params.cur_epoch, + threshold=params.discriminator_epoch_start, + ) + + # forward discriminator + ( + disc_stft_real_adv_loss, + disc_stft_fake_adv_loss, + disc_period_real_adv_loss, + disc_period_fake_adv_loss, + disc_scale_real_adv_loss, + disc_scale_fake_adv_loss, + stats_d, + ) = model( + speech=audio, + speech_lengths=audio_lens, + return_sample=False, + forward_generator=False, + ) + disc_loss = ( + disc_stft_real_adv_loss + + disc_stft_fake_adv_loss + + disc_period_real_adv_loss + + disc_period_fake_adv_loss + + disc_scale_real_adv_loss + + disc_scale_fake_adv_loss + ) * d_weight + assert disc_loss.requires_grad is False + loss_info["discriminator_loss"] = disc_loss + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + + g_weight = train_discriminator( + params.lambda_adv, + params.cur_epoch, + threshold=params.discriminator_epoch_start, + ) + # forward generator + ( + commit_loss, + gen_stft_adv_loss, + gen_period_adv_loss, + gen_scale_adv_loss, + feature_stft_loss, + feature_period_loss, + feature_scale_loss, + wav_reconstruction_loss, + mel_reconstruction_loss, + stats_g, + ) = model( + speech=audio, + speech_lengths=audio_lens, + forward_generator=True, + return_sample=False, + ) + gen_adv_loss = ( + gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss + ) * g_weight + feature_loss = feature_stft_loss + feature_period_loss + feature_scale_loss + reconstruction_loss = ( + params.lambda_wav * wav_reconstruction_loss + + params.lambda_rec * mel_reconstruction_loss + ) + gen_loss = ( + gen_adv_loss + + reconstruction_loss + + params.lambda_feat * feature_loss + + params.lambda_com * commit_loss + ) + assert gen_loss.requires_grad is False + loss_info["generator_loss"] = gen_loss + for k, v in stats_g.items(): + if "returned_sample" not in k: + loss_info[k] = v * batch_size + + # summary stats + tot_loss = tot_loss + loss_info + + # infer for first batch: + if batch_idx == 0 and rank == 0: + inner_model = model.module if isinstance(model, DDP) else model + _, audio_hat = inner_model.inference( + x=audio, target_bw=params.target_bw + ) + returned_sample = (audio_hat, audio) + + if world_size > 1: + tot_loss.reduce(device) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss, returned_sample + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer_g: torch.optim.Optimizer, + optimizer_d: torch.optim.Optimizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + ( + audio, + audio_lens, + _, + _, + ) = prepare_input(params, batch, device) + try: + # for discriminator + with autocast(enabled=params.use_fp16): + ( + disc_stft_real_adv_loss, + disc_stft_fake_adv_loss, + disc_period_real_adv_loss, + disc_period_fake_adv_loss, + disc_scale_real_adv_loss, + disc_scale_fake_adv_loss, + stats_d, + ) = model( + speech=audio, + speech_lengths=audio_lens, + return_sample=False, + forward_generator=False, + ) + loss_d = ( + disc_stft_real_adv_loss + + disc_stft_fake_adv_loss + + disc_period_real_adv_loss + + disc_period_fake_adv_loss + + disc_scale_real_adv_loss + + disc_scale_fake_adv_loss + ) * train_discriminator( + params.lambda_adv, + params.cur_epoch, + threshold=params.discriminator_train_start, + ) + optimizer_d.zero_grad() + loss_d.backward() + # for generator + with autocast(enabled=params.use_fp16): + ( + commit_loss, + gen_stft_adv_loss, + gen_period_adv_loss, + gen_scale_adv_loss, + feature_stft_loss, + feature_period_loss, + feature_scale_loss, + wav_reconstruction_loss, + mel_reconstruction_loss, + stats_g, + ) = model( + speech=audio, + speech_lengths=audio_lens, + forward_generator=True, + return_sample=False, + ) + loss_g = ( + (gen_stft_adv_loss + gen_period_adv_loss + gen_scale_adv_loss) + * train_discriminator( + params.lambda_adv, + 0, + threshold=params.discriminator_epoch_start, + ) + + ( + params.lambda_wav * wav_reconstruction_loss + + params.lambda_rec * mel_reconstruction_loss + ) + + params.lambda_feat + * (feature_stft_loss + feature_period_loss + feature_scale_loss) + + params.lambda_com * commit_loss + ) + optimizer_g.zero_grad() + loss_g.backward() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + libritts = LibriTTSCodecDataModule(args) + + if params.full_libri: + train_cuts = libritts.train_all_shuf_cuts() + else: + train_cuts = libritts.train_clean_100_cuts() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + encoder = model.encoder + decoder = model.decoder + quantizer = model.quantizer + multi_scale_discriminator = model.multi_scale_discriminator + multi_period_discriminator = model.multi_period_discriminator + multi_scale_stft_discriminator = model.multi_scale_stft_discriminator + + num_param_e = sum([p.numel() for p in encoder.parameters()]) + logging.info(f"Number of parameters in encoder: {num_param_e}") + num_param_d = sum([p.numel() for p in decoder.parameters()]) + logging.info(f"Number of parameters in decoder: {num_param_d}") + num_param_q = sum([p.numel() for p in quantizer.parameters()]) + logging.info(f"Number of parameters in quantizer: {num_param_q}") + num_param_ds = ( + sum([p.numel() for p in multi_scale_discriminator.parameters()]) + if multi_scale_discriminator is not None + else 0 + ) + logging.info(f"Number of parameters in multi_scale_discriminator: {num_param_ds}") + num_param_dp = ( + sum([p.numel() for p in multi_period_discriminator.parameters()]) + if multi_period_discriminator is not None + else 0 + ) + logging.info(f"Number of parameters in multi_period_discriminator: {num_param_dp}") + num_param_dstft = sum( + [p.numel() for p in multi_scale_stft_discriminator.parameters()] + ) + logging.info( + f"Number of parameters in multi_scale_stft_discriminator: {num_param_dstft}" + ) + logging.info( + f"Total number of parameters: {num_param_e + num_param_d + num_param_q + num_param_ds + num_param_dp + num_param_dstft}" + ) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = DDP( + model, + device_ids=[rank], + find_unused_parameters=True, + ) + + optimizer_g = torch.optim.AdamW( + itertools.chain( + encoder.parameters(), + quantizer.parameters(), + decoder.parameters(), + ), + lr=params.lr, + betas=(0.5, 0.9), + ) + discriminator_params = [ + multi_scale_stft_discriminator.parameters(), + ] + if multi_scale_discriminator is not None: + discriminator_params.append(multi_scale_discriminator.parameters()) + if multi_period_discriminator is not None: + discriminator_params.append(multi_period_discriminator.parameters()) + optimizer_d = torch.optim.AdamW( + itertools.chain(*discriminator_params), + lr=params.lr, + betas=(0.5, 0.9), + ) + + scheduler_g = WarmupCosineLrScheduler( + optimizer=optimizer_g, + max_iter=params.num_epochs * 1500, + eta_ratio=0.1, + warmup_iter=params.discriminator_epoch_start * 1500, + warmup_ratio=1e-4, + ) + scheduler_d = WarmupCosineLrScheduler( + optimizer=optimizer_d, + max_iter=params.num_epochs * 1500, + eta_ratio=0.1, + warmup_iter=params.discriminator_epoch_start * 1500, + warmup_ratio=1e-4, + ) + + if checkpoints is not None: + # load state_dict for optimizers + if "optimizer_g" in checkpoints: + logging.info("Loading optimizer_g state dict") + optimizer_g.load_state_dict(checkpoints["optimizer_g"]) + if "optimizer_d" in checkpoints: + logging.info("Loading optimizer_d state dict") + optimizer_d.load_state_dict(checkpoints["optimizer_d"]) + + # load state_dict for schedulers + if "scheduler_g" in checkpoints: + logging.info("Loading scheduler_g state dict") + scheduler_g.load_state_dict(checkpoints["scheduler_g"]) + if "scheduler_d" in checkpoints: + logging.info("Loading scheduler_d state dict") + scheduler_d.load_state_dict(checkpoints["scheduler_d"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + train_dl = libritts.train_dataloaders( + train_cuts, + world_size=world_size, + rank=rank, + ) + + valid_cuts = libritts.dev_clean_cuts() + valid_dl = libritts.valid_dataloaders( + valid_cuts, + world_size=world_size, + rank=rank, + ) + + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer_g=optimizer_g, + # optimizer_d=optimizer_d, + # params=params, + # ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + params.cur_epoch = epoch + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + train_one_epoch( + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if epoch % params.save_every_n == 0 or epoch == params.num_epochs: + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + LibriTTSCodecDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/libritts/CODEC/encodec/utils.py b/egs/libritts/CODEC/encodec/utils.py new file mode 120000 index 000000000..7c9586776 --- /dev/null +++ b/egs/libritts/CODEC/encodec/utils.py @@ -0,0 +1 @@ +../../../vctk/TTS/vits/utils.py \ No newline at end of file diff --git a/egs/libritts/CODEC/local/compute_spectrogram_libritts.py b/egs/libritts/CODEC/local/compute_spectrogram_libritts.py new file mode 100755 index 000000000..8d864db92 --- /dev/null +++ b/egs/libritts/CODEC/local/compute_spectrogram_libritts.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao,) +# 2024 The Chinese Univ. of HK (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the VCTK dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/spectrogram. +""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import torch +from lhotse import CutSet, LilcomChunkyWriter, Spectrogram, SpectrogramConfig +from lhotse.audio import RecordingSet +from lhotse.recipes.utils import read_manifests_if_cached +from lhotse.supervision import SupervisionSet + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + parser.add_argument( + "--sampling-rate", + type=int, + default=24000, + help="""Sampling rate of the audio for computing fbank, the default value for LibriTTS is 24000, audio files will be resampled if a different sample rate is provided""", + ) + + return parser.parse_args() + + +def compute_spectrogram_libritts( + dataset: Optional[str] = None, sampling_rate: int = 24000 +): + src_dir = Path("data/manifests") + output_dir = Path("data/spectrogram") + num_jobs = min(32, os.cpu_count()) + + frame_length = 1024 / sampling_rate # (in second) + frame_shift = 256 / sampling_rate # (in second) + use_fft_mag = True + + prefix = "libritts" + suffix = "jsonl.gz" + if dataset is None: + dataset_parts = ( + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + "train-clean-360", + "train-other-500", + ) + else: + dataset_parts = dataset.split(" ", -1) + + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=frame_length, + frame_shift=frame_shift, + use_fft_mag=use_fft_mag, + ) + extractor = Spectrogram(config) + + with get_executor() as ex: # Initialize the executor only once. + for partition, m in manifests.items(): + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{partition} already exists - skipping.") + return + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + if sampling_rate != 24000: + logging.info(f"Resampling audio to {sampling_rate}") + cut_set = cut_set.resample(sampling_rate) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + compute_spectrogram_libritts() diff --git a/egs/libritts/CODEC/local/display_manifest_statistics.py b/egs/libritts/CODEC/local/display_manifest_statistics.py new file mode 100755 index 000000000..ec00e0454 --- /dev/null +++ b/egs/libritts/CODEC/local/display_manifest_statistics.py @@ -0,0 +1,341 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# 2024 The Chinese Univ. of HK (authors: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. +""" + + +from lhotse import load_manifest_lazy + + +def main(): + paths = [ + "./data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz", + "./data/spectrogram/libritts_cuts_train-clean-360.jsonl.gz", + "./data/spectrogram/libritts_cuts_train-other-500.jsonl.gz", + "./data/spectrogram/libritts_cuts_dev-clean.jsonl.gz", + "./data/spectrogram/libritts_cuts_dev-other.jsonl.gz", + "./data/spectrogram/libritts_cuts_test-clean.jsonl.gz", + "./data/spectrogram/libritts_cuts_test-other.jsonl.gz", + ] + for path in paths: + cuts = load_manifest_lazy(path) + cuts.describe() + + +if __name__ == "__main__": + main() + +""" +./data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz statistics: +________________________________________ +_ Cuts count: _ 33236 _ +________________________________________ +_ Total duration (hh:mm:ss) _ 53:47:18 _ +________________________________________ +_ mean _ 5.8 _ +________________________________________ +_ std _ 4.6 _ +________________________________________ +_ min _ 0.2 _ +________________________________________ +_ 25% _ 2.4 _ +________________________________________ +_ 50% _ 4.5 _ +________________________________________ +_ 75% _ 7.9 _ +________________________________________ +_ 99% _ 21.4 _ +________________________________________ +_ 99.5% _ 23.7 _ +________________________________________ +_ 99.9% _ 27.8 _ +________________________________________ +_ max _ 33.2 _ +________________________________________ +_ Recordings available: _ 33236 _ +________________________________________ +_ Features available: _ 33236 _ +________________________________________ +_ Supervisions available: _ 33236 _ +________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +__________________________________________________________________ +_ Total speech duration _ 53:47:18 _ 100.00% of recording _ +__________________________________________________________________ +_ Total speaking time duration _ 53:47:18 _ 100.00% of recording _ +__________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +__________________________________________________________________ + +./data/spectrogram/libritts_cuts_train-clean-360.jsonl.gz statistics: +_________________________________________ +_ Cuts count: _ 116500 _ +_________________________________________ +_ Total duration (hh:mm:ss) _ 191:17:42 _ +_________________________________________ +_ mean _ 5.9 _ +_________________________________________ +_ std _ 4.6 _ +_________________________________________ +_ min _ 0.1 _ +_________________________________________ +_ 25% _ 2.4 _ +_________________________________________ +_ 50% _ 4.6 _ +_________________________________________ +_ 75% _ 8.1 _ +_________________________________________ +_ 99% _ 21.3 _ +_________________________________________ +_ 99.5% _ 23.4 _ +_________________________________________ +_ 99.9% _ 27.4 _ +_________________________________________ +_ max _ 40.4 _ +_________________________________________ +_ Recordings available: _ 116500 _ +_________________________________________ +_ Features available: _ 116500 _ +_________________________________________ +_ Supervisions available: _ 116500 _ +_________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +___________________________________________________________________ +_ Total speech duration _ 191:17:42 _ 100.00% of recording _ +___________________________________________________________________ +_ Total speaking time duration _ 191:17:42 _ 100.00% of recording _ +___________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +___________________________________________________________________ + +./data/spectrogram/libritts_cuts_train-other-500.jsonl.gz statistics: +_________________________________________ +_ Cuts count: _ 205043 _ +_________________________________________ +_ Total duration (hh:mm:ss) _ 310:04:36 _ +_________________________________________ +_ mean _ 5.4 _ +_________________________________________ +_ std _ 4.4 _ +_________________________________________ +_ min _ 0.1 _ +_________________________________________ +_ 25% _ 2.3 _ +_________________________________________ +_ 50% _ 4.2 _ +_________________________________________ +_ 75% _ 7.3 _ +_________________________________________ +_ 99% _ 20.6 _ +_________________________________________ +_ 99.5% _ 22.8 _ +_________________________________________ +_ 99.9% _ 27.4 _ +_________________________________________ +_ max _ 43.9 _ +_________________________________________ +_ Recordings available: _ 205043 _ +_________________________________________ +_ Features available: _ 205043 _ +_________________________________________ +_ Supervisions available: _ 205043 _ +_________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +___________________________________________________________________ +_ Total speech duration _ 310:04:36 _ 100.00% of recording _ +___________________________________________________________________ +_ Total speaking time duration _ 310:04:36 _ 100.00% of recording _ +___________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +___________________________________________________________________ + +./data/spectrogram/libritts_cuts_dev-clean.jsonl.gz statistics: +________________________________________ +_ Cuts count: _ 5736 _ +________________________________________ +_ Total duration (hh:mm:ss) _ 08:58:13 _ +________________________________________ +_ mean _ 5.6 _ +________________________________________ +_ std _ 4.3 _ +________________________________________ +_ min _ 0.3 _ +________________________________________ +_ 25% _ 2.4 _ +________________________________________ +_ 50% _ 4.4 _ +________________________________________ +_ 75% _ 7.8 _ +________________________________________ +_ 99% _ 19.9 _ +________________________________________ +_ 99.5% _ 21.9 _ +________________________________________ +_ 99.9% _ 26.3 _ +________________________________________ +_ max _ 30.1 _ +________________________________________ +_ Recordings available: _ 5736 _ +________________________________________ +_ Features available: _ 5736 _ +________________________________________ +_ Supervisions available: _ 5736 _ +________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +__________________________________________________________________ +_ Total speech duration _ 08:58:13 _ 100.00% of recording _ +__________________________________________________________________ +_ Total speaking time duration _ 08:58:13 _ 100.00% of recording _ +__________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +__________________________________________________________________ + +./data/spectrogram/libritts_cuts_dev-other.jsonl.gz statistics: +________________________________________ +_ Cuts count: _ 4613 _ +________________________________________ +_ Total duration (hh:mm:ss) _ 06:25:52 _ +________________________________________ +_ mean _ 5.0 _ +________________________________________ +_ std _ 4.1 _ +________________________________________ +_ min _ 0.3 _ +________________________________________ +_ 25% _ 2.2 _ +________________________________________ +_ 50% _ 3.8 _ +________________________________________ +_ 75% _ 6.5 _ +________________________________________ +_ 99% _ 19.7 _ +________________________________________ +_ 99.5% _ 24.5 _ +________________________________________ +_ 99.9% _ 31.0 _ +________________________________________ +_ max _ 32.6 _ +________________________________________ +_ Recordings available: _ 4613 _ +________________________________________ +_ Features available: _ 4613 _ +________________________________________ +_ Supervisions available: _ 4613 _ +________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +__________________________________________________________________ +_ Total speech duration _ 06:25:52 _ 100.00% of recording _ +__________________________________________________________________ +_ Total speaking time duration _ 06:25:52 _ 100.00% of recording _ +__________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +__________________________________________________________________ + +./data/spectrogram/libritts_cuts_test-clean.jsonl.gz statistics: +________________________________________ +_ Cuts count: _ 4837 _ +________________________________________ +_ Total duration (hh:mm:ss) _ 08:34:09 _ +________________________________________ +_ mean _ 6.4 _ +________________________________________ +_ std _ 5.1 _ +________________________________________ +_ min _ 0.3 _ +________________________________________ +_ 25% _ 2.4 _ +________________________________________ +_ 50% _ 4.8 _ +________________________________________ +_ 75% _ 8.9 _ +________________________________________ +_ 99% _ 22.6 _ +________________________________________ +_ 99.5% _ 24.4 _ +________________________________________ +_ 99.9% _ 29.6 _ +________________________________________ +_ max _ 36.7 _ +________________________________________ +_ Recordings available: _ 4837 _ +________________________________________ +_ Features available: _ 4837 _ +________________________________________ +_ Supervisions available: _ 4837 _ +________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +__________________________________________________________________ +_ Total speech duration _ 08:34:09 _ 100.00% of recording _ +__________________________________________________________________ +_ Total speaking time duration _ 08:34:09 _ 100.00% of recording _ +__________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +__________________________________________________________________ + +./data/spectrogram/libritts_cuts_test-other.jsonl.gz statistics: +________________________________________ +_ Cuts count: _ 5120 _ +________________________________________ +_ Total duration (hh:mm:ss) _ 06:41:31 _ +________________________________________ +_ mean _ 4.7 _ +________________________________________ +_ std _ 3.8 _ +________________________________________ +_ min _ 0.3 _ +________________________________________ +_ 25% _ 1.8 _ +________________________________________ +_ 50% _ 3.6 _ +________________________________________ +_ 75% _ 6.5 _ +________________________________________ +_ 99% _ 17.8 _ +________________________________________ +_ 99.5% _ 20.4 _ +________________________________________ +_ 99.9% _ 23.8 _ +________________________________________ +_ max _ 27.3 _ +________________________________________ +_ Recordings available: _ 5120 _ +________________________________________ +_ Features available: _ 5120 _ +________________________________________ +_ Supervisions available: _ 5120 _ +________________________________________ +SUPERVISION custom fields: +Speech duration statistics: +__________________________________________________________________ +_ Total speech duration _ 06:41:31 _ 100.00% of recording _ +__________________________________________________________________ +_ Total speaking time duration _ 06:41:31 _ 100.00% of recording _ +__________________________________________________________________ +_ Total silence duration _ 00:00:01 _ 0.00% of recording _ +__________________________________________________________________ +""" diff --git a/egs/libritts/CODEC/local/validate_manifest.py b/egs/libritts/CODEC/local/validate_manifest.py new file mode 120000 index 000000000..b4d52ebca --- /dev/null +++ b/egs/libritts/CODEC/local/validate_manifest.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/local/validate_manifest.py \ No newline at end of file diff --git a/egs/libritts/CODEC/prepare.sh b/egs/libritts/CODEC/prepare.sh new file mode 100755 index 000000000..6a471c3ad --- /dev/null +++ b/egs/libritts/CODEC/prepare.sh @@ -0,0 +1,78 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=0 +stop_stage=100 +sampling_rate=24000 +nj=32 + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/LibriTTS, + # you can create a symlink + # + # ln -sfv /path/to/LibriTTS $dl_dir/LibriTTS + # + if [ ! -d $dl_dir/LibriTTS ]; then + lhotse download libritts $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare LibriTTS manifest" + # We assume that you have downloaded the LibriTTS corpus + # to $dl_dir/LibriTTS + mkdir -p data/manifests + if [ ! -e data/manifests/.libritts.done ]; then + lhotse prepare libritts --num-jobs 32 $dl_dir/LibriTTS data/manifests + touch data/manifests/.libritts.done + fi +fi + + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Compute Spectrogram for LibriTTS" + mkdir -p data/spectrogram + if [ ! -e data/spectrogram/.libritts.done ]; then + ./local/compute_spectrogram_libritts.py --sampling-rate $sampling_rate + touch data/spectrogram/.libritts.done + fi + + # Here we shuffle and combine the train-clean-100, train-clean-360 and + # train-other-500 together to form the training set. + if [ ! -f data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz ]; then + cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \ + <(gunzip -c data/spectrogram/libritts_cuts_train-clean-360.jsonl.gz) \ + <(gunzip -c /data/spectrogramlibritts_cuts_train-other-500.jsonl.gz) | \ + shuf | gzip -c > data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz + fi + + if [ ! -e data/spectrogram/.libritts-validated.done ]; then + log "Validating data/spectrogram for LibriTTS" + ./local/validate_manifest.py \ + data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz + touch data/spectrogram/.libritts-validated.done + fi +fi + diff --git a/egs/libritts/CODEC/shared b/egs/libritts/CODEC/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/libritts/CODEC/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file From 88bacfb9e6acd06e98b48843da2bdde59a20426b Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 21 Oct 2024 13:51:56 +0800 Subject: [PATCH 24/59] minor fixes for the repo (#1775) * minor fixes for the repo Co-authored-by: Fangjun Kuang --- egs/librispeech/ASR/zipformer/model.py | 2 +- egs/speech_llm/ASR_LLM/whisper_llm_zh/label_smoothing.py | 1 + icefall/utils.py | 9 ++++----- 3 files changed, 6 insertions(+), 6 deletions(-) create mode 120000 egs/speech_llm/ASR_LLM/whisper_llm_zh/label_smoothing.py diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index deebb2a75..c7dbe1e0a 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -22,10 +22,10 @@ import k2 import torch import torch.nn as nn from encoder_interface import EncoderInterface +from lhotse.dataset import SpecAugment from scaling import ScaledLinear from icefall.utils import add_sos, make_pad_mask, time_warp -from lhotse.dataset import SpecAugment class AsrModel(nn.Module): diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/label_smoothing.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/label_smoothing.py new file mode 120000 index 000000000..e9d239fff --- /dev/null +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/label_smoothing.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/icefall/utils.py b/icefall/utils.py index b0a42cefa..0682252f9 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -2282,13 +2282,12 @@ def time_warp( time_warp_factor: Optional[int] = 80, supervision_segments: Optional[torch.Tensor] = None, ): - """Apply time warping on a batch of features - """ + """Apply time warping on a batch of features""" if time_warp_factor is None or time_warp_factor < 1: return features - assert len(features.shape) == 3, ( - "SpecAugment only supports batches of single-channel feature matrices." - ) + assert ( + len(features.shape) == 3 + ), f"SpecAugment only supports batches of single-channel feature matrices. {features.shape}" features = features.clone() if supervision_segments is None: # No supervisions - apply spec augment to full feature matrices. From 37a1420603bed04fe9169d83eee9ea0f7de9b52c Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Thu, 24 Oct 2024 13:16:18 +0800 Subject: [PATCH 25/59] remove incomplete recipe (#1778) Co-authored-by: yifanyeung --- README.md | 2 - .../ASR/local/compute_fbank_musan.py | 1 - .../compute_fbank_peoples_speech_splits.py | 154 ----------- ...compute_fbank_peoples_speech_valid_test.py | 93 ------- egs/peoples_speech/ASR/local/filter_cuts.py | 1 - .../ASR/local/prepare_lang_bpe.py | 1 - .../ASR/local/preprocess_peoples_speech.py | 123 --------- .../ASR/local/train_bpe_model.py | 1 - .../ASR/local/validate_bpe_lexicon.py | 1 - egs/peoples_speech/ASR/prepare.sh | 247 ------------------ egs/peoples_speech/ASR/shared | 1 - 11 files changed, 625 deletions(-) delete mode 120000 egs/peoples_speech/ASR/local/compute_fbank_musan.py delete mode 100755 egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_splits.py delete mode 100755 egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_valid_test.py delete mode 120000 egs/peoples_speech/ASR/local/filter_cuts.py delete mode 120000 egs/peoples_speech/ASR/local/prepare_lang_bpe.py delete mode 100755 egs/peoples_speech/ASR/local/preprocess_peoples_speech.py delete mode 120000 egs/peoples_speech/ASR/local/train_bpe_model.py delete mode 120000 egs/peoples_speech/ASR/local/validate_bpe_lexicon.py delete mode 100755 egs/peoples_speech/ASR/prepare.sh delete mode 120000 egs/peoples_speech/ASR/shared diff --git a/README.md b/README.md index 81cfc03ce..57db5eb8d 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,6 @@ for more details. - [LibriSpeech][librispeech] - [Libriheavy][libriheavy] - [Multi-Dialect Broadcast News Arabic Speech Recognition][mgb2] - - [PeopleSpeech][peoplespeech] - [SPGISpeech][spgispeech] - [Switchboard][swbd] - [TIMIT][timit] @@ -375,7 +374,6 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad [libricss]: egs/libricss/SURT [libriheavy]: egs/libriheavy/ASR [mgb2]: egs/mgb2/ASR -[peoplespeech]: egs/peoples_speech/ASR [spgispeech]: egs/spgispeech/ASR [voxpopuli]: egs/voxpopuli/ASR [xbmu-amdo31]: egs/xbmu-amdo31/ASR diff --git a/egs/peoples_speech/ASR/local/compute_fbank_musan.py b/egs/peoples_speech/ASR/local/compute_fbank_musan.py deleted file mode 120000 index 5833f2484..000000000 --- a/egs/peoples_speech/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_splits.py b/egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_splits.py deleted file mode 100755 index 6f05b9f8c..000000000 --- a/egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_splits.py +++ /dev/null @@ -1,154 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (Yifan Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import logging -from datetime import datetime -from pathlib import Path - -import torch -from lhotse import ( - CutSet, - KaldifeatFbank, - KaldifeatFbankConfig, - LilcomChunkyWriter, - set_audio_duration_mismatch_tolerance, - set_caching_enabled, -) - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--num-workers", - type=int, - default=20, - help="Number of dataloading workers used for reading the audio.", - ) - - parser.add_argument( - "--batch-duration", - type=float, - default=600.0, - help="The maximum number of audio seconds in a batch." - "Determines batch size dynamically.", - ) - - parser.add_argument( - "--num-splits", - type=int, - required=True, - help="The number of splits of the train subset", - ) - - parser.add_argument( - "--start", - type=int, - default=0, - help="Process pieces starting from this number (included).", - ) - - parser.add_argument( - "--stop", - type=int, - default=-1, - help="Stop processing pieces until this number (excluded).", - ) - - return parser.parse_args() - - -def compute_fbank_peoples_speech_splits(args): - subsets = ("dirty", "dirty_sa", "clean", "clean_sa") - num_splits = args.num_splits - output_dir = f"data/fbank/peoples_speech_train_split" - output_dir = Path(output_dir) - assert output_dir.exists(), f"{output_dir} does not exist!" - - num_digits = 8 - - start = args.start - stop = args.stop - if stop < start: - stop = num_splits - - stop = min(stop, num_splits) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) - logging.info(f"device: {device}") - - set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance - set_caching_enabled(False) - - for partition in subsets: - for i in range(start, stop): - idx = f"{i + 1}".zfill(num_digits) - logging.info(f"Processing {partition}: {idx}") - - cuts_path = output_dir / f"peoples_speech_cuts_{partition}.{idx}.jsonl.gz" - if cuts_path.is_file(): - logging.info(f"{cuts_path} exists - skipping") - continue - - raw_cuts_path = ( - output_dir / f"peoples_speech_cuts_{partition}_raw.{idx}.jsonl.gz" - ) - - logging.info(f"Loading {raw_cuts_path}") - cut_set = CutSet.from_file(raw_cuts_path) - - logging.info("Splitting cuts into smaller chunks.") - cut_set = cut_set.trim_to_supervisions( - keep_overlapping=False, min_duration=None - ) - - logging.info("Computing features") - cut_set = cut_set.compute_and_store_features_batch( - extractor=extractor, - storage_path=f"{output_dir}/peoples_speech_feats_{partition}_{idx}", - num_workers=args.num_workers, - batch_duration=args.batch_duration, - storage_type=LilcomChunkyWriter, - overwrite=True, - ) - - logging.info(f"Saving to {cuts_path}") - cut_set.to_file(cuts_path) - - -def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - args = get_args() - logging.info(vars(args)) - compute_fbank_peoples_speech_splits(args) - - -if __name__ == "__main__": - main() diff --git a/egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_valid_test.py b/egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_valid_test.py deleted file mode 100755 index 89f43a674..000000000 --- a/egs/peoples_speech/ASR/local/compute_fbank_peoples_speech_valid_test.py +++ /dev/null @@ -1,93 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the People's Speech dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import argparse -import logging -import os -from pathlib import Path -from typing import Optional - -import torch -from filter_cuts import filter_cuts -from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_peoples_speech_valid_test(): - src_dir = Path(f"data/manifests") - output_dir = Path(f"data/fbank") - num_workers = 42 - batch_duration = 600 - - subsets = ("validation", "test") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) - - logging.info(f"device: {device}") - - for partition in subsets: - cuts_path = output_dir / f"peoples_speech_cuts_{partition}.jsonl.gz" - if cuts_path.is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - - raw_cuts_path = output_dir / f"peoples_speech_cuts_{partition}_raw.jsonl.gz" - - logging.info(f"Loading {raw_cuts_path}") - cut_set = CutSet.from_file(raw_cuts_path) - - logging.info("Splitting cuts into smaller chunks") - cut_set = cut_set.trim_to_supervisions( - keep_overlapping=False, min_duration=None - ) - - logging.info("Computing features") - cut_set = cut_set.compute_and_store_features_batch( - extractor=extractor, - storage_path=f"{output_dir}/peoples_speech_feats_{partition}", - num_workers=num_workers, - batch_duration=batch_duration, - storage_type=LilcomChunkyWriter, - overwrite=True, - ) - - logging.info(f"Saving to {cuts_path}") - cut_set.to_file(cuts_path) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_peoples_speech_valid_test() diff --git a/egs/peoples_speech/ASR/local/filter_cuts.py b/egs/peoples_speech/ASR/local/filter_cuts.py deleted file mode 120000 index 27aca1729..000000000 --- a/egs/peoples_speech/ASR/local/filter_cuts.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/filter_cuts.py \ No newline at end of file diff --git a/egs/peoples_speech/ASR/local/prepare_lang_bpe.py b/egs/peoples_speech/ASR/local/prepare_lang_bpe.py deleted file mode 120000 index 36b40e7fc..000000000 --- a/egs/peoples_speech/ASR/local/prepare_lang_bpe.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/peoples_speech/ASR/local/preprocess_peoples_speech.py b/egs/peoples_speech/ASR/local/preprocess_peoples_speech.py deleted file mode 100755 index c5417049f..000000000 --- a/egs/peoples_speech/ASR/local/preprocess_peoples_speech.py +++ /dev/null @@ -1,123 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import logging -import re -from pathlib import Path -from typing import Optional - -from lhotse import CutSet, SupervisionSegment -from lhotse.recipes.utils import read_manifests_if_cached - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--dataset", - type=str, - help="""Dataset parts to compute fbank. If None, we will use all""", - ) - - return parser.parse_args() - - -def normalize_text(utt: str) -> str: - utt = re.sub(r"[{0}]+".format("-"), " ", utt) - return re.sub(r"[^a-zA-Z\s]", "", utt).upper() - - -def preprocess_peoples_speech(dataset: Optional[str] = None): - src_dir = Path(f"data/manifests") - output_dir = Path(f"data/fbank") - output_dir.mkdir(exist_ok=True) - - if dataset is None: - dataset_parts = ( - "validation", - "test", - "dirty", - "dirty_sa", - "clean", - "clean_sa", - ) - else: - dataset_parts = dataset.split(" ", -1) - - logging.info("Loading manifest, it may takes 8 minutes") - prefix = f"peoples_speech" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - suffix=suffix, - prefix=prefix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - for partition, m in manifests.items(): - logging.info(f"Processing {partition}") - raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}" - if raw_cuts_path.is_file(): - logging.info(f"{partition} already exists - skipping") - continue - - logging.info(f"Normalizing text in {partition}") - i = 0 - for sup in m["supervisions"]: - text = str(sup.text) - orig_text = text - sup.text = normalize_text(sup.text) - text = str(sup.text) - if i < 10 and len(orig_text) != len(text): - logging.info( - f"\nOriginal text vs normalized text:\n{orig_text}\n{text}" - ) - i += 1 - - # Create long-recording cut manifests. - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ).resample(16000) - - # Run data augmentation that needs to be done in the - # time domain. - logging.info(f"Saving to {raw_cuts_path}") - cut_set.to_file(raw_cuts_path) - - -def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - args = get_args() - logging.info(vars(args)) - preprocess_peoples_speech(dataset=args.dataset) - logging.info("Done") - - -if __name__ == "__main__": - main() diff --git a/egs/peoples_speech/ASR/local/train_bpe_model.py b/egs/peoples_speech/ASR/local/train_bpe_model.py deleted file mode 120000 index 6fad36421..000000000 --- a/egs/peoples_speech/ASR/local/train_bpe_model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/peoples_speech/ASR/local/validate_bpe_lexicon.py b/egs/peoples_speech/ASR/local/validate_bpe_lexicon.py deleted file mode 120000 index 721bb48e7..000000000 --- a/egs/peoples_speech/ASR/local/validate_bpe_lexicon.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/peoples_speech/ASR/prepare.sh b/egs/peoples_speech/ASR/prepare.sh deleted file mode 100755 index 3787858d9..000000000 --- a/egs/peoples_speech/ASR/prepare.sh +++ /dev/null @@ -1,247 +0,0 @@ -#!/usr/bin/env bash - -set -eou pipefail - -nj=32 -stage=-1 -stop_stage=100 - -# Split data/set to a number of pieces -# This is to avoid OOM during feature extraction. -num_per_split=4000 - -# We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. -# -# - $dl_dir/peoples_speech -# This directory contains the following files downloaded from -# https://huggingface.co/datasets/MLCommons/peoples_speech -# -# - test -# - train -# - validation -# -# - $dl_dir/musan -# This directory contains the following directories downloaded from -# http://www.openslr.org/17/ -# -# - music -# - noise -# - speech - -dl_dir=$PWD/download - -. shared/parse_options.sh || exit 1 - -# vocab size for sentence piece models. -# It will generate data/lang_bpe_xxx, -# data/lang_bpe_yyy if the array contains xxx, yyy -vocab_sizes=( - # 5000 - # 2000 - # 1000 - 500 -) - -# All files generated by this script are saved in "data". -# You can safely remove "data" and rerun this script to regenerate it. -mkdir -p data - -log() { - # This function is from espnet - local fname=${BASH_SOURCE[1]##*/} - echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" -} - -log "dl_dir: $dl_dir" - -if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Download data" - - # If you have pre-downloaded it to /path/to/peoples_speech, - # you can create a symlink - # - # ln -sfv /path/to/peoples_speech $dl_dir/peoples_speech - # - if [ ! -d $dl_dir/peoples_speech/train ]; then - git lfs install - git clone https://huggingface.co/datasets/MLCommons/peoples_speech - fi - - # If you have pre-downloaded it to /path/to/musan, - # you can create a symlink - # - # ln -sfv /path/to/musan $dl_dir/ - # - if [ ! -d $dl_dir/musan ]; then - lhotse download musan $dl_dir - fi -fi - -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then - log "Stage 1: Prepare People's Speech manifest" - # We assume that you have downloaded the People's Speech corpus - # to $dl_dir/peoples_speech - mkdir -p data/manifests - if [ ! -e data/manifests/.peoples_speech.done ]; then - lhotse prepare peoples-speech -j $nj $dl_dir/peoples_speech data/manifests - touch data/manifests/.peoples_speech.done - fi -fi - -if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Prepare musan manifest" - # We assume that you have downloaded the musan corpus - # to data/musan - mkdir -p data/manifests - if [ ! -e data/manifests/.musan.done ]; then - lhotse prepare musan $dl_dir/musan data/manifests - touch data/manifests/.musan.done - fi -fi - -if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Preprocess People's Speech manifest" - mkdir -p data/fbank - if [ ! -e data/fbank/.preprocess_complete ]; then - ./local/preprocess_peoples_speech.py - touch data/fbank/.preprocess_complete - fi -fi - -if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute fbank for valid and test subsets of People's Speech" - if [ ! -e data/fbank/.peoples_speech_valid_test.done ]; then - ./local/compute_fbank_peoples_speech_valid_test.py - touch data/fbank/.peoples_speech_valid_test.done - fi -fi - -if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then - log "Stage 5: Split train subset into pieces" - split_dir=data/fbank/peoples_speech_train_split - if [ ! -e $split_dir/.peoples_speech_dirty_split.done ]; then - lhotse split-lazy ./data/fbank/peoples_speech_cuts_dirty_raw.jsonl.gz $split_dir $num_per_split - touch $split_dir/.peoples_speech_dirty_split.done - fi - - if [ ! -e $split_dir/.peoples_speech_dirty_sa_split.done ]; then - lhotse split-lazy ./data/fbank/peoples_speech_cuts_dirty_sa_raw.jsonl.gz $split_dir $num_per_split - touch $split_dir/.peoples_speech_dirty_sa_split.done - fi - - if [ ! -e $split_dir/.peoples_speech_clean_split.done ]; then - lhotse split-lazy ./data/fbank/peoples_speech_cuts_clean_raw.jsonl.gz $split_dir $num_per_split - touch $split_dir/.peoples_speech_clean_split.done - fi - - if [ ! -e $split_dir/.peoples_speech_clean_sa_split.done ]; then - lhotse split-lazy ./data/fbank/peoples_speech_cuts_clean_sa_raw.jsonl.gz $split_dir $num_per_split - touch $split_dir/.peoples_speech_clean_sa_split.done - fi -fi - -if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then - log "Stage 6: Compute features for train subset of People's Speech" - if [ ! -e data/fbank/.peoples_speech_train.done ]; then - ./local/compute_fbank_peoples_speech_splits.py \ - --num-workers $nj \ - --batch-duration 600 \ - --start 0 \ - --num-splits 2000 - touch data/fbank/.peoples_speech_train.done - fi -fi - -if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then - log "Stage 7: Compute fbank for musan" - mkdir -p data/fbank - if [ ! -e data/fbank/.musan.done ]; then - ./local/compute_fbank_musan.py - touch data/fbank/.musan.done - fi -fi - -if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Prepare BPE based lang" - - for vocab_size in ${vocab_sizes[@]}; do - lang_dir=data/lang_bpe_${vocab_size} - mkdir -p $lang_dir - - if [ ! -f $lang_dir/transcript_words.txt ]; then - log "Generate data for BPE training" - file=$( - find "data/fbank/peoples_speech_cuts_dirty_raw.jsonl.gz" - find "data/fbank/peoples_speech_cuts_dirty_sa_raw.jsonl.gz" - find "data/fbank/peoples_speech_cuts_clean_raw.jsonl.gz" - find "data/fbank/peoples_speech_cuts_clean_sa_raw.jsonl.gz" - ) - gunzip -c ${file} | awk -F '"' '{print $30}' > $lang_dir/transcript_words.txt - - # Ensure space only appears once - sed -i 's/\t/ /g' $lang_dir/transcript_words.txt - sed -i 's/ +/ /g' $lang_dir/transcript_words.txt - fi - - if [ ! -f $lang_dir/words.txt ]; then - cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \ - | sort -u | sed '/^$/d' > $lang_dir/words.txt - (echo '!SIL'; echo ''; echo ''; ) | - cat - $lang_dir/words.txt | sort | uniq | awk ' - BEGIN { - print " 0"; - } - { - if ($1 == "") { - print " is in the vocabulary!" | "cat 1>&2" - exit 1; - } - if ($1 == "") { - print " is in the vocabulary!" | "cat 1>&2" - exit 1; - } - printf("%s %d\n", $1, NR); - } - END { - printf("#0 %d\n", NR+1); - printf(" %d\n", NR+2); - printf(" %d\n", NR+3); - }' > $lang_dir/words || exit 1; - mv $lang_dir/words $lang_dir/words.txt - fi - - if [ ! -f $lang_dir/bpe.model ]; then - ./local/train_bpe_model.py \ - --lang-dir $lang_dir \ - --vocab-size $vocab_size \ - --transcript $lang_dir/transcript_words.txt - fi - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang_bpe.py --lang-dir $lang_dir - - log "Validating $lang_dir/lexicon.txt" - ./local/validate_bpe_lexicon.py \ - --lexicon $lang_dir/lexicon.txt \ - --bpe-model $lang_dir/bpe.model - fi - - if [ ! -f $lang_dir/L.fst ]; then - log "Converting L.pt to L.fst" - ./shared/convert-k2-to-openfst.py \ - --olabels aux_labels \ - $lang_dir/L.pt \ - $lang_dir/L.fst - fi - - if [ ! -f $lang_dir/L_disambig.fst ]; then - log "Converting L_disambig.pt to L_disambig.fst" - ./shared/convert-k2-to-openfst.py \ - --olabels aux_labels \ - $lang_dir/L_disambig.pt \ - $lang_dir/L_disambig.fst - fi - done -fi diff --git a/egs/peoples_speech/ASR/shared b/egs/peoples_speech/ASR/shared deleted file mode 120000 index 4c5e91438..000000000 --- a/egs/peoples_speech/ASR/shared +++ /dev/null @@ -1 +0,0 @@ -../../../icefall/shared/ \ No newline at end of file From 05f756390cecdd7a64a1f57168e3185dd974cbe3 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 28 Oct 2024 00:59:04 +0800 Subject: [PATCH 26/59] Avoid using lr from checkpoint. (#1781) --- egs/librispeech/ASR/zipformer/optim.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 6f5180e29..8434fab13 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -787,7 +787,9 @@ class LRScheduler(object): is not the optimizer. """ return { - "base_lrs": self.base_lrs, + # the user might try to override the base_lr, so don't include this in the state. + # previously they were included. + # "base_lrs": self.base_lrs, "epoch": self.epoch, "batch": self.batch, } @@ -799,7 +801,12 @@ class LRScheduler(object): state_dict (dict): scheduler state. Should be an object returned from a call to :meth:`state_dict`. """ + # the things with base_lrs are a work-around for a previous problem + # where base_lrs were written with the state dict. + base_lrs = self.base_lrs self.__dict__.update(state_dict) + self.base_lrs = base_lrs + def get_last_lr(self) -> List[float]: """Return last computed learning rate by current scheduler. Will be a list of float.""" From 7e9eea6dc38898627d4d20dfa3b0daf54cbe1eb1 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 28 Oct 2024 11:53:11 +0800 Subject: [PATCH 27/59] Add pretrained.py for SURT (#1785) --- .../SURT/dprnn_zipformer/pretrained.py | 303 ++++++++++++++++++ egs/libricss/SURT/dprnn_zipformer/train.py | 2 - 2 files changed, 303 insertions(+), 2 deletions(-) create mode 100755 egs/libricss/SURT/dprnn_zipformer/pretrained.py diff --git a/egs/libricss/SURT/dprnn_zipformer/pretrained.py b/egs/libricss/SURT/dprnn_zipformer/pretrained.py new file mode 100755 index 000000000..5f9468957 --- /dev/null +++ b/egs/libricss/SURT/dprnn_zipformer/pretrained.py @@ -0,0 +1,303 @@ +#!/usr/bin/env python3 +""" +Usage: + +1. Download pre-trained models from +https://huggingface.co/desh2608/icefall-surt-libricss-dprnn-zipformer + +2. + +./dprnn_zipformer/pretrained.py \ + --checkpoint /path/to/pretrained.pt \ + --tokens /path/to/data/lang_bpe_500/tokens.txt \ + /path/to/foo.wav +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from beam_search import ( + beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_surt_model + +from icefall.utils import num_tokens + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--tokens", + type=str, + required=True, + help="""Path to tokens.txt.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + token_table = k2.SymbolTable.from_file(params.tokens) + + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_surt_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + opts.mel_opts.high_freq = -400 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + B, T, F = features.shape + processed = model.mask_encoder(features) # B,T,F*num_channels + masks = processed.view(B, T, F, params.num_channels).unbind(dim=-1) + x_masked = [features * m for m in masks] + + # Recognition + # Concatenate the inputs along the batch axis + h = torch.cat(x_masked, dim=0) + h_lens = feature_lengths.repeat(params.num_channels) + encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens) + + if model.joint_encoder_layer is not None: + encoder_out = model.joint_encoder_layer(encoder_out) + + def _group_channels(hyps: List[str]) -> List[List[str]]: + """ + Currently we have a batch of size M*B, where M is the number of + channels and B is the batch size. We need to group the hypotheses + into B groups, each of which contains M hypotheses. + + Example: + hyps = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2'] + _group_channels(hyps) = [['a1', 'a2'], ['b1', 'b2'], ['c1', 'c2']] + """ + assert len(hyps) == B * params.num_channels + out_hyps = [] + for i in range(B): + out_hyps.append(hyps[i::B]) + return out_hyps + + hyps = [] + msg = f"Using {params.method}" + logging.info(msg) + + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + + if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + hyps.append(token_ids_to_words(hyp)) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + s += f"{filename}:\n{hyp}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/libricss/SURT/dprnn_zipformer/train.py b/egs/libricss/SURT/dprnn_zipformer/train.py index 90d742e7c..148cafd4b 100755 --- a/egs/libricss/SURT/dprnn_zipformer/train.py +++ b/egs/libricss/SURT/dprnn_zipformer/train.py @@ -62,9 +62,7 @@ from asr_datamodule import LibriCssAsrDataModule from decoder import Decoder from dprnn import DPRNN from einops.layers.torch import Rearrange -from graph_pit.loss.optimized import optimized_graph_pit_mse_loss as gpit_mse from joiner import Joiner -from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import LOG_EPSILON, fix_random_seed from model import SURT From 516b4869b32600bed681f6fb1902abb0cbf1399d Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 29 Oct 2024 15:04:04 +0800 Subject: [PATCH 28/59] Add Matcha-TTS (#1773) --- .github/scripts/ljspeech/TTS/run-matcha.sh | 120 +++ .github/scripts/ljspeech/TTS/run.sh | 2 +- .github/workflows/audioset.yml | 6 +- .github/workflows/ljspeech.yml | 12 +- egs/ljspeech/TTS/.gitignore | 7 + egs/ljspeech/TTS/README.md | 118 +++ .../TTS/local/compute_fbank_ljspeech.py | 208 +++++ .../TTS/local/compute_fbank_statistics.py | 84 ++ .../TTS/local/prepare_tokens_ljspeech.py | 28 +- egs/ljspeech/TTS/local/validate_manifest.py | 1 + egs/ljspeech/TTS/matcha/LICENSE | 21 + egs/ljspeech/TTS/matcha/__init__.py | 0 egs/ljspeech/TTS/matcha/audio.py | 92 +++ .../TTS/matcha/compute_fbank_ljspeech.py | 1 + egs/ljspeech/TTS/matcha/export_onnx.py | 196 +++++ .../TTS/matcha/export_onnx_hifigan.py | 110 +++ egs/ljspeech/TTS/matcha/hifigan/LICENSE | 21 + egs/ljspeech/TTS/matcha/hifigan/README.md | 101 +++ egs/ljspeech/TTS/matcha/hifigan/__init__.py | 0 egs/ljspeech/TTS/matcha/hifigan/config.py | 100 +++ egs/ljspeech/TTS/matcha/hifigan/denoiser.py | 71 ++ egs/ljspeech/TTS/matcha/hifigan/env.py | 17 + egs/ljspeech/TTS/matcha/hifigan/meldataset.py | 245 ++++++ egs/ljspeech/TTS/matcha/hifigan/models.py | 406 ++++++++++ egs/ljspeech/TTS/matcha/hifigan/xutils.py | 60 ++ egs/ljspeech/TTS/matcha/inference.py | 199 +++++ egs/ljspeech/TTS/matcha/model.py | 97 +++ egs/ljspeech/TTS/matcha/models/README.md | 3 + egs/ljspeech/TTS/matcha/models/__init__.py | 0 .../TTS/matcha/models/components/__init__.py | 0 .../TTS/matcha/models/components/decoder.py | 459 +++++++++++ .../matcha/models/components/flow_matching.py | 140 ++++ .../matcha/models/components/text_encoder.py | 447 +++++++++++ .../matcha/models/components/transformer.py | 353 +++++++++ egs/ljspeech/TTS/matcha/models/matcha_tts.py | 295 +++++++ .../TTS/matcha/monotonic_align/.gitignore | 3 + .../TTS/matcha/monotonic_align/__init__.py | 23 + .../TTS/matcha/monotonic_align/core.pyx | 49 ++ .../TTS/matcha/monotonic_align/setup.py | 12 + egs/ljspeech/TTS/matcha/onnx_pretrained.py | 204 +++++ egs/ljspeech/TTS/matcha/requirements.txt | 3 + egs/ljspeech/TTS/matcha/tokenizer.py | 1 + egs/ljspeech/TTS/matcha/train.py | 723 ++++++++++++++++++ egs/ljspeech/TTS/matcha/tts_datamodule.py | 341 +++++++++ egs/ljspeech/TTS/matcha/utils.py | 1 + egs/ljspeech/TTS/prepare.sh | 86 ++- icefall/checkpoint.py | 2 +- 47 files changed, 5442 insertions(+), 26 deletions(-) create mode 100755 .github/scripts/ljspeech/TTS/run-matcha.sh create mode 100644 egs/ljspeech/TTS/.gitignore create mode 100755 egs/ljspeech/TTS/local/compute_fbank_ljspeech.py create mode 100755 egs/ljspeech/TTS/local/compute_fbank_statistics.py create mode 100644 egs/ljspeech/TTS/matcha/LICENSE create mode 100644 egs/ljspeech/TTS/matcha/__init__.py create mode 100644 egs/ljspeech/TTS/matcha/audio.py create mode 120000 egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py create mode 100755 egs/ljspeech/TTS/matcha/export_onnx.py create mode 100755 egs/ljspeech/TTS/matcha/export_onnx_hifigan.py create mode 100644 egs/ljspeech/TTS/matcha/hifigan/LICENSE create mode 100644 egs/ljspeech/TTS/matcha/hifigan/README.md create mode 100644 egs/ljspeech/TTS/matcha/hifigan/__init__.py create mode 100644 egs/ljspeech/TTS/matcha/hifigan/config.py create mode 100644 egs/ljspeech/TTS/matcha/hifigan/denoiser.py create mode 100644 egs/ljspeech/TTS/matcha/hifigan/env.py create mode 100644 egs/ljspeech/TTS/matcha/hifigan/meldataset.py create mode 100644 egs/ljspeech/TTS/matcha/hifigan/models.py create mode 100644 egs/ljspeech/TTS/matcha/hifigan/xutils.py create mode 100755 egs/ljspeech/TTS/matcha/inference.py create mode 100644 egs/ljspeech/TTS/matcha/model.py create mode 100644 egs/ljspeech/TTS/matcha/models/README.md create mode 100644 egs/ljspeech/TTS/matcha/models/__init__.py create mode 100644 egs/ljspeech/TTS/matcha/models/components/__init__.py create mode 100644 egs/ljspeech/TTS/matcha/models/components/decoder.py create mode 100644 egs/ljspeech/TTS/matcha/models/components/flow_matching.py create mode 100644 egs/ljspeech/TTS/matcha/models/components/text_encoder.py create mode 100644 egs/ljspeech/TTS/matcha/models/components/transformer.py create mode 100644 egs/ljspeech/TTS/matcha/models/matcha_tts.py create mode 100644 egs/ljspeech/TTS/matcha/monotonic_align/.gitignore create mode 100644 egs/ljspeech/TTS/matcha/monotonic_align/__init__.py create mode 100644 egs/ljspeech/TTS/matcha/monotonic_align/core.pyx create mode 100644 egs/ljspeech/TTS/matcha/monotonic_align/setup.py create mode 100755 egs/ljspeech/TTS/matcha/onnx_pretrained.py create mode 100644 egs/ljspeech/TTS/matcha/requirements.txt create mode 120000 egs/ljspeech/TTS/matcha/tokenizer.py create mode 100755 egs/ljspeech/TTS/matcha/train.py create mode 100644 egs/ljspeech/TTS/matcha/tts_datamodule.py create mode 120000 egs/ljspeech/TTS/matcha/utils.py diff --git a/.github/scripts/ljspeech/TTS/run-matcha.sh b/.github/scripts/ljspeech/TTS/run-matcha.sh new file mode 100755 index 000000000..37e1bc320 --- /dev/null +++ b/.github/scripts/ljspeech/TTS/run-matcha.sh @@ -0,0 +1,120 @@ +#!/usr/bin/env bash + +set -ex + +apt-get update +apt-get install -y sox + +python3 -m pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html +python3 -m pip install espnet_tts_frontend +python3 -m pip install numba conformer==0.3.2 diffusers librosa + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/ljspeech/TTS + +sed -i.bak s/600/8/g ./prepare.sh +sed -i.bak s/"first 100"/"first 3"/g ./prepare.sh +sed -i.bak s/500/5/g ./prepare.sh +git diff + +function prepare_data() { + # We have created a subset of the data for testing + # + mkdir -p download + pushd download + wget -q https://huggingface.co/csukuangfj/ljspeech-subset-for-ci-test/resolve/main/LJSpeech-1.1.tar.bz2 + tar xvf LJSpeech-1.1.tar.bz2 + popd + + ./prepare.sh + tree . +} + +function train() { + pushd ./matcha + sed -i.bak s/1500/3/g ./train.py + git diff . + popd + + ./matcha/train.py \ + --exp-dir matcha/exp \ + --num-epochs 1 \ + --save-every-n 1 \ + --num-buckets 2 \ + --tokens data/tokens.txt \ + --max-duration 20 + + ls -lh matcha/exp +} + +function infer() { + + curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 + + ./matcha/inference.py \ + --epoch 1 \ + --exp-dir ./matcha/exp \ + --tokens data/tokens.txt \ + --vocoder ./generator_v1 \ + --input-text "how are you doing?" \ + --output-wav ./generated.wav + + ls -lh *.wav + soxi ./generated.wav + rm -v ./generated.wav + rm -v generator_v1 +} + +function export_onnx() { + pushd matcha/exp + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/epoch-4000.pt + popd + + pushd data/fbank + rm -v *.json + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/data/cmvn.json + popd + + ./matcha/export_onnx.py \ + --exp-dir ./matcha/exp \ + --epoch 4000 \ + --tokens ./data/tokens.txt \ + --cmvn ./data/fbank/cmvn.json + + ls -lh *.onnx + + if false; then + # THe CI machine does not have enough memory to run it + # + curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 + curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2 + curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v3 + python3 ./matcha/export_onnx_hifigan.py + else + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v1.onnx + fi + + ls -lh *.onnx + + python3 ./matcha/onnx_pretrained.py \ + --acoustic-model ./model-steps-6.onnx \ + --vocoder ./hifigan_v1.onnx \ + --tokens ./data/tokens.txt \ + --input-text "how are you doing?" \ + --output-wav /icefall/generated-matcha-tts-steps-6-v1.wav + + ls -lh /icefall/*.wav + soxi /icefall/generated-matcha-tts-steps-6-v1.wav +} + +prepare_data +train +infer +export_onnx + +rm -rfv generator_v* matcha/exp diff --git a/.github/scripts/ljspeech/TTS/run.sh b/.github/scripts/ljspeech/TTS/run.sh index 707361782..733a12c47 100755 --- a/.github/scripts/ljspeech/TTS/run.sh +++ b/.github/scripts/ljspeech/TTS/run.sh @@ -22,7 +22,7 @@ git diff function prepare_data() { # We have created a subset of the data for testing # - mkdir download + mkdir -p download pushd download wget -q https://huggingface.co/csukuangfj/ljspeech-subset-for-ci-test/resolve/main/LJSpeech-1.1.tar.bz2 tar xvf LJSpeech-1.1.tar.bz2 diff --git a/.github/workflows/audioset.yml b/.github/workflows/audioset.yml index 280ef8f8e..9c9446239 100644 --- a/.github/workflows/audioset.yml +++ b/.github/workflows/audioset.yml @@ -83,7 +83,7 @@ jobs: ls -lh ./model-onnx/* - name: Upload model to huggingface - if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push' + if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push' env: HF_TOKEN: ${{ secrets.HF_TOKEN }} uses: nick-fields/retry@v3 @@ -116,7 +116,7 @@ jobs: rm -rf huggingface - name: Prepare for release - if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push' + if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push' shell: bash run: | d=sherpa-onnx-zipformer-audio-tagging-2024-04-09 @@ -125,7 +125,7 @@ jobs: ls -lh - name: Release exported onnx models - if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push' + if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push' uses: svenstaro/upload-release-action@v2 with: file_glob: true diff --git a/.github/workflows/ljspeech.yml b/.github/workflows/ljspeech.yml index e202d21b5..7dca96b37 100644 --- a/.github/workflows/ljspeech.yml +++ b/.github/workflows/ljspeech.yml @@ -70,6 +70,7 @@ jobs: cd /icefall git config --global --add safe.directory /icefall + .github/scripts/ljspeech/TTS/run-matcha.sh .github/scripts/ljspeech/TTS/run.sh - name: display files @@ -78,19 +79,13 @@ jobs: ls -lh - uses: actions/upload-artifact@v4 - if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' + if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' with: name: generated-test-files-${{ matrix.python-version }}-${{ matrix.torch-version }} path: ./*.wav - - uses: actions/upload-artifact@v4 - if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' - with: - name: generated-models-py${{ matrix.python-version }}-torch${{ matrix.torch-version }} - path: ./*.wav - - name: Release exported onnx models - if: matrix.python-version == '3.9' && matrix.torch-version == '2.2.0' && github.event_name == 'push' + if: matrix.python-version == '3.9' && matrix.torch-version == '2.3.0' && github.event_name == 'push' uses: svenstaro/upload-release-action@v2 with: file_glob: true @@ -99,4 +94,3 @@ jobs: repo_name: k2-fsa/sherpa-onnx repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} tag: tts-models - diff --git a/egs/ljspeech/TTS/.gitignore b/egs/ljspeech/TTS/.gitignore new file mode 100644 index 000000000..d5c19797a --- /dev/null +++ b/egs/ljspeech/TTS/.gitignore @@ -0,0 +1,7 @@ +build +core.c +*.so +my-output* +*.wav +*.onnx +generator_v* diff --git a/egs/ljspeech/TTS/README.md b/egs/ljspeech/TTS/README.md index 7b112c12c..1cd6e8fd7 100644 --- a/egs/ljspeech/TTS/README.md +++ b/egs/ljspeech/TTS/README.md @@ -101,3 +101,121 @@ export CUDA_VISIBLE_DEVICES=4,5,6,7 # (Note it is killed after `epoch-820.pt`) ``` +# matcha + +[./matcha](./matcha) contains the code for training [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS) + +This recipe provides a Matcha-TTS model trained on the LJSpeech dataset. + +Checkpoints and training logs can be found [here](https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28). +The pull-request for this recipe can be found at + +The training command is given below: +```bash +export CUDA_VISIBLE_DEVICES=0,1,2,3 + +python3 ./matcha/train.py \ + --exp-dir ./matcha/exp-new-3/ \ + --num-workers 4 \ + --world-size 4 \ + --num-epochs 4000 \ + --max-duration 1000 \ + --bucketing-sampler 1 \ + --start-epoch 1 +``` + +To inference, use: + +```bash +# Download Hifigan vocoder. We use Hifigan v1 below. You can select from v1, v2, or v3 + +wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 + +./matcha/inference \ + --exp-dir ./matcha/exp-new-3 \ + --epoch 4000 \ + --tokens ./data/tokens.txt \ + --vocoder ./generator_v1 \ + --input-text "how are you doing?" + --output-wav ./generated.wav +``` + +```bash +soxi ./generated.wav +``` +prints: +``` +Input File : './generated.wav' +Channels : 1 +Sample Rate : 22050 +Precision : 16-bit +Duration : 00:00:01.29 = 28416 samples ~ 96.6531 CDDA sectors +File Size : 56.9k +Bit Rate : 353k +Sample Encoding: 16-bit Signed Integer PCM +``` + +To export the checkpoint to onnx: + +```bash +# export the acoustic model to onnx + +./matcha/export_onnx.py \ + --exp-dir ./matcha/exp-new-3 \ + --epoch 4000 \ + --tokens ./data/tokens.txt +``` + +The above command generate the following files: + + - model-steps-2.onnx + - model-steps-3.onnx + - model-steps-4.onnx + - model-steps-5.onnx + - model-steps-6.onnx + +where the 2 in `model-steps-2.onnx` means it uses 2 steps for the ODE solver. + + +To export the Hifigan vocoder to onnx, please use: + +```bash +wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 +wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2 +wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v3 + +python3 ./matcha/export_onnx_hifigan.py +``` + +The above command generates 3 files: + + - hifigan_v1.onnx + - hifigan_v2.onnx + - hifigan_v3.onnx + +To use the generated onnx files to generate speech from text, please run: + +```bash +python3 ./matcha/onnx_pretrained.py \ + --acoustic-model ./model-steps-6.onnx \ + --vocoder ./hifigan_v1.onnx \ + --tokens ./data/tokens.txt \ + --input-text "Ask not what your country can do for you; ask what you can do for your country." \ + --output-wav ./matcha-epoch-4000-step6-hfigian-v1.wav +``` + +```bash +soxi ./matcha-epoch-4000-step6-hfigian-v1.wav + +Input File : './matcha-epoch-4000-step6-hfigian-v1.wav' +Channels : 1 +Sample Rate : 22050 +Precision : 16-bit +Duration : 00:00:05.46 = 120320 samples ~ 409.252 CDDA sectors +File Size : 241k +Bit Rate : 353k +Sample Encoding: 16-bit Signed Integer PCM +``` + +https://github.com/user-attachments/assets/b7c197a6-3870-49c6-90ca-db4d3776869b + diff --git a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py new file mode 100755 index 000000000..5152ae675 --- /dev/null +++ b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py @@ -0,0 +1,208 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the LJSpeech dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Union + +import numpy as np +import torch +from lhotse import CutSet, LilcomChunkyWriter, load_manifest +from lhotse.audio import RecordingSet +from lhotse.features.base import FeatureExtractor, register_extractor +from lhotse.supervision import SupervisionSet +from lhotse.utils import Seconds, compute_num_frames +from matcha.audio import mel_spectrogram + +from icefall.utils import get_executor + + +@dataclass +class MyFbankConfig: + n_fft: int + n_mels: int + sampling_rate: int + hop_length: int + win_length: int + f_min: float + f_max: float + + +@register_extractor +class MyFbank(FeatureExtractor): + + name = "MyFbank" + config_type = MyFbankConfig + + def __init__(self, config): + super().__init__(config=config) + + @property + def device(self) -> Union[str, torch.device]: + return self.config.device + + def feature_dim(self, sampling_rate: int) -> int: + return self.config.n_mels + + def extract( + self, + samples: np.ndarray, + sampling_rate: int, + ) -> torch.Tensor: + # Check for sampling rate compatibility. + expected_sr = self.config.sampling_rate + assert sampling_rate == expected_sr, ( + f"Mismatched sampling rate: extractor expects {expected_sr}, " + f"got {sampling_rate}" + ) + samples = torch.from_numpy(samples) + assert samples.ndim == 2, samples.shape + assert samples.shape[0] == 1, samples.shape + + mel = ( + mel_spectrogram( + samples, + self.config.n_fft, + self.config.n_mels, + self.config.sampling_rate, + self.config.hop_length, + self.config.win_length, + self.config.f_min, + self.config.f_max, + center=False, + ) + .squeeze() + .t() + ) + + assert mel.ndim == 2, mel.shape + assert mel.shape[1] == self.config.n_mels, mel.shape + + num_frames = compute_num_frames( + samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate + ) + + if mel.shape[0] > num_frames: + mel = mel[:num_frames] + elif mel.shape[0] < num_frames: + mel = mel.unsqueeze(0) + mel = torch.nn.functional.pad( + mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate" + ).squeeze(0) + + return mel.numpy() + + @property + def frame_shift(self) -> Seconds: + return self.config.hop_length / self.config.sampling_rate + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--num-jobs", + type=int, + default=4, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + return parser + + +def compute_fbank_ljspeech(num_jobs: int): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + + if num_jobs < 1: + num_jobs = os.cpu_count() + + logging.info(f"num_jobs: {num_jobs}") + logging.info(f"src_dir: {src_dir}") + logging.info(f"output_dir: {output_dir}") + config = MyFbankConfig( + n_fft=1024, + n_mels=80, + sampling_rate=22050, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, + ) + + prefix = "ljspeech" + suffix = "jsonl.gz" + partition = "all" + + recordings = load_manifest( + src_dir / f"{prefix}_recordings_{partition}.{suffix}", RecordingSet + ) + supervisions = load_manifest( + src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet + ) + + extractor = MyFbank(config) + + with get_executor() as ex: # Initialize the executor only once. + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{cuts_filename} already exists - skipping.") + return + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=recordings, supervisions=supervisions + ) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + # Torch's multithreaded behavior needs to be disabled or + # it wastes a lot of CPU and slow things down. + # Do this outside of main() in case it needs to take effect + # even when we are not invoking the main (e.g. when spawning subprocesses). + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_parser().parse_args() + compute_fbank_ljspeech(args.num_jobs) diff --git a/egs/ljspeech/TTS/local/compute_fbank_statistics.py b/egs/ljspeech/TTS/local/compute_fbank_statistics.py new file mode 100755 index 000000000..d0232c983 --- /dev/null +++ b/egs/ljspeech/TTS/local/compute_fbank_statistics.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script compute the mean and std of the fbank features. +""" + +import argparse +import json +import logging +from pathlib import Path + +import torch +from lhotse import CutSet, load_manifest_lazy + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "manifest", + type=Path, + help="Path to the manifest file", + ) + + parser.add_argument( + "cmvn", + type=Path, + help="Path to the cmvn.json", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + manifest = args.manifest + logging.info( + f"Computing fbank mean and std for {manifest} and saving to {args.cmvn}" + ) + + assert manifest.is_file(), f"{manifest} does not exist" + cut_set = load_manifest_lazy(manifest) + assert isinstance(cut_set, CutSet), type(cut_set) + + feat_dim = cut_set[0].features.num_features + num_frames = 0 + s = 0 + sq = 0 + for c in cut_set: + f = torch.from_numpy(c.load_features()) + num_frames += f.shape[0] + s += f.sum() + sq += f.square().sum() + + fbank_mean = s / (num_frames * feat_dim) + fbank_var = sq / (num_frames * feat_dim) - fbank_mean * fbank_mean + print("fbank var", fbank_var) + fbank_std = fbank_var.sqrt() + with open(args.cmvn, "w") as f: + json.dump({"fbank_mean": fbank_mean.item(), "fbank_std": fbank_std.item()}, f) + f.write("\n") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py index 4ba88604c..33a8ac2ab 100755 --- a/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py +++ b/egs/ljspeech/TTS/local/prepare_tokens_ljspeech.py @@ -28,17 +28,33 @@ try: except ModuleNotFoundError as ex: raise RuntimeError(f"{ex}\nPlease run\n pip install espnet_tts_frontend\n") +import argparse + from lhotse import CutSet, load_manifest from piper_phonemize import phonemize_espeak -def prepare_tokens_ljspeech(): - output_dir = Path("data/spectrogram") +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--in-out-dir", + type=Path, + required=True, + help="Input and output directory", + ) + + return parser + + +def prepare_tokens_ljspeech(in_out_dir): prefix = "ljspeech" suffix = "jsonl.gz" partition = "all" - cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}") + cut_set = load_manifest(in_out_dir / f"{prefix}_cuts_{partition}.{suffix}") new_cuts = [] for cut in cut_set: @@ -56,11 +72,13 @@ def prepare_tokens_ljspeech(): new_cuts.append(cut) new_cut_set = CutSet.from_cuts(new_cuts) - new_cut_set.to_file(output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}") + new_cut_set.to_file(in_out_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}") if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) - prepare_tokens_ljspeech() + args = get_parser().parse_args() + + prepare_tokens_ljspeech(args.in_out_dir) diff --git a/egs/ljspeech/TTS/local/validate_manifest.py b/egs/ljspeech/TTS/local/validate_manifest.py index 68159ae03..9535ba9f4 100755 --- a/egs/ljspeech/TTS/local/validate_manifest.py +++ b/egs/ljspeech/TTS/local/validate_manifest.py @@ -33,6 +33,7 @@ import argparse import logging from pathlib import Path +from compute_fbank_ljspeech import MyFbank from lhotse import CutSet, load_manifest_lazy from lhotse.dataset.speech_synthesis import validate_for_tts diff --git a/egs/ljspeech/TTS/matcha/LICENSE b/egs/ljspeech/TTS/matcha/LICENSE new file mode 100644 index 000000000..858018e75 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Shivam Mehta + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/egs/ljspeech/TTS/matcha/__init__.py b/egs/ljspeech/TTS/matcha/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/ljspeech/TTS/matcha/audio.py b/egs/ljspeech/TTS/matcha/audio.py new file mode 100644 index 000000000..534331e59 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/audio.py @@ -0,0 +1,92 @@ +# This file is copied from +# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/audio.py +import numpy as np +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn +from scipy.io.wavfile import read + +MAX_WAV_VALUE = 32768.0 + + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def mel_spectrogram( + y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False +): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window # pylint: disable=global-statement + if f"{str(fmax)}_{str(y.device)}" not in mel_basis: + mel = librosa_mel_fn( + sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax + ) + mel_basis[str(fmax) + "_" + str(y.device)] = ( + torch.from_numpy(mel).float().to(y.device) + ) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec diff --git a/egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py b/egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py new file mode 120000 index 000000000..85255ba0c --- /dev/null +++ b/egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py @@ -0,0 +1 @@ +../local/compute_fbank_ljspeech.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/matcha/export_onnx.py b/egs/ljspeech/TTS/matcha/export_onnx.py new file mode 100755 index 000000000..487ea2995 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/export_onnx.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This script exports a Matcha-TTS model to ONNX. +Note that the model outputs fbank. You need to use a vocoder to convert +it to audio. See also ./export_onnx_hifigan.py +""" + +import argparse +import json +import logging +from pathlib import Path +from typing import Any, Dict + +import onnx +import torch +from tokenizer import Tokenizer +from train import get_model, get_params + +from icefall.checkpoint import load_checkpoint + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=4000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="matcha/exp-new-3", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=Path, + default="data/tokens.txt", + ) + + parser.add_argument( + "--cmvn", + type=str, + default="data/fbank/cmvn.json", + help="""Path to vocabulary.""", + ) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, Any]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + + while len(model.metadata_props): + model.metadata_props.pop() + + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = str(value) + + onnx.save(model, filename) + + +class ModelWrapper(torch.nn.Module): + def __init__(self, model, num_steps: int = 5): + super().__init__() + self.model = model + self.num_steps = num_steps + + def forward( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + temperature: torch.Tensor, + length_scale: torch.Tensor, + ) -> torch.Tensor: + """ + Args: : + x: (batch_size, num_tokens), torch.int64 + x_lengths: (batch_size,), torch.int64 + temperature: (1,), torch.float32 + length_scale (1,), torch.float32 + Returns: + audio: (batch_size, num_samples) + + """ + mel = self.model.synthesise( + x=x, + x_lengths=x_lengths, + n_timesteps=self.num_steps, + temperature=temperature, + length_scale=length_scale, + )["mel"] + # mel: (batch_size, feat_dim, num_frames) + + return mel + + +@torch.inference_mode() +def main(): + parser = get_parser() + args = parser.parse_args() + params = get_params() + + params.update(vars(args)) + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + params.model_args.n_vocab = params.vocab_size + + with open(params.cmvn) as f: + stats = json.load(f) + params.data_args.data_statistics.mel_mean = stats["fbank_mean"] + params.data_args.data_statistics.mel_std = stats["fbank_std"] + + params.model_args.data_statistics.mel_mean = stats["fbank_mean"] + params.model_args.data_statistics.mel_std = stats["fbank_std"] + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + + for num_steps in [2, 3, 4, 5, 6]: + logging.info(f"num_steps: {num_steps}") + wrapper = ModelWrapper(model, num_steps=num_steps) + wrapper.eval() + + # Use a large value so the rotary position embedding in the text + # encoder has a large initial length + x = torch.ones(1, 1000, dtype=torch.int64) + x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64) + temperature = torch.tensor([1.0]) + length_scale = torch.tensor([1.0]) + + opset_version = 14 + filename = f"model-steps-{num_steps}.onnx" + torch.onnx.export( + wrapper, + (x, x_lengths, temperature, length_scale), + filename, + opset_version=opset_version, + input_names=["x", "x_length", "temperature", "length_scale"], + output_names=["mel"], + dynamic_axes={ + "x": {0: "N", 1: "L"}, + "x_length": {0: "N"}, + "mel": {0: "N", 2: "L"}, + }, + ) + + meta_data = { + "model_type": "matcha-tts", + "language": "English", + "voice": "en-us", + "has_espeak": 1, + "n_speakers": 1, + "sample_rate": 22050, + "version": 1, + "model_author": "icefall", + "maintainer": "k2-fsa", + "dataset": "LJ Speech", + "num_ode_steps": num_steps, + } + add_meta_data(filename=filename, meta_data=meta_data) + print(meta_data) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py b/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py new file mode 100755 index 000000000..63d1fac20 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + +import logging +from pathlib import Path +from typing import Any, Dict + +import onnx +import torch +from inference import load_vocoder + + +def add_meta_data(filename: str, meta_data: Dict[str, Any]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + + while len(model.metadata_props): + model.metadata_props.pop() + + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = str(value) + + onnx.save(model, filename) + + +class ModelWrapper(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward( + self, + mel: torch.Tensor, + ) -> torch.Tensor: + """ + Args: : + mel: (batch_size, feat_dim, num_frames), torch.float32 + Returns: + audio: (batch_size, num_samples), torch.float32 + """ + audio = self.model(mel).clamp(-1, 1).squeeze(1) + return audio + + +@torch.inference_mode() +def main(): + # Please go to + # https://github.com/csukuangfj/models/tree/master/hifigan + # to download the following files + model_filenames = ["./generator_v1", "./generator_v2", "./generator_v3"] + + for f in model_filenames: + logging.info(f) + if not Path(f).is_file(): + logging.info(f"Skipping {f} since {f} does not exist") + continue + model = load_vocoder(f) + wrapper = ModelWrapper(model) + wrapper.eval() + num_param = sum([p.numel() for p in wrapper.parameters()]) + logging.info(f"{f}: Number of parameters: {num_param}") + + # Use a large value so the rotary position embedding in the text + # encoder has a large initial length + x = torch.ones(1, 80, 100000, dtype=torch.float32) + opset_version = 14 + suffix = f.split("_")[-1] + filename = f"hifigan_{suffix}.onnx" + torch.onnx.export( + wrapper, + x, + filename, + opset_version=opset_version, + input_names=["mel"], + output_names=["audio"], + dynamic_axes={ + "mel": {0: "N", 2: "L"}, + "audio": {0: "N", 1: "L"}, + }, + ) + + meta_data = { + "model_type": "hifigan", + "model_filename": f.split("/")[-1], + "sample_rate": 22050, + "version": 1, + "model_author": "jik876", + "maintainer": "k2-fsa", + "dataset": "LJ Speech", + "url1": "https://github.com/jik876/hifi-gan", + "url2": "https://github.com/csukuangfj/models/tree/master/hifigan", + } + add_meta_data(filename=filename, meta_data=meta_data) + print(meta_data) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/ljspeech/TTS/matcha/hifigan/LICENSE b/egs/ljspeech/TTS/matcha/hifigan/LICENSE new file mode 100644 index 000000000..91751daed --- /dev/null +++ b/egs/ljspeech/TTS/matcha/hifigan/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Jungil Kong + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/egs/ljspeech/TTS/matcha/hifigan/README.md b/egs/ljspeech/TTS/matcha/hifigan/README.md new file mode 100644 index 000000000..5db258504 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/hifigan/README.md @@ -0,0 +1,101 @@ +# HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis + +### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae + +In our [paper](https://arxiv.org/abs/2010.05646), +we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.
+We provide our implementation and pretrained models as open source in this repository. + +**Abstract :** +Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms. +Although such methods improve the sampling efficiency and memory usage, +their sample quality has not yet reached that of autoregressive and flow-based generative models. +In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis. +As speech audio consists of sinusoidal signals with various periods, +we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality. +A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method +demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than +real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen +speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times +faster than real-time on CPU with comparable quality to an autoregressive counterpart. + +Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples. + +## Pre-requisites + +1. Python >= 3.6 +2. Clone this repository. +3. Install python requirements. Please refer [requirements.txt](requirements.txt) +4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/). + And move all wav files to `LJSpeech-1.1/wavs` + +## Training + +``` +python train.py --config config_v1.json +``` + +To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.
+Checkpoints and copy of the configuration file are saved in `cp_hifigan` directory by default.
+You can change the path by adding `--checkpoint_path` option. + +Validation loss during training with V1 generator.
+![validation loss](./validation_loss.png) + +## Pretrained Model + +You can also use pretrained models we provide.
+[Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)
+Details of each folder are as in follows: + +| Folder Name | Generator | Dataset | Fine-Tuned | +| ------------ | --------- | --------- | ------------------------------------------------------ | +| LJ_V1 | V1 | LJSpeech | No | +| LJ_V2 | V2 | LJSpeech | No | +| LJ_V3 | V3 | LJSpeech | No | +| LJ_FT_T2_V1 | V1 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | +| LJ_FT_T2_V2 | V2 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | +| LJ_FT_T2_V3 | V3 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | +| VCTK_V1 | V1 | VCTK | No | +| VCTK_V2 | V2 | VCTK | No | +| VCTK_V3 | V3 | VCTK | No | +| UNIVERSAL_V1 | V1 | Universal | No | + +We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets. + +## Fine-Tuning + +1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.
+ The file name of the generated mel-spectrogram should match the audio file and the extension should be `.npy`.
+ Example: + ` Audio File : LJ001-0001.wav +Mel-Spectrogram File : LJ001-0001.npy` +2. Create `ft_dataset` folder and copy the generated mel-spectrogram files into it.
+3. Run the following command. + ``` + python train.py --fine_tuning True --config config_v1.json + ``` + For other command line options, please refer to the training section. + +## Inference from wav file + +1. Make `test_files` directory and copy wav files into the directory. +2. Run the following command. + ` python inference.py --checkpoint_file [generator checkpoint file path]` + Generated wav files are saved in `generated_files` by default.
+ You can change the path by adding `--output_dir` option. + +## Inference for end-to-end speech synthesis + +1. Make `test_mel_files` directory and copy generated mel-spectrogram files into the directory.
+ You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2), + [Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth. +2. Run the following command. + ` python inference_e2e.py --checkpoint_file [generator checkpoint file path]` + Generated wav files are saved in `generated_files_from_mel` by default.
+ You can change the path by adding `--output_dir` option. + +## Acknowledgements + +We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips) +and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this. diff --git a/egs/ljspeech/TTS/matcha/hifigan/__init__.py b/egs/ljspeech/TTS/matcha/hifigan/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/ljspeech/TTS/matcha/hifigan/config.py b/egs/ljspeech/TTS/matcha/hifigan/config.py new file mode 100644 index 000000000..ecba62fd4 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/hifigan/config.py @@ -0,0 +1,100 @@ +v1 = { + "resblock": "1", + "num_gpus": 0, + "batch_size": 16, + "learning_rate": 0.0004, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + "upsample_rates": [8, 8, 2, 2], + "upsample_kernel_sizes": [16, 16, 4, 4], + "upsample_initial_channel": 512, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "resblock_initial_channel": 256, + "segment_size": 8192, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + "sampling_rate": 22050, + "fmin": 0, + "fmax": 8000, + "fmax_loss": None, + "num_workers": 4, + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1, + }, +} + +# See https://drive.google.com/drive/folders/1bB1tnGIxRN-edlf6k2Rmi1gNCK9Cpcvf +v2 = { + "resblock": "1", + "num_gpus": 0, + "batch_size": 16, + "learning_rate": 0.0002, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + "upsample_rates": [8, 8, 2, 2], + "upsample_kernel_sizes": [16, 16, 4, 4], + "upsample_initial_channel": 128, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "resblock_initial_channel": 64, + "segment_size": 8192, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + "sampling_rate": 22050, + "fmin": 0, + "fmax": 8000, + "fmax_loss": None, + "num_workers": 4, + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1, + }, +} + +# See https://drive.google.com/drive/folders/1KKvuJTLp_gZXC8lug7H_lSXct38_3kx1 +v3 = { + "resblock": "2", + "num_gpus": 0, + "batch_size": 16, + "learning_rate": 0.0002, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + "upsample_rates": [8, 8, 4], + "upsample_kernel_sizes": [16, 16, 8], + "upsample_initial_channel": 256, + "resblock_kernel_sizes": [3, 5, 7], + "resblock_dilation_sizes": [[1, 2], [2, 6], [3, 12]], + "resblock_initial_channel": 128, + "segment_size": 8192, + "num_mels": 80, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 256, + "win_size": 1024, + "sampling_rate": 22050, + "fmin": 0, + "fmax": 8000, + "fmax_loss": None, + "num_workers": 4, + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1, + }, +} diff --git a/egs/ljspeech/TTS/matcha/hifigan/denoiser.py b/egs/ljspeech/TTS/matcha/hifigan/denoiser.py new file mode 100644 index 000000000..b9aea61b8 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/hifigan/denoiser.py @@ -0,0 +1,71 @@ +# Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py + +"""Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio.""" +import torch + + +class Denoiser(torch.nn.Module): + """Removes model bias from audio produced with waveglow""" + + def __init__( + self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros" + ): + super().__init__() + self.filter_length = filter_length + self.hop_length = int(filter_length / n_overlap) + self.win_length = win_length + + dtype, device = ( + next(vocoder.parameters()).dtype, + next(vocoder.parameters()).device, + ) + self.device = device + if mode == "zeros": + mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device) + elif mode == "normal": + mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device) + else: + raise Exception(f"Mode {mode} if not supported") + + def stft_fn(audio, n_fft, hop_length, win_length, window): + spec = torch.stft( + audio, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + return_complex=True, + ) + spec = torch.view_as_real(spec) + return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2( + spec[..., -1], spec[..., 0] + ) + + self.stft = lambda x: stft_fn( + audio=x, + n_fft=self.filter_length, + hop_length=self.hop_length, + win_length=self.win_length, + window=torch.hann_window(self.win_length, device=device), + ) + self.istft = lambda x, y: torch.istft( + torch.complex(x * torch.cos(y), x * torch.sin(y)), + n_fft=self.filter_length, + hop_length=self.hop_length, + win_length=self.win_length, + window=torch.hann_window(self.win_length, device=device), + ) + + with torch.no_grad(): + bias_audio = vocoder(mel_input).float().squeeze(0) + bias_spec, _ = self.stft(bias_audio) + + self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None]) + + @torch.inference_mode() + def forward(self, audio, strength=0.0005): + audio_spec, audio_angles = self.stft(audio) + audio_spec_denoised = audio_spec - self.bias_spec.to(audio.device) * strength + audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0) + audio_denoised = self.istft(audio_spec_denoised, audio_angles) + return audio_denoised diff --git a/egs/ljspeech/TTS/matcha/hifigan/env.py b/egs/ljspeech/TTS/matcha/hifigan/env.py new file mode 100644 index 000000000..9ea4f948a --- /dev/null +++ b/egs/ljspeech/TTS/matcha/hifigan/env.py @@ -0,0 +1,17 @@ +""" from https://github.com/jik876/hifi-gan """ + +import os +import shutil + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__dict__ = self + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) diff --git a/egs/ljspeech/TTS/matcha/hifigan/meldataset.py b/egs/ljspeech/TTS/matcha/hifigan/meldataset.py new file mode 100644 index 000000000..6eb15a326 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/hifigan/meldataset.py @@ -0,0 +1,245 @@ +""" from https://github.com/jik876/hifi-gan """ + +import math +import os +import random + +import numpy as np +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn +from librosa.util import normalize +from scipy.io.wavfile import read + +MAX_WAV_VALUE = 32768.0 + + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def mel_spectrogram( + y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False +): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window # pylint: disable=global-statement + if fmax not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[str(fmax) + "_" + str(y.device)] = ( + torch.from_numpy(mel).float().to(y.device) + ) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec + + +def get_dataset_filelist(a): + with open(a.input_training_file, encoding="utf-8") as fi: + training_files = [ + os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") + for x in fi.read().split("\n") + if len(x) > 0 + ] + + with open(a.input_validation_file, encoding="utf-8") as fi: + validation_files = [ + os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") + for x in fi.read().split("\n") + if len(x) > 0 + ] + return training_files, validation_files + + +class MelDataset(torch.utils.data.Dataset): + def __init__( + self, + training_files, + segment_size, + n_fft, + num_mels, + hop_size, + win_size, + sampling_rate, + fmin, + fmax, + split=True, + shuffle=True, + n_cache_reuse=1, + device=None, + fmax_loss=None, + fine_tuning=False, + base_mels_path=None, + ): + self.audio_files = training_files + random.seed(1234) + if shuffle: + random.shuffle(self.audio_files) + self.segment_size = segment_size + self.sampling_rate = sampling_rate + self.split = split + self.n_fft = n_fft + self.num_mels = num_mels + self.hop_size = hop_size + self.win_size = win_size + self.fmin = fmin + self.fmax = fmax + self.fmax_loss = fmax_loss + self.cached_wav = None + self.n_cache_reuse = n_cache_reuse + self._cache_ref_count = 0 + self.device = device + self.fine_tuning = fine_tuning + self.base_mels_path = base_mels_path + + def __getitem__(self, index): + filename = self.audio_files[index] + if self._cache_ref_count == 0: + audio, sampling_rate = load_wav(filename) + audio = audio / MAX_WAV_VALUE + if not self.fine_tuning: + audio = normalize(audio) * 0.95 + self.cached_wav = audio + if sampling_rate != self.sampling_rate: + raise ValueError( + f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR" + ) + self._cache_ref_count = self.n_cache_reuse + else: + audio = self.cached_wav + self._cache_ref_count -= 1 + + audio = torch.FloatTensor(audio) + audio = audio.unsqueeze(0) + + if not self.fine_tuning: + if self.split: + if audio.size(1) >= self.segment_size: + max_audio_start = audio.size(1) - self.segment_size + audio_start = random.randint(0, max_audio_start) + audio = audio[:, audio_start : audio_start + self.segment_size] + else: + audio = torch.nn.functional.pad( + audio, (0, self.segment_size - audio.size(1)), "constant" + ) + + mel = mel_spectrogram( + audio, + self.n_fft, + self.num_mels, + self.sampling_rate, + self.hop_size, + self.win_size, + self.fmin, + self.fmax, + center=False, + ) + else: + mel = np.load( + os.path.join( + self.base_mels_path, + os.path.splitext(os.path.split(filename)[-1])[0] + ".npy", + ) + ) + mel = torch.from_numpy(mel) + + if len(mel.shape) < 3: + mel = mel.unsqueeze(0) + + if self.split: + frames_per_seg = math.ceil(self.segment_size / self.hop_size) + + if audio.size(1) >= self.segment_size: + mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1) + mel = mel[:, :, mel_start : mel_start + frames_per_seg] + audio = audio[ + :, + mel_start + * self.hop_size : (mel_start + frames_per_seg) + * self.hop_size, + ] + else: + mel = torch.nn.functional.pad( + mel, (0, frames_per_seg - mel.size(2)), "constant" + ) + audio = torch.nn.functional.pad( + audio, (0, self.segment_size - audio.size(1)), "constant" + ) + + mel_loss = mel_spectrogram( + audio, + self.n_fft, + self.num_mels, + self.sampling_rate, + self.hop_size, + self.win_size, + self.fmin, + self.fmax_loss, + center=False, + ) + + return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) + + def __len__(self): + return len(self.audio_files) diff --git a/egs/ljspeech/TTS/matcha/hifigan/models.py b/egs/ljspeech/TTS/matcha/hifigan/models.py new file mode 100644 index 000000000..e6da20610 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/hifigan/models.py @@ -0,0 +1,406 @@ +""" from https://github.com/jik876/hifi-gan """ + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +from .xutils import get_padding, init_weights + +LRELU_SLOPE = 0.1 + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super().__init__() + self.h = h + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super().__init__() + self.h = h + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Generator(torch.nn.Module): + def __init__(self, h): + super().__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm( + Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3) + ) + resblock = ResBlock1 if h.resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate( + zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) + ): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super().__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f( + Conv2d( + 1, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + 32, + 128, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + 128, + 512, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + 512, + 1024, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ] + ) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self): + super().__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorP(2), + DiscriminatorP(3), + DiscriminatorP(5), + DiscriminatorP(7), + DiscriminatorP(11), + ] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for _, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super().__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self): + super().__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ] + ) + self.meanpools = nn.ModuleList( + [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg**2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses diff --git a/egs/ljspeech/TTS/matcha/hifigan/xutils.py b/egs/ljspeech/TTS/matcha/hifigan/xutils.py new file mode 100644 index 000000000..eefadcb7a --- /dev/null +++ b/egs/ljspeech/TTS/matcha/hifigan/xutils.py @@ -0,0 +1,60 @@ +""" from https://github.com/jik876/hifi-gan """ + +import glob +import os + +import matplotlib +import torch +from torch.nn.utils import weight_norm + +matplotlib.use("Agg") +import matplotlib.pylab as plt + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print(f"Loading '{filepath}'") + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + print(f"Saving checkpoint to {filepath}") + torch.save(obj, filepath) + print("Complete.") + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + "????????") + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return None + return sorted(cp_list)[-1] diff --git a/egs/ljspeech/TTS/matcha/inference.py b/egs/ljspeech/TTS/matcha/inference.py new file mode 100755 index 000000000..64abd8e50 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/inference.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + +import argparse +import datetime as dt +import json +import logging +from pathlib import Path + +import soundfile as sf +import torch +from matcha.hifigan.config import v1, v2, v3 +from matcha.hifigan.denoiser import Denoiser +from matcha.hifigan.models import Generator as HiFiGAN +from tokenizer import Tokenizer +from train import get_model, get_params + +from icefall.checkpoint import load_checkpoint +from icefall.utils import AttributeDict + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=4000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="matcha/exp-new-3", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--vocoder", + type=Path, + default="./generator_v1", + help="Path to the vocoder", + ) + + parser.add_argument( + "--tokens", + type=Path, + default="data/tokens.txt", + ) + + parser.add_argument( + "--cmvn", + type=str, + default="data/fbank/cmvn.json", + help="""Path to vocabulary.""", + ) + + parser.add_argument( + "--input-text", + type=str, + required=True, + help="The text to generate speech for", + ) + + parser.add_argument( + "--output-wav", + type=str, + required=True, + help="The filename of the wave to save the generated speech", + ) + + return parser + + +def load_vocoder(checkpoint_path): + checkpoint_path = str(checkpoint_path) + if checkpoint_path.endswith("v1"): + h = AttributeDict(v1) + elif checkpoint_path.endswith("v2"): + h = AttributeDict(v2) + elif checkpoint_path.endswith("v3"): + h = AttributeDict(v3) + else: + raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}") + + hifigan = HiFiGAN(h).to("cpu") + hifigan.load_state_dict( + torch.load(checkpoint_path, map_location="cpu")["generator"] + ) + _ = hifigan.eval() + hifigan.remove_weight_norm() + return hifigan + + +def to_waveform(mel, vocoder, denoiser): + audio = vocoder(mel).clamp(-1, 1) + audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze() + return audio.cpu().squeeze() + + +def process_text(text: str, tokenizer): + x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True) + x = torch.tensor(x, dtype=torch.long) + x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu") + return {"x_orig": text, "x": x, "x_lengths": x_lengths} + + +def synthesise( + model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None +): + text_processed = process_text(text, tokenizer) + start_t = dt.datetime.now() + output = model.synthesise( + text_processed["x"], + text_processed["x_lengths"], + n_timesteps=n_timesteps, + temperature=temperature, + spks=spks, + length_scale=length_scale, + ) + # merge everything to one dict + output.update({"start_t": start_t, **text_processed}) + return output + + +@torch.inference_mode() +def main(): + parser = get_parser() + args = parser.parse_args() + params = get_params() + + params.update(vars(args)) + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + params.model_args.n_vocab = params.vocab_size + + with open(params.cmvn) as f: + stats = json.load(f) + params.data_args.data_statistics.mel_mean = stats["fbank_mean"] + params.data_args.data_statistics.mel_std = stats["fbank_std"] + + params.model_args.data_statistics.mel_mean = stats["fbank_mean"] + params.model_args.data_statistics.mel_std = stats["fbank_std"] + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not Path(f"{params.exp_dir}/epoch-{params.epoch}.pt").is_file(): + raise ValueError("{params.exp_dir}/epoch-{params.epoch}.pt does not exist") + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + model.eval() + + if not Path(params.vocoder).is_file(): + raise ValueError(f"{params.vocoder} does not exist") + + vocoder = load_vocoder(params.vocoder) + denoiser = Denoiser(vocoder, mode="zeros") + + # Number of ODE Solver steps + n_timesteps = 2 + + # Changes to the speaking rate + length_scale = 1.0 + + # Sampling temperature + temperature = 0.667 + + output = synthesise( + model=model, + tokenizer=tokenizer, + n_timesteps=n_timesteps, + text=params.input_text, + length_scale=length_scale, + temperature=temperature, + ) + output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) + + sf.write(params.output_wav, output["waveform"], 22050, "PCM_16") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + main() diff --git a/egs/ljspeech/TTS/matcha/model.py b/egs/ljspeech/TTS/matcha/model.py new file mode 100644 index 000000000..6539ffc24 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/model.py @@ -0,0 +1,97 @@ +# This file is copied from +# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/model.py +""" from https://github.com/jaywalnut310/glow-tts """ + +import numpy as np +import torch + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def fix_len_compatibility(length, num_downsamplings_in_unet=2): + factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet) + length = (length / factor).ceil() * factor + if not torch.onnx.is_in_onnx_export(): + return length.int().item() + else: + return length + + +def convert_pad_shape(pad_shape): + inverted_shape = pad_shape[::-1] + pad_shape = [item for sublist in inverted_shape for item in sublist] + return pad_shape + + +def generate_path(duration, mask): + device = duration.device + + b, t_x, t_y = mask.shape + cum_duration = torch.cumsum(duration, 1) + path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = ( + path + - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[ + :, :-1 + ] + ) + path = path * mask + return path + + +def duration_loss(logw, logw_, lengths): + loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths) + return loss + + +def normalize(data, mu, std): + if not isinstance(mu, (float, int)): + if isinstance(mu, list): + mu = torch.tensor(mu, dtype=data.dtype, device=data.device) + elif isinstance(mu, torch.Tensor): + mu = mu.to(data.device) + elif isinstance(mu, np.ndarray): + mu = torch.from_numpy(mu).to(data.device) + mu = mu.unsqueeze(-1) + + if not isinstance(std, (float, int)): + if isinstance(std, list): + std = torch.tensor(std, dtype=data.dtype, device=data.device) + elif isinstance(std, torch.Tensor): + std = std.to(data.device) + elif isinstance(std, np.ndarray): + std = torch.from_numpy(std).to(data.device) + std = std.unsqueeze(-1) + + return (data - mu) / std + + +def denormalize(data, mu, std): + if not isinstance(mu, float): + if isinstance(mu, list): + mu = torch.tensor(mu, dtype=data.dtype, device=data.device) + elif isinstance(mu, torch.Tensor): + mu = mu.to(data.device) + elif isinstance(mu, np.ndarray): + mu = torch.from_numpy(mu).to(data.device) + mu = mu.unsqueeze(-1) + + if not isinstance(std, float): + if isinstance(std, list): + std = torch.tensor(std, dtype=data.dtype, device=data.device) + elif isinstance(std, torch.Tensor): + std = std.to(data.device) + elif isinstance(std, np.ndarray): + std = torch.from_numpy(std).to(data.device) + std = std.unsqueeze(-1) + + return data * std + mu diff --git a/egs/ljspeech/TTS/matcha/models/README.md b/egs/ljspeech/TTS/matcha/models/README.md new file mode 100644 index 000000000..1099ef3c8 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/models/README.md @@ -0,0 +1,3 @@ +# Introduction +Files in this folder are copied from +https://github.com/shivammehta25/Matcha-TTS/tree/main/matcha/models diff --git a/egs/ljspeech/TTS/matcha/models/__init__.py b/egs/ljspeech/TTS/matcha/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/ljspeech/TTS/matcha/models/components/__init__.py b/egs/ljspeech/TTS/matcha/models/components/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/ljspeech/TTS/matcha/models/components/decoder.py b/egs/ljspeech/TTS/matcha/models/components/decoder.py new file mode 100644 index 000000000..14d19f5d4 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/models/components/decoder.py @@ -0,0 +1,459 @@ +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from conformer import ConformerBlock +from diffusers.models.activations import get_activation +from einops import pack, rearrange, repeat +from matcha.models.components.transformer import BasicTransformerBlock + + +class SinusoidalPosEmb(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even" + + def forward(self, x, scale=1000): + if x.ndim < 1: + x = x.unsqueeze(0) + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class Block1D(torch.nn.Module): + def __init__(self, dim, dim_out, groups=8): + super().__init__() + self.block = torch.nn.Sequential( + torch.nn.Conv1d(dim, dim_out, 3, padding=1), + torch.nn.GroupNorm(groups, dim_out), + nn.Mish(), + ) + + def forward(self, x, mask): + output = self.block(x * mask) + return output * mask + + +class ResnetBlock1D(torch.nn.Module): + def __init__(self, dim, dim_out, time_emb_dim, groups=8): + super().__init__() + self.mlp = torch.nn.Sequential( + nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out) + ) + + self.block1 = Block1D(dim, dim_out, groups=groups) + self.block2 = Block1D(dim_out, dim_out, groups=groups) + + self.res_conv = torch.nn.Conv1d(dim, dim_out, 1) + + def forward(self, x, mask, time_emb): + h = self.block1(x, mask) + h += self.mlp(time_emb).unsqueeze(-1) + h = self.block2(h, mask) + output = h + self.res_conv(x * mask) + return output + + +class Downsample1D(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class TimestepEmbedding(nn.Module): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = get_activation(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + + if self.act is not None: + sample = self.act(sample) + + sample = self.linear_2(sample) + + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Upsample1D(nn.Module): + """A 1D upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + """ + + def __init__( + self, + channels, + use_conv=False, + use_conv_transpose=True, + out_channels=None, + name="conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + self.conv = None + if use_conv_transpose: + self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, inputs): + assert inputs.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(inputs) + + outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest") + + if self.use_conv: + outputs = self.conv(outputs) + + return outputs + + +class ConformerWrapper(ConformerBlock): + def __init__( # pylint: disable=useless-super-delegation + self, + *, + dim, + dim_head=64, + heads=8, + ff_mult=4, + conv_expansion_factor=2, + conv_kernel_size=31, + attn_dropout=0, + ff_dropout=0, + conv_dropout=0, + conv_causal=False, + ): + super().__init__( + dim=dim, + dim_head=dim_head, + heads=heads, + ff_mult=ff_mult, + conv_expansion_factor=conv_expansion_factor, + conv_kernel_size=conv_kernel_size, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + conv_dropout=conv_dropout, + conv_causal=conv_causal, + ) + + def forward( + self, + hidden_states, + attention_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + ): + return super().forward(x=hidden_states, mask=attention_mask.bool()) + + +class Decoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + channels=(256, 256), + dropout=0.05, + attention_head_dim=64, + n_blocks=1, + num_mid_blocks=2, + num_heads=4, + act_fn="snake", + down_block_type="transformer", + mid_block_type="transformer", + up_block_type="transformer", + ): + super().__init__() + channels = tuple(channels) + self.in_channels = in_channels + self.out_channels = out_channels + + self.time_embeddings = SinusoidalPosEmb(in_channels) + time_embed_dim = channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=time_embed_dim, + act_fn="silu", + ) + + self.down_blocks = nn.ModuleList([]) + self.mid_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + output_channel = in_channels + for i in range(len(channels)): # pylint: disable=consider-using-enumerate + input_channel = output_channel + output_channel = channels[i] + is_last = i == len(channels) - 1 + resnet = ResnetBlock1D( + dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim + ) + transformer_blocks = nn.ModuleList( + [ + self.get_block( + down_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + downsample = ( + Downsample1D(output_channel) + if not is_last + else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + + self.down_blocks.append( + nn.ModuleList([resnet, transformer_blocks, downsample]) + ) + + for i in range(num_mid_blocks): + input_channel = channels[-1] + out_channels = channels[-1] + + resnet = ResnetBlock1D( + dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim + ) + + transformer_blocks = nn.ModuleList( + [ + self.get_block( + mid_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + + self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks])) + + channels = channels[::-1] + (channels[0],) + for i in range(len(channels) - 1): + input_channel = channels[i] + output_channel = channels[i + 1] + is_last = i == len(channels) - 2 + + resnet = ResnetBlock1D( + dim=2 * input_channel, + dim_out=output_channel, + time_emb_dim=time_embed_dim, + ) + transformer_blocks = nn.ModuleList( + [ + self.get_block( + up_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + upsample = ( + Upsample1D(output_channel, use_conv_transpose=True) + if not is_last + else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + + self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) + + self.final_block = Block1D(channels[-1], channels[-1]) + self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) + + self.initialize_weights() + # nn.init.normal_(self.final_proj.weight) + + @staticmethod + def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn): + if block_type == "conformer": + block = ConformerWrapper( + dim=dim, + dim_head=attention_head_dim, + heads=num_heads, + ff_mult=1, + conv_expansion_factor=2, + ff_dropout=dropout, + attn_dropout=dropout, + conv_dropout=dropout, + conv_kernel_size=31, + ) + elif block_type == "transformer": + block = BasicTransformerBlock( + dim=dim, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + else: + raise ValueError(f"Unknown block type {block_type}") + + return block + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv1d): + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.GroupNorm): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x, mask, mu, t, spks=None, cond=None): + """Forward pass of the UNet1DConditional model. + + Args: + x (torch.Tensor): shape (batch_size, in_channels, time) + mask (_type_): shape (batch_size, 1, time) + t (_type_): shape (batch_size) + spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. + cond (_type_, optional): placeholder for future use. Defaults to None. + + Raises: + ValueError: _description_ + ValueError: _description_ + + Returns: + _type_: _description_ + """ + + t = self.time_embeddings(t) + t = self.time_mlp(t) + + x = pack([x, mu], "b * t")[0] + + if spks is not None: + spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) + x = pack([x, spks], "b * t")[0] + + hiddens = [] + masks = [mask] + for resnet, transformer_blocks, downsample in self.down_blocks: + mask_down = masks[-1] + x = resnet(x, mask_down, t) + x = rearrange(x, "b c t -> b t c") + mask_down = rearrange(mask_down, "b 1 t -> b t") + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=mask_down, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t") + mask_down = rearrange(mask_down, "b t -> b 1 t") + hiddens.append(x) # Save hidden states for skip connections + x = downsample(x * mask_down) + masks.append(mask_down[:, :, ::2]) + + masks = masks[:-1] + mask_mid = masks[-1] + + for resnet, transformer_blocks in self.mid_blocks: + x = resnet(x, mask_mid, t) + x = rearrange(x, "b c t -> b t c") + mask_mid = rearrange(mask_mid, "b 1 t -> b t") + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=mask_mid, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t") + mask_mid = rearrange(mask_mid, "b t -> b 1 t") + + for resnet, transformer_blocks, upsample in self.up_blocks: + mask_up = masks.pop() + x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t) + x = rearrange(x, "b c t -> b t c") + mask_up = rearrange(mask_up, "b 1 t -> b t") + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=mask_up, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t") + mask_up = rearrange(mask_up, "b t -> b 1 t") + x = upsample(x * mask_up) + + x = self.final_block(x, mask_up) + output = self.final_proj(x * mask_up) + + return output * mask diff --git a/egs/ljspeech/TTS/matcha/models/components/flow_matching.py b/egs/ljspeech/TTS/matcha/models/components/flow_matching.py new file mode 100644 index 000000000..997689b1c --- /dev/null +++ b/egs/ljspeech/TTS/matcha/models/components/flow_matching.py @@ -0,0 +1,140 @@ +from abc import ABC + +import torch +import torch.nn.functional as F +from matcha.models.components.decoder import Decoder + + +class BASECFM(torch.nn.Module, ABC): + def __init__( + self, + n_feats, + cfm_params, + n_spks=1, + spk_emb_dim=128, + ): + super().__init__() + self.n_feats = n_feats + self.n_spks = n_spks + self.spk_emb_dim = spk_emb_dim + self.solver = cfm_params.solver + if hasattr(cfm_params, "sigma_min"): + self.sigma_min = cfm_params.sigma_min + else: + self.sigma_min = 1e-4 + + self.estimator = None + + @torch.inference_mode() + def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): + """Forward diffusion + + Args: + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + n_timesteps (int): number of diffusion steps + temperature (float, optional): temperature for scaling noise. Defaults to 1.0. + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + + Returns: + sample: generated mel-spectrogram + shape: (batch_size, n_feats, mel_timesteps) + """ + z = torch.randn_like(mu) * temperature + t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) + return self.solve_euler( + z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond + ) + + def solve_euler(self, x, t_span, mu, mask, spks, cond): + """ + Fixed euler solver for ODEs. + Args: + x (torch.Tensor): random noise + t_span (torch.Tensor): n_timesteps interpolated + shape: (n_timesteps + 1,) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + """ + t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] + + # I am storing this because I can later plot it by putting a debugger here and saving it to a file + # Or in future might add like a return_all_steps flag + sol = [] + + for step in range(1, len(t_span)): + dphi_dt = self.estimator(x, mask, mu, t, spks, cond) + + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + + return sol[-1] + + def compute_loss(self, x1, mask, mu, spks=None, cond=None): + """Computes diffusion loss + + Args: + x1 (torch.Tensor): Target + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): target mask + shape: (batch_size, 1, mel_timesteps) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + spks (torch.Tensor, optional): speaker embedding. Defaults to None. + shape: (batch_size, spk_emb_dim) + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_feats, mel_timesteps) + """ + b, _, t = mu.shape + + # random timestep + t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) + # sample noise p(x_0) + z = torch.randn_like(x1) + + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + + loss = F.mse_loss( + self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum" + ) / (torch.sum(mask) * u.shape[1]) + return loss, y + + +class CFM(BASECFM): + def __init__( + self, + in_channels, + out_channel, + cfm_params, + decoder_params, + n_spks=1, + spk_emb_dim=64, + ): + super().__init__( + n_feats=in_channels, + cfm_params=cfm_params, + n_spks=n_spks, + spk_emb_dim=spk_emb_dim, + ) + + in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0) + # Just change the architecture of the estimator here + self.estimator = Decoder( + in_channels=in_channels, out_channels=out_channel, **decoder_params + ) diff --git a/egs/ljspeech/TTS/matcha/models/components/text_encoder.py b/egs/ljspeech/TTS/matcha/models/components/text_encoder.py new file mode 100644 index 000000000..ca77cba51 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/models/components/text_encoder.py @@ -0,0 +1,447 @@ +""" from https://github.com/jaywalnut310/glow-tts """ + +import math + +import torch +import torch.nn as nn +from einops import rearrange +from matcha.model import sequence_mask + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-4): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = torch.nn.Parameter(torch.ones(channels)) + self.beta = torch.nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + n_dims = len(x.shape) + mean = torch.mean(x, 1, keepdim=True) + variance = torch.mean((x - mean) ** 2, 1, keepdim=True) + + x = (x - mean) * torch.rsqrt(variance + self.eps) + + shape = [1, -1] + [1] * (n_dims - 2) + x = x * self.gamma.view(*shape) + self.beta.view(*shape) + return x + + +class ConvReluNorm(nn.Module): + def __init__( + self, + in_channels, + hidden_channels, + out_channels, + kernel_size, + n_layers, + p_dropout, + ): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.conv_layers = torch.nn.ModuleList() + self.norm_layers = torch.nn.ModuleList() + self.conv_layers.append( + torch.nn.Conv1d( + in_channels, hidden_channels, kernel_size, padding=kernel_size // 2 + ) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = torch.nn.Sequential( + torch.nn.ReLU(), torch.nn.Dropout(p_dropout) + ) + for _ in range(n_layers - 1): + self.conv_layers.append( + torch.nn.Conv1d( + hidden_channels, + hidden_channels, + kernel_size, + padding=kernel_size // 2, + ) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout): + super().__init__() + self.in_channels = in_channels + self.filter_channels = filter_channels + self.p_dropout = p_dropout + + self.drop = torch.nn.Dropout(p_dropout) + self.conv_1 = torch.nn.Conv1d( + in_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) + self.norm_1 = LayerNorm(filter_channels) + self.conv_2 = torch.nn.Conv1d( + filter_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) + self.norm_2 = LayerNorm(filter_channels) + self.proj = torch.nn.Conv1d(filter_channels, 1, 1) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + + +class RotaryPositionalEmbeddings(nn.Module): + """ + ## RoPE module + + Rotary encoding transforms pairs of features by rotating in the 2D plane. + That is, it organizes the $d$ features as $\frac{d}{2}$ pairs. + Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it + by an angle depending on the position of the token. + """ + + def __init__(self, d: int, base: int = 10_000): + r""" + * `d` is the number of features $d$ + * `base` is the constant used for calculating $\Theta$ + """ + super().__init__() + + self.base = base + self.d = int(d) + self.cos_cached = None + self.sin_cached = None + + def _build_cache(self, x: torch.Tensor): + r""" + Cache $\cos$ and $\sin$ values + """ + # Return if cache is already built + if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]: + return + + # Get sequence length + seq_len = x.shape[0] + + # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to( + x.device + ) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.einsum("n,d->nd", seq_idx, theta) + + # Concatenate so that for row $m$ we have + # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$ + idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1) + + # Cache them + self.cos_cached = idx_theta2.cos()[:, None, None, :] + self.sin_cached = idx_theta2.sin()[:, None, None, :] + + def _neg_half(self, x: torch.Tensor): + # $\frac{d}{2}$ + d_2 = self.d // 2 + + # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ + return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1) + + def forward(self, x: torch.Tensor): + """ + * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]` + """ + # Cache $\cos$ and $\sin$ values + x = rearrange(x, "b h t d -> t b h d") + + self._build_cache(x) + + # Split the features, we can choose to apply rotary embeddings only to a partial set of features. + x_rope, x_pass = x[..., : self.d], x[..., self.d :] + + # Calculate + # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ + neg_half_x = self._neg_half(x_rope) + + x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + ( + neg_half_x * self.sin_cached[: x.shape[0]] + ) + + return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d") + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels, + out_channels, + n_heads, + heads_share=True, + p_dropout=0.0, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.heads_share = heads_share + self.proximal_bias = proximal_bias + self.p_dropout = p_dropout + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = torch.nn.Conv1d(channels, channels, 1) + self.conv_k = torch.nn.Conv1d(channels, channels, 1) + self.conv_v = torch.nn.Conv1d(channels, channels, 1) + + # from https://nn.labml.ai/transformers/rope/index.html + self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) + self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) + + self.conv_o = torch.nn.Conv1d(channels, out_channels, 1) + self.drop = torch.nn.Dropout(p_dropout) + + torch.nn.init.xavier_uniform_(self.conv_q.weight) + torch.nn.init.xavier_uniform_(self.conv_k.weight) + if proximal_init: + self.conv_k.weight.data.copy_(self.conv_q.weight.data) + self.conv_k.bias.data.copy_(self.conv_q.bias.data) + torch.nn.init.xavier_uniform_(self.conv_v.weight) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads) + key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads) + value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads) + + query = self.query_rotary_pe(query) + key = self.key_rotary_pe(key) + + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) + + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to( + device=scores.device, dtype=scores.dtype + ) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + p_attn = torch.nn.functional.softmax(scores, dim=-1) + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) + return output, p_attn + + @staticmethod + def _attention_bias_proximal(length): + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__( + self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0 + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.conv_1 = torch.nn.Conv1d( + in_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) + self.conv_2 = torch.nn.Conv1d( + filter_channels, out_channels, kernel_size, padding=kernel_size // 2 + ) + self.drop = torch.nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + return x * x_mask + + +class Encoder(nn.Module): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size=1, + p_dropout=0.0, + **kwargs, + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.drop = torch.nn.Dropout(p_dropout) + self.attn_layers = torch.nn.ModuleList() + self.norm_layers_1 = torch.nn.ModuleList() + self.ffn_layers = torch.nn.ModuleList() + self.norm_layers_2 = torch.nn.ModuleList() + for _ in range(self.n_layers): + self.attn_layers.append( + MultiHeadAttention( + hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + for i in range(self.n_layers): + x = x * x_mask + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class TextEncoder(nn.Module): + def __init__( + self, + encoder_type, + encoder_params, + duration_predictor_params, + n_vocab, + n_spks=1, + spk_emb_dim=128, + ): + super().__init__() + self.encoder_type = encoder_type + self.n_vocab = n_vocab + self.n_feats = encoder_params.n_feats + self.n_channels = encoder_params.n_channels + self.spk_emb_dim = spk_emb_dim + self.n_spks = n_spks + + self.emb = torch.nn.Embedding(n_vocab, self.n_channels) + torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5) + + if encoder_params.prenet: + self.prenet = ConvReluNorm( + self.n_channels, + self.n_channels, + self.n_channels, + kernel_size=5, + n_layers=3, + p_dropout=0.5, + ) + else: + self.prenet = lambda x, x_mask: x + + self.encoder = Encoder( + encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0), + encoder_params.filter_channels, + encoder_params.n_heads, + encoder_params.n_layers, + encoder_params.kernel_size, + encoder_params.p_dropout, + ) + + self.proj_m = torch.nn.Conv1d( + self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1 + ) + self.proj_w = DurationPredictor( + self.n_channels + (spk_emb_dim if n_spks > 1 else 0), + duration_predictor_params.filter_channels_dp, + duration_predictor_params.kernel_size, + duration_predictor_params.p_dropout, + ) + + def forward(self, x, x_lengths, spks=None): + """Run forward pass to the transformer based encoder and duration predictor + + Args: + x (torch.Tensor): text input + shape: (batch_size, max_text_length) + x_lengths (torch.Tensor): text input lengths + shape: (batch_size,) + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size,) + + Returns: + mu (torch.Tensor): average output of the encoder + shape: (batch_size, n_feats, max_text_length) + logw (torch.Tensor): log duration predicted by the duration predictor + shape: (batch_size, 1, max_text_length) + x_mask (torch.Tensor): mask for the text input + shape: (batch_size, 1, max_text_length) + """ + x = self.emb(x) * math.sqrt(self.n_channels) + x = torch.transpose(x, 1, -1) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + + x = self.prenet(x, x_mask) + if self.n_spks > 1: + x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1) + x = self.encoder(x, x_mask) + mu = self.proj_m(x) * x_mask + + x_dp = torch.detach(x) + logw = self.proj_w(x_dp, x_mask) + + return mu, logw, x_mask diff --git a/egs/ljspeech/TTS/matcha/models/components/transformer.py b/egs/ljspeech/TTS/matcha/models/components/transformer.py new file mode 100644 index 000000000..a82e560bc --- /dev/null +++ b/egs/ljspeech/TTS/matcha/models/components/transformer.py @@ -0,0 +1,353 @@ +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn +from diffusers.models.attention import ( + GEGLU, + GELU, + AdaLayerNorm, + AdaLayerNormZero, + ApproximateGELU, +) +from diffusers.models.attention_processor import Attention +from diffusers.models.lora import LoRACompatibleLinear +from diffusers.utils.torch_utils import maybe_allow_in_graph + + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + """ + + def __init__( + self, + in_features, + out_features, + alpha=1.0, + alpha_trainable=True, + alpha_logscale=True, + ): + """ + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + """ + super().__init__() + self.in_features = ( + out_features if isinstance(out_features, list) else [out_features] + ) + self.proj = LoRACompatibleLinear(in_features, out_features) + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha) + self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha) + self.beta = nn.Parameter(torch.ones(self.in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + """ + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + """ + x = self.proj(x) + if self.alpha_logscale: + alpha = torch.exp(self.alpha) + beta = torch.exp(self.beta) + else: + alpha = self.alpha + beta = self.beta + + x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow( + torch.sin(x * alpha), 2 + ) + + return x + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh") + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + elif activation_fn == "snakebeta": + act_fn = SnakeBeta(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(LoRACompatibleLinear(inner_dim, dim_out)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = ( + num_embeds_ada_norm is not None + ) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = ( + num_embeds_ada_norm is not None + ) and norm_type == "ada_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim + if not double_self_attention + else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + # scale_qk=False, # uncomment this to not to use flash attention + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + ) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ): + # Notice that normalization is always applied before the real computation in the following blocks. + # 1. Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + cross_attention_kwargs = ( + cross_attention_kwargs if cross_attention_kwargs is not None else {} + ) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states + if self.only_cross_attention + else None, + attention_mask=encoder_attention_mask + if self.only_cross_attention + else attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm2(hidden_states) + ) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = ( + norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ) + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = torch.cat( + [ + self.ff(hid_slice) + for hid_slice in norm_hidden_states.chunk( + num_chunks, dim=self._chunk_dim + ) + ], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states diff --git a/egs/ljspeech/TTS/matcha/models/matcha_tts.py b/egs/ljspeech/TTS/matcha/models/matcha_tts.py new file mode 100644 index 000000000..330d1dc47 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/models/matcha_tts.py @@ -0,0 +1,295 @@ +import datetime as dt +import math +import random + +import matcha.monotonic_align as monotonic_align +import torch +from matcha.model import ( + denormalize, + duration_loss, + fix_len_compatibility, + generate_path, + sequence_mask, +) +from matcha.models.components.flow_matching import CFM +from matcha.models.components.text_encoder import TextEncoder + + +class MatchaTTS(torch.nn.Module): # 🍵 + def __init__( + self, + n_vocab, + n_spks, + spk_emb_dim, + n_feats, + encoder, + decoder, + cfm, + data_statistics, + out_size, + optimizer=None, + scheduler=None, + prior_loss=True, + use_precomputed_durations=False, + ): + super().__init__() + + # self.save_hyperparameters(logger=False) + + self.n_vocab = n_vocab + self.n_spks = n_spks + self.spk_emb_dim = spk_emb_dim + self.n_feats = n_feats + self.out_size = out_size + self.prior_loss = prior_loss + self.use_precomputed_durations = use_precomputed_durations + + if n_spks > 1: + self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim) + + self.encoder = TextEncoder( + encoder.encoder_type, + encoder.encoder_params, + encoder.duration_predictor_params, + n_vocab, + n_spks, + spk_emb_dim, + ) + + self.decoder = CFM( + in_channels=2 * encoder.encoder_params.n_feats, + out_channel=encoder.encoder_params.n_feats, + cfm_params=cfm, + decoder_params=decoder, + n_spks=n_spks, + spk_emb_dim=spk_emb_dim, + ) + + if data_statistics is not None: + self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"])) + self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"])) + else: + self.register_buffer("mel_mean", torch.tensor(0.0)) + self.register_buffer("mel_std", torch.tensor(1.0)) + + @torch.inference_mode() + def synthesise( + self, x, x_lengths, n_timesteps, temperature=1.0, spks=None, length_scale=1.0 + ): + """ + Generates mel-spectrogram from text. Returns: + 1. encoder outputs + 2. decoder outputs + 3. generated alignment + + Args: + x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. + shape: (batch_size, max_text_length) + x_lengths (torch.Tensor): lengths of texts in batch. + shape: (batch_size,) + n_timesteps (int): number of steps to use for reverse diffusion in decoder. + temperature (float, optional): controls variance of terminal distribution. + spks (bool, optional): speaker ids. + shape: (batch_size,) + length_scale (float, optional): controls speech pace. + Increase value to slow down generated speech and vice versa. + + Returns: + dict: { + "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), + # Average mel spectrogram generated by the encoder + "decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), + # Refined mel spectrogram improved by the CFM + "attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length), + # Alignment map between text and mel spectrogram + "mel": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), + # Denormalized mel spectrogram + "mel_lengths": torch.Tensor, shape: (batch_size,), + # Lengths of mel spectrograms + "rtf": float, + # Real-time factor + """ + # For RTF computation + t = dt.datetime.now() + + if self.n_spks > 1: + # Get speaker embedding + spks = self.spk_emb(spks.long()) + + # Get encoder_outputs `mu_x` and log-scaled token durations `logw` + mu_x, logw, x_mask = self.encoder(x, x_lengths, spks) + + w = torch.exp(logw) * x_mask + w_ceil = torch.ceil(w) * length_scale + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_max_length = y_lengths.max() + y_max_length_ = fix_len_compatibility(y_max_length) + + # Using obtained durations `w` construct alignment map `attn` + y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype) + attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) + attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) + + # Align encoded text and get mu_y + mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) + mu_y = mu_y.transpose(1, 2) + encoder_outputs = mu_y[:, :, :y_max_length] + + # Generate sample tracing the probability flow + decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, spks) + decoder_outputs = decoder_outputs[:, :, :y_max_length] + + t = (dt.datetime.now() - t).total_seconds() + rtf = t * 22050 / (decoder_outputs.shape[-1] * 256) + + return { + "encoder_outputs": encoder_outputs, + "decoder_outputs": decoder_outputs, + "attn": attn[:, :, :y_max_length], + "mel": denormalize(decoder_outputs, self.mel_mean, self.mel_std), + "mel_lengths": y_lengths, + "rtf": rtf, + } + + def forward( + self, + x, + x_lengths, + y, + y_lengths, + spks=None, + out_size=None, + cond=None, + durations=None, + ): + """ + Computes 3 losses: + 1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS). + 2. prior loss: loss between mel-spectrogram and encoder outputs. + 3. flow matching loss: loss between mel-spectrogram and decoder outputs. + + Args: + x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. + shape: (batch_size, max_text_length) + x_lengths (torch.Tensor): lengths of texts in batch. + shape: (batch_size,) + y (torch.Tensor): batch of corresponding mel-spectrograms. + shape: (batch_size, n_feats, max_mel_length) + y_lengths (torch.Tensor): lengths of mel-spectrograms in batch. + shape: (batch_size,) + out_size (int, optional): length (in mel's sampling rate) of segment to cut, on which decoder will be trained. + Should be divisible by 2^{num of UNet downsamplings}. Needed to increase batch size. + spks (torch.Tensor, optional): speaker ids. + shape: (batch_size,) + """ + if self.n_spks > 1: + # Get speaker embedding + spks = self.spk_emb(spks) + + # Get encoder_outputs `mu_x` and log-scaled token durations `logw` + mu_x, logw, x_mask = self.encoder(x, x_lengths, spks) + y_max_length = y.shape[-1] + + y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask) + attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) + + if self.use_precomputed_durations: + attn = generate_path(durations.squeeze(1), attn_mask.squeeze(1)) + else: + # Use MAS to find most likely alignment `attn` between text and mel-spectrogram + with torch.no_grad(): + const = -0.5 * math.log(2 * math.pi) * self.n_feats + factor = -0.5 * torch.ones( + mu_x.shape, dtype=mu_x.dtype, device=mu_x.device + ) + y_square = torch.matmul(factor.transpose(1, 2), y**2) + y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y) + mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1) + log_prior = y_square - y_mu_double + mu_square + const + + attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1)) + attn = attn.detach() # b, t_text, T_mel + + # Compute loss between predicted log-scaled durations and those obtained from MAS + # refered to as prior loss in the paper + logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask + dur_loss = duration_loss(logw, logw_, x_lengths) + + # Cut a small segment of mel-spectrogram in order to increase batch size + # - "Hack" taken from Grad-TTS, in case of Grad-TTS, we cannot train batch size 32 on a 24GB GPU without it + # - Do not need this hack for Matcha-TTS, but it works with it as well + if not isinstance(out_size, type(None)): + max_offset = (y_lengths - out_size).clamp(0) + offset_ranges = list( + zip([0] * max_offset.shape[0], max_offset.cpu().numpy()) + ) + out_offset = torch.LongTensor( + [ + torch.tensor(random.choice(range(start, end)) if end > start else 0) + for start, end in offset_ranges + ] + ).to(y_lengths) + attn_cut = torch.zeros( + attn.shape[0], + attn.shape[1], + out_size, + dtype=attn.dtype, + device=attn.device, + ) + y_cut = torch.zeros( + y.shape[0], self.n_feats, out_size, dtype=y.dtype, device=y.device + ) + + y_cut_lengths = [] + for i, (y_, out_offset_) in enumerate(zip(y, out_offset)): + y_cut_length = out_size + (y_lengths[i] - out_size).clamp(None, 0) + y_cut_lengths.append(y_cut_length) + cut_lower, cut_upper = out_offset_, out_offset_ + y_cut_length + y_cut[i, :, :y_cut_length] = y_[:, cut_lower:cut_upper] + attn_cut[i, :, :y_cut_length] = attn[i, :, cut_lower:cut_upper] + + y_cut_lengths = torch.LongTensor(y_cut_lengths) + y_cut_mask = sequence_mask(y_cut_lengths).unsqueeze(1).to(y_mask) + + attn = attn_cut + y = y_cut + y_mask = y_cut_mask + + # Align encoded text with mel-spectrogram and get mu_y segment + mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) + mu_y = mu_y.transpose(1, 2) + + # Compute loss of the decoder + diff_loss, _ = self.decoder.compute_loss( + x1=y, mask=y_mask, mu=mu_y, spks=spks, cond=cond + ) + + if self.prior_loss: + prior_loss = torch.sum( + 0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask + ) + prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats) + else: + prior_loss = 0 + + return dur_loss, prior_loss, diff_loss, attn + + def get_losses(self, batch): + x, x_lengths = batch["x"], batch["x_lengths"] + y, y_lengths = batch["y"], batch["y_lengths"] + spks = batch["spks"] + + dur_loss, prior_loss, diff_loss, *_ = self( + x=x, + x_lengths=x_lengths, + y=y, + y_lengths=y_lengths, + spks=spks, + out_size=self.out_size, + durations=batch["durations"], + ) + return { + "dur_loss": dur_loss, + "prior_loss": prior_loss, + "diff_loss": diff_loss, + } diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/.gitignore b/egs/ljspeech/TTS/matcha/monotonic_align/.gitignore new file mode 100644 index 000000000..28bdad6b8 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/monotonic_align/.gitignore @@ -0,0 +1,3 @@ +build +core.c +*.so diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py b/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py new file mode 100644 index 000000000..5b26fe474 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py @@ -0,0 +1,23 @@ +# Copied from +# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/__init__.py +import numpy as np +import torch +from matcha.monotonic_align.core import maximum_path_c + + +def maximum_path(value, mask): + """Cython optimised version. + value: [b, t_x, t_y] + mask: [b, t_x, t_y] + """ + value = value * mask + device = value.device + dtype = value.dtype + value = value.data.cpu().numpy().astype(np.float32) + path = np.zeros_like(value).astype(np.int32) + mask = mask.data.cpu().numpy() + + t_x_max = mask.sum(1)[:, 0].astype(np.int32) + t_y_max = mask.sum(2)[:, 0].astype(np.int32) + maximum_path_c(path, value, t_x_max, t_y_max) + return torch.from_numpy(path).to(device=device, dtype=dtype) diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/core.pyx b/egs/ljspeech/TTS/matcha/monotonic_align/core.pyx new file mode 100644 index 000000000..eabc7f273 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/monotonic_align/core.pyx @@ -0,0 +1,49 @@ +# Copied from +# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/core.pyx +import numpy as np + +cimport cython +cimport numpy as np + +from cython.parallel import prange + + +@cython.boundscheck(False) +@cython.wraparound(False) +cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil: + cdef int x + cdef int y + cdef float v_prev + cdef float v_cur + cdef float tmp + cdef int index = t_x - 1 + + for y in range(t_y): + for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + if x == y: + v_cur = max_neg_val + else: + v_cur = value[x, y-1] + if x == 0: + if y == 0: + v_prev = 0. + else: + v_prev = max_neg_val + else: + v_prev = value[x-1, y-1] + value[x, y] = max(v_cur, v_prev) + value[x, y] + + for y in range(t_y - 1, -1, -1): + path[index, y] = 1 + if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): + index = index - 1 + + +@cython.boundscheck(False) +@cython.wraparound(False) +cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: + cdef int b = values.shape[0] + + cdef int i + for i in prange(b, nogil=True): + maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/setup.py b/egs/ljspeech/TTS/matcha/monotonic_align/setup.py new file mode 100644 index 000000000..df26c633e --- /dev/null +++ b/egs/ljspeech/TTS/matcha/monotonic_align/setup.py @@ -0,0 +1,12 @@ +# Copied from +# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/setup.py +from distutils.core import setup + +import numpy +from Cython.Build import cythonize + +setup( + name="monotonic_align", + ext_modules=cythonize("core.pyx"), + include_dirs=[numpy.get_include()], +) diff --git a/egs/ljspeech/TTS/matcha/onnx_pretrained.py b/egs/ljspeech/TTS/matcha/onnx_pretrained.py new file mode 100755 index 000000000..be34343d3 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/onnx_pretrained.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + +import argparse +import datetime as dt +import logging + +import onnxruntime as ort +import soundfile as sf +import torch +from inference import load_vocoder +from tokenizer import Tokenizer + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--acoustic-model", + type=str, + required=True, + help="Path to the acoustic model", + ) + + parser.add_argument( + "--tokens", + type=str, + required=True, + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--vocoder", + type=str, + required=True, + help="Path to the vocoder", + ) + + parser.add_argument( + "--input-text", + type=str, + required=True, + help="The text to generate speech for", + ) + + parser.add_argument( + "--output-wav", + type=str, + required=True, + help="The filename of the wave to save the generated speech", + ) + + return parser + + +class OnnxHifiGANModel: + def __init__( + self, + filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + self.model = ort.InferenceSession( + filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + + for i in self.model.get_inputs(): + print(i) + + print("-----") + + for i in self.model.get_outputs(): + print(i) + + def __call__(self, x: torch.tensor): + assert x.ndim == 3, x.shape + assert x.shape[0] == 1, x.shape + + audio = self.model.run( + [self.model.get_outputs()[0].name], + { + self.model.get_inputs()[0].name: x.numpy(), + }, + )[0] + + return torch.from_numpy(audio) + + +class OnnxModel: + def __init__( + self, + filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 2 + + self.session_opts = session_opts + self.tokenizer = Tokenizer("./data/tokens.txt") + self.model = ort.InferenceSession( + filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + + for i in self.model.get_inputs(): + print(i) + + print("-----") + + for i in self.model.get_outputs(): + print(i) + + def __call__(self, x: torch.tensor): + assert x.ndim == 2, x.shape + assert x.shape[0] == 1, x.shape + + x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64) + print("x_lengths", x_lengths) + print("x", x.shape) + + temperature = torch.tensor([1.0], dtype=torch.float32) + length_scale = torch.tensor([1.0], dtype=torch.float32) + + mel = self.model.run( + [self.model.get_outputs()[0].name], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lengths.numpy(), + self.model.get_inputs()[2].name: temperature.numpy(), + self.model.get_inputs()[3].name: length_scale.numpy(), + }, + )[0] + + return torch.from_numpy(mel) + + +@torch.no_grad() +def main(): + params = get_parser().parse_args() + logging.info(vars(params)) + + model = OnnxModel(params.acoustic_model) + vocoder = OnnxHifiGANModel(params.vocoder) + text = params.input_text + x = model.tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True) + x = torch.tensor(x, dtype=torch.int64) + + start_t = dt.datetime.now() + mel = model(x) + end_t = dt.datetime.now() + + start_t2 = dt.datetime.now() + audio = vocoder(mel) + end_t2 = dt.datetime.now() + + print("audio", audio.shape) # (1, 1, num_samples) + audio = audio.squeeze() + + t = (end_t - start_t).total_seconds() + t2 = (end_t2 - start_t2).total_seconds() + rtf_am = t * 22050 / audio.shape[-1] + rtf_vocoder = t2 * 22050 / audio.shape[-1] + print("RTF for acoustic model ", rtf_am) + print("RTF for vocoder", rtf_vocoder) + + # skip denoiser + sf.write(params.output_wav, audio, 22050, "PCM_16") + logging.info(f"Saved to {params.output_wav}") + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() + +""" + +|HifiGAN |RTF |#Parameters (M)| +|----------|-----|---------------| +|v1 |0.818| 13.926 | +|v2 |0.101| 0.925 | +|v3 |0.118| 1.462 | + +|Num steps|Acoustic Model RTF| +|---------|------------------| +| 2 | 0.039 | +| 3 | 0.047 | +| 4 | 0.071 | +| 5 | 0.076 | +| 6 | 0.103 | + +""" diff --git a/egs/ljspeech/TTS/matcha/requirements.txt b/egs/ljspeech/TTS/matcha/requirements.txt new file mode 100644 index 000000000..5aadc8984 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/requirements.txt @@ -0,0 +1,3 @@ +conformer==0.3.2 +diffusers # developed using version ==0.25.0 +librosa diff --git a/egs/ljspeech/TTS/matcha/tokenizer.py b/egs/ljspeech/TTS/matcha/tokenizer.py new file mode 120000 index 000000000..44a19b0f4 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/tokenizer.py @@ -0,0 +1 @@ +../vits/tokenizer.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py new file mode 100755 index 000000000..5e713fdfd --- /dev/null +++ b/egs/ljspeech/TTS/matcha/train.py @@ -0,0 +1,723 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + + +import argparse +import json +import logging +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Union + +import k2 +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse.utils import fix_random_seed +from matcha.model import fix_len_compatibility +from matcha.models.matcha_tts import MatchaTTS +from matcha.tokenizer import Tokenizer +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import LJSpeechTtsDataModule +from utils import MetricsTracker + +from icefall.checkpoint import load_checkpoint, save_checkpoint +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import AttributeDict, setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12335, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=1000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="matcha/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + parser.add_argument( + "--cmvn", + type=str, + default="data/fbank/cmvn.json", + help="""Path to vocabulary.""", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=10, + help="""Save checkpoint after processing this number of epochs" + periodically. We save checkpoint to exp-dir/ whenever + params.cur_epoch % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. + Since it will take around 1000 epochs, we suggest using a large + save_every_n to save disk space. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + return parser + + +def get_data_statistics(): + return AttributeDict( + { + "mel_mean": 0, + "mel_std": 1, + } + ) + + +def _get_data_params() -> AttributeDict: + params = AttributeDict( + { + "name": "ljspeech", + "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt", + "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt", + # "batch_size": 64, + # "num_workers": 1, + # "pin_memory": False, + "cleaners": ["english_cleaners2"], + "add_blank": True, + "n_spks": 1, + "n_fft": 1024, + "n_feats": 80, + "sample_rate": 22050, + "hop_length": 256, + "win_length": 1024, + "f_min": 0, + "f_max": 8000, + "seed": 1234, + "load_durations": False, + "data_statistics": get_data_statistics(), + } + ) + return params + + +def _get_model_params() -> AttributeDict: + n_feats = 80 + filter_channels_dp = 256 + encoder_params_p_dropout = 0.1 + params = AttributeDict( + { + "n_spks": 1, # for ljspeech. + "spk_emb_dim": 64, + "n_feats": n_feats, + "out_size": None, # or use 172 + "prior_loss": True, + "use_precomputed_durations": False, + "data_statistics": get_data_statistics(), + "encoder": AttributeDict( + { + "encoder_type": "RoPE Encoder", # not used + "encoder_params": AttributeDict( + { + "n_feats": n_feats, + "n_channels": 192, + "filter_channels": 768, + "filter_channels_dp": filter_channels_dp, + "n_heads": 2, + "n_layers": 6, + "kernel_size": 3, + "p_dropout": encoder_params_p_dropout, + "spk_emb_dim": 64, + "n_spks": 1, + "prenet": True, + } + ), + "duration_predictor_params": AttributeDict( + { + "filter_channels_dp": filter_channels_dp, + "kernel_size": 3, + "p_dropout": encoder_params_p_dropout, + } + ), + } + ), + "decoder": AttributeDict( + { + "channels": [256, 256], + "dropout": 0.05, + "attention_head_dim": 64, + "n_blocks": 1, + "num_mid_blocks": 2, + "num_heads": 2, + "act_fn": "snakebeta", + } + ), + "cfm": AttributeDict( + { + "name": "CFM", + "solver": "euler", + "sigma_min": 1e-4, + } + ), + "optimizer": AttributeDict( + { + "lr": 1e-4, + "weight_decay": 0.0, + } + ), + } + ) + + return params + + +def get_params(): + params = AttributeDict( + { + "model_args": _get_model_params(), + "data_args": _get_data_params(), + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": -1, # 0 + "log_interval": 10, + "valid_interval": 1500, + "env_info": get_env_info(), + } + ) + return params + + +def get_model(params): + m = MatchaTTS(**params.model_args) + return m + + +def load_checkpoint_if_available( + params: AttributeDict, model: nn.Module +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint(filename, model=model) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, params): + """Parse batch data""" + mel_mean = params.data_args.data_statistics.mel_mean + mel_std_inv = 1 / params.data_args.data_statistics.mel_std + for i in range(batch["features"].shape[0]): + n = batch["features_lens"][i] + batch["features"][i : i + 1, :n, :] = ( + batch["features"][i : i + 1, :n, :] - mel_mean + ) * mel_std_inv + batch["features"][i : i + 1, n:, :] = 0 + + audio = batch["audio"].to(device) + features = batch["features"].to(device) + audio_lens = batch["audio_lens"].to(device) + features_lens = batch["features_lens"].to(device) + tokens = batch["tokens"] + + tokens = tokenizer.tokens_to_token_ids( + tokens, intersperse_blank=True, add_sos=True, add_eos=True + ) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # a tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) + + max_feature_length = fix_len_compatibility(features.shape[1]) + if max_feature_length > features.shape[1]: + pad = max_feature_length - features.shape[1] + features = torch.nn.functional.pad(features, (0, 0, 0, pad)) + + # features_lens[features_lens.argmax()] += pad + + return audio, audio_lens, features, features_lens.long(), tokens, tokens_lens.long() + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, + rank: int = 0, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses + + # used to summary the stats over iterations + tot_loss = MetricsTracker() + + with torch.no_grad(): + for batch_idx, batch in enumerate(valid_dl): + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + ) = prepare_input(batch, tokenizer, device, params) + + losses = get_losses( + { + "x": tokens, + "x_lengths": tokens_lens, + "y": features.permute(0, 2, 1), + "y_lengths": features_lens, + "spks": None, # should change it for multi-speakers + "durations": None, + } + ) + + batch_size = len(batch["tokens"]) + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + s = 0 + + for key, value in losses.items(): + v = value.detach().item() + loss_info[key] = v * batch_size + s += v * batch_size + + loss_info["tot_loss"] = s + + # summary stats + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(device) + + loss_value = tot_loss["tot_loss"] / tot_loss["samples"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + optimizer: Optimizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + """ + model.train() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses + + # used to track the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + # used to track the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + params=params, + optimizer=optimizer, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + # audio: (N, T), float32 + # features: (N, T, C), float32 + # audio_lens, (N,), int32 + # features_lens, (N,), int32 + # tokens: List[List[str]], len(tokens) == N + + batch_size = len(batch["tokens"]) + + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + ) = prepare_input(batch, tokenizer, device, params) + try: + with autocast(enabled=params.use_fp16): + losses = get_losses( + { + "x": tokens, + "x_lengths": tokens_lens, + "y": features.permute(0, 2, 1), + "y_lengths": features_lens, + "spks": None, # should change it for multi-speakers + "durations": None, + } + ) + + loss = sum(losses.values()) + + optimizer.zero_grad() + scaler.scale(loss).backward() + scaler.step(optimizer) + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + s = 0 + + for key, value in losses.items(): + v = value.detach().item() + loss_info[key] = v * batch_size + s += v * batch_size + + loss_info["tot_loss"] = s + + tot_loss = tot_loss + loss_info + except: # noqa + save_bad_model() + raise + + if params.batch_idx_train % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. + # The _growth_interval of the grad scaler is configurable, + # but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or ( + cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 + ): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if params.batch_idx_train % params.log_interval == 0: + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"global_batch_idx: {params.batch_idx_train}, " + f"batch size: {batch_size}, " + f"loss[{loss_info}], tot_loss[{tot_loss}], " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if params.batch_idx_train % params.valid_interval == 1: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + tokenizer=tokenizer, + valid_dl=valid_dl, + world_size=world_size, + rank=rank, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + "Maximum memory allocated so far is " + f"{torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["tot_loss"] / tot_loss["samples"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + tokenizer = Tokenizer(params.tokens) + params.pad_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + params.model_args.n_vocab = params.vocab_size + + with open(params.cmvn) as f: + stats = json.load(f) + params.data_args.data_statistics.mel_mean = stats["fbank_mean"] + params.data_args.data_statistics.mel_std = stats["fbank_std"] + + params.model_args.data_statistics.mel_mean = stats["fbank_mean"] + params.model_args.data_statistics.mel_std = stats["fbank_std"] + + logging.info(params) + print(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of parameters: {num_param}") + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = torch.optim.Adam(model.parameters(), **params.model_args.optimizer) + + logging.info("About to create datamodule") + + ljspeech = LJSpeechTtsDataModule(args) + + train_cuts = ljspeech.train_cuts() + train_dl = ljspeech.train_dataloaders(train_cuts) + + valid_cuts = ljspeech.valid_cuts() + valid_dl = ljspeech.valid_dataloaders(valid_cuts) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + fix_random_seed(params.seed + epoch - 1) + if "sampler" in train_dl: + train_dl.sampler.set_epoch(epoch - 1) + + params.cur_epoch = epoch + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + train_one_epoch( + params=params, + model=model, + tokenizer=tokenizer, + optimizer=optimizer, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if epoch % params.save_every_n == 0 or epoch == params.num_epochs: + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer=optimizer, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + LJSpeechTtsDataModule.add_arguments(parser) + args = parser.parse_args() + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + main() diff --git a/egs/ljspeech/TTS/matcha/tts_datamodule.py b/egs/ljspeech/TTS/matcha/tts_datamodule.py new file mode 100644 index 000000000..8e37fc030 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/tts_datamodule.py @@ -0,0 +1,341 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from compute_fbank_ljspeech import MyFbank, MyFbankConfig +from lhotse import CutSet, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, + SpeechSynthesisDataset, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class LJSpeechTtsDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = MyFbankConfig( + n_fft=1024, + n_mels=80, + sampling_rate=sampling_rate, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, + ) + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(MyFbank(config)), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=True, + pin_memory=True, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = MyFbankConfig( + n_fft=1024, + n_mels=80, + sampling_rate=sampling_rate, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, + ) + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(MyFbank(config)), + return_cuts=self.args.return_cuts, + ) + else: + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=True, + pin_memory=True, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = MyFbankConfig( + n_fft=1024, + n_mels=80, + sampling_rate=sampling_rate, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, + ) + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(MyFbank(config)), + return_cuts=self.args.return_cuts, + ) + else: + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_train.jsonl.gz" + ) + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get validation cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_valid.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "ljspeech_cuts_test.jsonl.gz" + ) diff --git a/egs/ljspeech/TTS/matcha/utils.py b/egs/ljspeech/TTS/matcha/utils.py new file mode 120000 index 000000000..c2144f8e0 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/utils.py @@ -0,0 +1 @@ +../vits/utils.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh index 9ed0f93fd..6f16f8d47 100755 --- a/egs/ljspeech/TTS/prepare.sh +++ b/egs/ljspeech/TTS/prepare.sh @@ -5,7 +5,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python set -eou pipefail -stage=0 +stage=-1 stop_stage=100 dl_dir=$PWD/download @@ -31,7 +31,19 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then python3 setup.py build_ext --inplace cd ../../ else - log "monotonic_align lib already built" + log "monotonic_align lib for vits already built" + fi + + if [ ! -f ./matcha/monotonic_align/core.cpython-38-x86_64-linux-gnu.so ]; then + pushd matcha/monotonic_align + python3 setup.py build + mv -v build/lib.*/matcha/monotonic_align/core.*.so . + rm -rf build + rm core.c + ls -lh + popd + else + log "monotonic_align lib for matcha-tts already built" fi fi @@ -63,7 +75,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then fi if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then - log "Stage 2: Compute spectrogram for LJSpeech" + log "Stage 2: Compute spectrogram for LJSpeech (used by ./vits)" mkdir -p data/spectrogram if [ ! -e data/spectrogram/.ljspeech.done ]; then ./local/compute_spectrogram_ljspeech.py @@ -71,7 +83,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi if [ ! -e data/spectrogram/.ljspeech-validated.done ]; then - log "Validating data/spectrogram for LJSpeech" + log "Validating data/spectrogram for LJSpeech (used by ./vits)" python3 ./local/validate_manifest.py \ data/spectrogram/ljspeech_cuts_all.jsonl.gz touch data/spectrogram/.ljspeech-validated.done @@ -79,13 +91,13 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Prepare phoneme tokens for LJSpeech" + log "Stage 3: Prepare phoneme tokens for LJSpeech (used by ./vits)" # We assume you have installed piper_phonemize and espnet_tts_frontend. # If not, please install them with: # - piper_phonemize: pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html, # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then - ./local/prepare_tokens_ljspeech.py + ./local/prepare_tokens_ljspeech.py --in-out-dir ./data/spectrogram mv data/spectrogram/ljspeech_cuts_with_tokens_all.jsonl.gz \ data/spectrogram/ljspeech_cuts_all.jsonl.gz touch data/spectrogram/.ljspeech_with_token.done @@ -93,7 +105,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Split the LJSpeech cuts into train, valid and test sets" + log "Stage 4: Split the LJSpeech cuts into train, valid and test sets (used by vits)" if [ ! -e data/spectrogram/.ljspeech_split.done ]; then lhotse subset --last 600 \ data/spectrogram/ljspeech_cuts_all.jsonl.gz \ @@ -126,3 +138,63 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then ./local/prepare_token_file.py --tokens data/tokens.txt fi fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Generate fbank (used by ./matcha)" + mkdir -p data/fbank + if [ ! -e data/fbank/.ljspeech.done ]; then + ./local/compute_fbank_ljspeech.py + touch data/fbank/.ljspeech.done + fi + + if [ ! -e data/fbank/.ljspeech-validated.done ]; then + log "Validating data/fbank for LJSpeech (used by ./matcha)" + python3 ./local/validate_manifest.py \ + data/fbank/ljspeech_cuts_all.jsonl.gz + touch data/fbank/.ljspeech-validated.done + fi +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Prepare phoneme tokens for LJSpeech (used by ./matcha)" + # We assume you have installed piper_phonemize and espnet_tts_frontend. + # If not, please install them with: + # - piper_phonemize: pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html, + # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ + if [ ! -e data/fbank/.ljspeech_with_token.done ]; then + ./local/prepare_tokens_ljspeech.py --in-out-dir ./data/fbank + mv data/fbank/ljspeech_cuts_with_tokens_all.jsonl.gz \ + data/fbank/ljspeech_cuts_all.jsonl.gz + touch data/fbank/.ljspeech_with_token.done + fi +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Split the LJSpeech cuts into train, valid and test sets (used by ./matcha)" + if [ ! -e data/fbank/.ljspeech_split.done ]; then + lhotse subset --last 600 \ + data/fbank/ljspeech_cuts_all.jsonl.gz \ + data/fbank/ljspeech_cuts_validtest.jsonl.gz + lhotse subset --first 100 \ + data/fbank/ljspeech_cuts_validtest.jsonl.gz \ + data/fbank/ljspeech_cuts_valid.jsonl.gz + lhotse subset --last 500 \ + data/fbank/ljspeech_cuts_validtest.jsonl.gz \ + data/fbank/ljspeech_cuts_test.jsonl.gz + + rm data/fbank/ljspeech_cuts_validtest.jsonl.gz + + n=$(( $(gunzip -c data/fbank/ljspeech_cuts_all.jsonl.gz | wc -l) - 600 )) + lhotse subset --first $n \ + data/fbank/ljspeech_cuts_all.jsonl.gz \ + data/fbank/ljspeech_cuts_train.jsonl.gz + touch data/fbank/.ljspeech_split.done + fi +fi + +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Compute fbank mean and std (used by ./matcha)" + if [ ! -f ./data/fbank/cmvn.json ]; then + ./local/compute_fbank_statistics.py ./data/fbank/ljspeech_cuts_train.jsonl.gz ./data/fbank/cmvn.json + fi +fi diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index 308a06b1f..d31ce1301 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -90,7 +90,7 @@ def save_checkpoint( if params: for k, v in params.items(): - assert k not in checkpoint + assert k not in checkpoint, k checkpoint[k] = v torch.save(checkpoint, filename) From f23c8ce9dd2ab0d95f6dc002c0d7e30e7238a5ac Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 29 Oct 2024 15:50:49 +0800 Subject: [PATCH 29/59] Fix CI test for gigaspeech (#1787) --- ...eech-pruned-transducer-stateless2-2022-05-12.sh | 14 +++++++++++--- ...speech-lstm-transducer-stateless2-2022-09-03.sh | 4 ++-- .github/workflows/run-gigaspeech-2022-05-13.yml | 6 +++--- ...peech-lstm-transducer-stateless2-2022-09-03.yml | 12 ++++++------ .../ASR/pruned_transducer_stateless2/decode.py | 2 +- 5 files changed, 23 insertions(+), 15 deletions(-) diff --git a/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh b/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh index c9e798a68..8d98fca1e 100755 --- a/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh +++ b/.github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh @@ -19,7 +19,7 @@ repo=$(basename $repo_url) echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_NAME}" == x"workflow_dispatch" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then mkdir -p pruned_transducer_stateless2/exp ln -s $PWD/$repo/exp/pretrained-iter-3488000-avg-20.pt pruned_transducer_stateless2/exp/epoch-999.pt ln -s $PWD/$repo/data/lang_bpe_500 data/ @@ -29,8 +29,16 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == ls -lh data/fbank ls -lh pruned_transducer_stateless2/exp - ln -sf data/fbank/cuts_DEV.jsonl.gz data/fbank/gigaspeech_cuts_DEV.jsonl.gz - ln -sf data/fbank/cuts_TEST.jsonl.gz data/fbank/gigaspeech_cuts_TEST.jsonl.gz + pushd data/fbank + curl -SL -O https://huggingface.co/csukuangfj/giga-dev-dataset-fbank/resolve/main/data/fbank/cuts_DEV.jsonl.gz + curl -SL -O https://huggingface.co/csukuangfj/giga-dev-dataset-fbank/resolve/main/data/fbank/cuts_TEST.jsonl.gz + curl -SL -O https://huggingface.co/csukuangfj/giga-dev-dataset-fbank/resolve/main/data/fbank/feats_DEV.lca + curl -SL -O https://huggingface.co/csukuangfj/giga-dev-dataset-fbank/resolve/main/data/fbank/feats_TEST.lca + + ln -sf cuts_DEV.jsonl.gz gigaspeech_cuts_DEV.jsonl.gz + ln -sf cuts_TEST.jsonl.gz gigaspeech_cuts_TEST.jsonl.gz + popd + log "Decoding dev and test" diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh index d547bdd45..8f5a8dbb9 100755 --- a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh +++ b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh @@ -162,7 +162,7 @@ if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"LODR" ]]; then --ngram-lm-scale -0.16 fi -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_NAME}" == x"workflow_dispatch" ]]; then mkdir -p lstm_transducer_stateless2/exp ln -s $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt ln -s $PWD/$repo/data/lang_bpe_500 data/ @@ -175,7 +175,7 @@ if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then # use a small value for decoding with CPU max_duration=100 - for method in greedy_search fast_beam_search modified_beam_search; do + for method in greedy_search fast_beam_search; do log "Decoding with $method" ./lstm_transducer_stateless2/decode.py \ diff --git a/.github/workflows/run-gigaspeech-2022-05-13.yml b/.github/workflows/run-gigaspeech-2022-05-13.yml index 2c1d44fbf..9fd05d94a 100644 --- a/.github/workflows/run-gigaspeech-2022-05-13.yml +++ b/.github/workflows/run-gigaspeech-2022-05-13.yml @@ -41,7 +41,7 @@ concurrency: jobs: run_gigaspeech_2022_05_13: - if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' + if: github.event_name == 'workflow_dispatch' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' runs-on: ${{ matrix.os }} strategy: matrix: @@ -106,7 +106,7 @@ jobs: .github/scripts/run-gigaspeech-pruned-transducer-stateless2-2022-05-12.sh - name: Display decoding results for gigaspeech pruned_transducer_stateless2 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + if: github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' || github.event.label.name == 'run-decode' shell: bash run: | cd egs/gigaspeech/ASR/ @@ -122,7 +122,7 @@ jobs: - name: Upload decoding results for gigaspeech pruned_transducer_stateless2 uses: actions/upload-artifact@v4 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + if: github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' || github.event.label.name == 'run-decode' with: name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-gigaspeech-pruned_transducer_stateless2-2022-05-12 path: egs/gigaspeech/ASR/pruned_transducer_stateless2/exp/ diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml index 6a3f4eb40..1b222da2d 100644 --- a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml +++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml @@ -24,7 +24,7 @@ concurrency: jobs: run_librispeech_lstm_transducer_stateless2_2022_09_03: - if: github.event.label.name == 'ready' || github.event.label.name == 'LODR' || github.event.label.name == 'shallow-fusion' || github.event_name == 'push' || github.event_name == 'schedule' + if: github.event.label.name == 'ready' || github.event.label.name == 'LODR' || github.event.label.name == 'shallow-fusion' || github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' runs-on: ${{ matrix.os }} strategy: matrix: @@ -116,7 +116,7 @@ jobs: .github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh - name: Display decoding results for lstm_transducer_stateless2 - if: github.event_name == 'schedule' + if: github.event_name == 'schedule' || github.event_name == 'workflow_dispatch' shell: bash run: | cd egs/librispeech/ASR @@ -130,9 +130,9 @@ jobs: find fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 find fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - echo "===modified beam search===" - find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + # echo "===modified beam search===" + # find modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + # find modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - name: Display decoding results for lstm_transducer_stateless2 if: github.event.label.name == 'shallow-fusion' @@ -159,7 +159,7 @@ jobs: - name: Upload decoding results for lstm_transducer_stateless2 uses: actions/upload-artifact@v4 - if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'LODR' + if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'LODR' || github.event_name == 'workflow_dispatch' with: name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-lstm_transducer_stateless2-2022-09-03 path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/ diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py index ef430302d..f1efebcb9 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py @@ -260,7 +260,7 @@ def decode_one_batch( Return the decoding result. See above description for the format of the returned dict. """ - device = model.device + device = next(model.parameters()).device feature = batch["inputs"] assert feature.ndim == 3 From 6c7863c2f805afb162dd39c4d3eb279f93106b66 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 29 Oct 2024 22:26:25 +0800 Subject: [PATCH 30/59] Fix CI tests (#1788) Use numpy<2.0 --- .github/scripts/docker/Dockerfile | 8 +++-- .../scripts/docker/generate_build_matrix.py | 32 ++++++++----------- .../run-gigaspeech-zipformer-2023-10-17.sh | 18 +++++++++-- .../run-gigaspeech-zipformer-2023-10-17.yml | 22 ++++++------- .github/workflows/yesno.yml | 1 + 5 files changed, 45 insertions(+), 36 deletions(-) diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile index 15f49f826..94e8d8e1e 100644 --- a/.github/scripts/docker/Dockerfile +++ b/.github/scripts/docker/Dockerfile @@ -31,12 +31,15 @@ LABEL github_repo="https://github.com/k2-fsa/icefall" # Install dependencies RUN pip install --no-cache-dir \ - torch==${TORCH_VERSION} torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/cpu/torch_stable.html \ + torch==${TORCH_VERSION}+cpu -f https://download.pytorch.org/whl/torch \ + torchaudio==${TORCHAUDIO_VERSION}+cpu -f https://download.pytorch.org/whl/torchaudio \ k2==${_K2_VERSION} -f https://k2-fsa.github.io/k2/cpu.html \ \ git+https://github.com/lhotse-speech/lhotse \ kaldifeat==${_KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cpu.html \ + conformer==0.3.2 \ cython \ + diffusers \ dill \ espnet_tts_frontend \ graphviz \ @@ -45,10 +48,11 @@ RUN pip install --no-cache-dir \ kaldialign \ kaldifst \ kaldilm \ + librosa \ matplotlib \ multi_quantization \ numba \ - numpy \ + "numpy<2.0" \ onnxoptimizer \ onnxsim \ onnx \ diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py index 08281151e..9c53a38df 100755 --- a/.github/scripts/docker/generate_build_matrix.py +++ b/.github/scripts/docker/generate_build_matrix.py @@ -43,9 +43,11 @@ def get_torchaudio_version(torch_version): def get_matrix(): - k2_version = "1.24.4.dev20240223" - kaldifeat_version = "1.25.4.dev20240223" - version = "20240905" + k2_version = "1.24.4.dev20241029" + kaldifeat_version = "1.25.5.dev20241029" + version = "20241029" + + # torchaudio 2.5.0 does not support python 3.13 python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"] torch_version = [] # torch_version += ["1.13.0", "1.13.1"] @@ -56,6 +58,7 @@ def get_matrix(): torch_version += ["2.3.0", "2.3.1"] torch_version += ["2.4.0"] torch_version += ["2.4.1"] + torch_version += ["2.5.0"] matrix = [] for p in python_version: @@ -69,25 +72,16 @@ def get_matrix(): if version_gt(p, "3.11") and not version_gt(t, "2.1"): continue + if version_gt(p, "3.12") and not version_gt(t, "2.4"): + continue + + if version_gt(t, "2.4") and version_gt("3.10", p): + # torch>=2.5 requires python 3.10 + continue + k2_version_2 = k2_version kaldifeat_version_2 = kaldifeat_version - if t == "2.2.2": - k2_version_2 = "1.24.4.dev20240328" - kaldifeat_version_2 = "1.25.4.dev20240329" - elif t == "2.3.0": - k2_version_2 = "1.24.4.dev20240425" - kaldifeat_version_2 = "1.25.4.dev20240425" - elif t == "2.3.1": - k2_version_2 = "1.24.4.dev20240606" - kaldifeat_version_2 = "1.25.4.dev20240606" - elif t == "2.4.0": - k2_version_2 = "1.24.4.dev20240725" - kaldifeat_version_2 = "1.25.4.dev20240725" - elif t == "2.4.1": - k2_version_2 = "1.24.4.dev20240905" - kaldifeat_version_2 = "1.25.4.dev20240905" - matrix.append( { "k2-version": k2_version_2, diff --git a/.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh b/.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh index 329896ef6..438edd3b1 100755 --- a/.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh +++ b/.github/scripts/run-gigaspeech-zipformer-2023-10-17.sh @@ -129,20 +129,34 @@ done echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_NAME}" == x"workflow_dispatch" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then mkdir -p zipformer/exp ln -s $PWD/$repo/exp/pretrained.pt zipformer/exp/epoch-30.pt + mkdir -p data ln -s $PWD/$repo/data/lang_bpe_500 data/ ls -lh data ls -lh zipformer/exp + mkdir -p data/fbank + pushd data/fbank + + curl -SL -O https://huggingface.co/csukuangfj/giga-dev-dataset-fbank/resolve/main/data/fbank/cuts_DEV.jsonl.gz + curl -SL -O https://huggingface.co/csukuangfj/giga-dev-dataset-fbank/resolve/main/data/fbank/cuts_TEST.jsonl.gz + curl -SL -O https://huggingface.co/csukuangfj/giga-dev-dataset-fbank/resolve/main/data/fbank/feats_DEV.lca + curl -SL -O https://huggingface.co/csukuangfj/giga-dev-dataset-fbank/resolve/main/data/fbank/feats_TEST.lca + + ln -sf cuts_DEV.jsonl.gz gigaspeech_cuts_DEV.jsonl.gz + ln -sf cuts_TEST.jsonl.gz gigaspeech_cuts_TEST.jsonl.gz + + popd + log "Decoding test-clean and test-other" # use a small value for decoding with CPU max_duration=100 - for method in greedy_search fast_beam_search modified_beam_search; do + for method in greedy_search; do log "Decoding with $method" ./zipformer/decode.py \ diff --git a/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml b/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml index 4ecc2aea0..48322e75c 100644 --- a/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml +++ b/.github/workflows/run-gigaspeech-zipformer-2023-10-17.yml @@ -90,10 +90,6 @@ jobs: GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | - mkdir -p egs/gigaspeech/ASR/data - ln -sfv ~/tmp/fbank-libri egs/gigaspeech/ASR/data/fbank - ls -lh egs/gigaspeech/ASR/data/* - sudo apt-get -qq install git-lfs tree export PYTHONPATH=$PWD:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH @@ -112,7 +108,7 @@ jobs: tag: asr-models - name: Display decoding results for gigaspeech zipformer - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' || github.event_name == 'workflow_dispatch' shell: bash run: | cd egs/gigaspeech/ASR/ @@ -124,17 +120,17 @@ jobs: find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - echo "===fast_beam_search===" - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - - echo "===modified beam search===" - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 - find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + # echo "===fast_beam_search===" + # find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + # find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + # + # echo "===modified beam search===" + # find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + # find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 - name: Upload decoding results for gigaspeech zipformer uses: actions/upload-artifact@v4 - if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' || github.event_name == 'workflow_dispatch' with: name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-latest-cpu-zipformer-2022-11-11 path: egs/gigaspeech/ASR/zipformer/exp/ diff --git a/.github/workflows/yesno.yml b/.github/workflows/yesno.yml index de822b33f..a9d65516f 100644 --- a/.github/workflows/yesno.yml +++ b/.github/workflows/yesno.yml @@ -61,5 +61,6 @@ jobs: python3 -m torch.utils.collect_env python3 -m k2.version + pip list .github/scripts/yesno/ASR/run.sh From d513d456b8b51e6a52a2005efee71394d767ab48 Mon Sep 17 00:00:00 2001 From: Wei Kang Date: Wed, 30 Oct 2024 10:14:34 +0800 Subject: [PATCH 31/59] Add prefix beam search and corresponding decoding methods (#1786) * Add prefix beam search / shallow fussion / hotwords in librispeech ctc decode * Add librispeech cr-ctc prefix beam search results --- egs/librispeech/ASR/RESULTS.md | 9 +- egs/librispeech/ASR/zipformer/ctc_decode.py | 238 ++++++- icefall/decode.py | 674 +++++++++++++++++++- icefall/utils.py | 11 + 4 files changed, 908 insertions(+), 24 deletions(-) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 6a669f072..e5a82dfda 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -153,6 +153,7 @@ You can use to deploy it. | decoding method | test-clean | test-other | comment | |--------------------------------------|------------|------------|---------------------| | ctc-greedy-decoding | 2.57 | 5.95 | --epoch 50 --avg 25 | +| ctc-prefix-beam-search | 2.52 | 5.85 | --epoch 50 --avg 25 | The training command using 2 32G-V100 GPUs is: ```bash @@ -184,7 +185,7 @@ export CUDA_VISIBLE_DEVICES="0,1" The decoding command is: ```bash export CUDA_VISIBLE_DEVICES="0" -for m in ctc-greedy-search; do +for m in ctc-greedy-search ctc-prefix-beam-search; do ./zipformer/ctc_decode.py \ --epoch 50 \ --avg 25 \ @@ -212,6 +213,7 @@ You can use to deploy it. | decoding method | test-clean | test-other | comment | |--------------------------------------|------------|------------|---------------------| | ctc-greedy-decoding | 2.12 | 4.62 | --epoch 50 --avg 24 | +| ctc-prefix-beam-search | 2.1 | 4.61 | --epoch 50 --avg 24 | The training command using 4 32G-V100 GPUs is: ```bash @@ -238,7 +240,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" The decoding command is: ```bash export CUDA_VISIBLE_DEVICES="0" -for m in ctc-greedy-search; do +for m in ctc-greedy-search ctc-prefix-beam-search; do ./zipformer/ctc_decode.py \ --epoch 50 \ --avg 24 \ @@ -262,6 +264,7 @@ You can use to deploy it. | decoding method | test-clean | test-other | comment | |--------------------------------------|------------|------------|---------------------| | ctc-greedy-decoding | 2.03 | 4.37 | --epoch 50 --avg 26 | +| ctc-prefix-beam-search | 2.02 | 4.35 | --epoch 50 --avg 26 | The training command using 2 80G-A100 GPUs is: ```bash @@ -292,7 +295,7 @@ export CUDA_VISIBLE_DEVICES="0,1" The decoding command is: ```bash export CUDA_VISIBLE_DEVICES="0" -for m in ctc-greedy-search; do +for m in ctc-greedy-search ctc-prefix-beam-search; do ./zipformer/ctc_decode.py \ --epoch 50 \ --avg 26 \ diff --git a/egs/librispeech/ASR/zipformer/ctc_decode.py b/egs/librispeech/ASR/zipformer/ctc_decode.py index 9db429959..156989b78 100755 --- a/egs/librispeech/ASR/zipformer/ctc_decode.py +++ b/egs/librispeech/ASR/zipformer/ctc_decode.py @@ -111,6 +111,7 @@ Usage: import argparse import logging import math +import os from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -129,8 +130,14 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) + +from icefall.context_graph import ContextGraph, ContextState + from icefall.decode import ( ctc_greedy_search, + ctc_prefix_beam_search, + ctc_prefix_beam_search_attention_decoder_rescoring, + ctc_prefix_beam_search_shallow_fussion, get_lattice, nbest_decoding, nbest_oracle, @@ -140,7 +147,11 @@ from icefall.decode import ( rescore_with_n_best_list, rescore_with_whole_lattice, ) + +from icefall.ngram_lm import NgramLm, NgramLmStateCost from icefall.lexicon import Lexicon +from icefall.lm_wrapper import LmScorer + from icefall.utils import ( AttributeDict, get_texts, @@ -255,6 +266,12 @@ def get_parser(): lattice, rescore them with the attention decoder. - (9) attention-decoder-rescoring-with-ngram. Extract n paths from the LM rescored lattice, rescore them with the attention decoder. + - (10) ctc-prefix-beam-search. Extract n paths with the given beam, the best + path of the n paths is the decoding result. + - (11) ctc-prefix-beam-search-attention-decoder-rescoring. Extract n paths with + the given beam, rescore them with the attention decoder. + - (12) ctc-prefix-beam-search-shallow-fussion. Use NNLM shallow fussion during + beam search, LODR and hotwords are also supported in this decoding method. """, ) @@ -280,6 +297,23 @@ def get_parser(): """, ) + parser.add_argument( + "--nnlm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--nnlm-scale", + type=float, + default=0, + help="""The scale of the neural network LM, 0 means don't use nnlm shallow fussion. + Used only when `--use-shallow-fusion` is set to True. + """, + ) + parser.add_argument( "--hlg-scale", type=float, @@ -297,11 +331,52 @@ def get_parser(): """, ) + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--lodr-ngram", + type=str, + help="The path to the lodr ngram", + ) + + parser.add_argument( + "--lodr-lm-scale", + type=float, + default=0, + help="The scale of lodr ngram, should be less than 0. 0 means don't use lodr.", + ) + + parser.add_argument( + "--context-score", + type=float, + default=0, + help=""" + The bonus score of each token for the context biasing words/phrases. + 0 means don't use contextual biasing. + Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion. + """, + ) + parser.add_argument( "--skip-scoring", type=str2bool, default=False, - help="""Skip scoring, but still save the ASR output (for eval sets).""" + help="""Skip scoring, but still save the ASR output (for eval sets).""", ) add_model_arguments(parser) @@ -314,11 +389,12 @@ def get_decoding_params() -> AttributeDict: params = AttributeDict( { "frame_shift_ms": 10, - "search_beam": 20, - "output_beam": 8, + "search_beam": 20, # for k2 fsa composition + "output_beam": 8, # for k2 fsa composition "min_active_states": 30, "max_active_states": 10000, "use_double_scores": True, + "beam": 4, # for prefix-beam-search } ) return params @@ -333,6 +409,9 @@ def decode_one_batch( batch: dict, word_table: k2.SymbolTable, G: Optional[k2.Fsa] = None, + NNLM: Optional[LmScorer] = None, + LODR_lm: Optional[NgramLm] = None, + context_graph: Optional[ContextGraph] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -377,10 +456,7 @@ def decode_one_batch( Return the decoding result. See above description for the format of the returned dict. Note: If it decodes to nothing, then return None. """ - if HLG is not None: - device = HLG.device - else: - device = H.device + device = params.device feature = batch["inputs"] assert feature.ndim == 3 feature = feature.to(device) @@ -411,6 +487,51 @@ def decode_one_batch( key = "ctc-greedy-search" return {key: hyps} + if params.decoding_method == "ctc-prefix-beam-search": + token_ids = ctc_prefix_beam_search( + ctc_output=ctc_output, encoder_out_lens=encoder_out_lens + ) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "prefix-beam-search" + return {key: hyps} + + if params.decoding_method == "ctc-prefix-beam-search-attention-decoder-rescoring": + best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring( + ctc_output=ctc_output, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + ans = dict() + for a_scale_str, token_ids in best_path_dict.items(): + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + ans[a_scale_str] = hyps + return ans + + if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion": + token_ids = ctc_prefix_beam_search_shallow_fussion( + ctc_output=ctc_output, + encoder_out_lens=encoder_out_lens, + NNLM=NNLM, + LODR_lm=LODR_lm, + LODR_lm_scale=params.lodr_lm_scale, + context_graph=context_graph, + ) + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "prefix-beam-search-shallow-fussion" + return {key: hyps} + supervision_segments = torch.stack( ( supervisions["sequence_idx"], @@ -584,6 +705,9 @@ def decode_dataset( bpe_model: Optional[spm.SentencePieceProcessor], word_table: k2.SymbolTable, G: Optional[k2.Fsa] = None, + NNLM: Optional[LmScorer] = None, + LODR_lm: Optional[NgramLm] = None, + context_graph: Optional[ContextGraph] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -634,6 +758,9 @@ def decode_dataset( batch=batch, word_table=word_table, G=G, + NNLM=NNLM, + LODR_lm=LODR_lm, + context_graph=context_graph, ) for name, hyps in hyps_dict.items(): @@ -664,9 +791,7 @@ def save_asr_output( """ for key, results in results_dict.items(): - recogs_filename = ( - params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - ) + recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" results = sorted(results) store_transcripts(filename=recogs_filename, texts=results) @@ -680,7 +805,8 @@ def save_wer_results( results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], ): if params.decoding_method in ( - "attention-decoder-rescoring-with-ngram", "whole-lattice-rescoring" + "attention-decoder-rescoring-with-ngram", + "whole-lattice-rescoring", ): # Set it to False since there are too many logs. enable_log = False @@ -721,6 +847,7 @@ def save_wer_results( def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) args.lang_dir = Path(args.lang_dir) @@ -735,8 +862,11 @@ def main(): set_caching_enabled(True) # lhotse assert params.decoding_method in ( - "ctc-greedy-search", "ctc-decoding", + "ctc-greedy-search", + "ctc-prefix-beam-search", + "ctc-prefix-beam-search-attention-decoder-rescoring", + "ctc-prefix-beam-search-shallow-fussion", "1best", "nbest", "nbest-rescoring", @@ -762,6 +892,16 @@ def main(): params.suffix += f"_chunk-{params.chunk_size}" params.suffix += f"_left-context-{params.left_context_frames}" + if "prefix-beam-search" in params.decoding_method: + params.suffix += f"_beam-{params.beam}" + if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion": + if params.nnlm_scale != 0: + params.suffix += f"_nnlm-scale-{params.nnlm_scale}" + if params.lodr_lm_scale != 0: + params.suffix += f"_lodr-scale-{params.lodr_lm_scale}" + if params.context_score != 0: + params.suffix += f"_context_score-{params.context_score}" + if params.use_averaged_model: params.suffix += "_use-averaged-model" @@ -771,6 +911,7 @@ def main(): device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", 0) + params.device = device logging.info(f"Device: {device}") logging.info(params) @@ -786,14 +927,24 @@ def main(): params.sos_id = 1 if params.decoding_method in [ - "ctc-greedy-search", "ctc-decoding", "attention-decoder-rescoring-no-ngram" + "ctc-decoding", + "ctc-greedy-search", + "ctc-prefix-beam-search", + "ctc-prefix-beam-search-attention-decoder-rescoring", + "ctc-prefix-beam-search-shallow-fussion", + "attention-decoder-rescoring-no-ngram", ]: HLG = None - H = k2.ctc_topo( - max_token=max_token_id, - modified=False, - device=device, - ) + H = None + if params.decoding_method in [ + "ctc-decoding", + "attention-decoder-rescoring-no-ngram", + ]: + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) bpe_model = spm.SentencePieceProcessor() bpe_model.load(str(params.lang_dir / "bpe.model")) else: @@ -844,7 +995,8 @@ def main(): G = k2.Fsa.from_dict(d) if params.decoding_method in [ - "whole-lattice-rescoring", "attention-decoder-rescoring-with-ngram" + "whole-lattice-rescoring", + "attention-decoder-rescoring-with-ngram", ]: # Add epsilon self-loops to G as we will compose # it with the whole lattice later @@ -858,6 +1010,51 @@ def main(): else: G = None + # only load the neural network LM if required + NNLM = None + if ( + params.decoding_method == "ctc-prefix-beam-search-shallow-fussion" + and params.nnlm_scale != 0 + ): + NNLM = LmScorer( + lm_type=params.nnlm_type, + params=params, + device=device, + lm_scale=params.nnlm_scale, + ) + NNLM.to(device) + NNLM.eval() + + LODR_lm = None + if ( + params.decoding_method == "ctc-prefix-beam-search-shallow-fussion" + and params.lodr_lm_scale != 0 + ): + assert os.path.exists( + params.lodr_ngram + ), f"LODR ngram does not exists, given path : {params.lodr_ngram}" + logging.info(f"Loading LODR (token level lm): {params.lodr_ngram}") + LODR_lm = NgramLm( + params.lodr_ngram, + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {LODR_lm.lm.num_states}") + + context_graph = None + if ( + params.decoding_method == "ctc-prefix-beam-search-shallow-fussion" + and params.context_score != 0 + ): + assert os.path.exists( + params.context_file + ), f"context_file does not exists, given path : {params.context_file}" + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append(bpe_model.encode(line.strip())) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + logging.info("About to create model") model = get_model(params) @@ -967,6 +1164,9 @@ def main(): bpe_model=bpe_model, word_table=lexicon.word_table, G=G, + NNLM=NNLM, + LODR_lm=LODR_lm, + context_graph=context_graph, ) save_asr_output( diff --git a/icefall/decode.py b/icefall/decode.py index dd3af1e99..5f90ee168 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -1,4 +1,5 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -15,11 +16,16 @@ # limitations under the License. import logging -from typing import Dict, List, Optional, Union +from dataclasses import dataclass, field +from multiprocessing.pool import Pool +from typing import Dict, List, Optional, Tuple, Union import k2 import torch +from icefall.context_graph import ContextGraph, ContextState +from icefall.ngram_lm import NgramLm, NgramLmStateCost +from icefall.lm_wrapper import LmScorer from icefall.utils import add_eos, add_sos, get_texts DEFAULT_LM_SCALE = [ @@ -1497,3 +1503,667 @@ def ctc_greedy_search( hyps = [h[h != blank_id].tolist() for h in hyps] return hyps + + +@dataclass +class Hypothesis: + # The predicted tokens so far. + # Newly predicted tokens are appended to `ys`. + ys: List[int] = field(default_factory=list) + + # The log prob of ys that ends with blank token. + # It contains only one entry. + log_prob_blank: torch.Tensor = torch.zeros(1, dtype=torch.float32) + + # The log prob of ys that ends with non blank token. + # It contains only one entry. + log_prob_non_blank: torch.Tensor = torch.tensor( + [float("-inf")], dtype=torch.float32 + ) + + # timestamp[i] is the frame index after subsampling + # on which ys[i] is decoded + timestamp: List[int] = field(default_factory=list) + + # The lm score of ys + # May contain external LM score (including LODR score) and contextual biasing score + # It contains only one entry + lm_score: torch.Tensor = torch.zeros(1, dtype=torch.float32) + + # the lm log_probs for next token given the history ys + # The number of elements should be equal to vocabulary size. + lm_log_probs: Optional[torch.Tensor] = None + + # the RNNLM states (h and c in LSTM) + state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + + # LODR (N-gram LM) state + LODR_state: Optional[NgramLmStateCost] = None + + # N-gram LM state + Ngram_state: Optional[NgramLmStateCost] = None + + # Context graph state + context_state: Optional[ContextState] = None + + # This is the total score of current path, acoustic plus external LM score. + @property + def tot_score(self) -> torch.Tensor: + return self.log_prob + self.lm_score + + # This is only the probability from model output (i.e External LM score not included). + @property + def log_prob(self) -> torch.Tensor: + return torch.logaddexp(self.log_prob_non_blank, self.log_prob_blank) + + @property + def key(self) -> tuple: + """Return a tuple representation of self.ys""" + return tuple(self.ys) + + def clone(self) -> "Hypothesis": + return Hypothesis( + ys=self.ys, + log_prob_blank=self.log_prob_blank, + log_prob_non_blank=self.log_prob_non_blank, + timestamp=self.timestamp, + lm_log_probs=self.lm_log_probs, + lm_score=self.lm_score, + state=self.state, + LODR_state=self.LODR_state, + Ngram_state=self.Ngram_state, + context_state=self.context_state, + ) + + +class HypothesisList(object): + def __init__(self, data: Optional[Dict[tuple, Hypothesis]] = None) -> None: + """ + Args: + data: + A dict of Hypotheses. Its key is its `value.key`. + """ + if data is None: + self._data = {} + else: + self._data = data + + @property + def data(self) -> Dict[tuple, Hypothesis]: + return self._data + + def add(self, hyp: Hypothesis) -> None: + """Add a Hypothesis to `self`. + If `hyp` already exists in `self`, its probability is updated using + `log-sum-exp` with the existed one. + Args: + hyp: + The hypothesis to be added. + """ + key = hyp.key + if key in self: + old_hyp = self._data[key] # shallow copy + torch.logaddexp( + old_hyp.log_prob_blank, hyp.log_prob_blank, out=old_hyp.log_prob_blank + ) + torch.logaddexp( + old_hyp.log_prob_non_blank, + hyp.log_prob_non_blank, + out=old_hyp.log_prob_non_blank, + ) + else: + self._data[key] = hyp + + def get_most_probable(self, length_norm: bool = False) -> Hypothesis: + """Get the most probable hypothesis, i.e., the one with + the largest `tot_score`. + Args: + length_norm: + If True, the `tot_score` of a hypothesis is normalized by the + number of tokens in it. + Returns: + Return the hypothesis that has the largest `tot_score`. + """ + if length_norm: + return max(self._data.values(), key=lambda hyp: hyp.tot_score / len(hyp.ys)) + else: + return max(self._data.values(), key=lambda hyp: hyp.tot_score) + + def remove(self, hyp: Hypothesis) -> None: + """Remove a given hypothesis. + Caution: + `self` is modified **in-place**. + Args: + hyp: + The hypothesis to be removed from `self`. + Note: It must be contained in `self`. Otherwise, + an exception is raised. + """ + key = hyp.key + assert key in self, f"{key} does not exist" + del self._data[key] + + def filter(self, threshold: torch.Tensor) -> "HypothesisList": + """Remove all Hypotheses whose tot_score is less than threshold. + Caution: + `self` is not modified. Instead, a new HypothesisList is returned. + Returns: + Return a new HypothesisList containing all hypotheses from `self` + with `tot_score` being greater than the given `threshold`. + """ + ans = HypothesisList() + for _, hyp in self._data.items(): + if hyp.tot_score > threshold: + ans.add(hyp) # shallow copy + return ans + + def topk(self, k: int, length_norm: bool = False) -> "HypothesisList": + """Return the top-k hypothesis. + Args: + length_norm: + If True, the `tot_score` of a hypothesis is normalized by the + number of tokens in it. + """ + hyps = list(self._data.items()) + + if length_norm: + hyps = sorted( + hyps, key=lambda h: h[1].tot_score / len(h[1].ys), reverse=True + )[:k] + else: + hyps = sorted(hyps, key=lambda h: h[1].tot_score, reverse=True)[:k] + + ans = HypothesisList(dict(hyps)) + return ans + + def __contains__(self, key: tuple): + return key in self._data + + def __getitem__(self, key: tuple): + return self._data[key] + + def __iter__(self): + return iter(self._data.values()) + + def __len__(self) -> int: + return len(self._data) + + def __str__(self) -> str: + s = [] + for key in self: + s.append(key) + return ", ".join(str(s)) + + +def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: + """Return a ragged shape with axes [utt][num_hyps]. + Args: + hyps: + len(hyps) == batch_size. It contains the current hypothesis for + each utterance in the batch. + Returns: + Return a ragged shape with 2 axes [utt][num_hyps]. Note that + the shape is on CPU. + """ + num_hyps = [len(h) for h in hyps] + + # torch.cumsum() is inclusive sum, so we put a 0 at the beginning + # to get exclusive sum later. + num_hyps.insert(0, 0) + + num_hyps = torch.tensor(num_hyps) + row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) + ans = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=row_splits[-1].item() + ) + return ans + + +def _step_worker( + log_probs: torch.Tensor, + indexes: torch.Tensor, + B: HypothesisList, + beam: int = 4, + blank_id: int = 0, + nnlm_scale: float = 0, + LODR_lm_scale: float = 0, + context_graph: Optional[ContextGraph] = None, +) -> HypothesisList: + """The worker to decode one step. + Args: + log_probs: + topk log_probs of current step (i.e. the kept tokens of first pass pruning), + the shape is (beam,) + topk_indexes: + The indexes of the topk_values above, the shape is (beam,) + B: + An instance of HypothesisList containing the kept hypothesis. + beam: + The number of hypothesis to be kept at each step. + blank_id: + The id of blank in the vocabulary. + lm_scale: + The scale of nn lm. + LODR_lm_scale: + The scale of the LODR_lm + context_graph: + A ContextGraph instance containing contextual phrases. + Return: + Returns the updated HypothesisList. + """ + A = list(B) + B = HypothesisList() + for h in range(len(A)): + hyp = A[h] + for k in range(log_probs.size(0)): + log_prob, index = log_probs[k], indexes[k] + new_token = index.item() + update_prefix = False + new_hyp = hyp.clone() + if new_token == blank_id: + # Case 0: *a + ε => *a + # *aε + ε => *a + # Prefix does not change, update log_prob of blank + new_hyp.log_prob_non_blank = torch.tensor( + [float("-inf")], dtype=torch.float32 + ) + new_hyp.log_prob_blank = hyp.log_prob + log_prob + B.add(new_hyp) + elif len(hyp.ys) > 0 and hyp.ys[-1] == new_token: + # Case 1: *a + a => *a + # Prefix does not change, update log_prob of non_blank + new_hyp.log_prob_non_blank = hyp.log_prob_non_blank + log_prob + new_hyp.log_prob_blank = torch.tensor( + [float("-inf")], dtype=torch.float32 + ) + B.add(new_hyp) + + # Case 2: *aε + a => *aa + # Prefix changes, update log_prob of blank + new_hyp = hyp.clone() + # Caution: DO NOT use append, as clone is shallow copy + new_hyp.ys = hyp.ys + [new_token] + new_hyp.log_prob_non_blank = hyp.log_prob_blank + log_prob + new_hyp.log_prob_blank = torch.tensor( + [float("-inf")], dtype=torch.float32 + ) + update_prefix = True + else: + # Case 3: *a + b => *ab, *aε + b => *ab + # Prefix changes, update log_prob of non_blank + # Caution: DO NOT use append, as clone is shallow copy + new_hyp.ys = hyp.ys + [new_token] + new_hyp.log_prob_non_blank = hyp.log_prob + log_prob + new_hyp.log_prob_blank = torch.tensor( + [float("-inf")], dtype=torch.float32 + ) + update_prefix = True + + if update_prefix: + lm_score = hyp.lm_score + if hyp.lm_log_probs is not None: + lm_score = lm_score + hyp.lm_log_probs[new_token] * nnlm_scale + new_hyp.lm_log_probs = None + + if context_graph is not None and hyp.context_state is not None: + ( + context_score, + new_context_state, + matched_state, + ) = context_graph.forward_one_step(hyp.context_state, new_token) + lm_score = lm_score + context_score + new_hyp.context_state = new_context_state + + if hyp.LODR_state is not None: + state_cost = hyp.LODR_state.forward_one_step(new_token) + # calculate the score of the latest token + current_ngram_score = state_cost.lm_score - hyp.LODR_state.lm_score + assert current_ngram_score <= 0.0, ( + state_cost.lm_score, + hyp.LODR_state.lm_score, + ) + lm_score = lm_score + LODR_lm_scale * current_ngram_score + new_hyp.LODR_state = state_cost + + new_hyp.lm_score = lm_score + B.add(new_hyp) + B = B.topk(beam) + return B + + +def _sequence_worker( + topk_values: torch.Tensor, + topk_indexes: torch.Tensor, + B: HypothesisList, + encoder_out_lens: torch.Tensor, + beam: int = 4, + blank_id: int = 0, +) -> HypothesisList: + """The worker to decode one sequence. + Args: + topk_values: + topk log_probs of model output (i.e. the kept tokens of first pass pruning), + the shape is (T, beam) + topk_indexes: + The indexes of the topk_values above, the shape is (T, beam) + B: + An instance of HypothesisList containing the kept hypothesis. + encoder_out_lens: + The lengths (frames) of sequences after subsampling, the shape is (B,) + beam: + The number of hypothesis to be kept at each step. + blank_id: + The id of blank in the vocabulary. + Return: + Returns the updated HypothesisList. + """ + B.add(Hypothesis()) + for j in range(encoder_out_lens): + log_probs, indexes = topk_values[j], topk_indexes[j] + B = _step_worker(log_probs, indexes, B, beam, blank_id) + return B + + +def ctc_prefix_beam_search( + ctc_output: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: int = 4, + blank_id: int = 0, + process_pool: Optional[Pool] = None, + return_nbest: Optional[bool] = False, +) -> Union[List[List[int]], List[HypothesisList]]: + """Implement prefix search decoding in "Connectionist Temporal Classification: + Labelling Unsegmented Sequence Data with Recurrent Neural Networks". + Args: + ctc_output: + The output of ctc head (log probability), the shape is (B, T, V) + encoder_out_lens: + The lengths (frames) of sequences after subsampling, the shape is (B,) + beam: + The number of hypothesis to be kept at each step. + blank_id: + The id of blank in the vocabulary. + process_pool: + The process pool for parallel decoding, if not provided, it will use all + you cpu cores by default. + return_nbest: + If true, return a list of HypothesisList, return a list of list of decoded token ids otherwise. + """ + batch_size, num_frames, vocab_size = ctc_output.shape + + # TODO: using a larger beam for first pass pruning + topk_values, topk_indexes = ctc_output.topk(beam) # (B, T, beam) + topk_values = topk_values.cpu() + topk_indexes = topk_indexes.cpu() + + B = [HypothesisList() for _ in range(batch_size)] + + pool = Pool() if process_pool is None else process_pool + arguments = [] + for i in range(batch_size): + arguments.append( + ( + topk_values[i], + topk_indexes[i], + B[i], + encoder_out_lens[i].item(), + beam, + blank_id, + ) + ) + async_results = pool.starmap_async(_sequence_worker, arguments) + B = list(async_results.get()) + if process_pool is None: + pool.close() + pool.join() + if return_nbest: + return B + else: + best_hyps = [b.get_most_probable() for b in B] + return [hyp.ys for hyp in best_hyps] + + +def ctc_prefix_beam_search_shallow_fussion( + ctc_output: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: int = 4, + blank_id: int = 0, + LODR_lm: Optional[NgramLm] = None, + LODR_lm_scale: Optional[float] = 0, + NNLM: Optional[LmScorer] = None, + context_graph: Optional[ContextGraph] = None, +) -> List[List[int]]: + """Implement prefix search decoding in "Connectionist Temporal Classification: + Labelling Unsegmented Sequence Data with Recurrent Neural Networks" and add + nervous language model shallow fussion, it also supports contextual + biasing with a given grammar. + Args: + ctc_output: + The output of ctc head (log probability), the shape is (B, T, V) + encoder_out_lens: + The lengths (frames) of sequences after subsampling, the shape is (B,) + beam: + The number of hypothesis to be kept at each step. + blank_id: + The id of blank in the vocabulary. + LODR_lm: + A low order n-gram LM, whose score will be subtracted during shallow fusion + LODR_lm_scale: + The scale of the LODR_lm + LM: + A neural net LM, e.g an RNNLM or transformer LM + context_graph: + A ContextGraph instance containing contextual phrases. + Return: + Returns a list of list of decoded token ids. + """ + batch_size, num_frames, vocab_size = ctc_output.shape + # TODO: using a larger beam for first pass pruning + topk_values, topk_indexes = ctc_output.topk(beam) # (B, T, beam) + topk_values = topk_values.cpu() + topk_indexes = topk_indexes.cpu() + encoder_out_lens = encoder_out_lens.tolist() + device = ctc_output.device + + nnlm_scale = 0 + init_scores = None + init_states = None + if NNLM is not None: + nnlm_scale = NNLM.lm_scale + sos_id = getattr(NNLM, "sos_id", 1) + # get initial lm score and lm state by scoring the "sos" token + sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) + lens = torch.tensor([1]).to(device) + init_scores, init_states = NNLM.score_token(sos_token, lens) + init_scores, init_states = init_scores.cpu(), ( + init_states[0].cpu(), + init_states[1].cpu(), + ) + + B = [HypothesisList() for _ in range(batch_size)] + for i in range(batch_size): + B[i].add( + Hypothesis( + ys=[], + log_prob_non_blank=torch.tensor([float("-inf")], dtype=torch.float32), + log_prob_blank=torch.zeros(1, dtype=torch.float32), + lm_score=torch.zeros(1, dtype=torch.float32), + state=init_states, + lm_log_probs=None if init_scores is None else init_scores.reshape(-1), + LODR_state=None if LODR_lm is None else NgramLmStateCost(LODR_lm), + context_state=None if context_graph is None else context_graph.root, + ) + ) + for j in range(num_frames): + for i in range(batch_size): + if j < encoder_out_lens[i]: + log_probs, indexes = topk_values[i][j], topk_indexes[i][j] + B[i] = _step_worker( + log_probs=log_probs, + indexes=indexes, + B=B[i], + beam=beam, + blank_id=blank_id, + nnlm_scale=nnlm_scale, + LODR_lm_scale=LODR_lm_scale, + context_graph=context_graph, + ) + if NNLM is None: + continue + # update lm_log_probs + token_list = [] # a list of list + hs = [] + cs = [] + indexes = [] # (batch_idx, key) + for batch_idx, hyps in enumerate(B): + for hyp in hyps: + if hyp.lm_log_probs is None: # those hyps that prefix changes + if NNLM.lm_type == "rnn": + token_list.append([hyp.ys[-1]]) + # store the LSTM states + hs.append(hyp.state[0]) + cs.append(hyp.state[1]) + else: + # for transformer LM + token_list.append([sos_id] + hyp.ys[:]) + indexes.append((batch_idx, hyp.key)) + if len(token_list) != 0: + x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) + if NNLM.lm_type == "rnn": + tokens_to_score = ( + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) + ) + hs = torch.cat(hs, dim=1).to(device) + cs = torch.cat(cs, dim=1).to(device) + state = (hs, cs) + else: + # for transformer LM + tokens_list = [torch.tensor(tokens) for tokens in token_list] + tokens_to_score = ( + torch.nn.utils.rnn.pad_sequence( + tokens_list, batch_first=True, padding_value=0.0 + ) + .to(device) + .to(torch.int64) + ) + state = None + + scores, lm_states = NNLM.score_token(tokens_to_score, x_lens, state) + scores, lm_states = scores.cpu(), (lm_states[0].cpu(), lm_states[1].cpu()) + assert scores.size(0) == len(indexes), (scores.size(0), len(indexes)) + for i in range(scores.size(0)): + batch_idx, key = indexes[i] + B[batch_idx][key].lm_log_probs = scores[i] + if NNLM.lm_type == "rnn": + state = ( + lm_states[0][:, i, :].unsqueeze(1), + lm_states[1][:, i, :].unsqueeze(1), + ) + B[batch_idx][key].state = state + + # finalize context_state, if the matched contexts do not reach final state + # we need to add the score on the corresponding backoff arc + if context_graph is not None: + for hyps in B: + for hyp in hyps: + context_score, new_context_state = context_graph.finalize( + hyp.context_state + ) + hyp.lm_score += context_score + hyp.context_state = new_context_state + + best_hyps = [b.get_most_probable() for b in B] + return [hyp.ys for hyp in best_hyps] + + +def ctc_prefix_beam_search_attention_decoder_rescoring( + ctc_output: torch.Tensor, + attention_decoder: torch.nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: int = 8, + blank_id: int = 0, + attention_scale: Optional[float] = None, + process_pool: Optional[Pool] = None, +): + """Implement prefix search decoding in "Connectionist Temporal Classification: + Labelling Unsegmented Sequence Data with Recurrent Neural Networks" and add + attention decoder rescoring. + Args: + ctc_output: + The output of ctc head (log probability), the shape is (B, T, V) + attention_decoder: + The attention decoder. + encoder_out: + The output of encoder, the shape is (B, T, D) + encoder_out_lens: + The lengths (frames) of sequences after subsampling, the shape is (B,) + beam: + The number of hypothesis to be kept at each step. + blank_id: + The id of blank in the vocabulary. + attention_scale: + The scale of attention decoder score, if not provided it will search in + a default list (see the code below). + process_pool: + The process pool for parallel decoding, if not provided, it will use all + you cpu cores by default. + """ + # List[HypothesisList] + nbest = ctc_prefix_beam_search( + ctc_output=ctc_output, + encoder_out_lens=encoder_out_lens, + beam=beam, + blank_id=blank_id, + return_nbest=True, + ) + + device = ctc_output.device + + hyp_shape = get_hyps_shape(nbest).to(device) + hyp_to_utt_map = hyp_shape.row_ids(1).to(torch.long) + # the shape of encoder_out is (N, T, C), so we use axis=0 here + expanded_encoder_out = encoder_out.index_select(0, hyp_to_utt_map) + expanded_encoder_out_lens = encoder_out_lens.index_select(0, hyp_to_utt_map) + + nbest = [list(x) for x in nbest] + token_ids = [] + scores = [] + for hyps in nbest: + for hyp in hyps: + token_ids.append(hyp.ys) + scores.append(hyp.log_prob.reshape(1)) + scores = torch.cat(scores).to(device) + + nll = attention_decoder.nll( + encoder_out=expanded_encoder_out, + encoder_out_lens=expanded_encoder_out_lens, + token_ids=token_ids, + ) + assert nll.ndim == 2 + assert nll.shape[0] == len(token_ids) + + attention_scores = -nll.sum(dim=1) + + if attention_scale is None: + attention_scale_list = [0.01, 0.05, 0.08] + attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] + attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0] + attention_scale_list += [5.0, 6.0, 7.0, 8.0, 9.0] + else: + attention_scale_list = [attention_scale] + + ans = dict() + + start_indexes = hyp_shape.row_splits(1)[0:-1] + for a_scale in attention_scale_list: + tot_scores = scores + a_scale * attention_scores + ragged_tot_scores = k2.RaggedTensor(hyp_shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + max_indexes = max_indexes - start_indexes + max_indexes = max_indexes.cpu() + best_path = [nbest[i][max_indexes[i]].ys for i in range(len(max_indexes))] + key = f"attention_scale_{a_scale}" + ans[key] = best_path + return ans diff --git a/icefall/utils.py b/icefall/utils.py index 0682252f9..41eebadd4 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -19,8 +19,10 @@ import argparse import collections +import json import logging import os +import pathlib import random import re import subprocess @@ -180,6 +182,15 @@ class AttributeDict(dict): return raise AttributeError(f"No such attribute '{key}'") + def __str__(self, indent: int = 2): + tmp = {} + for k, v in self.items(): + # PosixPath is ont JSON serializable + if isinstance(v, pathlib.Path) or isinstance(v, torch.device): + v = str(v) + tmp[k] = v + return json.dumps(tmp, indent=indent, sort_keys=True) + def encode_supervisions( supervisions: dict, From 87cadfcd2ee23e1709ef74353c1679e5d965b14d Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 30 Oct 2024 21:14:12 +0800 Subject: [PATCH 32/59] fixed formatting issue (#1791) * isort fixed formatting issue --- egs/librispeech/ASR/zipformer/ctc_decode.py | 6 +----- icefall/decode.py | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/ctc_decode.py b/egs/librispeech/ASR/zipformer/ctc_decode.py index 156989b78..fe9347b95 100755 --- a/egs/librispeech/ASR/zipformer/ctc_decode.py +++ b/egs/librispeech/ASR/zipformer/ctc_decode.py @@ -130,9 +130,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) - from icefall.context_graph import ContextGraph, ContextState - from icefall.decode import ( ctc_greedy_search, ctc_prefix_beam_search, @@ -147,11 +145,9 @@ from icefall.decode import ( rescore_with_n_best_list, rescore_with_whole_lattice, ) - -from icefall.ngram_lm import NgramLm, NgramLmStateCost from icefall.lexicon import Lexicon from icefall.lm_wrapper import LmScorer - +from icefall.ngram_lm import NgramLm, NgramLmStateCost from icefall.utils import ( AttributeDict, get_texts, diff --git a/icefall/decode.py b/icefall/decode.py index 5f90ee168..5d836bd48 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -24,8 +24,8 @@ import k2 import torch from icefall.context_graph import ContextGraph, ContextState -from icefall.ngram_lm import NgramLm, NgramLmStateCost from icefall.lm_wrapper import LmScorer +from icefall.ngram_lm import NgramLm, NgramLmStateCost from icefall.utils import add_eos, add_sos, get_texts DEFAULT_LM_SCALE = [ From 119e1ce3e8b645011c5173c8a6e55c2308569fa4 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Thu, 31 Oct 2024 09:54:12 +0800 Subject: [PATCH 33/59] fix str2bool (#1792) --- egs/libriheavy/ASR/local/prepare_manifest.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/egs/libriheavy/ASR/local/prepare_manifest.py b/egs/libriheavy/ASR/local/prepare_manifest.py index d7e184d86..a57a3749d 100755 --- a/egs/libriheavy/ASR/local/prepare_manifest.py +++ b/egs/libriheavy/ASR/local/prepare_manifest.py @@ -20,6 +20,8 @@ import json import sys from pathlib import Path +from icefall.utils import str2bool + def simple_cleanup(text: str) -> str: table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]") @@ -34,7 +36,7 @@ def main(): ), "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR KEEP_CUSTOM_FIELDS" fname = Path(sys.argv[1]).name oname = Path(sys.argv[2]) / fname - keep_custom_fields = bool(sys.argv[3]) + keep_custom_fields = str2bool(sys.argv[3]) with gzip.open(sys.argv[1], "r") as fin, gzip.open(oname, "w") as fout: for line in fin: cut = json.loads(line) From 66225fbe3323c3fec34c14dbe2dc920c39bc87cf Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 1 Nov 2024 15:33:13 +0800 Subject: [PATCH 34/59] VITS recipe for LibriTTS corpus (#1776) --- README.md | 3 + egs/libritts/CODEC/encodec/train.py | 18 +- egs/libritts/CODEC/prepare.sh | 5 +- egs/libritts/TTS/README.md | 51 + .../TTS/local/compute_spectrogram_libritts.py | 1 + egs/libritts/TTS/local/prepare_token_file.py | 1 + .../TTS/local/prepare_tokens_libritts.py | 89 ++ egs/libritts/TTS/local/validate_manifest.py | 1 + egs/libritts/TTS/prepare.sh | 134 +++ egs/libritts/TTS/shared | 1 + egs/libritts/TTS/vits/duration_predictor.py | 1 + egs/libritts/TTS/vits/flow.py | 1 + egs/libritts/TTS/vits/generator.py | 1 + egs/libritts/TTS/vits/hifigan.py | 1 + egs/libritts/TTS/vits/infer.py | 280 +++++ egs/libritts/TTS/vits/loss.py | 1 + egs/libritts/TTS/vits/monotonic_align | 1 + egs/libritts/TTS/vits/posterior_encoder.py | 1 + egs/libritts/TTS/vits/residual_coupling.py | 1 + egs/libritts/TTS/vits/test_onnx.py | 141 +++ egs/libritts/TTS/vits/text_encoder.py | 1 + egs/libritts/TTS/vits/tokenizer.py | 1 + egs/libritts/TTS/vits/train.py | 1015 +++++++++++++++++ egs/libritts/TTS/vits/transform.py | 1 + egs/libritts/TTS/vits/tts_datamodule.py | 432 +++++++ egs/libritts/TTS/vits/utils.py | 1 + egs/libritts/TTS/vits/vits.py | 1 + egs/libritts/TTS/vits/wavenet.py | 1 + egs/ljspeech/TTS/vits/generator.py | 7 +- egs/ljspeech/TTS/vits/train.py | 4 +- egs/ljspeech/TTS/vits/vits.py | 6 + egs/vctk/TTS/vits/train.py | 4 +- 32 files changed, 2190 insertions(+), 17 deletions(-) create mode 100644 egs/libritts/TTS/README.md create mode 120000 egs/libritts/TTS/local/compute_spectrogram_libritts.py create mode 120000 egs/libritts/TTS/local/prepare_token_file.py create mode 100755 egs/libritts/TTS/local/prepare_tokens_libritts.py create mode 120000 egs/libritts/TTS/local/validate_manifest.py create mode 100755 egs/libritts/TTS/prepare.sh create mode 120000 egs/libritts/TTS/shared create mode 120000 egs/libritts/TTS/vits/duration_predictor.py create mode 120000 egs/libritts/TTS/vits/flow.py create mode 120000 egs/libritts/TTS/vits/generator.py create mode 120000 egs/libritts/TTS/vits/hifigan.py create mode 100755 egs/libritts/TTS/vits/infer.py create mode 120000 egs/libritts/TTS/vits/loss.py create mode 120000 egs/libritts/TTS/vits/monotonic_align create mode 120000 egs/libritts/TTS/vits/posterior_encoder.py create mode 120000 egs/libritts/TTS/vits/residual_coupling.py create mode 100755 egs/libritts/TTS/vits/test_onnx.py create mode 120000 egs/libritts/TTS/vits/text_encoder.py create mode 120000 egs/libritts/TTS/vits/tokenizer.py create mode 100755 egs/libritts/TTS/vits/train.py create mode 120000 egs/libritts/TTS/vits/transform.py create mode 100644 egs/libritts/TTS/vits/tts_datamodule.py create mode 120000 egs/libritts/TTS/vits/utils.py create mode 120000 egs/libritts/TTS/vits/vits.py create mode 120000 egs/libritts/TTS/vits/wavenet.py diff --git a/README.md b/README.md index 57db5eb8d..0e550ffb1 100644 --- a/README.md +++ b/README.md @@ -333,6 +333,7 @@ We provide a Colab notebook to test the pre-trained model: [![Open In Colab](htt - [LJSpeech][ljspeech] - [VCTK][vctk] + - [LibriTTS][libritts_tts] ### Supported Models @@ -372,6 +373,7 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad [commonvoice]: egs/commonvoice/ASR [csj]: egs/csj/ASR [libricss]: egs/libricss/SURT +[libritts_asr]: egs/libritts/ASR [libriheavy]: egs/libriheavy/ASR [mgb2]: egs/mgb2/ASR [spgispeech]: egs/spgispeech/ASR @@ -380,3 +382,4 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad [vctk]: egs/vctk/TTS [ljspeech]: egs/ljspeech/TTS +[libritts_tts]: egs/libritts/TTS diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index bf231c5b6..a4f2eb7ab 100755 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -138,7 +138,7 @@ def get_parser(): parser.add_argument( "--save-every-n", type=int, - default=1, + default=5, help="""Save checkpoint after processing this number of epochs" periodically. We save checkpoint to exp-dir/ whenever params.cur_epoch % save_every_n == 0. The checkpoint filename @@ -1093,14 +1093,14 @@ def run(rank, world_size, args): rank=rank, ) - # if not params.print_diagnostics: - # scan_pessimistic_batches_for_oom( - # model=model, - # train_dl=train_dl, - # optimizer_g=optimizer_g, - # optimizer_d=optimizer_d, - # params=params, - # ) + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + params=params, + ) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: diff --git a/egs/libritts/CODEC/prepare.sh b/egs/libritts/CODEC/prepare.sh index 6a471c3ad..da04249ac 100755 --- a/egs/libritts/CODEC/prepare.sh +++ b/egs/libritts/CODEC/prepare.sh @@ -45,12 +45,11 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then # to $dl_dir/LibriTTS mkdir -p data/manifests if [ ! -e data/manifests/.libritts.done ]; then - lhotse prepare libritts --num-jobs 32 $dl_dir/LibriTTS data/manifests + lhotse prepare libritts --num-jobs ${nj} $dl_dir/LibriTTS data/manifests touch data/manifests/.libritts.done fi fi - if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 2: Compute Spectrogram for LibriTTS" mkdir -p data/spectrogram @@ -64,7 +63,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then if [ ! -f data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz ]; then cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \ <(gunzip -c data/spectrogram/libritts_cuts_train-clean-360.jsonl.gz) \ - <(gunzip -c /data/spectrogramlibritts_cuts_train-other-500.jsonl.gz) | \ + <(gunzip -c data/spectrogramlibritts_cuts_train-other-500.jsonl.gz) | \ shuf | gzip -c > data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz fi diff --git a/egs/libritts/TTS/README.md b/egs/libritts/TTS/README.md new file mode 100644 index 000000000..4d4fb8580 --- /dev/null +++ b/egs/libritts/TTS/README.md @@ -0,0 +1,51 @@ +# Introduction + +LibriTTS is a multi-speaker English corpus of approximately 585 hours of read English speech at 24kHz sampling rate, prepared by Heiga Zen with the assistance of Google Speech and Google Brain team members. +The LibriTTS corpus is designed for TTS research. It is derived from the original materials (mp3 audio files from LibriVox and text files from Project Gutenberg) of the LibriSpeech corpus. +The main differences from the LibriSpeech corpus are listed below: +1. The audio files are at 24kHz sampling rate. +2. The speech is split at sentence breaks. +3. Both original and normalized texts are included. +4. Contextual information (e.g., neighbouring sentences) can be extracted. +5. Utterances with significant background noise are excluded. +For more information, refer to the paper "LibriTTS: A Corpus Derived from LibriSpeech for Text-to-Speech", Heiga Zen, Viet Dang, Rob Clark, Yu Zhang, Ron J. Weiss, Ye Jia, Zhifeng Chen, and Yonghui Wu, arXiv, 2019. If you use the LibriTTS corpus in your work, please cite this paper where it was introduced. + +> [!CAUTION] +> The next-gen Kaldi framework provides tools and models for generating high-quality, synthetic speech (Text-to-Speech, TTS). +> While these recipes has the potential to advance various fields such as accessibility, language education, and AI-driven solutions, it also carries certain ethical and legal responsibilities. +> +> By using this framework, you agree to the following: +> 1. Legal and Ethical Use: You shall not use this framework, or any models derived from it, for any unlawful or unethical purposes. This includes, but is not limited to: Creating voice clones without the explicit, informed consent of the individual whose voice is being cloned. Engaging in any form of identity theft, impersonation, or fraud using cloned voices. Violating any local, national, or international laws regarding privacy, intellectual property, or personal data. +> +> 2. Responsibility of Use: The users of this framework are solely responsible for ensuring that their use of voice cloning technologies complies with all applicable laws and ethical guidelines. We explicitly disclaim any liability for misuse of the technology. +> +> 3. Attribution and Use of Open-Source Components: This project is provided under the Apache 2.0 license. Users must adhere to the terms of this license and provide appropriate attribution when required. +> +> 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties. + + +# VITS + +This recipe provides a VITS model trained on the LibriTTS dataset. + +Pretrained model can be found [here](https://huggingface.co/zrjin/icefall-tts-libritts-vits-2024-10-30). + +The training command is given below: +``` +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +./vits/train.py \ + --world-size 4 \ + --num-epochs 400 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir vits/exp \ + --max-duration 500 +``` + +To inference, use: +``` +./vits/infer.py \ + --exp-dir vits/exp \ + --epoch 400 \ + --tokens data/tokens.txt +``` diff --git a/egs/libritts/TTS/local/compute_spectrogram_libritts.py b/egs/libritts/TTS/local/compute_spectrogram_libritts.py new file mode 120000 index 000000000..5a6ebba58 --- /dev/null +++ b/egs/libritts/TTS/local/compute_spectrogram_libritts.py @@ -0,0 +1 @@ +../../CODEC/local/compute_spectrogram_libritts.py \ No newline at end of file diff --git a/egs/libritts/TTS/local/prepare_token_file.py b/egs/libritts/TTS/local/prepare_token_file.py new file mode 120000 index 000000000..afc29a22b --- /dev/null +++ b/egs/libritts/TTS/local/prepare_token_file.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/local/prepare_token_file.py \ No newline at end of file diff --git a/egs/libritts/TTS/local/prepare_tokens_libritts.py b/egs/libritts/TTS/local/prepare_tokens_libritts.py new file mode 100755 index 000000000..faeb611f5 --- /dev/null +++ b/egs/libritts/TTS/local/prepare_tokens_libritts.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao, +# Zengrui Jin,) +# 2024 Tsinghua University (authors: Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file reads the texts in given manifest and save the new cuts with phoneme tokens. +""" + +import logging +from pathlib import Path + +import tacotron_cleaner.cleaners +from lhotse import CutSet, load_manifest +from piper_phonemize import phonemize_espeak +from tqdm.auto import tqdm + + +def remove_punc_to_upper(text: str) -> str: + text = text.replace("‘", "'") + text = text.replace("’", "'") + tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'") + s_list = [x.upper() if x in tokens else " " for x in text] + s = " ".join("".join(s_list).split()).strip() + return s + + +def prepare_tokens_libritts(): + output_dir = Path("data/spectrogram") + prefix = "libritts" + suffix = "jsonl.gz" + partitions = ( + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-all-shuf", + "train-clean-460", + # "train-clean-100", + # "train-clean-360", + # "train-other-500", + ) + + for partition in partitions: + cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}") + + new_cuts = [] + for cut in tqdm(cut_set): + # Each cut only contains one supervision + assert len(cut.supervisions) == 1, (len(cut.supervisions), cut) + text = cut.supervisions[0].text + # Text normalization + text = tacotron_cleaner.cleaners.custom_english_cleaners(text) + # Convert to phonemes + tokens_list = phonemize_espeak(text, "en-us") + tokens = [] + for t in tokens_list: + tokens.extend(t) + cut.tokens = tokens + cut.supervisions[0].normalized_text = remove_punc_to_upper(text) + + new_cuts.append(cut) + + new_cut_set = CutSet.from_cuts(new_cuts) + new_cut_set.to_file( + output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}" + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + prepare_tokens_libritts() diff --git a/egs/libritts/TTS/local/validate_manifest.py b/egs/libritts/TTS/local/validate_manifest.py new file mode 120000 index 000000000..b4d52ebca --- /dev/null +++ b/egs/libritts/TTS/local/validate_manifest.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/local/validate_manifest.py \ No newline at end of file diff --git a/egs/libritts/TTS/prepare.sh b/egs/libritts/TTS/prepare.sh new file mode 100755 index 000000000..44016e6d2 --- /dev/null +++ b/egs/libritts/TTS/prepare.sh @@ -0,0 +1,134 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=0 +stop_stage=100 +sampling_rate=24000 +nj=32 + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: build monotonic_align lib" + if [ ! -d vits/monotonic_align/build ]; then + cd vits/monotonic_align + python setup.py build_ext --inplace + cd ../../ + else + log "monotonic_align lib already built" + fi +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/LibriTTS, + # you can create a symlink + # + # ln -sfv /path/to/LibriTTS $dl_dir/LibriTTS + # + if [ ! -d $dl_dir/LibriTTS ]; then + lhotse download libritts $dl_dir + fi + + if [ ! -d $dl_dir/xvector_nnet_1a_libritts_clean_460 ]; then + log "Downloading x-vector" + + git clone https://huggingface.co/datasets/zrjin/xvector_nnet_1a_libritts_clean_460 $dl_dir/xvector_nnet_1a_libritts_clean_460 + + mkdir -p exp/xvector_nnet_1a/ + cp -r $dl_dir/xvector_nnet_1a_libritts_clean_460/* exp/xvector_nnet_1a/ + fi + +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare LibriTTS manifest" + # We assume that you have downloaded the LibriTTS corpus + # to $dl_dir/LibriTTS + mkdir -p data/manifests + if [ ! -e data/manifests/.libritts.done ]; then + lhotse prepare libritts --num-jobs ${nj} $dl_dir/LibriTTS data/manifests + touch data/manifests/.libritts.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Compute Spectrogram for LibriTTS" + mkdir -p data/spectrogram + if [ ! -e data/spectrogram/.libritts.done ]; then + ./local/compute_spectrogram_libritts.py --sampling-rate $sampling_rate + touch data/spectrogram/.libritts.done + fi + + # Here we shuffle and combine the train-clean-100, train-clean-360 and + # train-other-500 together to form the training set. + if [ ! -f data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz ]; then + cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \ + <(gunzip -c data/spectrogram/libritts_cuts_train-clean-360.jsonl.gz) \ + <(gunzip -c data/spectrogramlibritts_cuts_train-other-500.jsonl.gz) | \ + shuf | gzip -c > data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz + fi + + # Here we shuffle and combine the train-clean-100, train-clean-360 + # together to form the training set. + if [ ! -f data/spectrogram/libritts_cuts_train-clean-460.jsonl.gz ]; then + cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \ + <(gunzip -c data/spectrogram/libritts_cuts_train-clean-360.jsonl.gz) | \ + shuf | gzip -c > data/spectrogram/libritts_cuts_train-clean-460.jsonl.gz + fi + + if [ ! -e data/spectrogram/.libritts-validated.done ]; then + log "Validating data/spectrogram for LibriTTS" + ./local/validate_manifest.py \ + data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz + touch data/spectrogram/.libritts-validated.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Prepare phoneme tokens for LibriTTS" + # We assume you have installed piper_phonemize and espnet_tts_frontend. + # If not, please install them with: + # - piper_phonemize: + # refer to https://github.com/rhasspy/piper-phonemize, + # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5 + # - espnet_tts_frontend: + # `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ + if [ ! -e data/spectrogram/.libritts_with_token.done ]; then + ./local/prepare_tokens_libritts.py + touch data/spectrogram/.libritts_with_token.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Generate token file" + # We assume you have installed piper_phonemize and espnet_tts_frontend. + # If not, please install them with: + # - piper_phonemize: + # refer to https://github.com/rhasspy/piper-phonemize, + # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5 + # - espnet_tts_frontend: + # `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ + if [ ! -e data/tokens.txt ]; then + ./local/prepare_token_file.py --tokens data/tokens.txt + fi +fi diff --git a/egs/libritts/TTS/shared b/egs/libritts/TTS/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/libritts/TTS/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/libritts/TTS/vits/duration_predictor.py b/egs/libritts/TTS/vits/duration_predictor.py new file mode 120000 index 000000000..9972b476f --- /dev/null +++ b/egs/libritts/TTS/vits/duration_predictor.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/duration_predictor.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/flow.py b/egs/libritts/TTS/vits/flow.py new file mode 120000 index 000000000..e65d91ea7 --- /dev/null +++ b/egs/libritts/TTS/vits/flow.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/flow.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/generator.py b/egs/libritts/TTS/vits/generator.py new file mode 120000 index 000000000..611679bfa --- /dev/null +++ b/egs/libritts/TTS/vits/generator.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/generator.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/hifigan.py b/egs/libritts/TTS/vits/hifigan.py new file mode 120000 index 000000000..5ac025de7 --- /dev/null +++ b/egs/libritts/TTS/vits/hifigan.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/hifigan.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/infer.py b/egs/libritts/TTS/vits/infer.py new file mode 100755 index 000000000..675678606 --- /dev/null +++ b/egs/libritts/TTS/vits/infer.py @@ -0,0 +1,280 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script performs model inference on test set. + +Usage: +./vits/infer.py \ + --epoch 1000 \ + --exp-dir ./vits/exp \ + --max-duration 500 +""" + + +import argparse +import logging +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import List + +import k2 +import numpy as np +import torch +import torch.nn as nn +import torchaudio +from lhotse.features.io import KaldiReader +from tokenizer import Tokenizer +from train import get_model, get_params +from tts_datamodule import LibrittsTtsDataModule + +from icefall.checkpoint import load_checkpoint +from icefall.utils import AttributeDict, setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=1000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +def infer_dataset( + dl: torch.utils.data.DataLoader, + subset: str, + params: AttributeDict, + model: nn.Module, + tokenizer: Tokenizer, + speaker_map: KaldiReader, +) -> None: + """Decode dataset. + The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + tokenizer: + Used to convert text to phonemes. + """ + + # Background worker save audios to disk. + def _save_worker( + subset: str, + batch_size: int, + cut_ids: List[str], + audio: torch.Tensor, + audio_pred: torch.Tensor, + audio_lens: List[int], + audio_lens_pred: List[int], + ): + for i in range(batch_size): + torchaudio.save( + str(params.save_wav_dir / subset / f"{cut_ids[i]}_gt.wav"), + audio[i : i + 1, : audio_lens[i]], + sample_rate=params.sampling_rate, + ) + torchaudio.save( + str(params.save_wav_dir / subset / f"{cut_ids[i]}_pred.wav"), + audio_pred[i : i + 1, : audio_lens_pred[i]], + sample_rate=params.sampling_rate, + ) + + device = next(model.parameters()).device + num_cuts = 0 + log_interval = 5 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + futures = [] + with ThreadPoolExecutor(max_workers=1) as executor: + for batch_idx, batch in enumerate(dl): + batch_size = len(batch["tokens"]) + + tokens = batch["tokens"] + tokens = tokenizer.tokens_to_token_ids( + tokens, intersperse_blank=True, add_sos=True, add_eos=True + ) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) + + audio = batch["audio"] + audio_lens = batch["audio_lens"].tolist() + cut_ids = [cut.id for cut in batch["cut"]] + sids = ["_".join(cut_id.split("_")[:2]) for cut_id in cut_ids] + spembs = ( + torch.Tensor(np.array([speaker_map.read(sid) for sid in sids])) + .squeeze(1) + .to(device) + ) + + audio_pred, _, durations = model.inference_batch( + text=tokens, + text_lengths=tokens_lens, + spembs=spembs, + ) + audio_pred = audio_pred.detach().cpu() + # convert to samples + audio_lens_pred = ( + (durations.sum(1) * params.frame_shift).to(dtype=torch.int64).tolist() + ) + + futures.append( + executor.submit( + _save_worker, + subset, + batch_size, + cut_ids, + audio, + audio_pred, + audio_lens, + audio_lens_pred, + ) + ) + + num_cuts += batch_size + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + # return results + for f in futures: + f.result() + + +@torch.no_grad() +def main(): + parser = get_parser() + LibrittsTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.suffix = f"epoch-{params.epoch}" + + params.res_dir = params.exp_dir / "infer" / params.suffix + params.save_wav_dir = params.res_dir / "wav" + params.save_wav_dir.mkdir(parents=True, exist_ok=True) + + setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") + logging.info("Infer started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + + # we need cut ids to display recognition results. + args.return_cuts = True + libritts = LibrittsTtsDataModule(args) + + logging.info(f"Device: {device}") + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + + model.to(device) + model.eval() + + num_param_g = sum([p.numel() for p in model.generator.parameters()]) + logging.info(f"Number of parameters in generator: {num_param_g}") + num_param_d = sum([p.numel() for p in model.discriminator.parameters()]) + logging.info(f"Number of parameters in discriminator: {num_param_d}") + logging.info(f"Total number of parameters: {num_param_g + num_param_d}") + + test_clean_cuts = libritts.test_clean_cuts() + test_clean_speaker_map = libritts.test_clean_xvector() + test_clean_dl = libritts.test_dataloaders(test_clean_cuts) + + dev_clean_cuts = libritts.dev_clean_cuts() + dev_clean_speaker_map = libritts.dev_clean_xvector() + dev_clean_dl = libritts.dev_dataloaders(dev_clean_cuts) + + infer_sets = { + "test-clean": (test_clean_dl, test_clean_speaker_map), + "dev-clean": (dev_clean_dl, dev_clean_speaker_map), + } + + for subset, data in infer_sets.items(): + save_wav_dir = params.res_dir / "wav" / subset + save_wav_dir.mkdir(parents=True, exist_ok=True) + dl, speaker_map = data + + logging.info(f"Processing {subset} set, saving to {save_wav_dir}") + + infer_dataset( + dl=dl, + subset=subset, + params=params, + model=model, + tokenizer=tokenizer, + speaker_map=speaker_map, + ) + + logging.info(f"Wav files are saved to {params.save_wav_dir}") + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/libritts/TTS/vits/loss.py b/egs/libritts/TTS/vits/loss.py new file mode 120000 index 000000000..672e5ff68 --- /dev/null +++ b/egs/libritts/TTS/vits/loss.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/loss.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/monotonic_align b/egs/libritts/TTS/vits/monotonic_align new file mode 120000 index 000000000..71934e7cc --- /dev/null +++ b/egs/libritts/TTS/vits/monotonic_align @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/monotonic_align \ No newline at end of file diff --git a/egs/libritts/TTS/vits/posterior_encoder.py b/egs/libritts/TTS/vits/posterior_encoder.py new file mode 120000 index 000000000..41d64a3a6 --- /dev/null +++ b/egs/libritts/TTS/vits/posterior_encoder.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/posterior_encoder.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/residual_coupling.py b/egs/libritts/TTS/vits/residual_coupling.py new file mode 120000 index 000000000..f979adbf0 --- /dev/null +++ b/egs/libritts/TTS/vits/residual_coupling.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/residual_coupling.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/test_onnx.py b/egs/libritts/TTS/vits/test_onnx.py new file mode 100755 index 000000000..ae6587338 --- /dev/null +++ b/egs/libritts/TTS/vits/test_onnx.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +# +# Copyright 2023-2024 Xiaomi Corporation (Author: Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script is used to test the exported onnx model by vits/export-onnx.py + +Use the onnx model to generate a wav: +./vits/test_onnx.py \ + --model-filename vits/exp/vits-epoch-1000.onnx \ + --tokens data/tokens.txt +""" + + +import argparse +import logging +from pathlib import Path + +import onnxruntime as ort +import torch +import torchaudio +from tokenizer import Tokenizer + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the onnx model.", + ) + + parser.add_argument( + "--speakers", + type=Path, + default=Path("data/speakers.txt"), + help="Path to speakers.txt file.", + ) + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +class OnnxModel: + def __init__(self, model_filename: str): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 4 + + self.session_opts = session_opts + + self.model = ort.InferenceSession( + model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") + + def __call__( + self, tokens: torch.Tensor, tokens_lens: torch.Tensor, speaker: torch.Tensor + ) -> torch.Tensor: + """ + Args: + tokens: + A 1-D tensor of shape (1, T) + Returns: + A tensor of shape (1, T') + """ + noise_scale = torch.tensor([0.667], dtype=torch.float32) + noise_scale_dur = torch.tensor([0.8], dtype=torch.float32) + alpha = torch.tensor([1.0], dtype=torch.float32) + + out = self.model.run( + [ + self.model.get_outputs()[0].name, + ], + { + self.model.get_inputs()[0].name: tokens.numpy(), + self.model.get_inputs()[1].name: tokens_lens.numpy(), + self.model.get_inputs()[2].name: noise_scale.numpy(), + self.model.get_inputs()[3].name: alpha.numpy(), + self.model.get_inputs()[4].name: noise_scale_dur.numpy(), + self.model.get_inputs()[5].name: speaker.numpy(), + }, + )[0] + return torch.from_numpy(out) + + +def main(): + args = get_parser().parse_args() + + tokenizer = Tokenizer(args.tokens) + + with open(args.speakers) as f: + speaker_map = {line.strip(): i for i, line in enumerate(f)} + args.num_spks = len(speaker_map) + + logging.info("About to create onnx model") + model = OnnxModel(args.model_filename) + + text = "I went there to see the land, the people and how their system works, end quote." + tokens = tokenizer.texts_to_token_ids( + [text], intersperse_blank=True, add_sos=True, add_eos=True + ) + tokens = torch.tensor(tokens) # (1, T) + tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1, T) + speaker = torch.tensor([1], dtype=torch.int64) # (1, ) + audio = model(tokens, tokens_lens, speaker) # (1, T') + + torchaudio.save(str("test_onnx.wav"), audio, sample_rate=22050) + logging.info("Saved to test_onnx.wav") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/libritts/TTS/vits/text_encoder.py b/egs/libritts/TTS/vits/text_encoder.py new file mode 120000 index 000000000..0efba277e --- /dev/null +++ b/egs/libritts/TTS/vits/text_encoder.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/text_encoder.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/tokenizer.py b/egs/libritts/TTS/vits/tokenizer.py new file mode 120000 index 000000000..057b0dc4b --- /dev/null +++ b/egs/libritts/TTS/vits/tokenizer.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/tokenizer.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/train.py b/egs/libritts/TTS/vits/train.py new file mode 100755 index 000000000..447fbcf5d --- /dev/null +++ b/egs/libritts/TTS/vits/train.py @@ -0,0 +1,1015 @@ +#!/usr/bin/env python3 +# Copyright 2023-2024 Xiaomi Corporation (Author: Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import k2 +import numpy as np +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse.cut import Cut +from lhotse.features.io import KaldiReader +from lhotse.utils import fix_random_seed +from tokenizer import Tokenizer +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import LibrittsTtsDataModule +from utils import MetricsTracker, plot_feature, save_checkpoint +from vits import VITS + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, setup_logger, str2bool + +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=1000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + parser.add_argument( + "--lr", type=float, default=2.0e-4, help="The base learning rate." + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=20, + help="""Save checkpoint after processing this number of epochs" + periodically. We save checkpoint to exp-dir/ whenever + params.cur_epoch % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. + Since it will take around 1000 epochs, we suggest using a large + save_every_n to save disk space. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + # training params + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": -1, # 0 + "log_interval": 50, + "valid_interval": 200, + "env_info": get_env_info(), + "sampling_rate": 24000, + "frame_shift": 256, + "frame_length": 1024, + "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length + "n_mels": 80, + "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss + "lambda_mel": 45.0, # loss scaling coefficient for Mel loss + "lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss + "lambda_dur": 1.0, # loss scaling coefficient for duration loss + "lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, model: nn.Module +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint(filename, model=model) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def get_model(params: AttributeDict) -> nn.Module: + mel_loss_params = { + "n_mels": params.n_mels, + "frame_length": params.frame_length, + "frame_shift": params.frame_shift, + } + generator_params = { + "hidden_channels": 192, + "spks": None, + "langs": None, + "spk_embed_dim": 512, + "global_channels": 256, + "segment_size": 32, + "text_encoder_attention_heads": 2, + "text_encoder_ffn_expand": 4, + "text_encoder_cnn_module_kernel": 5, + "text_encoder_blocks": 6, + "text_encoder_dropout_rate": 0.1, + "decoder_kernel_size": 7, + "decoder_channels": 512, + "decoder_upsample_scales": [8, 8, 2, 2], + "decoder_upsample_kernel_sizes": [16, 16, 4, 4], + "decoder_resblock_kernel_sizes": [3, 7, 11], + "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "use_weight_norm_in_decoder": True, + "posterior_encoder_kernel_size": 5, + "posterior_encoder_layers": 16, + "posterior_encoder_stacks": 1, + "posterior_encoder_base_dilation": 1, + "posterior_encoder_dropout_rate": 0.0, + "use_weight_norm_in_posterior_encoder": True, + "flow_flows": 4, + "flow_kernel_size": 5, + "flow_base_dilation": 1, + "flow_layers": 4, + "flow_dropout_rate": 0.0, + "use_weight_norm_in_flow": True, + "use_only_mean_in_flow": True, + "stochastic_duration_predictor_kernel_size": 3, + "stochastic_duration_predictor_dropout_rate": 0.5, + "stochastic_duration_predictor_flows": 4, + "stochastic_duration_predictor_dds_conv_layers": 3, + } + model = VITS( + vocab_size=params.vocab_size, + feature_dim=params.feature_dim, + sampling_rate=params.sampling_rate, + generator_params=generator_params, + mel_loss_params=mel_loss_params, + lambda_adv=params.lambda_adv, + lambda_mel=params.lambda_mel, + lambda_feat_match=params.lambda_feat_match, + lambda_dur=params.lambda_dur, + lambda_kl=params.lambda_kl, + ) + return model + + +def prepare_input( + batch: dict, + tokenizer: Tokenizer, + device: torch.device, + speaker_map: KaldiReader, +): + """Parse batch data""" + + def parse_sids(batch: dict) -> List[str]: + return ["_".join(cut.id.split("_")[:2]) for cut in batch["cut"]] + + audio = batch["audio"].to(device) + features = batch["features"].to(device) + audio_lens = batch["audio_lens"].to(device) + features_lens = batch["features_lens"].to(device) + tokens = batch["tokens"] + spembs = ( + torch.Tensor(np.array([speaker_map.read(sid) for sid in parse_sids(batch)])) + .squeeze(1) + .to(device) + ) + + tokens = tokenizer.tokens_to_token_ids( + tokens, intersperse_blank=True, add_sos=True, add_eos=True + ) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # a tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) + + return audio, audio_lens, features, features_lens, tokens, tokens_lens, spembs + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + optimizer_g: Optimizer, + optimizer_d: Optimizer, + scheduler_g: LRSchedulerType, + scheduler_d: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + dev_dl: torch.utils.data.DataLoader, + train_speaker_map: KaldiReader, + dev_speaker_map: KaldiReader, + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + tokenizer: + Used to convert text to phonemes. + optimizer_g: + The optimizer for generator. + optimizer_d: + The optimizer for discriminator. + scheduler_g: + The learning rate scheduler for generator, we call step() every epoch. + scheduler_d: + The learning rate scheduler for discriminator, we call step() every epoch. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + params=params, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + + batch_size = len(batch["tokens"]) + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + spembs, + ) = prepare_input(batch, tokenizer, device, train_speaker_map) + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + try: + with autocast(enabled=params.use_fp16): + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + spembs=spembs, + forward_generator=False, + ) + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + # update discriminator + optimizer_d.zero_grad() + scaler.scale(loss_d).backward() + scaler.step(optimizer_d) + + with autocast(enabled=params.use_fp16): + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + spembs=spembs, + forward_generator=True, + return_sample=params.batch_idx_train % params.log_interval == 0, + ) + for k, v in stats_g.items(): + if "returned_sample" not in k: + loss_info[k] = v * batch_size + # update generator + optimizer_g.zero_grad() + scaler.scale(loss_g).backward() + scaler.step(optimizer_g) + scaler.update() + + # summary stats + tot_loss = tot_loss + loss_info + except: # noqa + save_bad_model() + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if params.batch_idx_train % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or ( + cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 + ): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if params.batch_idx_train % params.log_interval == 0: + cur_lr_g = max(scheduler_g.get_last_lr()) + cur_lr_d = max(scheduler_d.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " + f"loss[{loss_info}], tot_loss[{tot_loss}], " + f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate_g", cur_lr_g, params.batch_idx_train + ) + tb_writer.add_scalar( + "train/learning_rate_d", cur_lr_d, params.batch_idx_train + ) + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + if "returned_sample" in stats_g: + speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] + tb_writer.add_audio( + "train/speech_hat_", + speech_hat_, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_audio( + "train/speech_", + speech_, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_image( + "train/mel_hat_", + plot_feature(mel_hat_), + params.batch_idx_train, + dataformats="HWC", + ) + tb_writer.add_image( + "train/mel_", + plot_feature(mel_), + params.batch_idx_train, + dataformats="HWC", + ) + + if ( + params.batch_idx_train % params.valid_interval == 0 + and not params.print_diagnostics + ): + logging.info("Computing validation loss") + valid_info, (speech_hat, speech) = compute_validation_loss( + params=params, + model=model, + tokenizer=tokenizer, + dev_dl=dev_dl, + dev_speaker_map=dev_speaker_map, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + tb_writer.add_audio( + "train/valid_speech_hat", + speech_hat, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_audio( + "train/valid_speech", + speech, + params.batch_idx_train, + params.sampling_rate, + ) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + dev_dl: torch.utils.data.DataLoader, + dev_speaker_map: KaldiReader, + world_size: int = 1, + rank: int = 0, +) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]: + """Run the validation process.""" + model.eval() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations + tot_loss = MetricsTracker() + returned_sample = None + + with torch.no_grad(): + for batch_idx, batch in enumerate(dev_dl): + batch_size = len(batch["tokens"]) + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + spembs, + ) = prepare_input(batch, tokenizer, device, dev_speaker_map) + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + spembs=spembs, + forward_generator=False, + ) + assert loss_d.requires_grad is False + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + spembs=spembs, + forward_generator=True, + ) + assert loss_g.requires_grad is False + for k, v in stats_g.items(): + loss_info[k] = v * batch_size + + # summary stats + tot_loss = tot_loss + loss_info + + # infer for first batch: + if batch_idx == 0 and rank == 0: + inner_model = model.module if isinstance(model, DDP) else model + audio_pred, _, duration = inner_model.inference( + text=tokens[0, : tokens_lens[0].item()], + spembs=spembs[0], + ) + audio_pred = audio_pred.data.cpu().numpy() + audio_len_pred = ( + (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() + ) + assert audio_len_pred == len(audio_pred), ( + audio_len_pred, + len(audio_pred), + ) + audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy() + returned_sample = (audio_pred, audio_gt) + + if world_size > 1: + tot_loss.reduce(device) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss, returned_sample + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + tokenizer: Tokenizer, + optimizer_g: torch.optim.Optimizer, + optimizer_d: torch.optim.Optimizer, + train_speaker_map: KaldiReader, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + spembs, + ) = prepare_input(batch, tokenizer, device, train_speaker_map) + try: + # for discriminator + with autocast(enabled=params.use_fp16): + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + spembs=spembs, + forward_generator=False, + ) + optimizer_d.zero_grad() + loss_d.backward() + # for generator + with autocast(enabled=params.use_fp16): + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + spembs=spembs, + forward_generator=True, + ) + optimizer_g.zero_grad() + loss_g.backward() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + + libritts = LibrittsTtsDataModule(args) + + if params.full_libri: + train_cuts = libritts.train_all_shuf_cuts() + train_speaker_map = libritts.train_all_shuf_xvector() + else: + train_cuts = libritts.train_clean_460_cuts() + train_speaker_map = libritts.train_clean_460_xvector() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + generator = model.generator + discriminator = model.discriminator + + num_param_g = sum([p.numel() for p in generator.parameters()]) + logging.info(f"Number of parameters in generator: {num_param_g}") + num_param_d = sum([p.numel() for p in discriminator.parameters()]) + logging.info(f"Number of parameters in discriminator: {num_param_d}") + logging.info(f"Total number of parameters: {num_param_g + num_param_d}") + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer_g = torch.optim.AdamW( + generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 + ) + optimizer_d = torch.optim.AdamW( + discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 + ) + + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875) + + if checkpoints is not None: + # load state_dict for optimizers + if "optimizer_g" in checkpoints: + logging.info("Loading optimizer_g state dict") + optimizer_g.load_state_dict(checkpoints["optimizer_g"]) + if "optimizer_d" in checkpoints: + logging.info("Loading optimizer_d state dict") + optimizer_d.load_state_dict(checkpoints["optimizer_d"]) + + # load state_dict for schedulers + if "scheduler_g" in checkpoints: + logging.info("Loading scheduler_g state dict") + scheduler_g.load_state_dict(checkpoints["scheduler_g"]) + if "scheduler_d" in checkpoints: + logging.info("Loading scheduler_d state dict") + scheduler_d.load_state_dict(checkpoints["scheduler_d"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_dl = libritts.train_dataloaders(train_cuts) + + dev_clean_cuts = libritts.dev_clean_cuts() + dev_speaker_map = libritts.dev_clean_xvector() + dev_dl = libritts.dev_dataloaders(dev_clean_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + tokenizer=tokenizer, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + train_speaker_map=train_speaker_map, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + params.cur_epoch = epoch + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + train_one_epoch( + params=params, + model=model, + tokenizer=tokenizer, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + train_dl=train_dl, + dev_dl=dev_dl, + train_speaker_map=train_speaker_map, + dev_speaker_map=dev_speaker_map, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if epoch % params.save_every_n == 0 or epoch == params.num_epochs: + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + # step per epoch + scheduler_g.step() + scheduler_d.step() + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + LibrittsTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/libritts/TTS/vits/transform.py b/egs/libritts/TTS/vits/transform.py new file mode 120000 index 000000000..962647408 --- /dev/null +++ b/egs/libritts/TTS/vits/transform.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/transform.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/tts_datamodule.py b/egs/libritts/TTS/vits/tts_datamodule.py new file mode 100644 index 000000000..e98e49c1f --- /dev/null +++ b/egs/libritts/TTS/vits/tts_datamodule.py @@ -0,0 +1,432 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, + SpeechSynthesisDataset, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.features.io import KaldiReader +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +LIBRITTS_SAMPLING_RATE = 24000 + + +class LibrittsTtsDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=False, + help="""When enabled, use the entire LibriTTS training set. + Otherwise, use the 460h clean subset.""", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/spectrogram"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--speaker-embeds", + type=Path, + default=Path("exp/xvector_nnet_1a/"), + help="Path to directory with speaker embeddings.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=8, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=True, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + sampling_rate = LIBRITTS_SAMPLING_RATE + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + train = SpeechSynthesisDataset( + return_text=True, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + sampling_rate = LIBRITTS_SAMPLING_RATE + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + validate = SpeechSynthesisDataset( + return_text=True, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + else: + validate = SpeechSynthesisDataset( + return_text=True, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + dev_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create valid dataloader") + dev_dl = DataLoader( + validate, + sampler=dev_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return dev_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + sampling_rate = LIBRITTS_SAMPLING_RATE + config = SpectrogramConfig( + sampling_rate=sampling_rate, + frame_length=1024 / sampling_rate, # (in second), + frame_shift=256 / sampling_rate, # (in second) + use_fft_mag=True, + ) + test = SpeechSynthesisDataset( + return_text=True, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=OnTheFlyFeatures(Spectrogram(config)), + return_cuts=self.args.return_cuts, + ) + else: + test = SpeechSynthesisDataset( + return_text=True, + return_tokens=True, + return_spk_ids=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_with_tokens_train-all-shuf.jsonl.gz" + ) + + @lru_cache() + def train_clean_460_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100 and train-clean-360 cuts" + ) + return load_manifest_lazy( + self.args.manifest_dir + / "libritts_cuts_with_tokens_train-clean-460.jsonl.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_with_tokens_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_with_tokens_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_with_tokens_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_with_tokens_test-other.jsonl.gz" + ) + + @lru_cache() + def train_all_shuf_xvector(self) -> KaldiReader: + raise NotImplementedError( + "Please implement the method to load speaker embeddings." + ) + + @lru_cache() + def train_clean_460_xvector(self) -> KaldiReader: + logging.info("About to get speaker embeddings for train-clean-460") + return KaldiReader( + str(self.args.speaker_embeds / "xvectors_train_clean_460" / "feats.scp") + ) + + @lru_cache() + def train_clean_100_xvector(self) -> KaldiReader: + raise NotImplementedError( + "Please implement the method to load speaker embeddings." + ) + + @lru_cache() + def train_clean_360_xvector(self) -> KaldiReader: + raise NotImplementedError( + "Please implement the method to load speaker embeddings." + ) + + @lru_cache() + def train_other_500_xvector(self) -> KaldiReader: + raise NotImplementedError( + "Please implement the method to load speaker embeddings." + ) + + @lru_cache() + def dev_clean_xvector(self) -> KaldiReader: + logging.info("About to get speaker embeddings for dev-clean") + return KaldiReader( + str(self.args.speaker_embeds / "xvectors_dev_clean" / "feats.scp") + ) + + @lru_cache() + def dev_other_xvector(self) -> KaldiReader: + raise NotImplementedError( + "Please implement the method to load speaker embeddings." + ) + + @lru_cache() + def test_clean_xvector(self) -> KaldiReader: + logging.info("About to get speaker embeddings for test-clean") + return KaldiReader( + str(self.args.speaker_embeds / "xvectors_test_clean" / "feats.scp") + ) + + @lru_cache() + def test_other_xvector(self) -> KaldiReader: + raise NotImplementedError( + "Please implement the method to load speaker embeddings." + ) diff --git a/egs/libritts/TTS/vits/utils.py b/egs/libritts/TTS/vits/utils.py new file mode 120000 index 000000000..085e764b4 --- /dev/null +++ b/egs/libritts/TTS/vits/utils.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/utils.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/vits.py b/egs/libritts/TTS/vits/vits.py new file mode 120000 index 000000000..1f58cf6fe --- /dev/null +++ b/egs/libritts/TTS/vits/vits.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/vits.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/wavenet.py b/egs/libritts/TTS/vits/wavenet.py new file mode 120000 index 000000000..28f0a78ee --- /dev/null +++ b/egs/libritts/TTS/vits/wavenet.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/wavenet.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/vits/generator.py b/egs/ljspeech/TTS/vits/generator.py index b9add9e82..521b0121f 100644 --- a/egs/ljspeech/TTS/vits/generator.py +++ b/egs/ljspeech/TTS/vits/generator.py @@ -409,7 +409,12 @@ class VITSGenerator(torch.nn.Module): g = self.global_emb(sids.view(-1)).unsqueeze(-1) if self.spk_embed_dim is not None: # (B, global_channels, 1) - g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1) + if spembs.ndim == 2: + g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1) + elif spembs.ndim == 1: + g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1) + else: + raise ValueError("spembs should be 1D or 2D (batch mode) tensor.") if g is None: g = g_ else: diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py index 34b943765..184ae79af 100755 --- a/egs/ljspeech/TTS/vits/train.py +++ b/egs/ljspeech/TTS/vits/train.py @@ -542,13 +542,13 @@ def train_one_epoch( tb_writer, "train/valid_", params.batch_idx_train ) tb_writer.add_audio( - "train/valdi_speech_hat", + "train/valid_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate, ) tb_writer.add_audio( - "train/valdi_speech", + "train/valid_speech", speech, params.batch_idx_train, params.sampling_rate, diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py index 0b9575cbd..a1fabf9ad 100644 --- a/egs/ljspeech/TTS/vits/vits.py +++ b/egs/ljspeech/TTS/vits/vits.py @@ -622,6 +622,8 @@ class VITS(nn.Module): text: torch.Tensor, text_lengths: torch.Tensor, sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, durations: Optional[torch.Tensor] = None, noise_scale: float = 0.667, noise_scale_dur: float = 0.8, @@ -635,6 +637,8 @@ class VITS(nn.Module): text (Tensor): Input text index tensor (B, T_text). text_lengths (Tensor): Input text index tensor (B,). sids (Tensor): Speaker index tensor (B,). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Tensor): Language index tensor (B,). noise_scale (float): Noise scale value for flow. noise_scale_dur (float): Noise scale value for duration predictor. alpha (float): Alpha parameter to control the speed of generated speech. @@ -650,6 +654,8 @@ class VITS(nn.Module): text=text, text_lengths=text_lengths, sids=sids, + spembs=spembs, + lids=lids, noise_scale=noise_scale, noise_scale_dur=noise_scale_dur, alpha=alpha, diff --git a/egs/vctk/TTS/vits/train.py b/egs/vctk/TTS/vits/train.py index 55bd69327..4686de169 100755 --- a/egs/vctk/TTS/vits/train.py +++ b/egs/vctk/TTS/vits/train.py @@ -597,13 +597,13 @@ def train_one_epoch( tb_writer, "train/valid_", params.batch_idx_train ) tb_writer.add_audio( - "train/valdi_speech_hat", + "train/valid_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate, ) tb_writer.add_audio( - "train/valdi_speech", + "train/valid_speech", speech, params.batch_idx_train, params.sampling_rate, From 57451b03828bf5325c91eaf12c80aaae9063283d Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Fri, 1 Nov 2024 22:49:19 +0800 Subject: [PATCH 35/59] refactor ksponspeech recipe (#1794) Co-authored-by: Your Name <> --- egs/ksponspeech/ASR/local/__init__.py | 0 .../ASR/local/compute_fbank_musan.py | 159 +----------------- egs/ksponspeech/ASR/local/filter_cuts.py | 158 +---------------- egs/ksponspeech/ASR/local/train_bpe_model.py | 116 +------------ .../ASR/local/validate_manifest.py | 102 +---------- egs/ksponspeech/ASR/zipformer/README.md | 1 - 6 files changed, 4 insertions(+), 532 deletions(-) delete mode 100644 egs/ksponspeech/ASR/local/__init__.py mode change 100755 => 120000 egs/ksponspeech/ASR/local/compute_fbank_musan.py mode change 100644 => 120000 egs/ksponspeech/ASR/local/filter_cuts.py mode change 100755 => 120000 egs/ksponspeech/ASR/local/train_bpe_model.py mode change 100755 => 120000 egs/ksponspeech/ASR/local/validate_manifest.py delete mode 100644 egs/ksponspeech/ASR/zipformer/README.md diff --git a/egs/ksponspeech/ASR/local/__init__.py b/egs/ksponspeech/ASR/local/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/ksponspeech/ASR/local/compute_fbank_musan.py b/egs/ksponspeech/ASR/local/compute_fbank_musan.py deleted file mode 100755 index c0bdacfe5..000000000 --- a/egs/ksponspeech/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1,158 +0,0 @@ -#!/usr/bin/env python3 -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the musan dataset. -It looks for manifests in the directory `src_dir` (default is data/manifests). - -The generated fbank features are saved in data/fbank. -""" -import argparse -import logging -import os -from pathlib import Path - -import torch -from lhotse import ( - CutSet, - Fbank, - FbankConfig, - LilcomChunkyWriter, - MonoCut, - WhisperFbank, - WhisperFbankConfig, - combine, -) -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor, str2bool - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def is_cut_long(c: MonoCut) -> bool: - return c.duration > 5 - - -def compute_fbank_musan( - src_dir: str = "data/manifests", - num_mel_bins: int = 80, - whisper_fbank: bool = False, - output_dir: str = "data/fbank", -): - src_dir = Path(src_dir) - output_dir = Path(output_dir) - num_jobs = min(15, os.cpu_count()) - - dataset_parts = ( - "music", - "speech", - "noise", - ) - prefix = "musan" - suffix = "jsonl.gz" - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - prefix=prefix, - suffix=suffix, - ) - assert manifests is not None - - assert len(manifests) == len(dataset_parts), ( - len(manifests), - len(dataset_parts), - list(manifests.keys()), - dataset_parts, - ) - - musan_cuts_path = output_dir / "musan_cuts.jsonl.gz" - - if musan_cuts_path.is_file(): - logging.info(f"{musan_cuts_path} already exists - skipping") - return - - logging.info("Extracting features for Musan") - - if whisper_fbank: - extractor = WhisperFbank( - WhisperFbankConfig(num_filters=num_mel_bins, device="cuda") - ) - else: - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. - # create chunks of Musan with duration 5 - 10 seconds - musan_cuts = ( - CutSet.from_manifests( - recordings=combine(part["recordings"] for part in manifests.values()) - ) - .cut_into_windows(10.0) - .filter(is_cut_long) - .compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/musan_feats", - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomChunkyWriter, - ) - ) - musan_cuts.to_file(musan_cuts_path) - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--src-dir", - type=str, - default="data/manifests", - help="Source manifests directory.", - ) - parser.add_argument( - "--num-mel-bins", - type=int, - default=80, - help="""The number of mel bins for Fbank""", - ) - parser.add_argument( - "--whisper-fbank", - type=str2bool, - default=False, - help="Use WhisperFbank instead of Fbank. Default: False.", - ) - parser.add_argument( - "--output-dir", - type=str, - default="data/fbank", - help="Output directory. Default: data/fbank.", - ) - return parser.parse_args() - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - args = get_args() - compute_fbank_musan( - src_dir=args.src_dir, - num_mel_bins=args.num_mel_bins, - whisper_fbank=args.whisper_fbank, - output_dir=args.output_dir, - ) diff --git a/egs/ksponspeech/ASR/local/compute_fbank_musan.py b/egs/ksponspeech/ASR/local/compute_fbank_musan.py new file mode 120000 index 000000000..5833f2484 --- /dev/null +++ b/egs/ksponspeech/ASR/local/compute_fbank_musan.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/local/filter_cuts.py b/egs/ksponspeech/ASR/local/filter_cuts.py deleted file mode 100644 index f081da5df..000000000 --- a/egs/ksponspeech/ASR/local/filter_cuts.py +++ /dev/null @@ -1,157 +0,0 @@ -#!/usr/bin/env python3 - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script removes short and long utterances from a cutset. - -Caution: - You may need to tune the thresholds for your own dataset. - -Usage example: - - python3 ./local/filter_cuts.py \ - --bpe-model data/lang_bpe_5000/bpe.model \ - --in-cuts data/fbank/speechtools_cuts_test.jsonl.gz \ - --out-cuts data/fbank-filtered/speechtools_cuts_test.jsonl.gz -""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -from lhotse import CutSet, load_manifest_lazy -from lhotse.cut import Cut - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "--bpe-model", - type=Path, - help="Path to the bpe.model", - ) - - parser.add_argument( - "--in-cuts", - type=Path, - help="Path to the input cutset", - ) - - parser.add_argument( - "--out-cuts", - type=Path, - help="Path to the output cutset", - ) - - return parser.parse_args() - - -def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): - total = 0 # number of total utterances before removal - removed = 0 # number of removed utterances - - def remove_short_and_long_utterances(c: Cut): - """Return False to exclude the input cut""" - nonlocal removed, total - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ./display_manifest_statistics.py - # - # You should use ./display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - total += 1 - if c.duration < 1.0 or c.duration > 20.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) - removed += 1 - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./pruned_transducer_stateless2/conformer.py, the - # conv module uses the following expression - # for subsampling - if c.num_frames is None: - num_frames = c.duration * 100 # approximate - else: - num_frames = c.num_frames - - T = ((num_frames - 1) // 2 - 1) // 2 - # Note: for ./lstm_transducer_stateless/lstm.py, the formula is - # T = ((num_frames - 3) // 2 - 1) // 2 - - # Note: for ./pruned_transducer_stateless7/zipformer.py, the formula is - # T = ((num_frames - 7) // 2 + 1) // 2 - - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - removed += 1 - return False - - return True - - # We use to_eager() here so that we can print out the value of total - # and removed below. - ans = cut_set.filter(remove_short_and_long_utterances).to_eager() - ratio = removed / total * 100 - logging.info( - f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed." - ) - return ans - - -def main(): - args = get_args() - logging.info(vars(args)) - - if args.out_cuts.is_file(): - logging.info(f"{args.out_cuts} already exists - skipping") - return - - assert args.in_cuts.is_file(), f"{args.in_cuts} does not exist" - assert args.bpe_model.is_file(), f"{args.bpe_model} does not exist" - - sp = spm.SentencePieceProcessor() - sp.load(str(args.bpe_model)) - - cut_set = load_manifest_lazy(args.in_cuts) - assert isinstance(cut_set, CutSet) - - cut_set = filter_cuts(cut_set, sp) - logging.info(f"Saving to {args.out_cuts}") - args.out_cuts.parent.mkdir(parents=True, exist_ok=True) - cut_set.to_file(args.out_cuts) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/ksponspeech/ASR/local/filter_cuts.py b/egs/ksponspeech/ASR/local/filter_cuts.py new file mode 120000 index 000000000..27aca1729 --- /dev/null +++ b/egs/ksponspeech/ASR/local/filter_cuts.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/filter_cuts.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/local/train_bpe_model.py b/egs/ksponspeech/ASR/local/train_bpe_model.py deleted file mode 100755 index 5979d5b98..000000000 --- a/egs/ksponspeech/ASR/local/train_bpe_model.py +++ /dev/null @@ -1,115 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# You can install sentencepiece via: -# -# pip install sentencepiece -# -# Due to an issue reported in -# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 -# -# Please install a version >=0.1.96 - -import argparse -import shutil -from pathlib import Path -from typing import Dict - -import sentencepiece as spm - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - The generated bpe.model is saved to this directory. - """, - ) - - parser.add_argument( - "--transcript", - type=str, - help="Training transcript.", - ) - - parser.add_argument( - "--vocab-size", - type=int, - help="Vocabulary size for BPE training", - ) - - return parser.parse_args() - - -def generate_tokens(lang_dir: Path): - """ - Generate the tokens.txt from a bpe model. - """ - sp = spm.SentencePieceProcessor() - sp.load(str(lang_dir / "bpe.model")) - token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())} - with open(lang_dir / "tokens.txt", "w", encoding="utf-8") as f: - for sym, i in token2id.items(): - f.write(f"{sym} {i}\n") - - -def main(): - args = get_args() - vocab_size = args.vocab_size - lang_dir = Path(args.lang_dir) - - model_type = "unigram" - - model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" - train_text = args.transcript - character_coverage = 1.0 - input_sentence_size = 100000000 - - user_defined_symbols = ["", ""] - unk_id = len(user_defined_symbols) - # Note: unk_id is fixed to 2. - # If you change it, you should also change other - # places that are using it. - - model_file = Path(model_prefix + ".model") - if not model_file.is_file(): - spm.SentencePieceTrainer.train( - input=train_text, - vocab_size=vocab_size, - model_type=model_type, - model_prefix=model_prefix, - input_sentence_size=input_sentence_size, - character_coverage=character_coverage, - user_defined_symbols=user_defined_symbols, - unk_id=unk_id, - bos_id=-1, - eos_id=-1, - ) - else: - print(f"{model_file} exists - skipping") - return - - shutil.copyfile(model_file, f"{lang_dir}/bpe.model") - - generate_tokens(lang_dir) - - -if __name__ == "__main__": - main() diff --git a/egs/ksponspeech/ASR/local/train_bpe_model.py b/egs/ksponspeech/ASR/local/train_bpe_model.py new file mode 120000 index 000000000..6fad36421 --- /dev/null +++ b/egs/ksponspeech/ASR/local/train_bpe_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/local/validate_manifest.py b/egs/ksponspeech/ASR/local/validate_manifest.py deleted file mode 100755 index 98f273419..000000000 --- a/egs/ksponspeech/ASR/local/validate_manifest.py +++ /dev/null @@ -1,101 +0,0 @@ -#!/usr/bin/env python3 -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script checks the following assumptions of the generated manifest: - -- Single supervision per cut -- Supervision time bounds are within cut time bounds - -We will add more checks later if needed. - -Usage example: - - python3 ./local/validate_manifest.py \ - ./data/fbank/speechtools_cuts_train.jsonl.gz - -""" - -import argparse -import logging -from pathlib import Path - -from lhotse import CutSet, load_manifest_lazy -from lhotse.cut import Cut -from lhotse.dataset.speech_recognition import validate_for_asr - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument( - "manifest", - type=Path, - help="Path to the manifest file", - ) - - return parser.parse_args() - - -def validate_one_supervision_per_cut(c: Cut): - if len(c.supervisions) != 1: - raise ValueError(f"{c.id} has {len(c.supervisions)} supervisions") - - -def validate_supervision_and_cut_time_bounds(c: Cut): - tol = 2e-3 # same tolerance as in 'validate_for_asr()' - s = c.supervisions[0] - - # Supervision start time is relative to Cut ... - # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html - if s.start < -tol: - raise ValueError( - f"{c.id}: Supervision start time {s.start} must not be negative." - ) - if s.start > tol: - raise ValueError( - f"{c.id}: Supervision start time {s.start} is not at the beginning of the Cut. Please apply `lhotse cut trim-to-supervisions`." - ) - if c.start + s.end > c.end + tol: - raise ValueError( - f"{c.id}: Supervision end time {c.start+s.end} is larger " - f"than cut end time {c.end}" - ) - - -def main(): - args = get_args() - - manifest = args.manifest - logging.info(f"Validating {manifest}") - - assert manifest.is_file(), f"{manifest} does not exist" - cut_set = load_manifest_lazy(manifest) - assert isinstance(cut_set, CutSet) - - for c in cut_set: - validate_one_supervision_per_cut(c) - validate_supervision_and_cut_time_bounds(c) - - # Validation from K2 training - # - checks supervision start is 0 - # - checks supervision.duration is not longer than cut.duration - # - there is tolerance 2ms - validate_for_asr(cut_set) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/ksponspeech/ASR/local/validate_manifest.py b/egs/ksponspeech/ASR/local/validate_manifest.py new file mode 120000 index 000000000..0a9725e87 --- /dev/null +++ b/egs/ksponspeech/ASR/local/validate_manifest.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/validate_manifest.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/zipformer/README.md b/egs/ksponspeech/ASR/zipformer/README.md deleted file mode 100644 index c8c2104cd..000000000 --- a/egs/ksponspeech/ASR/zipformer/README.md +++ /dev/null @@ -1 +0,0 @@ -This recipe implements Zipformer model. From cbe012d54c7007bfbbb82e71a9f1184500fb0824 Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Fri, 22 Nov 2024 11:18:01 +0800 Subject: [PATCH 36/59] Valle Recipe for WenetSpeech4TTS, LibriTTS, LibriTTS-R (#1805) * add valle * update readme --- egs/libritts/TTS/README.md | 65 +- ...te_neural_codec_and_prepare_text_tokens.py | 1 + egs/libritts/TTS/prepare.sh | 43 +- egs/libritts/TTS/valle | 1 + egs/wenetspeech4tts/TTS/README.md | 72 + ...te_neural_codec_and_prepare_text_tokens.py | 609 ++++++ .../TTS/local/display_manifest_statistics.py | 53 + egs/wenetspeech4tts/TTS/prepare.sh | 100 + egs/wenetspeech4tts/TTS/shared | 1 + ...te_neural_codec_and_prepare_text_tokens.py | 1 + egs/wenetspeech4tts/TTS/valle/infer.py | 300 +++ egs/wenetspeech4tts/TTS/valle/optim.py | 1 + egs/wenetspeech4tts/TTS/valle/tokenizer.py | 111 ++ egs/wenetspeech4tts/TTS/valle/train.py | 1244 ++++++++++++ .../TTS/valle/tts_datamodule.py | 343 ++++ egs/wenetspeech4tts/TTS/valle/valle.py | 1745 +++++++++++++++++ 16 files changed, 4675 insertions(+), 15 deletions(-) create mode 120000 egs/libritts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py create mode 120000 egs/libritts/TTS/valle create mode 100644 egs/wenetspeech4tts/TTS/README.md create mode 100755 egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py create mode 100755 egs/wenetspeech4tts/TTS/local/display_manifest_statistics.py create mode 100755 egs/wenetspeech4tts/TTS/prepare.sh create mode 120000 egs/wenetspeech4tts/TTS/shared create mode 120000 egs/wenetspeech4tts/TTS/valle/compute_neural_codec_and_prepare_text_tokens.py create mode 100644 egs/wenetspeech4tts/TTS/valle/infer.py create mode 120000 egs/wenetspeech4tts/TTS/valle/optim.py create mode 100644 egs/wenetspeech4tts/TTS/valle/tokenizer.py create mode 100755 egs/wenetspeech4tts/TTS/valle/train.py create mode 100644 egs/wenetspeech4tts/TTS/valle/tts_datamodule.py create mode 100644 egs/wenetspeech4tts/TTS/valle/valle.py diff --git a/egs/libritts/TTS/README.md b/egs/libritts/TTS/README.md index 4d4fb8580..67424a1ca 100644 --- a/egs/libritts/TTS/README.md +++ b/egs/libritts/TTS/README.md @@ -1,7 +1,7 @@ # Introduction -LibriTTS is a multi-speaker English corpus of approximately 585 hours of read English speech at 24kHz sampling rate, prepared by Heiga Zen with the assistance of Google Speech and Google Brain team members. -The LibriTTS corpus is designed for TTS research. It is derived from the original materials (mp3 audio files from LibriVox and text files from Project Gutenberg) of the LibriSpeech corpus. +LibriTTS is a multi-speaker English corpus of approximately 585 hours of read English speech at 24kHz sampling rate, prepared by Heiga Zen with the assistance of Google Speech and Google Brain team members. +The LibriTTS corpus is designed for TTS research. It is derived from the original materials (mp3 audio files from LibriVox and text files from Project Gutenberg) of the LibriSpeech corpus. The main differences from the LibriSpeech corpus are listed below: 1. The audio files are at 24kHz sampling rate. 2. The speech is split at sentence breaks. @@ -11,16 +11,16 @@ The main differences from the LibriSpeech corpus are listed below: For more information, refer to the paper "LibriTTS: A Corpus Derived from LibriSpeech for Text-to-Speech", Heiga Zen, Viet Dang, Rob Clark, Yu Zhang, Ron J. Weiss, Ye Jia, Zhifeng Chen, and Yonghui Wu, arXiv, 2019. If you use the LibriTTS corpus in your work, please cite this paper where it was introduced. > [!CAUTION] -> The next-gen Kaldi framework provides tools and models for generating high-quality, synthetic speech (Text-to-Speech, TTS). +> The next-gen Kaldi framework provides tools and models for generating high-quality, synthetic speech (Text-to-Speech, TTS). > While these recipes has the potential to advance various fields such as accessibility, language education, and AI-driven solutions, it also carries certain ethical and legal responsibilities. -> +> > By using this framework, you agree to the following: > 1. Legal and Ethical Use: You shall not use this framework, or any models derived from it, for any unlawful or unethical purposes. This includes, but is not limited to: Creating voice clones without the explicit, informed consent of the individual whose voice is being cloned. Engaging in any form of identity theft, impersonation, or fraud using cloned voices. Violating any local, national, or international laws regarding privacy, intellectual property, or personal data. -> +> > 2. Responsibility of Use: The users of this framework are solely responsible for ensuring that their use of voice cloning technologies complies with all applicable laws and ethical guidelines. We explicitly disclaim any liability for misuse of the technology. -> +> > 3. Attribution and Use of Open-Source Components: This project is provided under the Apache 2.0 license. Users must adhere to the terms of this license and provide appropriate attribution when required. -> +> > 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties. @@ -49,3 +49,54 @@ To inference, use: --epoch 400 \ --tokens data/tokens.txt ``` + +# [VALL-E](https://arxiv.org/abs/2301.02111) + +./valle contains the code for training VALL-E TTS model. + +Checkpoints and training logs can be found [here](https://huggingface.co/yuekai/vall-e_libritts). The demo of the model trained with libritts and [libritts-r](https://www.openslr.org/141/) is available [here](https://huggingface.co/spaces/yuekai/valle-libritts-demo). + +Preparation: + +``` +bash prepare.sh --start-stage 4 +``` + +The training command is given below: + +``` +world_size=8 +exp_dir=exp/valle + +## Train AR model +python3 valle/train.py --max-duration 320 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 1 \ + --num-buckets 6 --dtype "bfloat16" --save-every-n 1000 --valid-interval 2000 \ + --share-embedding true --norm-first true --add-prenet false \ + --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ + --base-lr 0.03 --warmup-steps 200 --average-period 0 \ + --num-epochs 20 --start-epoch 1 --start-batch 0 --accumulate-grad-steps 1 \ + --exp-dir ${exp_dir} --world-size ${world_size} + +## Train NAR model +# cd ${exp_dir} +# ln -s ${exp_dir}/best-valid-loss.pt epoch-99.pt # --start-epoch 100=99+1 +# cd - +python3 valle/train.py --max-duration 160 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 2 \ + --num-buckets 6 --dtype "float32" --save-every-n 1000 --valid-interval 2000 \ + --share-embedding true --norm-first true --add-prenet false \ + --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ + --base-lr 0.03 --warmup-steps 200 --average-period 0 \ + --num-epochs 40 --start-epoch 100 --start-batch 0 --accumulate-grad-steps 2 \ + --exp-dir ${exp_dir} --world-size ${world_size} +``` + +To inference, use: +``` +huggingface-cli login +huggingface-cli download --local-dir ${exp_dir} yuekai/vall-e_libritts +top_p=1.0 +python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg}_top_p_${top_p} \ + --top-k -1 --temperature 1.0 \ + --text ./libritts.txt \ + --checkpoint ${exp_dir}/epoch-${epoch}-avg-${avg}.pt --top-p ${top_p} +``` diff --git a/egs/libritts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py b/egs/libritts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py new file mode 120000 index 000000000..68579ffd4 --- /dev/null +++ b/egs/libritts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py @@ -0,0 +1 @@ +../../../wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py \ No newline at end of file diff --git a/egs/libritts/TTS/prepare.sh b/egs/libritts/TTS/prepare.sh index 44016e6d2..1700e0737 100755 --- a/egs/libritts/TTS/prepare.sh +++ b/egs/libritts/TTS/prepare.sh @@ -32,7 +32,7 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then cd vits/monotonic_align python setup.py build_ext --inplace cd ../../ - else + else log "monotonic_align lib already built" fi fi @@ -75,11 +75,11 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 2: Compute Spectrogram for LibriTTS" mkdir -p data/spectrogram if [ ! -e data/spectrogram/.libritts.done ]; then - ./local/compute_spectrogram_libritts.py --sampling-rate $sampling_rate + ./local/compute_spectrogram_libritts.py --sampling-rate $sampling_rate touch data/spectrogram/.libritts.done fi - # Here we shuffle and combine the train-clean-100, train-clean-360 and + # Here we shuffle and combine the train-clean-100, train-clean-360 and # train-other-500 together to form the training set. if [ ! -f data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz ]; then cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \ @@ -88,7 +88,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then shuf | gzip -c > data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz fi - # Here we shuffle and combine the train-clean-100, train-clean-360 + # Here we shuffle and combine the train-clean-100, train-clean-360 # together to form the training set. if [ ! -f data/spectrogram/libritts_cuts_train-clean-460.jsonl.gz ]; then cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \ @@ -108,10 +108,10 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Prepare phoneme tokens for LibriTTS" # We assume you have installed piper_phonemize and espnet_tts_frontend. # If not, please install them with: - # - piper_phonemize: + # - piper_phonemize: # refer to https://github.com/rhasspy/piper-phonemize, # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5 - # - espnet_tts_frontend: + # - espnet_tts_frontend: # `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ if [ ! -e data/spectrogram/.libritts_with_token.done ]; then ./local/prepare_tokens_libritts.py @@ -123,12 +123,39 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then log "Stage 4: Generate token file" # We assume you have installed piper_phonemize and espnet_tts_frontend. # If not, please install them with: - # - piper_phonemize: + # - piper_phonemize: # refer to https://github.com/rhasspy/piper-phonemize, # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5 - # - espnet_tts_frontend: + # - espnet_tts_frontend: # `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ if [ ! -e data/tokens.txt ]; then ./local/prepare_token_file.py --tokens data/tokens.txt fi fi + +audio_feats_dir=data/tokenized +dataset_parts="--dataset-parts all" # debug "-p dev-clean -p test-clean" +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Tokenize/Fbank LibriTTS for valle" + mkdir -p ${audio_feats_dir} + if [ ! -e ${audio_feats_dir}/.libritts.tokenize.done ]; then + python3 ./local/compute_neural_codec_and_prepare_text_tokens.py --dataset-parts "${dataset_parts}" \ + --audio-extractor "Encodec" \ + --batch-duration 400 \ + --src-dir "data/manifests" \ + --output-dir "${audio_feats_dir}" + fi + touch ${audio_feats_dir}/.libritts.tokenize.done + + lhotse combine \ + ${audio_feats_dir}/libritts_cuts_train-clean-100.jsonl.gz \ + ${audio_feats_dir}/libritts_cuts_train-clean-360.jsonl.gz \ + ${audio_feats_dir}/libritts_cuts_train-other-500.jsonl.gz \ + ${audio_feats_dir}/cuts_train.jsonl.gz + lhotse copy \ + ${audio_feats_dir}/libritts_cuts_dev-clean.jsonl.gz \ + ${audio_feats_dir}/cuts_dev.jsonl.gz + lhotse copy \ + ${audio_feats_dir}/libritts_cuts_test-clean.jsonl.gz \ + ${audio_feats_dir}/cuts_test.jsonl.gz +fi diff --git a/egs/libritts/TTS/valle b/egs/libritts/TTS/valle new file mode 120000 index 000000000..c8fe8fdb0 --- /dev/null +++ b/egs/libritts/TTS/valle @@ -0,0 +1 @@ +../../wenetspeech4tts/TTS/valle/ \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/README.md b/egs/wenetspeech4tts/TTS/README.md new file mode 100644 index 000000000..f35bb51c7 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/README.md @@ -0,0 +1,72 @@ +# Introduction + +[**WenetSpeech4TTS**](https://huggingface.co/datasets/Wenetspeech4TTS/WenetSpeech4TTS) is a multi-domain **Mandarin** corpus derived from the open-sourced [WenetSpeech](https://arxiv.org/abs/2110.03370) dataset. + +> [!CAUTION] +> The next-gen Kaldi framework provides tools and models for generating high-quality, synthetic speech (Text-to-Speech, TTS). +> While these recipes has the potential to advance various fields such as accessibility, language education, and AI-driven solutions, it also carries certain ethical and legal responsibilities. +> +> By using this framework, you agree to the following: +> 1. Legal and Ethical Use: You shall not use this framework, or any models derived from it, for any unlawful or unethical purposes. This includes, but is not limited to: Creating voice clones without the explicit, informed consent of the individual whose voice is being cloned. Engaging in any form of identity theft, impersonation, or fraud using cloned voices. Violating any local, national, or international laws regarding privacy, intellectual property, or personal data. +> +> 2. Responsibility of Use: The users of this framework are solely responsible for ensuring that their use of voice cloning technologies complies with all applicable laws and ethical guidelines. We explicitly disclaim any liability for misuse of the technology. +> +> 3. Attribution and Use of Open-Source Components: This project is provided under the Apache 2.0 license. Users must adhere to the terms of this license and provide appropriate attribution when required. +> +> 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties. + + +# [VALL-E](https://arxiv.org/abs/2301.02111) + +./valle contains the code for training VALL-E TTS model. + +Checkpoints and training logs can be found [here](https://huggingface.co/yuekai/vall-e_wenetspeech4tts). The demo of the model trained with Wenetspeech4TTS Premium (945 hours) is available [here](https://huggingface.co/spaces/yuekai/valle_wenetspeech4tts_demo). + +Preparation: + +``` +bash prepare.sh +``` + +The training command is given below: + +``` +world_size=8 +exp_dir=exp/valle + +## Train AR model +python3 valle/train.py --max-duration 320 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 1 \ + --num-buckets 6 --dtype "bfloat16" --save-every-n 1000 --valid-interval 2000 \ + --share-embedding true --norm-first true --add-prenet false \ + --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ + --base-lr 0.03 --warmup-steps 200 --average-period 0 \ + --num-epochs 20 --start-epoch 1 --start-batch 0 --accumulate-grad-steps 1 \ + --exp-dir ${exp_dir} --world-size ${world_size} + +## Train NAR model +# cd ${exp_dir} +# ln -s ${exp_dir}/best-valid-loss.pt epoch-99.pt # --start-epoch 100=99+1 +# cd - +python3 valle/train.py --max-duration 160 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 2 \ + --num-buckets 6 --dtype "float32" --save-every-n 1000 --valid-interval 2000 \ + --share-embedding true --norm-first true --add-prenet false \ + --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ + --base-lr 0.03 --warmup-steps 200 --average-period 0 \ + --num-epochs 40 --start-epoch 100 --start-batch 0 --accumulate-grad-steps 2 \ + --exp-dir ${exp_dir} --world-size ${world_size} +``` + +To inference, use: +``` +huggingface-cli login +huggingface-cli download --local-dir ${exp_dir} yuekai/vall-e_wenetspeech4tts +top_p=1.0 +python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg}_top_p_${top_p} \ + --top-k -1 --temperature 1.0 \ + --text ./aishell3.txt \ + --checkpoint ${exp_dir}/epoch-${epoch}-avg-${avg}.pt \ + --text-extractor pypinyin_initials_finals --top-p ${top_p} +``` + +# Credits +- [vall-e](https://github.com/lifeiteng/vall-e) diff --git a/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py b/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py new file mode 100755 index 000000000..5494bf340 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py @@ -0,0 +1,609 @@ +#!/usr/bin/env python3 +# Copyright 2023 (authors: Feiteng Li) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Phonemize Text and EnCodec Audio. + +Usage example: + python3 ./local/compute_neural_codec_and_prepare_text_tokens.py --dataset-parts "${dataset_parts}" \ + --text-extractor ${text_extractor} \ + --audio-extractor ${audio_extractor} \ + --batch-duration 2500 --prefix "wenetspeech4tts" \ + --src-dir "data/manifests" --split 100 \ + --output-dir "${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100" + +""" +import argparse +import logging +import os +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.multiprocessing +from encodec import EncodecModel +from encodec.utils import convert_audio +from lhotse import CutSet, NumpyHdf5Writer +from lhotse.features import FeatureExtractor +from lhotse.recipes.utils import read_manifests_if_cached +from lhotse.utils import Seconds, compute_num_frames +from phonemizer.backend import EspeakBackend +from phonemizer.backend.espeak.language_switch import LanguageSwitch +from phonemizer.backend.espeak.words_mismatch import WordMismatch +from phonemizer.punctuation import Punctuation +from phonemizer.separator import Separator +from tqdm.auto import tqdm + +from icefall.utils import get_executor + +try: + from pypinyin import Style, pinyin + from pypinyin.style._utils import get_finals, get_initials +except Exception: + pass + + +import re +from typing import Pattern + +import numpy as np +from k2 import SymbolTable + +# from valle.data import ( +# AudioTokenConfig, +# AudioTokenExtractor, +# TextTokenizer, +# tokenize_text, +# ) +# from valle.data.fbank import get_fbank_extractor +# from valle.utils import SymbolTable + +os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) +torch.multiprocessing.set_sharing_strategy("file_system") + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--src-dir", + type=Path, + default=Path("data/manifests"), + help="Path to the manifest files", + ) + parser.add_argument( + "--output-dir", + type=Path, + default=Path("data/tokenized"), + help="Path to the tokenized files", + ) + parser.add_argument( + "--text-extractor", + type=str, + default="espeak", + help="espeak or pypinyin or pypinyin_initials_finals", + ) + parser.add_argument( + "--audio-extractor", + type=str, + default="Encodec", + help="Encodec or Fbank", + ) + parser.add_argument( + "--dataset-parts", + type=str, + default="dev-clean test-clean", + help="Space separated dataset parts", + ) + parser.add_argument( + "--prefix", + type=str, + default="libritts", + help="prefix of the manifest file", + ) + parser.add_argument( + "--suffix", + type=str, + default="jsonl.gz", + help="suffix of the manifest file", + ) + parser.add_argument( + "--batch-duration", + type=float, + default=400.0, + help="The maximum number of audio seconds in a batch." + "Determines batch size dynamically.", + ) + parser.add_argument( + "--split", + type=int, + default=1, + help="Split the cut_set into multiple parts", + ) + + return parser.parse_args() + + +class PypinyinBackend: + """PypinyinBackend for Chinese. Most codes is referenced from espnet. + There are two types pinyin or initials_finals, one is + just like "ni1 hao3", the other is like "n i1 h ao3". + """ + + def __init__( + self, + backend="initials_finals", + punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(), + ) -> None: + self.backend = backend + self.punctuation_marks = punctuation_marks + + def phonemize( + self, text: List[str], separator: Separator, strip=True, njobs=1 + ) -> List[str]: + assert isinstance(text, List) + phonemized = [] + for _text in text: + _text = re.sub(" +", " ", _text.strip()) + _text = _text.replace(" ", separator.word) + phones = [] + if self.backend == "pypinyin": + for n, py in enumerate( + pinyin(_text, style=Style.TONE3, neutral_tone_with_five=True) + ): + if all([c in self.punctuation_marks for c in py[0]]): + if len(phones): + assert phones[-1] == separator.syllable + phones.pop(-1) + + phones.extend(list(py[0])) + else: + phones.extend([py[0], separator.syllable]) + elif self.backend == "pypinyin_initials_finals": + for n, py in enumerate( + pinyin(_text, style=Style.TONE3, neutral_tone_with_five=True) + ): + if all([c in self.punctuation_marks for c in py[0]]): + if len(phones): + assert phones[-1] == separator.syllable + phones.pop(-1) + phones.extend(list(py[0])) + else: + if py[0][-1].isalnum(): + initial = get_initials(py[0], strict=False) + if py[0][-1].isdigit(): + final = get_finals(py[0][:-1], strict=False) + py[0][-1] + else: + final = get_finals(py[0], strict=False) + phones.extend( + [ + initial, + separator.phone, + final, + separator.syllable, + ] + ) + else: + assert ValueError + else: + raise NotImplementedError + phonemized.append( + "".join(phones).rstrip(f"{separator.word}{separator.syllable}") + ) + return phonemized + + +class TextTokenizer: + """Phonemize Text.""" + + def __init__( + self, + language="en-us", + backend="espeak", + separator=Separator(word="_", syllable="-", phone="|"), + preserve_punctuation=True, + punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(), + with_stress: bool = False, + tie: Union[bool, str] = False, + language_switch: LanguageSwitch = "keep-flags", + words_mismatch: WordMismatch = "ignore", + ) -> None: + if backend == "espeak": + phonemizer = EspeakBackend( + language, + punctuation_marks=punctuation_marks, + preserve_punctuation=preserve_punctuation, + with_stress=with_stress, + tie=tie, + language_switch=language_switch, + words_mismatch=words_mismatch, + ) + elif backend in ["pypinyin", "pypinyin_initials_finals"]: + phonemizer = PypinyinBackend( + backend=backend, + punctuation_marks=punctuation_marks + separator.word, + ) + else: + raise NotImplementedError(f"{backend}") + + self.backend = phonemizer + self.separator = separator + + def to_list(self, phonemized: str) -> List[str]: + fields = [] + for word in phonemized.split(self.separator.word): + # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z. + pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE) + fields.extend( + [p for p in pp if p != self.separator.phone] + [self.separator.word] + ) + assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count( + self.separator.phone + ) + return fields[:-1] + + def __call__(self, text, strip=True) -> List[List[str]]: + if isinstance(text, str): + text = [text] + + phonemized = self.backend.phonemize( + text, separator=self.separator, strip=strip, njobs=1 + ) + return [self.to_list(p) for p in phonemized] + + +def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]: + phonemes = tokenizer([text.strip()]) + return phonemes[0] # k2symbols + + +def remove_encodec_weight_norm(model): + from encodec.modules import SConv1d + from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock + from torch.nn.utils import remove_weight_norm + + encoder = model.encoder.model + for key in encoder._modules: + if isinstance(encoder._modules[key], SEANetResnetBlock): + remove_weight_norm(encoder._modules[key].shortcut.conv.conv) + block_modules = encoder._modules[key].block._modules + for skey in block_modules: + if isinstance(block_modules[skey], SConv1d): + remove_weight_norm(block_modules[skey].conv.conv) + elif isinstance(encoder._modules[key], SConv1d): + remove_weight_norm(encoder._modules[key].conv.conv) + + decoder = model.decoder.model + for key in decoder._modules: + if isinstance(decoder._modules[key], SEANetResnetBlock): + remove_weight_norm(decoder._modules[key].shortcut.conv.conv) + block_modules = decoder._modules[key].block._modules + for skey in block_modules: + if isinstance(block_modules[skey], SConv1d): + remove_weight_norm(block_modules[skey].conv.conv) + elif isinstance(decoder._modules[key], SConvTranspose1d): + remove_weight_norm(decoder._modules[key].convtr.convtr) + elif isinstance(decoder._modules[key], SConv1d): + remove_weight_norm(decoder._modules[key].conv.conv) + + +class AudioTokenizer: + """EnCodec audio.""" + + def __init__( + self, + device: Any = None, + ) -> None: + # Instantiate a pretrained EnCodec model + model = EncodecModel.encodec_model_24khz() + model.set_target_bandwidth(6.0) + remove_encodec_weight_norm(model) + + if not device: + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda:0") + + self._device = device + + self.codec = model.to(device) + self.sample_rate = model.sample_rate + self.channels = model.channels + + @property + def device(self): + return self._device + + def encode(self, wav: torch.Tensor) -> torch.Tensor: + return self.codec.encode(wav.to(self.device)) + + def decode(self, frames: torch.Tensor) -> torch.Tensor: + return self.codec.decode(frames) + + +@dataclass +class AudioTokenConfig: + frame_shift: Seconds = 320.0 / 24000 + num_quantizers: int = 8 + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig": + return AudioTokenConfig(**data) + + +class AudioTokenExtractor(FeatureExtractor): + name = "encodec" + config_type = AudioTokenConfig + + def __init__(self, config: Optional[Any] = None): + super(AudioTokenExtractor, self).__init__(config) + self.tokenizer = AudioTokenizer() + + def extract( + self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int + ) -> np.ndarray: + if not isinstance(samples, torch.Tensor): + samples = torch.from_numpy(samples) + if sampling_rate != self.tokenizer.sample_rate: + samples = convert_audio( + samples, + sampling_rate, + self.tokenizer.sample_rate, + self.tokenizer.channels, + ) + if len(samples.shape) == 2: + samples = samples.unsqueeze(0) + else: + raise ValueError() + + device = self.tokenizer.device + encoded_frames = self.tokenizer.encode(samples.detach().to(device)) + codes = encoded_frames[0][0] # [B, n_q, T] + if True: + duration = round(samples.shape[-1] / sampling_rate, ndigits=12) + expected_num_frames = compute_num_frames( + duration=duration, + frame_shift=self.frame_shift, + sampling_rate=sampling_rate, + ) + assert abs(codes.shape[-1] - expected_num_frames) <= 1 + codes = codes[..., :expected_num_frames] + return codes.cpu().squeeze(0).permute(1, 0).numpy() + + @property + def frame_shift(self) -> Seconds: + return self.config.frame_shift + + def feature_dim(self, sampling_rate: int) -> int: + return self.config.num_quantizers + + def pad_tensor_list(self, tensor_list, device, padding_value=0): + lengths = [tensor.shape[0] for tensor in tensor_list] + tensor_list = [torch.Tensor(t).to(device) for t in tensor_list] + padded_tensor = torch.nn.utils.rnn.pad_sequence( + tensor_list, batch_first=True, padding_value=padding_value + ) + return padded_tensor, lengths + + def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray: + samples = [wav.squeeze() for wav in samples] + device = self.tokenizer.device + samples, lengths = self.pad_tensor_list(samples, device) + samples = samples.unsqueeze(1) + + if not isinstance(samples, torch.Tensor): + samples = torch.from_numpy(samples) + if len(samples.shape) != 3: + raise ValueError() + if sampling_rate != self.tokenizer.sample_rate: + samples = [ + convert_audio( + wav, + sampling_rate, + self.tokenizer.sample_rate, + self.tokenizer.channels, + ) + for wav in samples + ] + samples = torch.stack(samples, 0) # convert samples from list to tensor + # Extract discrete codes from EnCodec + with torch.no_grad(): + encoded_frames = self.tokenizer.encode(samples.detach().to(device)) + encoded_frames = encoded_frames[0][0] # [B, n_q, T] + batch_codes = [] + for b, length in enumerate(lengths): + codes = encoded_frames[b] + duration = round(length / sampling_rate, ndigits=12) + expected_num_frames = compute_num_frames( + duration=duration, + frame_shift=self.frame_shift, + sampling_rate=sampling_rate, + ) + batch_codes.append(codes[..., :expected_num_frames]) + return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes] + + +def main(): + args = get_args() + + dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip() + if dataset_parts == "all": # LibriTTS + dataset_parts = [ + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + "train-clean-360", + "train-other-500", + ] + else: + dataset_parts = dataset_parts.replace("-p", "").strip().split(" ") + + assert len(dataset_parts) >= 1 + + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=args.src_dir, + prefix=args.prefix, + suffix=args.suffix, + types=["recordings", "supervisions", "cuts"], + ) + + text_tokenizer = None + if args.text_extractor: + text_tokenizer = TextTokenizer(backend=args.text_extractor) + + audio_extractor = None + if args.audio_extractor: + if args.audio_extractor == "Encodec": + audio_extractor = AudioTokenExtractor(AudioTokenConfig()) + else: + raise NotImplementedError(f"{args.audio_extractor}") + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + unique_symbols = set() + num_jobs = min(32, os.cpu_count()) + logging.info(f"dataset_parts: {dataset_parts} manifests {len(manifests)}") + + prefix = args.prefix + if prefix and not prefix.endswith("_"): + prefix = f"{prefix}_" + with get_executor() as ex: + for partition, m in manifests.items(): + logging.info( + f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}" + ) + try: + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + except Exception: + cut_set = m["cuts"] + + # Split cut_set if split > 1 + split = 1 + if args.split > 1: + cut_sets = cut_set.split(args.split) + split = args.split + else: + cut_sets = [cut_set] + + for idx, part in enumerate(cut_sets): + if args.audio_extractor: + if args.audio_extractor == "Encodec": + storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx if split > 1 else ''}" + else: + storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx if split > 1 else ''}" + + if args.prefix.lower() in [ + "ljspeech", + "aishell", + "baker", + "wenetspeech4tts", + ]: + part = part.resample(24000) + assert args.prefix.lower() in [ + "ljspeech", + "aishell", + "baker", + "wenetspeech4tts", + "libritts", + "libritts-r", + ] + with torch.no_grad(): + if ( + torch.cuda.is_available() + and args.audio_extractor == "Encodec" + ): + part = part.compute_and_store_features_batch( + extractor=audio_extractor, + storage_path=storage_path, + num_workers=num_jobs, + batch_duration=args.batch_duration, + collate=False, + overwrite=True, + storage_type=NumpyHdf5Writer, + ) + else: + part = part.compute_and_store_features( + extractor=audio_extractor, + storage_path=storage_path, + num_jobs=num_jobs if ex is None else 64, + executor=ex, + storage_type=NumpyHdf5Writer, + ) + + # TextTokenizer + if args.text_extractor: + for c in tqdm(part): + if args.prefix == "ljspeech": + text = c.supervisions[0].custom["normalized_text"] + text = text.replace(""", '"').replace(""", '"') + phonemes = tokenize_text(text_tokenizer, text=text) + elif args.prefix in [ + "aishell", + "aishell2", + "wenetspeech4tts", + "libritts", + "libritts-r", + ]: + phonemes = tokenize_text( + text_tokenizer, text=c.supervisions[0].text + ) + if c.supervisions[0].custom is None: + c.supervisions[0].custom = {} + c.supervisions[0].normalized_text = c.supervisions[0].text + else: + raise NotImplementedError(f"{args.prefix}") + unique_symbols.update(phonemes) + c.tokens = phonemes + assert c.supervisions[ + 0 + ].normalized_text, "normalized_text is None" + + # Save each part with an index if split > 1 + cuts_filename = ( + f"{prefix}cuts_{partition}.{idx if split > 1 else ''}.{args.suffix}" + ) + part.to_file(f"{args.output_dir}/{cuts_filename}") + logging.info(f"Saved {cuts_filename}") + + if args.text_extractor: + unique_phonemes = SymbolTable() + for s in sorted(list(unique_symbols)): + unique_phonemes.add(s) + logging.info(f"{len(unique_symbols)} unique phonemes: {unique_symbols}") + + unique_phonemes_file = f"{args.output_dir}/unique_text_tokens.k2symbols" + unique_phonemes.to_file(unique_phonemes_file) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/wenetspeech4tts/TTS/local/display_manifest_statistics.py b/egs/wenetspeech4tts/TTS/local/display_manifest_statistics.py new file mode 100755 index 000000000..f967dfd2b --- /dev/null +++ b/egs/wenetspeech4tts/TTS/local/display_manifest_statistics.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2023 (authors: Feiteng Li) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in the manifests. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. +""" + +import argparse +from pathlib import Path + +from lhotse import load_manifest_lazy + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/tokenized"), + help="Path to the tokenized manifests.", + ) + return parser.parse_args() + + +def main(): + args = get_args() + manifest_dir = args.manifest_dir or Path("data/tokenized") + for part in ["train", "dev", "test"]: + print(f"## {part}") + cuts = load_manifest_lazy(manifest_dir / f"cuts_{part}.jsonl.gz") + cuts.describe() + print("\n") + + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech4tts/TTS/prepare.sh b/egs/wenetspeech4tts/TTS/prepare.sh new file mode 100755 index 000000000..54e140dbb --- /dev/null +++ b/egs/wenetspeech4tts/TTS/prepare.sh @@ -0,0 +1,100 @@ +#!/usr/bin/env bash + +set -eou pipefail + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +stage=1 +stop_stage=4 + +dl_dir=$PWD/download + +dataset_parts="Premium" # Basic for all 10k hours data, Premium for about 10% of the data + +text_extractor="pypinyin_initials_finals" # default is espeak for English +audio_extractor="Encodec" # or Fbank +audio_feats_dir=data/tokenized + +. shared/parse_options.sh || exit 1 + + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "dl_dir: $dl_dir" + log "Stage 0: Download data" + huggingface-cli login + huggingface-cli download --repo-type dataset --local-dir $dl_dir Wenetspeech4TTS/WenetSpeech4TTS + + # Extract the downloaded data: + for folder in Standard Premium Basic; do + for file in "$dl_dir/$folder"/*.tar.gz; do + tar -xzvf "$file" -C "$dl_dir/$folder" + done + done +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare wenetspeech4tts manifest" + # We assume that you have downloaded the wenetspeech4tts corpus + # to $dl_dir/wenetspeech4tts + mkdir -p data/manifests + if [ ! -e data/manifests/.wenetspeech4tts.done ]; then + lhotse prepare wenetspeech4tts $dl_dir data/manifests --dataset-parts "${dataset_parts}" + touch data/manifests/.wenetspeech4tts.done + fi +fi + + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Tokenize/Fbank wenetspeech4tts" + mkdir -p ${audio_feats_dir} + if [ ! -e ${audio_feats_dir}/.wenetspeech4tts.tokenize.done ]; then + python3 ./local/compute_neural_codec_and_prepare_text_tokens.py --dataset-parts "${dataset_parts}" \ + --text-extractor ${text_extractor} \ + --audio-extractor ${audio_extractor} \ + --batch-duration 2500 --prefix "wenetspeech4tts" \ + --src-dir "data/manifests" \ + --split 100 \ + --output-dir "${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100" + cp ${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100/unique_text_tokens.k2symbols ${audio_feats_dir} + fi + touch ${audio_feats_dir}/.wenetspeech4tts.tokenize.done +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Combine features" + if [ ! -f ${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz ]; then + pieces=$(find ${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100 -name "*.jsonl.gz") + lhotse combine $pieces ${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Prepare wenetspeech4tts train/dev/test" + if [ ! -e ${audio_feats_dir}/.wenetspeech4tts.train.done ]; then + + lhotse subset --first 400 \ + ${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz \ + ${audio_feats_dir}/cuts_dev.jsonl.gz + + lhotse subset --last 400 \ + ${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz \ + ${audio_feats_dir}/cuts_test.jsonl.gz + + lhotse copy \ + ${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz \ + ${audio_feats_dir}/cuts_train.jsonl.gz + + touch ${audio_feats_dir}/.wenetspeech4tts.train.done + fi + python3 ./local/display_manifest_statistics.py --manifest-dir ${audio_feats_dir} +fi diff --git a/egs/wenetspeech4tts/TTS/shared b/egs/wenetspeech4tts/TTS/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/valle/compute_neural_codec_and_prepare_text_tokens.py b/egs/wenetspeech4tts/TTS/valle/compute_neural_codec_and_prepare_text_tokens.py new file mode 120000 index 000000000..e70ee319a --- /dev/null +++ b/egs/wenetspeech4tts/TTS/valle/compute_neural_codec_and_prepare_text_tokens.py @@ -0,0 +1 @@ +../local/compute_neural_codec_and_prepare_text_tokens.py \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/valle/infer.py b/egs/wenetspeech4tts/TTS/valle/infer.py new file mode 100644 index 000000000..fd7ba9f21 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/valle/infer.py @@ -0,0 +1,300 @@ +#!/usr/bin/env python3 +# Copyright 2023 (authors: Feiteng Li) +# Copyright 2024 (authors: Yuekai Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script is used to synthesize speech from text prompts and audio prompts. +Usage example: + python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg} \ + --checkpoint=${exp_dir}/epoch-${epoch}-avg-${avg}.pt \ + --text-prompts "KNOT one point one five miles per hour." \ + --audio-prompts ./prompts/8463_294825_000043_000000.wav \ + --text "To get up and running quickly just follow the steps below." + + top_p=1.0 + python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg}_top_p_${top_p} \ + --top-k -1 --temperature 1.0 \ + --text ./aishell3.txt \ + --checkpoint ${exp_dir}/epoch-${epoch}-avg-${avg}.pt \ + --text-extractor pypinyin_initials_finals --top-p ${top_p} + +""" +import argparse +import logging +import os +from pathlib import Path + +os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + +import torch +import torchaudio +from compute_neural_codec_and_prepare_text_tokens import ( + AudioTokenizer, + TextTokenizer, + tokenize_text, +) +from encodec.utils import convert_audio +from k2 import symbol_table +from tokenizer import get_text_token_collater +from valle import VALLE + +from icefall.utils import AttributeDict, str2bool + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--text-prompts", + type=str, + default="", + help="Text prompts which are separated by |.", + ) + + parser.add_argument( + "--audio-prompts", + type=str, + default="", + help="Audio prompts which are separated by | and should be aligned with --text-prompts.", + ) + + parser.add_argument( + "--text", + type=str, + default="", + help="prompt text\t prompt audio\ttarget text\ttarget audio", + ) + + parser.add_argument( + "--text-extractor", + type=str, + default="espeak", + help="espeak or pypinyin or pypinyin_initials_finals", + ) + + parser.add_argument( + "--checkpoint", + type=str, + default="exp/vallf_nano_full/checkpoint-100000.pt", + help="Path to the saved checkpoint.", + ) + + parser.add_argument( + "--output-dir", + type=Path, + default=Path("infer/demo"), + help="Path to the tokenized files.", + ) + + parser.add_argument( + "--top-k", + type=int, + default=-100, + help="Whether AR Decoder do top_k(if > 0) sampling.", + ) + + parser.add_argument( + "--top-p", + type=float, + default=1.0, + help="Whether AR Decoder do top_p(if > 0) sampling.", + ) + + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="The temperature of AR Decoder top_k sampling.", + ) + + parser.add_argument( + "--continual", + type=str2bool, + default=False, + help="Do continual task.", + ) + + parser.add_argument( + "--repetition-aware-sampling", + type=str2bool, + default=False, + help="Whether AR Decoder do valle-2 repetition-aware sampling. https://arxiv.org/pdf/2406.05370", + ) + + return parser.parse_args() + + +def load_model(checkpoint, device): + if not checkpoint: + return None + + checkpoint = torch.load(checkpoint, map_location=device) + + params = AttributeDict(checkpoint) + model = VALLE( + params.decoder_dim, + params.nhead, + params.num_decoder_layers, + norm_first=params.norm_first, + add_prenet=params.add_prenet, + prefix_mode=params.prefix_mode, + share_embedding=params.share_embedding, + nar_scale_factor=params.scale_factor, + prepend_bos=params.prepend_bos, + num_quantizers=params.num_quantizers, + ) + + missing_keys, unexpected_keys = model.load_state_dict( + checkpoint["model"], strict=True + ) + assert not missing_keys + model.to(device) + model.eval() + + return model, params.text_tokens + + +def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str): + # Load and pre-process the audio waveform + wav, sr = torchaudio.load(audio_path) + wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels) + wav = wav.unsqueeze(0) + + # Extract discrete codes from EnCodec + with torch.no_grad(): + encoded_frames = tokenizer.encode(wav) + return encoded_frames + + +@torch.no_grad() +def main(): + args = get_args() + text_tokenizer = TextTokenizer(backend=args.text_extractor) + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + model, text_tokens = load_model(args.checkpoint, device) + + text_collater = get_text_token_collater(text_tokens) + + audio_tokenizer = AudioTokenizer() + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + text_prompts = " ".join(args.text_prompts.split("|")) + + audio_prompts = [] + if args.audio_prompts: + for n, audio_file in enumerate(args.audio_prompts.split("|")): + encoded_frames = tokenize_audio(audio_tokenizer, audio_file) + if False: + samples = audio_tokenizer.decode(encoded_frames) + torchaudio.save(f"{args.output_dir}/p{n}.wav", samples[0], 24000) + + audio_prompts.append(encoded_frames[0][0]) + + assert len(args.text_prompts.split("|")) == len(audio_prompts) + audio_prompts = torch.concat(audio_prompts, dim=-1).transpose(2, 1) + audio_prompts = audio_prompts.to(device) + + if os.path.isfile(args.text): # for demos + # https://github.com/lifeiteng/lifeiteng.github.com/blob/main/valle/prepare.py + with open(args.text) as f: + for line in f: + fields = line.strip().split(" ") + fields = [item for item in fields if item] + assert len(fields) == 4 + prompt_text, prompt_audio, text, audio_path = fields + logging.info(f"synthesize text: {text}") + text_tokens, text_tokens_lens = text_collater( + [ + tokenize_text( + text_tokenizer, text=f"{prompt_text} {text}".strip() + ) + ] + ) + _, enroll_x_lens = text_collater( + [tokenize_text(text_tokenizer, text=f"{prompt_text}".strip())] + ) + + audio_prompts = tokenize_audio(audio_tokenizer, prompt_audio) + audio_prompts = audio_prompts[0][0].transpose(2, 1).to(device) + + # synthesis + encoded_frames = model.inference( + text_tokens.to(device), + text_tokens_lens.to(device), + audio_prompts, + enroll_x_lens=enroll_x_lens, + top_k=args.top_k, + temperature=args.temperature, + top_p=args.top_p, + ras=args.repetition_aware_sampling, + ) + + samples = audio_tokenizer.decode( + [(encoded_frames.transpose(2, 1), None)] + ) + # store + # save audio path into args.output_dir + audio_path + audio_path = f"{args.output_dir}/{audio_path}" + # mkdir -p + os.makedirs(os.path.dirname(audio_path), exist_ok=True) + torchaudio.save(audio_path, samples[0].cpu(), 24000) + return + + for n, text in enumerate(args.text.split("|")): + logging.info(f"synthesize text: {text}") + text_tokens, text_tokens_lens = text_collater( + [tokenize_text(text_tokenizer, text=f"{text_prompts} {text}".strip())] + ) + + # synthesis + if args.continual: + assert text == "" + encoded_frames = model.continual( + text_tokens.to(device), + text_tokens_lens.to(device), + audio_prompts, + ) + else: + enroll_x_lens = None + if text_prompts: + _, enroll_x_lens = text_collater( + [tokenize_text(text_tokenizer, text=f"{text_prompts}".strip())] + ) + encoded_frames = model.inference( + text_tokens.to(device), + text_tokens_lens.to(device), + audio_prompts, + enroll_x_lens=enroll_x_lens, + top_k=args.top_k, + temperature=args.temperature, + top_p=args.top_p, + ras=args.repetition_aware_sampling, + ) + + if audio_prompts != []: + samples = audio_tokenizer.decode([(encoded_frames.transpose(2, 1), None)]) + # store + torchaudio.save(f"{args.output_dir}/{n}.wav", samples[0].cpu(), 24000) + else: # Transformer + pass + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/wenetspeech4tts/TTS/valle/optim.py b/egs/wenetspeech4tts/TTS/valle/optim.py new file mode 120000 index 000000000..5eaa3cffd --- /dev/null +++ b/egs/wenetspeech4tts/TTS/valle/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/valle/tokenizer.py b/egs/wenetspeech4tts/TTS/valle/tokenizer.py new file mode 100644 index 000000000..db4f00396 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/valle/tokenizer.py @@ -0,0 +1,111 @@ +from pathlib import Path +from typing import List, Tuple + +import numpy as np +import torch +from k2 import SymbolTable + + +class TextTokenCollater: + """Collate list of text tokens + + Map sentences to integers. Sentences are padded to equal length. + Beginning and end-of-sequence symbols can be added. + + Example: + >>> token_collater = TextTokenCollater(text_tokens) + >>> tokens_batch, tokens_lens = token_collater(text) + + Returns: + tokens_batch: IntTensor of shape (B, L) + B: batch dimension, number of input sentences + L: length of the longest sentence + tokens_lens: IntTensor of shape (B,) + Length of each sentence after adding and + but before padding. + """ + + def __init__( + self, + text_tokens: List[str], + add_eos: bool = True, + add_bos: bool = True, + pad_symbol: str = "", + bos_symbol: str = "", + eos_symbol: str = "", + ): + self.pad_symbol = pad_symbol + + self.add_eos = add_eos + self.add_bos = add_bos + + self.bos_symbol = bos_symbol + self.eos_symbol = eos_symbol + + unique_tokens = ( + [pad_symbol] + + ([bos_symbol] if add_bos else []) + + ([eos_symbol] if add_eos else []) + + sorted(text_tokens) + ) + + self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)} + self.idx2token = [token for token in unique_tokens] + + def index(self, tokens_list: List[str]) -> Tuple[torch.Tensor, torch.Tensor]: + seqs, seq_lens = [], [] + for tokens in tokens_list: + assert all([True if s in self.token2idx else False for s in tokens]) is True + seq = ( + ([self.bos_symbol] if self.add_bos else []) + + list(tokens) + + ([self.eos_symbol] if self.add_eos else []) + ) + seqs.append(seq) + seq_lens.append(len(seq)) + + max_len = max(seq_lens) + for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)): + seq.extend([self.pad_symbol] * (max_len - seq_len)) + + tokens = torch.from_numpy( + np.array( + [[self.token2idx[token] for token in seq] for seq in seqs], + dtype=np.int64, + ) + ) + tokens_lens = torch.IntTensor(seq_lens) + + return tokens, tokens_lens + + def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]: + tokens_seqs = [[p for p in text] for text in texts] + max_len = len(max(tokens_seqs, key=len)) + + seqs = [ + ([self.bos_symbol] if self.add_bos else []) + + list(seq) + + ([self.eos_symbol] if self.add_eos else []) + + [self.pad_symbol] * (max_len - len(seq)) + for seq in tokens_seqs + ] + + tokens_batch = torch.from_numpy( + np.array( + [[self.token2idx[token] for token in seq] for seq in seqs], + dtype=np.int64, + ) + ) + + tokens_lens = torch.IntTensor( + [len(seq) + int(self.add_eos) + int(self.add_bos) for seq in tokens_seqs] + ) + + return tokens_batch, tokens_lens + + +def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater: + text_tokens_path = Path(text_tokens_file) + unique_tokens = SymbolTable.from_file(text_tokens_path) + collater = TextTokenCollater(unique_tokens.symbols, add_bos=True, add_eos=True) + return collater diff --git a/egs/wenetspeech4tts/TTS/valle/train.py b/egs/wenetspeech4tts/TTS/valle/train.py new file mode 100755 index 000000000..fde209511 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/valle/train.py @@ -0,0 +1,1244 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo) +# Copyright 2023 (authors: Feiteng Li) +# Copyright 2024 (authors: Yuekai Zhang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +world_size=8 +exp_dir=exp/valle + +## Train AR model +python3 valle/train.py --max-duration 320 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 1 \ + --num-buckets 6 --dtype "bfloat16" --save-every-n 1000 --valid-interval 2000 \ + --share-embedding true --norm-first true --add-prenet false \ + --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ + --base-lr 0.03 --warmup-steps 200 --average-period 0 \ + --num-epochs 20 --start-epoch 1 --start-batch 0 --accumulate-grad-steps 1 \ + --exp-dir ${exp_dir} --world-size ${world_size} + +## Train NAR model +# cd ${exp_dir} +# ln -s ${exp_dir}/best-valid-loss.pt epoch-99.pt # --start-epoch 100=99+1 +# cd - +python3 valle/train.py --max-duration 160 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 2 \ + --num-buckets 6 --dtype "float32" --save-every-n 1000 --valid-interval 2000 \ + --share-embedding true --norm-first true --add-prenet false \ + --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ + --base-lr 0.03 --warmup-steps 200 --average-period 0 \ + --num-epochs 40 --start-epoch 100 --start-batch 0 --accumulate-grad-steps 2 \ + --exp-dir ${exp_dir} --world-size ${world_size} +""" + +import argparse +import copy +import logging +import os +import random +import warnings +from contextlib import nullcontext +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse import CutSet +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from optim import Eden, ScaledAdam +from tokenizer import TextTokenCollater, get_text_token_collater +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import TtsDataModule +from valle import VALLE + +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--decoder-dim", + type=int, + default=1024, + help="Embedding dimension in the decoder model.", + ) + parser.add_argument( + "--nhead", + type=int, + default=16, + help="Number of attention heads in the Decoder layers.", + ) + parser.add_argument( + "--num-decoder-layers", + type=int, + default=12, + help="Number of Decoder layers.", + ) + parser.add_argument( + "--scale-factor", + type=float, + default=1.0, + help="Model scale factor which will be assigned different meanings in different models.", + ) + parser.add_argument( + "--norm-first", + type=str2bool, + default=True, + help="Pre or Post Normalization.", + ) + parser.add_argument( + "--add-prenet", + type=str2bool, + default=False, + help="Whether add PreNet after Inputs.", + ) + + parser.add_argument( + "--prefix-mode", + type=int, + default=0, + help="The mode for how to prefix VALL-E NAR Decoder, " + "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.", + ) + parser.add_argument( + "--share-embedding", + type=str2bool, + default=True, + help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.", + ) + parser.add_argument( + "--prepend-bos", + type=str2bool, + default=False, + help="Whether prepend to the acoustic tokens -> AR Decoder inputs.", + ) + parser.add_argument( + "--num-quantizers", + type=int, + default=8, + help="Number of Audio/Semantic quantization layers.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=20, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="exp/valle_dev", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--text-tokens", + type=str, + default="data/tokenized/unique_text_tokens.k2symbols", + help="Path to the unique text tokens file", + ) + + parser.add_argument( + "--optimizer-name", + type=str, + default="ScaledAdam", + help="The optimizer.", + ) + parser.add_argument( + "--scheduler-name", + type=str, + default="Eden", + help="The scheduler.", + ) + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + parser.add_argument( + "--warmup-steps", + type=int, + default=200, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=10000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train %% save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + parser.add_argument( + "--valid-interval", + type=int, + default=10000, + help="""Run validation if batch_idx %% valid_interval is 0.""", + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=20, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=0, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--accumulate-grad-steps", + type=int, + default=1, + help="""update gradient when batch_idx_train %% accumulate_grad_steps == 0. + """, + ) + + parser.add_argument( + "--dtype", + type=str, + default="float32", + help="Training dtype: float32 bfloat16 float16.", + ) + + parser.add_argument( + "--filter-min-duration", + type=float, + default=0.0, + help="Keep only utterances with duration > this.", + ) + parser.add_argument( + "--filter-max-duration", + type=float, + default=20.0, + help="Keep only utterances with duration < this.", + ) + + parser.add_argument( + "--train-stage", + type=int, + default=0, + help="""0: train all modules, For VALL-E, support 1: AR Decoder 2: NAR Decoder(s) + """, + ) + + parser.add_argument( + "--visualize", + type=str2bool, + default=False, + help="visualize model results in eval step.", + ) + + parser.add_argument( + "--oom-check", + type=str2bool, + default=False, + help="perform OOM check on dataloader batches before starting training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 100, + "reset_interval": 200, + "valid_interval": 10000, + "env_info": get_env_info(), + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + if isinstance(model, DDP): + raise ValueError("load_checkpoint before DDP") + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + saved_stage = saved_params.get("train_stage", 0) + if params.train_stage != saved_stage: + # switch training stage + if params.train_stage and saved_stage: # switch between 1 and 2 + params.start_epoch = 1 + params.start_batch = 0 + else: + # switch between 0 and 1/2 + assert params.num_epochs >= params.start_epoch + params.batch_idx_train = saved_params["batch_idx_train"] + + for key in ["optimizer", "grad_scaler", "sampler"]: + if key in saved_params: + saved_params.pop(key) + + # when base on stage 0, we keep scheduler + if saved_stage != 0: + for key in ["scheduler"]: + if key in saved_params: + saved_params.pop(key) + + best_train_filename = params.exp_dir / "best-train-loss.pt" + if best_train_filename.is_file(): + copyfile( + src=best_train_filename, + dst=params.exp_dir / f"best-train-loss-stage{saved_stage}.pt", + ) + + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + if best_valid_filename.is_file(): + copyfile( + src=best_valid_filename, + dst=params.exp_dir / f"best-valid-loss-stage{saved_stage}.pt", + ) + else: + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def prepare_input(batch: dict, tokenizer: TextTokenCollater, device: torch.device): + """Parse batch data""" + + features = batch["features"].to(device) + features_lens = batch["features_lens"].to(device) + if "tokens" not in batch: + raise NotImplementedError("Need to tokenize text") + # tokens = [] + # for c in batch["cuts"]: + # phonemes = tokenize_text( + # tokenizer, text=c.supervisions[0].text + # ) + # tokens.append(phonemes) + else: + tokens = batch["tokens"] + + text_tokens, text_tokens_lens = tokenizer(tokens) + text_tokens = text_tokens.to(device) + text_tokens_lens = text_tokens_lens.to(device) + + return features, features_lens, text_tokens, text_tokens_lens + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: TextTokenCollater, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + ( + audio_features, + audio_features_lens, + text_tokens, + text_tokens_lens, + ) = prepare_input(batch, tokenizer, device) + # at entry, TextTokens is (N, P) + assert text_tokens.ndim == 2 + assert audio_features.ndim == 3 + + with torch.set_grad_enabled(is_training): + predicts, loss, metrics = model( + x=text_tokens, + x_lens=text_tokens_lens, + y=audio_features, + y_lens=audio_features_lens, + train_stage=params.train_stage, + ) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (audio_features_lens).sum().item() + info["utterances"] = text_tokens.size(0) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + for metric in metrics: + info[metric] = metrics[metric].detach().cpu().item() + del metrics + + return predicts, loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: TextTokenCollater, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + predicts, loss, loss_info = compute_loss( + params=params, + model=model, + tokenizer=tokenizer, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + if world_size > 1: + tot_loss.reduce(loss.device) + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + if params.visualize: + output_dir = Path(f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}") + output_dir.mkdir(parents=True, exist_ok=True) + if isinstance(model, DDP): + model.module.visualize(predicts, batch, output_dir=output_dir) + else: + model.visualize(predicts, batch, output_dir=output_dir) + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: TextTokenCollater, + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + rng: random.Random, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + rng: + Random for selecting. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + tot_loss = MetricsTracker() + iter_dl = iter(train_dl) + + dtype, enabled = torch.float32, False + if params.dtype in ["bfloat16", "bf16"]: + dtype, enabled = torch.bfloat16, True + elif params.dtype in ["float16", "fp16"]: + dtype, enabled = torch.float16, True + + batch_idx = 0 + while True: + try: + batch = next(iter_dl) + except StopIteration: + logging.info("Reaches end of dataloader.") + break + + batch_idx += 1 + + params.batch_idx_train += 1 + batch_size = len(batch["text"]) + + try: + with torch.cuda.amp.autocast(dtype=dtype, enabled=enabled): + _, loss, loss_info = compute_loss( + params=params, + model=model, + tokenizer=tokenizer, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info * ( + 1 / params.reset_interval + ) + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + scaler.scale(loss).backward() + if params.batch_idx_train >= params.accumulate_grad_steps: + if params.batch_idx_train % params.accumulate_grad_steps == 0: + if params.optimizer_name not in ["ScaledAdam", "Eve"]: + # Unscales the gradients of optimizer's assigned params in-place + scaler.unscale_(optimizer) + # Since the gradients of optimizer's assigned params are unscaled, clips as usual: + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + for k in range(params.accumulate_grad_steps): + if isinstance(scheduler, Eden): + scheduler.step_batch(params.batch_idx_train) + else: + scheduler.step() + + set_batch_count(model, params.batch_idx_train) + except: # noqa + display_and_save_batch(batch, params=params) + raise + + if params.average_period > 0: + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + # Perform Operation in rank 0 + if rank == 0: + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + # Perform Operation in rank 0 + if rank == 0: + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.dtype in ["float16", "fp16"]: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = ( + scaler._scale.item() if params.dtype in ["float16", "fp16"] else 1.0 + ) + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, train_loss[{loss_info}], " + f"tot_loss[{tot_loss}], " + f"batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + + ( + f", grad_scale: {cur_grad_scale}" + if params.dtype in ["float16", "fp16"] + else "" + ) + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + loss_info.write_summary( + tb_writer, + "train/current_", + params.batch_idx_train, + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.dtype in ["float16", "fp16"]: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if params.batch_idx_train % params.valid_interval == 0: + # Calculate validation loss in Rank 0 + model.eval() + logging.info("Computing validation loss") + with torch.cuda.amp.autocast(dtype=dtype): + valid_info = compute_validation_loss( + params=params, + model=model, + tokenizer=tokenizer, + valid_dl=valid_dl, + world_size=world_size, + ) + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + model.train() + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def filter_short_and_long_utterances( + cuts: CutSet, min_duration: float, max_duration: float +) -> CutSet: + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 0.6 second and 20 seconds + if c.duration < min_duration or c.duration > max_duration: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + cuts = cuts.filter(remove_short_and_long_utt) + + return cuts + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + rng = random.Random(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + if params.train_stage: + tb_writer = SummaryWriter( + log_dir=f"{params.exp_dir}/tensorboard_stage{params.train_stage}" + ) + else: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + + logging.info(f"Device: {device}") + + tokenizer = get_text_token_collater(params.text_tokens) + logging.info(params) + + logging.info("About to create model") + + model = VALLE( + params.decoder_dim, + params.nhead, + params.num_decoder_layers, + norm_first=params.norm_first, + add_prenet=params.add_prenet, + prefix_mode=params.prefix_mode, + share_embedding=params.share_embedding, + nar_scale_factor=params.scale_factor, + prepend_bos=params.prepend_bos, + num_quantizers=params.num_quantizers, + ) + + with open(f"{params.exp_dir}/model.txt", "w") as f: + print(model) + print(model, file=f) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0 and params.average_period > 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + if params.train_stage: + _model = model.module if isinstance(model, DDP) else model + model_parameters = _model.stage_parameters(params.train_stage) + else: + model_parameters = model.parameters() + + if params.optimizer_name == "ScaledAdam": + optimizer = ScaledAdam( + model_parameters, + lr=params.base_lr, + clipping_scale=2.0, + ) + elif params.optimizer_name == "AdamW": + optimizer = torch.optim.AdamW( + model_parameters, + lr=params.base_lr, + betas=(0.9, 0.95), + weight_decay=1e-2, + eps=1e-8, + ) + elif params.optimizer_name == "Adam": + optimizer = torch.optim.Adam( + model_parameters, + lr=params.base_lr, + betas=(0.9, 0.95), + eps=1e-8, + ) + else: + raise NotImplementedError() + + scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps) + optimizer.zero_grad() + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.inf_check: + register_inf_check_hooks(model) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + dataset = TtsDataModule(args) + train_cuts = dataset.train_cuts() + valid_cuts = dataset.dev_cuts() + + train_cuts = filter_short_and_long_utterances( + train_cuts, params.filter_min_duration, params.filter_max_duration + ) + valid_cuts = filter_short_and_long_utterances( + valid_cuts, params.filter_min_duration, params.filter_max_duration + ) + + train_dl = dataset.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + valid_dl = dataset.dev_dataloaders(valid_cuts) + + if params.oom_check: + scan_pessimistic_batches_for_oom( + model=model, + tokenizer=tokenizer, + train_dl=train_dl, + optimizer=optimizer, + params=params, + ) + + scaler = GradScaler(enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + if isinstance(scheduler, Eden): + scheduler.step_epoch(epoch - 1) + + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + tokenizer=tokenizer, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + train_dl=train_dl, + valid_dl=valid_dl, + rng=rng, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + tokenizer: TextTokenCollater, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + + dtype = torch.float32 + if params.dtype in ["bfloat16", "bf16"]: + dtype = torch.bfloat16 + elif params.dtype in ["float16", "fp16"]: + dtype = torch.float16 + + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(dtype=dtype): + _, loss, _ = compute_loss( + params=params, + model=model, + tokenizer=tokenizer, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + TtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech4tts/TTS/valle/tts_datamodule.py b/egs/wenetspeech4tts/TTS/valle/tts_datamodule.py new file mode 100644 index 000000000..8e34d06dc --- /dev/null +++ b/egs/wenetspeech4tts/TTS/valle/tts_datamodule.py @@ -0,0 +1,343 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao, +# Zengrui Jin,) +# Copyright 2023 (authors: Feiteng Li) +# Copyright 2024 (Author: Yuekai Zhang) +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, + SpeechSynthesisDataset, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.features.io import KaldiReader +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class TtsDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in TTS + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/tokenized"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--speaker-embeds", + type=Path, + default=Path("exp/xvector_nnet_1a/"), + help="Path to directory with speaker embeddings.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=4, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=False, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + group.add_argument( + "--dataset", + type=str, + default="libritts", + help="--input-strategy PromptedPrecomputedFeatures needs dataset name to prepare prompts.", + ) + + parser.add_argument( + "--sampling-rate", + type=int, + default=24000, + help="""Audio sampling rate.""", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=True, + return_tokens=True, + return_spk_ids=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + raise NotImplementedError + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + raise NotImplementedError + else: + validate = SpeechSynthesisDataset( + return_text=True, + return_tokens=True, + return_spk_ids=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + dev_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create valid dataloader") + dev_dl = DataLoader( + validate, + sampler=dev_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + ) + + return dev_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + raise NotImplementedError + else: + test = SpeechSynthesisDataset( + return_text=True, + return_tokens=True, + return_spk_ids=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_train.jsonl.gz") + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz") + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz") + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz" + ) diff --git a/egs/wenetspeech4tts/TTS/valle/valle.py b/egs/wenetspeech4tts/TTS/valle/valle.py new file mode 100644 index 000000000..b2eb8ae69 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/valle/valle.py @@ -0,0 +1,1745 @@ +# Copyright 2023 (authors: Feiteng Li) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +import numbers +import random +from functools import partial +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import Linear, Module +from torch.nn import functional as F +from torch.nn.init import constant_, xavier_normal_, xavier_uniform_ +from torch.nn.modules.linear import NonDynamicallyQuantizableLinear +from torch.nn.parameter import Parameter +from torchmetrics.classification import MulticlassAccuracy + +from icefall.utils import make_pad_mask + +NUM_TEXT_TOKENS = 5000 +NUM_AUDIO_TOKENS = 1024 # EnCodec RVQ bins + + +class PromptedFeatures: + def __init__(self, prompts, features): + self.prompts = prompts + self.features = features + + def to(self, device): + return PromptedFeatures(self.prompts.to(device), self.features.to(device)) + + def sum(self): + return self.features.sum() + + @property + def ndim(self): + return self.features.ndim + + @property + def data(self): + return (self.prompts, self.features) + + +class TokenEmbedding(nn.Module): + def __init__( + self, + dim_model: int, + vocab_size: int, + dropout: float = 0.0, + ): + super().__init__() + + self.vocab_size = vocab_size + self.dim_model = dim_model + + self.dropout = torch.nn.Dropout(p=dropout) + self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model) + + @property + def weight(self) -> torch.Tensor: + return self.word_embeddings.weight + + def embedding(self, index: int) -> torch.Tensor: + return self.word_embeddings.weight[index : index + 1] + + def forward(self, x: torch.Tensor): + X = self.word_embeddings(x) + X = self.dropout(X) + + return X + + +class SinePositionalEmbedding(nn.Module): + def __init__( + self, + dim_model: int, + dropout: float = 0.0, + scale: bool = False, + alpha: bool = False, + ): + super().__init__() + self.dim_model = dim_model + self.x_scale = math.sqrt(dim_model) if scale else 1.0 + self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha) + self.dropout = torch.nn.Dropout(p=dropout) + + self.reverse = False + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, 4000)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.dim_model) + if self.reverse: + position = torch.arange( + x.size(1) - 1, -1, -1.0, dtype=torch.float32 + ).unsqueeze(1) + else: + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.dim_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.dim_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype).detach() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + self.extend_pe(x) + output = x.unsqueeze(-1) if x.ndim == 2 else x + output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)] + return self.dropout(output) + + +class Transpose(nn.Identity): + """(N, T, D) -> (N, D, T)""" + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return input.transpose(1, 2) + + +_shape_t = Union[int, List[int], torch.Size] + + +class MultiheadAttention(Module): + r"""Allows the model to jointly attend to information + from different representation subspaces as described in the paper: + `Attention Is All You Need `_. + + Multi-Head Attention is defined as: + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + + where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. + + ``forward()`` will use a special optimized implementation if all of the following + conditions are met: + + - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This + restriction will be loosened in the future.) + - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad`` + - training is disabled (using ``.eval()``) + - dropout is 0 + - ``add_bias_kv`` is ``False`` + - ``add_zero_attn`` is ``False`` + - ``batch_first`` is ``True`` and the input is batched + - ``kdim`` and ``vdim`` are equal to ``embed_dim`` + - at most one of ``key_padding_mask`` or ``attn_mask`` is passed + - if a `NestedTensor `_ is passed, neither ``key_padding_mask`` + nor ``attn_mask`` is passed + + If the optimized implementation is in use, a + `NestedTensor `_ can be passed for + ``query``/``key``/``value`` to represent padding more efficiently than using a + padding mask. In this case, a `NestedTensor `_ + will be returned, and an additional speedup proportional to the fraction of the input + that is padding can be expected. + + Args: + embed_dim: Total dimension of the model. + num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split + across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``). + dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout). + bias: If specified, adds bias to input / output projection layers. Default: ``True``. + add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``. + add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1. + Default: ``False``. + kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``). + vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + + Examples:: + + >>> # xdoctest: +SKIP + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + + """ + __constants__ = ["batch_first"] + bias_k: Optional[torch.Tensor] + bias_v: Optional[torch.Tensor] + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + linear1_cls=Linear, + linear2_cls=Linear, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.batch_first = batch_first + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + if add_bias_kv: + self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + else: + self.bias_k = self.bias_v = None + + if linear1_cls == Linear: + if not self._qkv_same_embed_dim: + self.q_proj_weight = Parameter( + torch.empty((embed_dim, embed_dim), **factory_kwargs) + ) + self.k_proj_weight = Parameter( + torch.empty((embed_dim, self.kdim), **factory_kwargs) + ) + self.v_proj_weight = Parameter( + torch.empty((embed_dim, self.vdim), **factory_kwargs) + ) + self.register_parameter("in_proj_weight", None) + else: + self.in_proj_weight = Parameter( + torch.empty((3 * embed_dim, embed_dim), **factory_kwargs) + ) + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) + + if bias: + self.in_proj_bias = Parameter( + torch.empty(3 * embed_dim, **factory_kwargs) + ) + else: + self.register_parameter("in_proj_bias", None) + self.out_proj = NonDynamicallyQuantizableLinear( + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) + + self._reset_parameters() + else: + if not self._qkv_same_embed_dim: + raise NotImplementedError + else: + self.in_proj_linear = linear1_cls( + embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs + ) + self.in_proj_weight = self.in_proj_linear.weight + + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) + + if bias: + self.in_proj_bias = self.in_proj_linear.bias + else: + self.register_parameter("in_proj_bias", None) + + self.out_proj = linear2_cls( + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) + + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + self.add_zero_attn = add_zero_attn + + def _reset_parameters(self): + if self._qkv_same_embed_dim: + xavier_uniform_(self.in_proj_weight) + else: + xavier_uniform_(self.q_proj_weight) + xavier_uniform_(self.k_proj_weight) + xavier_uniform_(self.v_proj_weight) + + if self.in_proj_bias is not None: + constant_(self.in_proj_bias, 0.0) + constant_(self.out_proj.bias, 0.0) + + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + def __setstate__(self, state): + # Support loading old MultiheadAttention checkpoints generated by v1.1.0 + if "_qkv_same_embed_dim" not in state: + state["_qkv_same_embed_dim"] = True + + super(MultiheadAttention, self).__setstate__(state) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` + or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, + :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. + Queries are compared against key-value pairs to produce the output. + See "Attention Is All You Need" for more details. + key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` + or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, + :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. + See "Attention Is All You Need" for more details. + value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when + ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source + sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. + See "Attention Is All You Need" for more details. + key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` + to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. + Binary and byte masks are supported. + For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for + the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. + need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. + Default: ``True``. + attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape + :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, + :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be + broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. + Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the + corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the + corresponding position is not allowed to attend. For a float mask, the mask values will be added to + the attention weight. + average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across + heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an + effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) + + Outputs: + - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, + :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, + where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the + embedding dimension ``embed_dim``. + - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, + returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. + + .. note:: + `batch_first` argument is ignored for unbatched inputs. + """ + is_batched = query.dim() == 3 + if key_padding_mask is not None: + _kpm_dtype = key_padding_mask.dtype + if _kpm_dtype != torch.bool and not torch.is_floating_point( + key_padding_mask + ): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported" + ) + why_not_fast_path = "" + if not is_batched: + why_not_fast_path = ( + f"input not batched; expected query.dim() of 3 but got {query.dim()}" + ) + elif query is not key or key is not value: + # When lifting this restriction, don't forget to either + # enforce that the dtypes all match or test cases where + # they don't! + why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" + elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype: + why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" + elif ( + self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype + ): + # this case will fail anyway, but at least they'll get a useful error message. + why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" + elif self.training: + why_not_fast_path = "training is enabled" + elif not self.batch_first: + why_not_fast_path = "batch_first was not True" + elif self.bias_k is not None: + why_not_fast_path = "self.bias_k was not None" + elif self.bias_v is not None: + why_not_fast_path = "self.bias_v was not None" + elif self.dropout: + why_not_fast_path = f"dropout was {self.dropout}, required zero" + elif self.add_zero_attn: + why_not_fast_path = "add_zero_attn was enabled" + elif not self._qkv_same_embed_dim: + why_not_fast_path = "_qkv_same_embed_dim was not True" + elif attn_mask is not None: + why_not_fast_path = "attn_mask was not None" + elif query.is_nested and key_padding_mask is not None: + why_not_fast_path = ( + "key_padding_mask is not supported with NestedTensor input" + ) + elif self.num_heads % 2 == 1: + why_not_fast_path = "num_heads is odd" + elif torch.is_autocast_enabled(): + why_not_fast_path = "autocast is enabled" + + if not why_not_fast_path: + tensor_args = ( + query, + key, + value, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj.weight, + self.out_proj.bias, + ) + # We have to use list comprehensions below because TorchScript does not support + # generator expressions. + if torch.overrides.has_torch_function(tensor_args): + why_not_fast_path = "some Tensor argument has_torch_function" + elif not all( + [ + (x is None or x.is_cuda or "cpu" in str(x.device)) + for x in tensor_args + ] + ): + why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" + elif torch.is_grad_enabled() and any( + [x is not None and x.requires_grad for x in tensor_args] + ): + why_not_fast_path = ( + "grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad" + ) + if not why_not_fast_path: + return torch._native_multi_head_attention( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj.weight, + self.out_proj.bias, + key_padding_mask if key_padding_mask is not None else attn_mask, + need_weights, + average_attn_weights, + 1 + if key_padding_mask is not None + else 0 + if attn_mask is not None + else None, + ) + + any_nested = query.is_nested or key.is_nested or value.is_nested + assert not any_nested, ( + "MultiheadAttention does not support NestedTensor outside of its fast path. " + + f"The fast path was not hit because {why_not_fast_path}" + ) + + if self.batch_first and is_batched: + # make sure that the transpose op does not affect the "is" property + if key is value: + if query is key: + query = key = value = query.transpose(1, 0) + else: + query, key = [x.transpose(1, 0) for x in (query, key)] + value = key + else: + query, key, value = [x.transpose(1, 0) for x in (query, key, value)] + + if not self._qkv_same_embed_dim: + attn_output, attn_output_weights = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, + k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight, + average_attn_weights=average_attn_weights, + ) + else: + attn_output, attn_output_weights = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + average_attn_weights=average_attn_weights, + ) + if self.batch_first and is_batched: + return attn_output.transpose(1, 0), attn_output_weights + else: + return attn_output, attn_output_weights + + +class LayerNorm(nn.Module): + __constants__ = ["normalized_shape", "eps", "elementwise_affine"] + normalized_shape: Tuple[int, ...] + eps: float + elementwise_affine: bool + + def __init__( + self, + normalized_shape: _shape_t, + eps: float = 1e-5, + elementwise_affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + self.bias = nn.Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + return ( + F.layer_norm( + input, + self.normalized_shape, + self.weight, + self.bias, + self.eps, + ), + embedding, + ) + + assert embedding is None + return F.layer_norm( + input, self.normalized_shape, self.weight, self.bias, self.eps + ) + + def extra_repr(self) -> str: + return ( + "{normalized_shape}, eps={eps}, " + "elementwise_affine={elementwise_affine}".format(**self.__dict__) + ) + + +class AdaptiveLayerNorm(nn.Module): + r"""Adaptive Layer Normalization""" + + def __init__(self, d_model, norm) -> None: + super(AdaptiveLayerNorm, self).__init__() + self.project_layer = nn.Linear(d_model, 2 * d_model) + self.norm = norm + self.d_model = d_model + self.eps = self.norm.eps + + def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return (weight * self.norm(input) + bias, embedding) + + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return weight * self.norm(input) + bias + + +class TransformerEncoderLayer(nn.Module): + __constants__ = ["batch_first", "norm_first"] + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + batch_first: bool = False, + norm_first: bool = False, + device=None, + dtype=None, + linear1_self_attention_cls: nn.Module = nn.Linear, + linear2_self_attention_cls: nn.Module = nn.Linear, + linear1_feedforward_cls: nn.Module = nn.Linear, + linear2_feedforward_cls: nn.Module = nn.Linear, + layer_norm_cls: nn.Module = LayerNorm, + layer_norm_eps: float = 1e-5, + adaptive_layer_norm=False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(TransformerEncoderLayer, self).__init__() + self.self_attn = MultiheadAttention( + d_model, + nhead, + dropout=dropout, + batch_first=batch_first, + linear1_cls=linear1_self_attention_cls, + linear2_cls=linear2_self_attention_cls, + **factory_kwargs, + ) + + # Implementation of Feedforward model + self.linear1 = linear1_feedforward_cls( + d_model, dim_feedforward, **factory_kwargs + ) + self.dropout = nn.Dropout(dropout) + self.linear2 = linear2_feedforward_cls( + dim_feedforward, d_model, **factory_kwargs + ) + + self.norm_first = norm_first + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + # Legacy string support for activation function. + if isinstance(activation, str): + activation = _get_activation_fn(activation) + elif isinstance(activation, partial): + activation = activation(d_model) + # elif activation == BalancedDoubleSwish: + # activation = BalancedDoubleSwish(d_model) + + # # We can't test self.activation in forward() in TorchScript, + # # so stash some information about it instead. + # if activation is F.relu or isinstance(activation, torch.nn.ReLU): + # self.activation_relu_or_gelu = 1 + # elif activation is F.gelu or isinstance(activation, torch.nn.GELU): + # self.activation_relu_or_gelu = 2 + # else: + # self.activation_relu_or_gelu = 0 + self.activation = activation + + norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) + # if layer_norm_cls == IdentityNorm: + # norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + # else: + if True: + norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) + + if adaptive_layer_norm: + self.norm1 = AdaptiveLayerNorm(d_model, norm1) + self.norm2 = AdaptiveLayerNorm(d_model, norm2) + else: + self.norm1 = norm1 + self.norm2 = norm2 + + def __setstate__(self, state): + super(TransformerEncoderLayer, self).__setstate__(state) + if not hasattr(self, "activation"): + self.activation = F.relu + + def forward( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + x, stage_embedding = src, None + is_src_tuple = False + if isinstance(src, tuple): + x, stage_embedding = src + is_src_tuple = True + + if src_key_padding_mask is not None: + _skpm_dtype = src_key_padding_mask.dtype + if _skpm_dtype != torch.bool and not torch.is_floating_point( + src_key_padding_mask + ): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported" + ) + + if self.norm_first: + x = x + self._sa_block( + self.norm1(x, stage_embedding), + src_mask, + src_key_padding_mask, + ) + x = x + self._ff_block(self.norm2(x, stage_embedding)) + else: + x = self.norm1( + x + self._sa_block(x, src_mask, src_key_padding_mask), + stage_embedding, + ) + x = self.norm2(x + self._ff_block(x), stage_embedding) + + if is_src_tuple: + return (x, stage_embedding) + return x + + # self-attention block + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + ) -> Tensor: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + )[0] + return self.dropout1(x) + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + + +class TransformerEncoder(nn.Module): + r"""TransformerEncoder is a stack of N encoder layers. Users can build the + BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. + + Args: + encoder_layer: an instance of the TransformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + enable_nested_tensor: if True, input will automatically convert to nested tensor + (and convert back on output). This will improve the overall performance of + TransformerEncoder when padding rate is high. Default: ``True`` (enabled). + + Examples:: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = transformer_encoder(src) + """ + __constants__ = ["norm"] + + def __init__(self, encoder_layer, num_layers, norm=None): + super(TransformerEncoder, self).__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + return_layer_states: bool = False, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + return_layer_states: return layers' state (optional). + + Shape: + see the docs in Transformer class. + """ + if return_layer_states: + layer_states = [] # layers' output + output = src + for mod in self.layers: + output = mod( + output, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + ) + layer_states.append(output[0]) + + if self.norm is not None: + output = self.norm(output) + + return layer_states, output + + output = src + for mod in self.layers: + output = mod( + output, src_mask=mask, src_key_padding_mask=src_key_padding_mask + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + + +class VALLE(nn.Module): + """It implements https://arxiv.org/abs/2301.02111 + "Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers" + """ + + def __init__( + self, + d_model: int, + nhead: int, + num_layers: int, + norm_first: bool = True, + add_prenet: bool = False, + decoder_cls=TransformerEncoder, + decoder_layer_cls=TransformerEncoderLayer, + prefix_mode: int = 0, + share_embedding: bool = True, + nar_scale_factor: float = 1.0, + prepend_bos: bool = False, + num_quantizers: int = 8, + **kwargs, + ): + """ + Args: + d_model: + The number of expected features in the input (required). + nhead: + The number of heads in the multiheadattention models (required). + num_layers: + The number of sub-decoder-layers in the decoder (required). + """ + super().__init__() + nar_d_model = int(d_model * nar_scale_factor) + + self.ar_text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x + self.nar_text_embedding = TokenEmbedding(nar_d_model, NUM_TEXT_TOKENS) + + # ID NUM_AUDIO_TOKENS -> PAD + # ID NUM_AUDIO_TOKENS + 1 -> BOS + self.ar_audio_prepend_bos = prepend_bos + self.ar_audio_embedding = TokenEmbedding( + d_model, NUM_AUDIO_TOKENS + 1 + int(prepend_bos) + ) + + # PreNet + if add_prenet: + self.ar_text_prenet = nn.Sequential( + Transpose(), + nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), + nn.BatchNorm1d(d_model), + nn.ReLU(), + nn.Dropout(0.5), + nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), + nn.BatchNorm1d(d_model), + nn.ReLU(), + nn.Dropout(0.5), + nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), + nn.BatchNorm1d(d_model), + nn.ReLU(), + nn.Dropout(0.5), + Transpose(), + nn.Linear(d_model, d_model), + ) + + self.ar_audio_prenet = nn.Sequential( + nn.Linear(d_model, 256), + nn.ReLU(), + nn.Dropout(0.25), + nn.Linear(256, 256), + nn.ReLU(), + nn.Dropout(0.25), + nn.Linear(256, d_model), + ) + else: + self.ar_text_prenet = nn.Identity() + self.ar_audio_prenet = nn.Identity() + + self.ar_text_position = SinePositionalEmbedding( + d_model, + dropout=0.1, + scale=False, + alpha=True, + ) + self.ar_audio_position = SinePositionalEmbedding( + d_model, + dropout=0.1, + scale=False, + alpha=True, + ) + + self.ar_decoder = decoder_cls( + decoder_layer_cls( + d_model, + nhead, + dim_feedforward=d_model * 4, + dropout=0.1, + batch_first=True, + norm_first=norm_first, + ), + num_layers=num_layers, + norm=LayerNorm(d_model) if norm_first else None, + ) + self.ar_predict_layer = nn.Linear(d_model, NUM_AUDIO_TOKENS + 1, bias=False) + + self.ar_accuracy_metric = MulticlassAccuracy( + NUM_AUDIO_TOKENS + 1, + top_k=10, + average="micro", + multidim_average="global", + ignore_index=NUM_AUDIO_TOKENS, + ) + + self.rng = random.Random(0) + self.num_heads = nhead + self.prefix_mode = prefix_mode + self.num_quantizers = num_quantizers + + assert num_quantizers >= 1 + if num_quantizers > 1: + self.nar_audio_embeddings = nn.ModuleList( + [TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS + 1)] + + [ + TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS) + for i in range(num_quantizers - 1) + ] + ) # W_a + + # PreNet + if add_prenet: + self.nar_text_prenet = nn.Sequential( + Transpose(), + nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"), + nn.BatchNorm1d(nar_d_model), + nn.ReLU(), + nn.Dropout(0.5), + nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"), + nn.BatchNorm1d(nar_d_model), + nn.ReLU(), + nn.Dropout(0.5), + nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"), + nn.BatchNorm1d(nar_d_model), + nn.ReLU(), + nn.Dropout(0.5), + Transpose(), + nn.Linear(nar_d_model, nar_d_model), + ) + self.nar_audio_prenet = nn.Sequential( + nn.Linear(nar_d_model, 256), + nn.ReLU(), + nn.Dropout(0.25), + nn.Linear(256, 256), + nn.ReLU(), + nn.Dropout(0.25), + nn.Linear(256, nar_d_model), + ) + else: + self.nar_text_prenet = nn.Identity() + self.nar_audio_prenet = nn.Identity() + + self.nar_text_position = SinePositionalEmbedding( + nar_d_model, + dropout=0.0, + scale=False, + alpha=False, + ) + self.nar_audio_position = SinePositionalEmbedding( + nar_d_model, + dropout=0.1, + scale=False, + alpha=False, + ) + + self.nar_decoder = decoder_cls( + decoder_layer_cls( + nar_d_model, + int(nhead * nar_scale_factor), + dim_feedforward=nar_d_model * 4, + dropout=0.1, + batch_first=True, + norm_first=norm_first, + adaptive_layer_norm=True, + ), + num_layers=int(num_layers * nar_scale_factor), + norm=AdaptiveLayerNorm(nar_d_model, norm=nn.LayerNorm(nar_d_model)) + if norm_first + else None, + ) + self.nar_predict_layers = nn.ModuleList( + [ + nn.Linear(nar_d_model, NUM_AUDIO_TOKENS, bias=False) + for i in range(num_quantizers - 1) + ] + ) + self.nar_stage_embeddings = nn.ModuleList( + [TokenEmbedding(nar_d_model, 1) for i in range(num_quantizers - 1)] + ) + + if share_embedding: + # We share the parameters of the output projection layer with the parameters of the acoustic embedding Wa + # NOTE(Feiteng): In the experiment, this undermines accuracy + # self.ar_predict_layer.weight = self.ar_audio_embedding.weight + + # We also share the parameters of the acoustic embedding layer and the output prediction layer, + # which means the weights of the j-th prediction layer are the same as the (j + 1)-th acoustic embedding layer. + for j in range(0, num_quantizers - 2): + self.nar_predict_layers[j].weight = self.nar_audio_embeddings[ + j + 2 + ].weight + + self.nar_accuracy_metric = MulticlassAccuracy( + NUM_AUDIO_TOKENS + 1, + top_k=10, + average="micro", + multidim_average="global", + ignore_index=NUM_AUDIO_TOKENS, + ) + + def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]: + assert stage > 0 + if stage == 1: + for name, param in self.named_parameters(): + if name.startswith("ar_"): + print(f" AR parameter: {name}") + yield param + + if stage == 2: + for name, param in self.named_parameters(): + if name.startswith("nar_"): + print(f"NAR parameter: {name}") + yield param + + def stage_named_parameters( + self, stage: int = 1 + ) -> Iterator[Tuple[str, nn.Parameter]]: + assert stage > 0 + if stage == 1: + for pair in self.named_parameters(): + if pair[0].startswith("ar_"): + yield pair + + if stage == 2: + for pair in self.named_parameters(): + if pair[0].startswith("nar_"): + yield pair + + def pad_y_eos(self, y, y_mask_int, eos_id): + targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad( + y_mask_int, (0, 1), value=1 + ) + # inputs, targets + if self.ar_audio_prepend_bos: + return ( + F.pad(targets[:, :-1], (1, 0), value=NUM_AUDIO_TOKENS + 1), + targets, + ) + + return targets[:, :-1], targets[:, 1:] + + def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes): + # 5.1 For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds + # from the same utterance. + # We implement this differently. + if self.prefix_mode == 0: + # no prefix + prefix_len = 0 + y_emb = self.nar_audio_embeddings[0](y) + for j in range(1, nar_stage): + # Formula (4) (5) + y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j]) + elif self.prefix_mode == 1: + # prefix at begining + int_low = (0.25 * y_lens.min()).type(torch.int64).item() + prefix_len = torch.randint(int_low, int_low * 2, size=()).item() + prefix_len = min(prefix_len, 225) # 24000/320 * 3s = 225 frames + + y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len]) + y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:]) + for j in range(1, self.num_quantizers): + y_prompts += self.nar_audio_embeddings[j](codes[:, :prefix_len, j]) + if j < nar_stage: + y_emb += self.nar_audio_embeddings[j](codes[:, prefix_len:, j]) + y_emb = torch.concat([y_prompts, y_emb], axis=1) + elif self.prefix_mode in [2, 4]: + if self.prefix_mode == 2: + # random prefix + prefix_len = min(225, int(0.25 * y_lens.min().item())) + + y_prompts_codes = [] + for b in range(codes.shape[0]): + start = self.rng.randint(0, y_lens[b].item() - prefix_len) + y_prompts_codes.append( + torch.clone(codes[b, start : start + prefix_len]) + ) + codes[b, start : start + prefix_len, nar_stage] = NUM_AUDIO_TOKENS + y_prompts_codes = torch.stack(y_prompts_codes, dim=0) + else: + prefix_len = y_prompts_codes.shape[1] + + y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0]) + y_emb = self.nar_audio_embeddings[0](y) + for j in range(1, self.num_quantizers): + y_prompts += self.nar_audio_embeddings[j](y_prompts_codes[..., j]) + if j < nar_stage: + y_emb += self.nar_audio_embeddings[j](codes[..., j]) + y_emb = torch.concat([y_prompts, y_emb], axis=1) + else: + raise ValueError + + return y_emb, prefix_len + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: Union[torch.Tensor, PromptedFeatures], + y_lens: Union[torch.Tensor, PromptedFeatures], + reduction: str = "sum", + train_stage: int = 0, + **kwargs, + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: + """ + Args: + x: + A 2-D tensor of shape (N, S). + x_lens: + A 1-D tensor of shape (N,). It contains the number of tokens in `x` + before padding. + y: + A 3-D tensor of shape (N, T, 8). + y_lens: + A 1-D tensor of shape (N,). It contains the number of tokens in `x` + before padding. + train_stage: + 0: AR & NAR modules, 1: AR modules, 2: NAR modules + Returns: + Return the predicted audio code matrix, cross-entropy loss and Top-10 accuracy. + """ + assert x.ndim == 2, x.shape + assert x_lens.ndim == 1, x_lens.shape + + y_prompts_codes = None + if isinstance(y, PromptedFeatures): + y_prompts_codes, y = y.data + prompts_len, y_lens = y_lens.data + assert prompts_len.min() == prompts_len.max() + assert self.prefix_mode == 4 + y_prompts_codes = y_prompts_codes.type(torch.int64) + + assert y.ndim == 3, y.shape + assert y_lens.ndim == 1, y_lens.shape + + # NOTE: x has been padded in TextTokenCollater + x_mask = make_pad_mask(x_lens).to(x.device) + y_mask = make_pad_mask(y_lens).to(y.device) + y_mask_int = y_mask.type(torch.int64) + + text = x + codes = y.type(torch.int64) * (1 - y_mask_int.unsqueeze(dim=-1)) + + y, targets = self.pad_y_eos(codes[..., 0], y_mask_int, eos_id=NUM_AUDIO_TOKENS) + + x_len = x_lens.max() + + metrics = {} + total_loss = 0.0 + + xy_padding_mask = torch.concat([x_mask, y_mask], dim=1) + if self.ar_audio_prepend_bos: + ar_xy_padding_mask = torch.concat( + [x_mask, F.pad(y_mask, (1, 0), value=False)], dim=1 + ) + else: + ar_xy_padding_mask = xy_padding_mask + # AR Decoder + if train_stage in [0, 1]: + x = self.ar_text_embedding(text) + x = self.ar_text_prenet(x) + x = self.ar_text_position(x) + + y_len = y_lens.max() + int(self.ar_audio_prepend_bos) + + x_attn_mask = F.pad( + torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device), + (0, y_len), + value=True, + ) + y_attn_mask = F.pad( + torch.triu( + torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), + diagonal=1, + ), + (x_len, 0), + value=False, + ) + xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0) + + # merge key padding and attention masks + bsz, src_len = x.shape[0], x_len + y_len + _xy_padding_mask = ( + ar_xy_padding_mask.view(bsz, 1, 1, src_len) + .expand(-1, self.num_heads, -1, -1) + .reshape(bsz * self.num_heads, 1, src_len) + ) + xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask) + + new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype) + new_attn_mask.masked_fill_(xy_attn_mask, float("-inf")) + xy_attn_mask = new_attn_mask + + y_emb = self.ar_audio_embedding(y) + y_emb = self.ar_audio_prenet(y_emb) + y_pos = self.ar_audio_position(y_emb) + + xy_pos = torch.concat([x, y_pos], dim=1) + + xy_dec, _ = self.ar_decoder( + (xy_pos, None), + mask=xy_attn_mask, + # src_key_padding_mask=xy_padding_mask, + # is_causal=True, + ) + logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1) + # loss + total_loss = F.cross_entropy(logits, targets, reduction=reduction) + + metrics["ArTop10Accuracy"] = self.ar_accuracy_metric( + logits.detach(), targets + ).item() * y_lens.sum().type(torch.float32) + + if self.num_quantizers == 1: + return ((x, codes), total_loss, metrics) + + # Non-AR Decoders + if self.ar_audio_prepend_bos: + y = y[:, 1:] + if train_stage in [0, 2]: + num_nar_layers = self.num_quantizers - 1 + nar_stage = self.rng.choices( + [_k for _k in range(1, self.num_quantizers)], + weights=[1.0 / num_nar_layers] * num_nar_layers, + k=1, + )[0] + + x = self.nar_text_embedding(text) + x = self.nar_text_prenet(x) + x = self.nar_text_position(x) + + y_emb, prefix_len = self._prepare_prompts( + y, y_lens, codes, nar_stage, y_prompts_codes + ) + + y_len = y_lens.max() + targets = codes[..., nar_stage] + NUM_AUDIO_TOKENS * y_mask_int + if self.prefix_mode in [2, 4]: + xy_padding_mask = torch.concat( + [ + x_mask, + F.pad(y_mask, (y_emb.shape[1] - y_len, 0), value=False), + ], + dim=1, + ) + elif self.prefix_mode == 1: + targets = targets[:, prefix_len:] + + y_pos = self.nar_audio_prenet(y_emb) + y_pos = self.nar_audio_position(y_pos) + xy_pos = torch.concat([x, y_pos], dim=1) + xy_dec, _ = self.nar_decoder( + (xy_pos, self.nar_stage_embeddings[nar_stage - 1].weight), + src_key_padding_mask=xy_padding_mask, + # is_causal=False, + ) + xy_dec = xy_dec[:, x_lens.max() + prefix_len :] + if self.prefix_mode == 4: + prefix_len = 0 # reset for Top10Accuracy metric + logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute(0, 2, 1) + + # loss + total_length = (y_lens).sum().type(torch.float32) + total_loss += F.cross_entropy( + logits, + targets, + ignore_index=NUM_AUDIO_TOKENS, + reduction=reduction, + ) * (total_length / (total_length - prefix_len * x.shape[0])) + metrics["NarTop10Accuracy"] = ( + self.nar_accuracy_metric( + F.pad( + logits.detach(), + (0, 0, 0, 1, 0, 0), + value=logits.min().cpu().item(), + ), + targets, + ).item() + * total_length + ) + + if train_stage == 0: + total_loss = total_loss / 2.0 + + return ((x, codes), total_loss, metrics) + + def inference( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: torch.Tensor, + enroll_x_lens: torch.Tensor, + top_k: int = -100, + temperature: float = 1.0, + top_p: float = 1.0, + ras: bool = False, + ) -> torch.Tensor: + """ + Args: + x: + A 2-D tensor of shape (1, S). + x_lens: + A 1-D tensor of shape (1,). It contains the number of tokens in `x` + before padding. + y: + A 3-D tensor of shape (1, T, 8). + top_k: (`optional`) int + The number of highest probability tokens to keep for top-k-filtering. Default to -100. + temperature: (`optional`) float + The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. + ras: (`optional`) bool + Whether to use repetition-aware sampling. Default to False. + Returns: + Return the predicted audio code matrix. + """ + assert x.ndim == 2, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.ndim == 3, y.shape + assert y.shape[0] == 1, y.shape + + assert torch.all(x_lens > 0) + + # NOTE: x has been padded in TextTokenCollater + text = x + x = self.ar_text_embedding(text) + x = self.ar_text_prenet(x) + x = self.ar_text_position(x) + + text_len = x_lens.max() + prompts = y + prefix_len = y.shape[1] + + # AR Decoder + # TODO: Managing decoder steps avoid repetitive computation + y = prompts[..., 0] + if self.ar_audio_prepend_bos: + y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1) + + x_len = x_lens.max() + x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool) + + while True: + y_emb = self.ar_audio_embedding(y) + y_emb = self.ar_audio_prenet(y_emb) + y_pos = self.ar_audio_position(y_emb) + xy_pos = torch.concat([x, y_pos], dim=1) + + y_len = y.shape[1] + x_attn_mask_pad = F.pad( + x_attn_mask, + (0, y_len), + value=True, + ) + y_attn_mask = F.pad( + torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), + (x_len, 0), + value=False, + ) + xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to( + y.device + ) + + xy_dec, _ = self.ar_decoder( + (xy_pos, None), + mask=xy_attn_mask, + ) + logits = self.ar_predict_layer(xy_dec[:, -1]) + samples = topk_sampling( + logits, + top_k=top_k, + top_p=top_p, + temperature=temperature, + repetition_aware_sampling=ras, + preceding_tokens=y, + ) + + if ( + torch.argmax(logits, dim=-1)[0] == NUM_AUDIO_TOKENS + or samples[0, 0] == NUM_AUDIO_TOKENS + or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16 + ): + if prompts.shape[1] == y.shape[1]: + raise SyntaxError("well trained model shouldn't reach here.") + break + + y = torch.concat([y, samples], dim=1) + + codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]] + if self.num_quantizers == 1: + return torch.stack(codes, dim=-1) + + # Non-AR Decoders + y_emb = self.nar_audio_embeddings[0](y[:, int(self.ar_audio_prepend_bos) :]) + + if self.prefix_mode in [2, 4]: # Exclude enrolled_phonemes + enrolled_len = enroll_x_lens.max().item() + # SOS + Synthesis Text + EOS + text = torch.concat( + [ + text[:, :1], + text[:, enrolled_len - 1 :], + ], + dim=1, + ) + text_len = text_len - (enrolled_len - 2) + assert text.shape[0] == 1 + + x = self.nar_text_embedding(text) + x = self.nar_text_prenet(x) + x = self.nar_text_position(x) + + if self.prefix_mode == 0: + for i, (predict_layer, embedding_layer) in enumerate( + zip( + self.nar_predict_layers, + self.nar_audio_embeddings[1:], + ) + ): + y_pos = self.nar_audio_prenet(y_emb) + y_pos = self.nar_audio_position(y_pos) + xy_pos = torch.concat([x, y_pos], dim=1) + + xy_dec, _ = self.nar_decoder( + (xy_pos, self.nar_stage_embeddings[i].weight) + ) + logits = predict_layer(xy_dec[:, text_len + prefix_len :]) + + samples = torch.argmax(logits, dim=-1) + codes.append(samples) + + if i < self.num_quantizers - 2: + y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1]) + y_emb[:, prefix_len:] += embedding_layer(samples) + else: + for j in range(1, self.num_quantizers): + y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j]) + + for i, (predict_layer, embedding_layer) in enumerate( + zip( + self.nar_predict_layers, + self.nar_audio_embeddings[1:], + ) + ): + y_pos = self.nar_audio_prenet(y_emb) + y_pos = self.nar_audio_position(y_pos) + xy_pos = torch.concat([x, y_pos], dim=1) + + xy_dec, _ = self.nar_decoder( + (xy_pos, self.nar_stage_embeddings[i].weight) + ) + logits = predict_layer(xy_dec[:, text_len + prefix_len :]) + + samples = torch.argmax(logits, dim=-1) + codes.append(samples) + + if i < self.num_quantizers - 2: + y_emb[:, prefix_len:] += embedding_layer(samples) + + assert len(codes) == self.num_quantizers + return torch.stack(codes, dim=-1) + + def continual( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + x: + A 2-D tensor of shape (1, S). + x_lens: + A 1-D tensor of shape (1,). It contains the number of tokens in `x` + before padding. + y: + A 3-D tensor of shape (1, T, 8). + Returns: + Return the predicted audio code matrix. + """ + assert x.ndim == 2, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.ndim == 3, y.shape + assert y.shape[0] == 1, y.shape + + assert torch.all(x_lens > 0) + assert self.num_quantizers == 8 + + # NOTE: x has been padded in TextTokenCollater + text = x + x = self.ar_text_embedding(text) + x = self.ar_text_prenet(x) + x = self.ar_text_position(x) + + text_len = x_lens.max() + + prefix_len = min(int(y.shape[1] * 0.5), 3 * 75) + + # AR Decoder + prompts = y[:, :prefix_len] + + codes = [y[:, prefix_len:, 0]] + # Non-AR Decoders + x = self.nar_text_embedding(text) + x = self.nar_text_prenet(x) + x = self.nar_text_position(x) + + y_emb = self.nar_audio_embeddings[0](y[..., 0]) + + if self.prefix_mode == 0: + for i, (predict_layer, embedding_layer) in enumerate( + zip( + self.nar_predict_layers, + self.nar_audio_embeddings[1:], + ) + ): + y_pos = self.nar_audio_position(y_emb) + y_pos = self.nar_audio_prenet(y_pos) + xy_pos = torch.concat([x, y_pos], dim=1) + + xy_dec, _ = self.nar_decoder( + (xy_pos, self.nar_stage_embeddings[i].weight) + ) + logits = predict_layer(xy_dec[:, text_len + prefix_len :]) + + samples = torch.argmax(logits, dim=-1) + codes.append(samples) + + if i < 6: + y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1]) + y_emb[:, prefix_len:] += embedding_layer(samples) + else: + for j in range(1, 8): + y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j]) + + for i, (predict_layer, embedding_layer) in enumerate( + zip( + self.nar_predict_layers, + self.nar_audio_embeddings[1:], + ) + ): + y_pos = self.nar_audio_prenet(y_emb) + y_pos = self.nar_audio_position(y_pos) + xy_pos = torch.concat([x, y_pos], dim=1) + + xy_dec, _ = self.nar_decoder( + (xy_pos, self.nar_stage_embeddings[i].weight) + ) + logits = predict_layer(xy_dec[:, text_len + prefix_len :]) + + samples = torch.argmax(logits, dim=-1) + codes.append(samples) + + if i < 6: + y_emb[:, prefix_len:] += embedding_layer(samples) + + assert len(codes) == 8 + return torch.stack(codes, dim=-1) + + +# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py +def top_k_top_p_filtering( + logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1 +): + """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (batch size, vocabulary size) + if top_k > 0: keep only top k tokens with highest probability (top-k filtering). + if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) + Make sure we keep at least min_tokens_to_keep per batch example in the output + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + if top_k > 0: + top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs > top_p + if min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove + ) + logits[indices_to_remove] = filter_value + return logits + + +def topk_sampling( + logits, + top_k=10, + top_p=1.0, + temperature=1.0, + repetition_aware_sampling=False, + preceding_tokens=None, +): + if temperature != 1.0: + logits = logits / temperature + # Top-p/top-k filtering + logits_filtered = top_k_top_p_filtering( + logits.clone(), top_k=top_k, top_p=top_p, min_tokens_to_keep=2 + ) + # Sample + probs = F.softmax(logits_filtered, dim=-1) + tokens = torch.multinomial(probs, num_samples=1) + + if repetition_aware_sampling: + window_size = 10 + threshold = 0.1 + # we first generate the target code ct′ + # by nucleus sampling with a pre-defined top-p value v. Then, we + # calculate the repetition ratio r of token ct′ + # in the preceding code sequence with a window size K. + # If the ratio r exceeds a pre-defined repetition threshold ratio tn, we replace the target code ct′ + # by + # random sampling from p(ct′ + # |x, c window_size: + preceding_tokens = preceding_tokens[:, -window_size:] + if preceding_tokens.shape[1] > 0: + for i, item in enumerate(preceding_tokens): + # check if the repeat ratio exceeds the threshold + if (item == tokens[i]).sum() / window_size > threshold: + # replace the target code ct′ by random sampling + probs = F.softmax(logits[i], dim=-1) + token_new = torch.multinomial(probs, num_samples=1) + tokens[i] = token_new + return tokens From 18fa6a0fecb16c4b825f87e5ace06e7e84b3a3ff Mon Sep 17 00:00:00 2001 From: Han Zhu Date: Fri, 29 Nov 2024 11:45:05 +0800 Subject: [PATCH 37/59] Fix LibriTTS prepare.sh (#1815) --- egs/libritts/TTS/prepare.sh | 2 +- egs/ljspeech/TTS/local/validate_manifest.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/egs/libritts/TTS/prepare.sh b/egs/libritts/TTS/prepare.sh index 1700e0737..a0a6d2ae1 100755 --- a/egs/libritts/TTS/prepare.sh +++ b/egs/libritts/TTS/prepare.sh @@ -84,7 +84,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then if [ ! -f data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz ]; then cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \ <(gunzip -c data/spectrogram/libritts_cuts_train-clean-360.jsonl.gz) \ - <(gunzip -c data/spectrogramlibritts_cuts_train-other-500.jsonl.gz) | \ + <(gunzip -c data/spectrogram/libritts_cuts_train-other-500.jsonl.gz) | \ shuf | gzip -c > data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz fi diff --git a/egs/ljspeech/TTS/local/validate_manifest.py b/egs/ljspeech/TTS/local/validate_manifest.py index 9535ba9f4..68159ae03 100755 --- a/egs/ljspeech/TTS/local/validate_manifest.py +++ b/egs/ljspeech/TTS/local/validate_manifest.py @@ -33,7 +33,6 @@ import argparse import logging from pathlib import Path -from compute_fbank_ljspeech import MyFbank from lhotse import CutSet, load_manifest_lazy from lhotse.dataset.speech_synthesis import validate_for_tts From a1ade8ecb77a78ff55d1bf41a918051b5f306731 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 29 Nov 2024 16:36:02 +0800 Subject: [PATCH 38/59] fixed failed assertion in the `xbmu_ambo31` recipe (#1816) --- .../ASR/pruned_transducer_stateless7/train.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py index d24c27326..dd72551d9 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py @@ -974,7 +974,16 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) From bdd0f85704a1c257a89811a1cece728425fe6e9b Mon Sep 17 00:00:00 2001 From: Han Zhu <1106766460@qq.com> Date: Thu, 5 Dec 2024 15:12:06 +0800 Subject: [PATCH 39/59] Fix the normalized_text in LibriTTS recipe (#1825) --- egs/libritts/TTS/local/prepare_tokens_libritts.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/egs/libritts/TTS/local/prepare_tokens_libritts.py b/egs/libritts/TTS/local/prepare_tokens_libritts.py index faeb611f5..cdc39ea6b 100755 --- a/egs/libritts/TTS/local/prepare_tokens_libritts.py +++ b/egs/libritts/TTS/local/prepare_tokens_libritts.py @@ -31,15 +31,6 @@ from piper_phonemize import phonemize_espeak from tqdm.auto import tqdm -def remove_punc_to_upper(text: str) -> str: - text = text.replace("‘", "'") - text = text.replace("’", "'") - tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'") - s_list = [x.upper() if x in tokens else " " for x in text] - s = " ".join("".join(s_list).split()).strip() - return s - - def prepare_tokens_libritts(): output_dir = Path("data/spectrogram") prefix = "libritts" @@ -72,7 +63,7 @@ def prepare_tokens_libritts(): for t in tokens_list: tokens.extend(t) cut.tokens = tokens - cut.supervisions[0].normalized_text = remove_punc_to_upper(text) + cut.supervisions[0].normalized_text = text new_cuts.append(cut) From 6e6b022e413a49cc2cf1c14995db39656e0ad85b Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 6 Dec 2024 16:14:51 +0800 Subject: [PATCH 40/59] performed end to end testing to the VALL-E recipe (#1818) * added the missing ``visualize`` function * minor fixes --- ...te_neural_codec_and_prepare_text_tokens.py | 22 +++-- egs/wenetspeech4tts/TTS/valle/infer.py | 2 +- .../TTS/valle/requirements.txt | 2 + egs/wenetspeech4tts/TTS/valle/train.py | 9 +- egs/wenetspeech4tts/TTS/valle/valle.py | 85 +++++++++++++++++++ 5 files changed, 109 insertions(+), 11 deletions(-) create mode 100644 egs/wenetspeech4tts/TTS/valle/requirements.txt diff --git a/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py b/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py index 5494bf340..7de2c6202 100755 --- a/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py +++ b/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py @@ -516,9 +516,19 @@ def main(): for idx, part in enumerate(cut_sets): if args.audio_extractor: if args.audio_extractor == "Encodec": - storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx if split > 1 else ''}" + if split > 1: + storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx}" + else: + storage_path = ( + f"{args.output_dir}/{args.prefix}_encodec_{partition}" + ) else: - storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx if split > 1 else ''}" + if split > 1: + storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx}" + else: + storage_path = ( + f"{args.output_dir}/{args.prefix}_fbank_{partition}" + ) if args.prefix.lower() in [ "ljspeech", @@ -587,9 +597,11 @@ def main(): ].normalized_text, "normalized_text is None" # Save each part with an index if split > 1 - cuts_filename = ( - f"{prefix}cuts_{partition}.{idx if split > 1 else ''}.{args.suffix}" - ) + if split > 1: + cuts_filename = f"{prefix}cuts_{partition}.{idx}.{args.suffix}" + else: + cuts_filename = f"{prefix}cuts_{partition}.{args.suffix}" + part.to_file(f"{args.output_dir}/{cuts_filename}") logging.info(f"Saved {cuts_filename}") diff --git a/egs/wenetspeech4tts/TTS/valle/infer.py b/egs/wenetspeech4tts/TTS/valle/infer.py index fd7ba9f21..44a251c56 100644 --- a/egs/wenetspeech4tts/TTS/valle/infer.py +++ b/egs/wenetspeech4tts/TTS/valle/infer.py @@ -86,7 +86,7 @@ def get_args(): parser.add_argument( "--checkpoint", type=str, - default="exp/vallf_nano_full/checkpoint-100000.pt", + default="./valle/exp/checkpoint-100000.pt", help="Path to the saved checkpoint.", ) diff --git a/egs/wenetspeech4tts/TTS/valle/requirements.txt b/egs/wenetspeech4tts/TTS/valle/requirements.txt new file mode 100644 index 000000000..06958dbea --- /dev/null +++ b/egs/wenetspeech4tts/TTS/valle/requirements.txt @@ -0,0 +1,2 @@ +phonemizer==3.2.1 +git+https://github.com/facebookresearch/encodec.git \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/valle/train.py b/egs/wenetspeech4tts/TTS/valle/train.py index fde209511..e9ec548f3 100755 --- a/egs/wenetspeech4tts/TTS/valle/train.py +++ b/egs/wenetspeech4tts/TTS/valle/train.py @@ -4,6 +4,7 @@ # Mingshuang Luo) # Copyright 2023 (authors: Feiteng Li) # Copyright 2024 (authors: Yuekai Zhang) +# Copyright 2024 Tsinghua University (authors: Zengrui Jin,) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -48,10 +49,8 @@ python3 valle/train.py --max-duration 160 --filter-min-duration 0.5 --filter-max import argparse import copy import logging -import os import random import warnings -from contextlib import nullcontext from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Tuple, Union @@ -216,7 +215,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="exp/valle_dev", + default="./valle/exp", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -686,9 +685,9 @@ def compute_validation_loss( output_dir = Path(f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}") output_dir.mkdir(parents=True, exist_ok=True) if isinstance(model, DDP): - model.module.visualize(predicts, batch, output_dir=output_dir) + model.module.visualize(predicts, batch, tokenizer, output_dir=output_dir) else: - model.visualize(predicts, batch, output_dir=output_dir) + model.visualize(predicts, batch, tokenizer, output_dir=output_dir) return tot_loss diff --git a/egs/wenetspeech4tts/TTS/valle/valle.py b/egs/wenetspeech4tts/TTS/valle/valle.py index b2eb8ae69..4bfa2b577 100644 --- a/egs/wenetspeech4tts/TTS/valle/valle.py +++ b/egs/wenetspeech4tts/TTS/valle/valle.py @@ -19,8 +19,11 @@ import random from functools import partial from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union +import matplotlib.pyplot as plt +import numpy as np import torch import torch.nn as nn +from tokenizer import TextTokenCollater from torch import Tensor from torch.nn import Linear, Module from torch.nn import functional as F @@ -1658,6 +1661,88 @@ class VALLE(nn.Module): assert len(codes) == 8 return torch.stack(codes, dim=-1) + def visualize( + self, + predicts: Tuple[torch.Tensor], + batch: Dict[str, Union[List, torch.Tensor]], + tokenizer: TextTokenCollater, + output_dir: str, + limit: int = 4, + ) -> None: + audio_features = batch["features"].to("cpu").detach().numpy() + audio_features_lens = batch["features_lens"].to("cpu").detach().numpy() + + tokens = batch["tokens"] + text_tokens, text_tokens_lens = tokenizer(tokens) + assert text_tokens.ndim == 2 + + texts = batch["text"] + utt_ids = [cut.id for cut in batch["cut"]] + + encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy() + decoder_outputs = predicts[1] + if isinstance(decoder_outputs, list): + decoder_outputs = decoder_outputs[-1] + decoder_outputs = decoder_outputs.to("cpu").type(torch.float32).detach().numpy() + + vmin, vmax = 0, 1024 # Encodec + if decoder_outputs.dtype == np.float32: + vmin, vmax = -6, 0 # Fbank + + num_figures = 3 + for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])): + _ = plt.figure(figsize=(14, 8 * num_figures)) + + S = text_tokens_lens[b] + T = audio_features_lens[b] + + # encoder + plt.subplot(num_figures, 1, 1) + plt.title(f"Text: {text}") + plt.imshow( + X=np.transpose(encoder_outputs[b]), + cmap=plt.get_cmap("jet"), + aspect="auto", + interpolation="nearest", + ) + plt.gca().invert_yaxis() + plt.axvline(x=S - 0.4, linewidth=2, color="r") + plt.xlabel("Encoder Output") + plt.colorbar() + + # decoder + plt.subplot(num_figures, 1, 2) + plt.imshow( + X=np.transpose(decoder_outputs[b]), + cmap=plt.get_cmap("jet"), + aspect="auto", + interpolation="nearest", + vmin=vmin, + vmax=vmax, + ) + plt.gca().invert_yaxis() + plt.axvline(x=T - 0.4, linewidth=2, color="r") + plt.xlabel("Decoder Output") + plt.colorbar() + + # target + plt.subplot(num_figures, 1, 3) + plt.imshow( + X=np.transpose(audio_features[b]), + cmap=plt.get_cmap("jet"), + aspect="auto", + interpolation="nearest", + vmin=vmin, + vmax=vmax, + ) + plt.gca().invert_yaxis() + plt.axvline(x=T - 0.4, linewidth=2, color="r") + plt.xlabel("Decoder Target") + plt.colorbar() + + plt.savefig(f"{output_dir}/{utt_id}.png") + plt.close() + # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py def top_k_top_p_filtering( From 1c4dd464a0dcba042ea0050c8b7a0416799025b3 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sun, 8 Dec 2024 03:18:15 +0800 Subject: [PATCH 41/59] Performed end to end testing on the matcha recipe (#1797) * minor fixes to the `ljspeech/matcha` recipe --- .github/scripts/ljspeech/TTS/run-matcha.sh | 2 +- egs/ljspeech/TTS/README.md | 4 +- egs/ljspeech/TTS/local/audio.py | 1 + .../TTS/local/compute_fbank_ljspeech.py | 91 +---- egs/ljspeech/TTS/local/fbank.py | 1 + .../TTS/matcha/compute_fbank_ljspeech.py | 1 - .../TTS/matcha/export_onnx_hifigan.py | 2 +- egs/ljspeech/TTS/matcha/fbank.py | 88 +++++ egs/ljspeech/TTS/matcha/infer.py | 328 ++++++++++++++++++ egs/ljspeech/TTS/matcha/inference.py | 199 ----------- .../TTS/matcha/models/components/decoder.py | 2 +- .../matcha/models/components/flow_matching.py | 2 +- .../matcha/models/components/text_encoder.py | 2 +- egs/ljspeech/TTS/matcha/models/matcha_tts.py | 8 +- .../TTS/matcha/monotonic_align/.gitignore | 2 +- .../TTS/matcha/monotonic_align/__init__.py | 5 +- .../TTS/matcha/monotonic_align/core.pyx | 2 - .../TTS/matcha/monotonic_align/setup.py | 30 +- egs/ljspeech/TTS/matcha/requirements.txt | 1 + egs/ljspeech/TTS/matcha/train.py | 13 +- egs/ljspeech/TTS/matcha/tts_datamodule.py | 15 +- egs/ljspeech/TTS/prepare.sh | 30 +- egs/ljspeech/TTS/vits/infer.py | 2 +- .../TTS/vits/monotonic_align/.gitignore | 3 + egs/ljspeech/TTS/vits/test_model.py | 1 - 25 files changed, 485 insertions(+), 350 deletions(-) create mode 120000 egs/ljspeech/TTS/local/audio.py create mode 120000 egs/ljspeech/TTS/local/fbank.py delete mode 120000 egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py create mode 100644 egs/ljspeech/TTS/matcha/fbank.py create mode 100755 egs/ljspeech/TTS/matcha/infer.py delete mode 100755 egs/ljspeech/TTS/matcha/inference.py create mode 100644 egs/ljspeech/TTS/vits/monotonic_align/.gitignore diff --git a/.github/scripts/ljspeech/TTS/run-matcha.sh b/.github/scripts/ljspeech/TTS/run-matcha.sh index 37e1bc320..0876cb47f 100755 --- a/.github/scripts/ljspeech/TTS/run-matcha.sh +++ b/.github/scripts/ljspeech/TTS/run-matcha.sh @@ -56,7 +56,7 @@ function infer() { curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 - ./matcha/inference.py \ + ./matcha/infer.py \ --epoch 1 \ --exp-dir ./matcha/exp \ --tokens data/tokens.txt \ diff --git a/egs/ljspeech/TTS/README.md b/egs/ljspeech/TTS/README.md index 1cd6e8fd7..82850cd04 100644 --- a/egs/ljspeech/TTS/README.md +++ b/egs/ljspeech/TTS/README.md @@ -131,12 +131,12 @@ To inference, use: wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 -./matcha/inference \ +./matcha/synth.py \ --exp-dir ./matcha/exp-new-3 \ --epoch 4000 \ --tokens ./data/tokens.txt \ --vocoder ./generator_v1 \ - --input-text "how are you doing?" + --input-text "how are you doing?" \ --output-wav ./generated.wav ``` diff --git a/egs/ljspeech/TTS/local/audio.py b/egs/ljspeech/TTS/local/audio.py new file mode 120000 index 000000000..b70d91c92 --- /dev/null +++ b/egs/ljspeech/TTS/local/audio.py @@ -0,0 +1 @@ +../matcha/audio.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py index 5152ae675..296f9a4f4 100755 --- a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py +++ b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py @@ -27,102 +27,17 @@ The generated fbank features are saved in data/fbank. import argparse import logging import os -from dataclasses import dataclass from pathlib import Path -from typing import Union -import numpy as np import torch +from fbank import MatchaFbank, MatchaFbankConfig from lhotse import CutSet, LilcomChunkyWriter, load_manifest from lhotse.audio import RecordingSet -from lhotse.features.base import FeatureExtractor, register_extractor from lhotse.supervision import SupervisionSet -from lhotse.utils import Seconds, compute_num_frames -from matcha.audio import mel_spectrogram from icefall.utils import get_executor -@dataclass -class MyFbankConfig: - n_fft: int - n_mels: int - sampling_rate: int - hop_length: int - win_length: int - f_min: float - f_max: float - - -@register_extractor -class MyFbank(FeatureExtractor): - - name = "MyFbank" - config_type = MyFbankConfig - - def __init__(self, config): - super().__init__(config=config) - - @property - def device(self) -> Union[str, torch.device]: - return self.config.device - - def feature_dim(self, sampling_rate: int) -> int: - return self.config.n_mels - - def extract( - self, - samples: np.ndarray, - sampling_rate: int, - ) -> torch.Tensor: - # Check for sampling rate compatibility. - expected_sr = self.config.sampling_rate - assert sampling_rate == expected_sr, ( - f"Mismatched sampling rate: extractor expects {expected_sr}, " - f"got {sampling_rate}" - ) - samples = torch.from_numpy(samples) - assert samples.ndim == 2, samples.shape - assert samples.shape[0] == 1, samples.shape - - mel = ( - mel_spectrogram( - samples, - self.config.n_fft, - self.config.n_mels, - self.config.sampling_rate, - self.config.hop_length, - self.config.win_length, - self.config.f_min, - self.config.f_max, - center=False, - ) - .squeeze() - .t() - ) - - assert mel.ndim == 2, mel.shape - assert mel.shape[1] == self.config.n_mels, mel.shape - - num_frames = compute_num_frames( - samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate - ) - - if mel.shape[0] > num_frames: - mel = mel[:num_frames] - elif mel.shape[0] < num_frames: - mel = mel.unsqueeze(0) - mel = torch.nn.functional.pad( - mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate" - ).squeeze(0) - - return mel.numpy() - - @property - def frame_shift(self) -> Seconds: - return self.config.hop_length / self.config.sampling_rate - - def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -149,7 +64,7 @@ def compute_fbank_ljspeech(num_jobs: int): logging.info(f"num_jobs: {num_jobs}") logging.info(f"src_dir: {src_dir}") logging.info(f"output_dir: {output_dir}") - config = MyFbankConfig( + config = MatchaFbankConfig( n_fft=1024, n_mels=80, sampling_rate=22050, @@ -170,7 +85,7 @@ def compute_fbank_ljspeech(num_jobs: int): src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet ) - extractor = MyFbank(config) + extractor = MatchaFbank(config) with get_executor() as ex: # Initialize the executor only once. cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" diff --git a/egs/ljspeech/TTS/local/fbank.py b/egs/ljspeech/TTS/local/fbank.py new file mode 120000 index 000000000..5bcf1fde5 --- /dev/null +++ b/egs/ljspeech/TTS/local/fbank.py @@ -0,0 +1 @@ +../matcha/fbank.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py b/egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py deleted file mode 120000 index 85255ba0c..000000000 --- a/egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py +++ /dev/null @@ -1 +0,0 @@ -../local/compute_fbank_ljspeech.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py b/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py index 63d1fac20..5c96b3bc7 100755 --- a/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py +++ b/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py @@ -7,7 +7,7 @@ from typing import Any, Dict import onnx import torch -from inference import load_vocoder +from infer import load_vocoder def add_meta_data(filename: str, meta_data: Dict[str, Any]): diff --git a/egs/ljspeech/TTS/matcha/fbank.py b/egs/ljspeech/TTS/matcha/fbank.py new file mode 100644 index 000000000..d729fa425 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/fbank.py @@ -0,0 +1,88 @@ +from dataclasses import dataclass +from typing import Union + +import numpy as np +import torch +from audio import mel_spectrogram +from lhotse.features.base import FeatureExtractor, register_extractor +from lhotse.utils import Seconds, compute_num_frames + + +@dataclass +class MatchaFbankConfig: + n_fft: int + n_mels: int + sampling_rate: int + hop_length: int + win_length: int + f_min: float + f_max: float + + +@register_extractor +class MatchaFbank(FeatureExtractor): + + name = "MatchaFbank" + config_type = MatchaFbankConfig + + def __init__(self, config): + super().__init__(config=config) + + @property + def device(self) -> Union[str, torch.device]: + return self.config.device + + def feature_dim(self, sampling_rate: int) -> int: + return self.config.n_mels + + def extract( + self, + samples: np.ndarray, + sampling_rate: int, + ) -> torch.Tensor: + # Check for sampling rate compatibility. + expected_sr = self.config.sampling_rate + assert sampling_rate == expected_sr, ( + f"Mismatched sampling rate: extractor expects {expected_sr}, " + f"got {sampling_rate}" + ) + samples = torch.from_numpy(samples) + assert samples.ndim == 2, samples.shape + assert samples.shape[0] == 1, samples.shape + + mel = ( + mel_spectrogram( + samples, + self.config.n_fft, + self.config.n_mels, + self.config.sampling_rate, + self.config.hop_length, + self.config.win_length, + self.config.f_min, + self.config.f_max, + center=False, + ) + .squeeze() + .t() + ) + + assert mel.ndim == 2, mel.shape + assert mel.shape[1] == self.config.n_mels, mel.shape + + num_frames = compute_num_frames( + samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate + ) + + if mel.shape[0] > num_frames: + mel = mel[:num_frames] + elif mel.shape[0] < num_frames: + mel = mel.unsqueeze(0) + mel = torch.nn.functional.pad( + mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate" + ).squeeze(0) + + return mel.numpy() + + @property + def frame_shift(self) -> Seconds: + return self.config.hop_length / self.config.sampling_rate diff --git a/egs/ljspeech/TTS/matcha/infer.py b/egs/ljspeech/TTS/matcha/infer.py new file mode 100755 index 000000000..0b221d5c5 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/infer.py @@ -0,0 +1,328 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + +import argparse +import datetime as dt +import json +import logging +from pathlib import Path + +import soundfile as sf +import torch +import torch.nn as nn +from hifigan.config import v1, v2, v3 +from hifigan.denoiser import Denoiser +from hifigan.models import Generator as HiFiGAN +from tokenizer import Tokenizer +from train import get_model, get_params +from tts_datamodule import LJSpeechTtsDataModule + +from icefall.checkpoint import load_checkpoint +from icefall.utils import AttributeDict, setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=4000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="matcha/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--vocoder", + type=Path, + default="./generator_v1", + help="Path to the vocoder", + ) + + parser.add_argument( + "--tokens", + type=Path, + default="data/tokens.txt", + ) + + parser.add_argument( + "--cmvn", + type=str, + default="data/fbank/cmvn.json", + help="""Path to vocabulary.""", + ) + + # The following arguments are used for inference on single text + parser.add_argument( + "--input-text", + type=str, + required=False, + help="The text to generate speech for", + ) + + parser.add_argument( + "--output-wav", + type=str, + required=False, + help="The filename of the wave to save the generated speech", + ) + + parser.add_argument( + "--sampling-rate", + type=int, + default=22050, + help="The sampling rate of the generated speech (default: 22050 for LJSpeech)", + ) + + return parser + + +def load_vocoder(checkpoint_path: Path) -> nn.Module: + checkpoint_path = str(checkpoint_path) + if checkpoint_path.endswith("v1"): + h = AttributeDict(v1) + elif checkpoint_path.endswith("v2"): + h = AttributeDict(v2) + elif checkpoint_path.endswith("v3"): + h = AttributeDict(v3) + else: + raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}") + + hifigan = HiFiGAN(h).to("cpu") + hifigan.load_state_dict( + torch.load(checkpoint_path, map_location="cpu")["generator"] + ) + _ = hifigan.eval() + hifigan.remove_weight_norm() + return hifigan + + +def to_waveform( + mel: torch.Tensor, vocoder: nn.Module, denoiser: nn.Module +) -> torch.Tensor: + audio = vocoder(mel).clamp(-1, 1) + audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze() + return audio.squeeze() + + +def process_text(text: str, tokenizer: Tokenizer, device: str = "cpu") -> dict: + x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True) + x = torch.tensor(x, dtype=torch.long, device=device) + x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device) + return {"x_orig": text, "x": x, "x_lengths": x_lengths} + + +def synthesize( + model: nn.Module, + tokenizer: Tokenizer, + n_timesteps: int, + text: str, + length_scale: float, + temperature: float, + device: str = "cpu", + spks=None, +) -> dict: + text_processed = process_text(text=text, tokenizer=tokenizer, device=device) + start_t = dt.datetime.now() + output = model.synthesise( + text_processed["x"], + text_processed["x_lengths"], + n_timesteps=n_timesteps, + temperature=temperature, + spks=spks, + length_scale=length_scale, + ) + # merge everything to one dict + output.update({"start_t": start_t, **text_processed}) + return output + + +def infer_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + vocoder: nn.Module, + denoiser: nn.Module, + tokenizer: Tokenizer, +) -> None: + """Decode dataset. + The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + tokenizer: + Used to convert text to phonemes. + """ + + device = next(model.parameters()).device + num_cuts = 0 + log_interval = 5 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + for batch_idx, batch in enumerate(dl): + batch_size = len(batch["tokens"]) + + texts = [c.supervisions[0].normalized_text for c in batch["cut"]] + + audio = batch["audio"] + audio_lens = batch["audio_lens"].tolist() + cut_ids = [cut.id for cut in batch["cut"]] + + for i in range(batch_size): + output = synthesize( + model=model, + tokenizer=tokenizer, + n_timesteps=params.n_timesteps, + text=texts[i], + length_scale=params.length_scale, + temperature=params.temperature, + device=device, + ) + output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) + + sf.write( + file=params.save_wav_dir / f"{cut_ids[i]}_pred.wav", + data=output["waveform"], + samplerate=params.data_args.sampling_rate, + subtype="PCM_16", + ) + sf.write( + file=params.save_wav_dir / f"{cut_ids[i]}_gt.wav", + data=audio[i].numpy(), + samplerate=params.data_args.sampling_rate, + subtype="PCM_16", + ) + + num_cuts += batch_size + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + +@torch.inference_mode() +def main(): + parser = get_parser() + LJSpeechTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.suffix = f"epoch-{params.epoch}" + + params.res_dir = params.exp_dir / "infer" / params.suffix + params.save_wav_dir = params.res_dir / "wav" + params.save_wav_dir.mkdir(parents=True, exist_ok=True) + + setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") + logging.info("Infer started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + logging.info(f"Device: {device}") + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + params.model_args.n_vocab = params.vocab_size + + with open(params.cmvn) as f: + stats = json.load(f) + params.data_args.data_statistics.mel_mean = stats["fbank_mean"] + params.data_args.data_statistics.mel_std = stats["fbank_std"] + + params.model_args.data_statistics.mel_mean = stats["fbank_mean"] + params.model_args.data_statistics.mel_std = stats["fbank_std"] + + # Number of ODE Solver steps + params.n_timesteps = 2 + + # Changes to the speaking rate + params.length_scale = 1.0 + + # Sampling temperature + params.temperature = 0.667 + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + model.to(device) + model.eval() + + # we need cut ids to organize tts results. + args.return_cuts = True + ljspeech = LJSpeechTtsDataModule(args) + + test_cuts = ljspeech.test_cuts() + test_dl = ljspeech.test_dataloaders(test_cuts) + + if not Path(params.vocoder).is_file(): + raise ValueError(f"{params.vocoder} does not exist") + + vocoder = load_vocoder(params.vocoder) + vocoder.to(device) + + denoiser = Denoiser(vocoder, mode="zeros") + denoiser.to(device) + + if params.input_text is not None and params.output_wav is not None: + logging.info("Synthesizing a single text") + output = synthesize( + model=model, + tokenizer=tokenizer, + n_timesteps=params.n_timesteps, + text=params.input_text, + length_scale=params.length_scale, + temperature=params.temperature, + device=device, + ) + output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) + + sf.write( + file=params.output_wav, + data=output["waveform"], + samplerate=params.sampling_rate, + subtype="PCM_16", + ) + else: + logging.info("Decoding the test set") + infer_dataset( + dl=test_dl, + params=params, + model=model, + vocoder=vocoder, + denoiser=denoiser, + tokenizer=tokenizer, + ) + + +if __name__ == "__main__": + main() diff --git a/egs/ljspeech/TTS/matcha/inference.py b/egs/ljspeech/TTS/matcha/inference.py deleted file mode 100755 index 64abd8e50..000000000 --- a/egs/ljspeech/TTS/matcha/inference.py +++ /dev/null @@ -1,199 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) - -import argparse -import datetime as dt -import json -import logging -from pathlib import Path - -import soundfile as sf -import torch -from matcha.hifigan.config import v1, v2, v3 -from matcha.hifigan.denoiser import Denoiser -from matcha.hifigan.models import Generator as HiFiGAN -from tokenizer import Tokenizer -from train import get_model, get_params - -from icefall.checkpoint import load_checkpoint -from icefall.utils import AttributeDict - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=4000, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - """, - ) - - parser.add_argument( - "--exp-dir", - type=Path, - default="matcha/exp-new-3", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--vocoder", - type=Path, - default="./generator_v1", - help="Path to the vocoder", - ) - - parser.add_argument( - "--tokens", - type=Path, - default="data/tokens.txt", - ) - - parser.add_argument( - "--cmvn", - type=str, - default="data/fbank/cmvn.json", - help="""Path to vocabulary.""", - ) - - parser.add_argument( - "--input-text", - type=str, - required=True, - help="The text to generate speech for", - ) - - parser.add_argument( - "--output-wav", - type=str, - required=True, - help="The filename of the wave to save the generated speech", - ) - - return parser - - -def load_vocoder(checkpoint_path): - checkpoint_path = str(checkpoint_path) - if checkpoint_path.endswith("v1"): - h = AttributeDict(v1) - elif checkpoint_path.endswith("v2"): - h = AttributeDict(v2) - elif checkpoint_path.endswith("v3"): - h = AttributeDict(v3) - else: - raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}") - - hifigan = HiFiGAN(h).to("cpu") - hifigan.load_state_dict( - torch.load(checkpoint_path, map_location="cpu")["generator"] - ) - _ = hifigan.eval() - hifigan.remove_weight_norm() - return hifigan - - -def to_waveform(mel, vocoder, denoiser): - audio = vocoder(mel).clamp(-1, 1) - audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze() - return audio.cpu().squeeze() - - -def process_text(text: str, tokenizer): - x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True) - x = torch.tensor(x, dtype=torch.long) - x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu") - return {"x_orig": text, "x": x, "x_lengths": x_lengths} - - -def synthesise( - model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None -): - text_processed = process_text(text, tokenizer) - start_t = dt.datetime.now() - output = model.synthesise( - text_processed["x"], - text_processed["x_lengths"], - n_timesteps=n_timesteps, - temperature=temperature, - spks=spks, - length_scale=length_scale, - ) - # merge everything to one dict - output.update({"start_t": start_t, **text_processed}) - return output - - -@torch.inference_mode() -def main(): - parser = get_parser() - args = parser.parse_args() - params = get_params() - - params.update(vars(args)) - - tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.pad_id - params.vocab_size = tokenizer.vocab_size - params.model_args.n_vocab = params.vocab_size - - with open(params.cmvn) as f: - stats = json.load(f) - params.data_args.data_statistics.mel_mean = stats["fbank_mean"] - params.data_args.data_statistics.mel_std = stats["fbank_std"] - - params.model_args.data_statistics.mel_mean = stats["fbank_mean"] - params.model_args.data_statistics.mel_std = stats["fbank_std"] - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not Path(f"{params.exp_dir}/epoch-{params.epoch}.pt").is_file(): - raise ValueError("{params.exp_dir}/epoch-{params.epoch}.pt does not exist") - - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - model.eval() - - if not Path(params.vocoder).is_file(): - raise ValueError(f"{params.vocoder} does not exist") - - vocoder = load_vocoder(params.vocoder) - denoiser = Denoiser(vocoder, mode="zeros") - - # Number of ODE Solver steps - n_timesteps = 2 - - # Changes to the speaking rate - length_scale = 1.0 - - # Sampling temperature - temperature = 0.667 - - output = synthesise( - model=model, - tokenizer=tokenizer, - n_timesteps=n_timesteps, - text=params.input_text, - length_scale=length_scale, - temperature=temperature, - ) - output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) - - sf.write(params.output_wav, output["waveform"], 22050, "PCM_16") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - main() diff --git a/egs/ljspeech/TTS/matcha/models/components/decoder.py b/egs/ljspeech/TTS/matcha/models/components/decoder.py index 14d19f5d4..102d87713 100644 --- a/egs/ljspeech/TTS/matcha/models/components/decoder.py +++ b/egs/ljspeech/TTS/matcha/models/components/decoder.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from conformer import ConformerBlock from diffusers.models.activations import get_activation from einops import pack, rearrange, repeat -from matcha.models.components.transformer import BasicTransformerBlock +from models.components.transformer import BasicTransformerBlock class SinusoidalPosEmb(torch.nn.Module): diff --git a/egs/ljspeech/TTS/matcha/models/components/flow_matching.py b/egs/ljspeech/TTS/matcha/models/components/flow_matching.py index 997689b1c..eb795ef32 100644 --- a/egs/ljspeech/TTS/matcha/models/components/flow_matching.py +++ b/egs/ljspeech/TTS/matcha/models/components/flow_matching.py @@ -2,7 +2,7 @@ from abc import ABC import torch import torch.nn.functional as F -from matcha.models.components.decoder import Decoder +from models.components.decoder import Decoder class BASECFM(torch.nn.Module, ABC): diff --git a/egs/ljspeech/TTS/matcha/models/components/text_encoder.py b/egs/ljspeech/TTS/matcha/models/components/text_encoder.py index ca77cba51..364ff1938 100644 --- a/egs/ljspeech/TTS/matcha/models/components/text_encoder.py +++ b/egs/ljspeech/TTS/matcha/models/components/text_encoder.py @@ -5,7 +5,7 @@ import math import torch import torch.nn as nn from einops import rearrange -from matcha.model import sequence_mask +from model import sequence_mask class LayerNorm(nn.Module): diff --git a/egs/ljspeech/TTS/matcha/models/matcha_tts.py b/egs/ljspeech/TTS/matcha/models/matcha_tts.py index 330d1dc47..fe0a72402 100644 --- a/egs/ljspeech/TTS/matcha/models/matcha_tts.py +++ b/egs/ljspeech/TTS/matcha/models/matcha_tts.py @@ -2,17 +2,17 @@ import datetime as dt import math import random -import matcha.monotonic_align as monotonic_align +import monotonic_align as monotonic_align import torch -from matcha.model import ( +from model import ( denormalize, duration_loss, fix_len_compatibility, generate_path, sequence_mask, ) -from matcha.models.components.flow_matching import CFM -from matcha.models.components.text_encoder import TextEncoder +from models.components.flow_matching import CFM +from models.components.text_encoder import TextEncoder class MatchaTTS(torch.nn.Module): # 🍵 diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/.gitignore b/egs/ljspeech/TTS/matcha/monotonic_align/.gitignore index 28bdad6b8..3def4ae26 100644 --- a/egs/ljspeech/TTS/matcha/monotonic_align/.gitignore +++ b/egs/ljspeech/TTS/matcha/monotonic_align/.gitignore @@ -1,3 +1,3 @@ build core.c -*.so +*.so \ No newline at end of file diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py b/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py index 5b26fe474..f87ae1bd3 100644 --- a/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py +++ b/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py @@ -1,8 +1,7 @@ -# Copied from -# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/__init__.py import numpy as np import torch -from matcha.monotonic_align.core import maximum_path_c + +from .core import maximum_path_c def maximum_path(value, mask): diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/core.pyx b/egs/ljspeech/TTS/matcha/monotonic_align/core.pyx index eabc7f273..091fcc3a5 100644 --- a/egs/ljspeech/TTS/matcha/monotonic_align/core.pyx +++ b/egs/ljspeech/TTS/matcha/monotonic_align/core.pyx @@ -1,5 +1,3 @@ -# Copied from -# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/core.pyx import numpy as np cimport cython diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/setup.py b/egs/ljspeech/TTS/matcha/monotonic_align/setup.py index df26c633e..beacf2e36 100644 --- a/egs/ljspeech/TTS/matcha/monotonic_align/setup.py +++ b/egs/ljspeech/TTS/matcha/monotonic_align/setup.py @@ -1,12 +1,30 @@ -# Copied from +# Modified from # https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/setup.py -from distutils.core import setup - -import numpy from Cython.Build import cythonize +from setuptools import Extension, setup +from setuptools.command.build_ext import build_ext as _build_ext + +class build_ext(_build_ext): + """Overwrite build_ext.""" + + def finalize_options(self): + """Prevent numpy from thinking it is still in its setup process.""" + _build_ext.finalize_options(self) + __builtins__.__NUMPY_SETUP__ = False + import numpy + + self.include_dirs.append(numpy.get_include()) + + +exts = [ + Extension( + name="core", + sources=["core.pyx"], + ) +] setup( name="monotonic_align", - ext_modules=cythonize("core.pyx"), - include_dirs=[numpy.get_include()], + ext_modules=cythonize(exts, language_level=3), + cmdclass={"build_ext": build_ext}, ) diff --git a/egs/ljspeech/TTS/matcha/requirements.txt b/egs/ljspeech/TTS/matcha/requirements.txt index 5aadc8984..d7829c1e1 100644 --- a/egs/ljspeech/TTS/matcha/requirements.txt +++ b/egs/ljspeech/TTS/matcha/requirements.txt @@ -1,3 +1,4 @@ conformer==0.3.2 diffusers # developed using version ==0.25.0 librosa +einops \ No newline at end of file diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py index 5e713fdfd..31135f623 100755 --- a/egs/ljspeech/TTS/matcha/train.py +++ b/egs/ljspeech/TTS/matcha/train.py @@ -14,9 +14,9 @@ import torch import torch.multiprocessing as mp import torch.nn as nn from lhotse.utils import fix_random_seed -from matcha.model import fix_len_compatibility -from matcha.models.matcha_tts import MatchaTTS -from matcha.tokenizer import Tokenizer +from model import fix_len_compatibility +from models.matcha_tts import MatchaTTS +from tokenizer import Tokenizer from torch.cuda.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer @@ -150,7 +150,7 @@ def _get_data_params() -> AttributeDict: "n_spks": 1, "n_fft": 1024, "n_feats": 80, - "sample_rate": 22050, + "sampling_rate": 22050, "hop_length": 256, "win_length": 1024, "f_min": 0, @@ -445,11 +445,6 @@ def train_one_epoch( saved_bad_model = False - # used to track the stats over iterations in one epoch - tot_loss = MetricsTracker() - - saved_bad_model = False - def save_bad_model(suffix: str = ""): save_checkpoint( filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", diff --git a/egs/ljspeech/TTS/matcha/tts_datamodule.py b/egs/ljspeech/TTS/matcha/tts_datamodule.py index 8e37fc030..1e637b766 100644 --- a/egs/ljspeech/TTS/matcha/tts_datamodule.py +++ b/egs/ljspeech/TTS/matcha/tts_datamodule.py @@ -24,7 +24,7 @@ from pathlib import Path from typing import Any, Dict, Optional import torch -from compute_fbank_ljspeech import MyFbank, MyFbankConfig +from fbank import MatchaFbank, MatchaFbankConfig from lhotse import CutSet, load_manifest_lazy from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures CutConcatenate, @@ -32,7 +32,6 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures DynamicBucketingSampler, PrecomputedFeatures, SimpleCutSampler, - SpecAugment, SpeechSynthesisDataset, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples @@ -177,7 +176,7 @@ class LJSpeechTtsDataModule: if self.args.on_the_fly_feats: sampling_rate = 22050 - config = MyFbankConfig( + config = MatchaFbankConfig( n_fft=1024, n_mels=80, sampling_rate=sampling_rate, @@ -189,7 +188,7 @@ class LJSpeechTtsDataModule: train = SpeechSynthesisDataset( return_text=False, return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(MyFbank(config)), + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), return_cuts=self.args.return_cuts, ) @@ -238,7 +237,7 @@ class LJSpeechTtsDataModule: logging.info("About to create dev dataset") if self.args.on_the_fly_feats: sampling_rate = 22050 - config = MyFbankConfig( + config = MatchaFbankConfig( n_fft=1024, n_mels=80, sampling_rate=sampling_rate, @@ -250,7 +249,7 @@ class LJSpeechTtsDataModule: validate = SpeechSynthesisDataset( return_text=False, return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(MyFbank(config)), + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), return_cuts=self.args.return_cuts, ) else: @@ -282,7 +281,7 @@ class LJSpeechTtsDataModule: logging.info("About to create test dataset") if self.args.on_the_fly_feats: sampling_rate = 22050 - config = MyFbankConfig( + config = MatchaFbankConfig( n_fft=1024, n_mels=80, sampling_rate=sampling_rate, @@ -294,7 +293,7 @@ class LJSpeechTtsDataModule: test = SpeechSynthesisDataset( return_text=False, return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(MyFbank(config)), + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh index 6f16f8d47..ec5062933 100755 --- a/egs/ljspeech/TTS/prepare.sh +++ b/egs/ljspeech/TTS/prepare.sh @@ -25,26 +25,16 @@ log() { log "dl_dir: $dl_dir" if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - log "Stage -1: build monotonic_align lib" - if [ ! -d vits/monotonic_align/build ]; then - cd vits/monotonic_align - python3 setup.py build_ext --inplace - cd ../../ - else - log "monotonic_align lib for vits already built" - fi - - if [ ! -f ./matcha/monotonic_align/core.cpython-38-x86_64-linux-gnu.so ]; then - pushd matcha/monotonic_align - python3 setup.py build - mv -v build/lib.*/matcha/monotonic_align/core.*.so . - rm -rf build - rm core.c - ls -lh - popd - else - log "monotonic_align lib for matcha-tts already built" - fi + log "Stage -1: build monotonic_align lib (used by vits and matcha recipes)" + for recipe in vits matcha; do + if [ ! -d $recipe/monotonic_align/build ]; then + cd $recipe/monotonic_align + python3 setup.py build_ext --inplace + cd ../../ + else + log "monotonic_align lib for $recipe already built" + fi + done fi if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py index 7be76e315..cf1067dfc 100755 --- a/egs/ljspeech/TTS/vits/infer.py +++ b/egs/ljspeech/TTS/vits/infer.py @@ -234,7 +234,7 @@ def main(): logging.info(f"Number of parameters in discriminator: {num_param_d}") logging.info(f"Total number of parameters: {num_param_g + num_param_d}") - # we need cut ids to display recognition results. + # we need cut ids to organize tts results. args.return_cuts = True ljspeech = LJSpeechTtsDataModule(args) diff --git a/egs/ljspeech/TTS/vits/monotonic_align/.gitignore b/egs/ljspeech/TTS/vits/monotonic_align/.gitignore new file mode 100644 index 000000000..3def4ae26 --- /dev/null +++ b/egs/ljspeech/TTS/vits/monotonic_align/.gitignore @@ -0,0 +1,3 @@ +build +core.c +*.so \ No newline at end of file diff --git a/egs/ljspeech/TTS/vits/test_model.py b/egs/ljspeech/TTS/vits/test_model.py index 1de10f012..4faaa96a5 100755 --- a/egs/ljspeech/TTS/vits/test_model.py +++ b/egs/ljspeech/TTS/vits/test_model.py @@ -18,7 +18,6 @@ from tokenizer import Tokenizer from train import get_model, get_params -from vits import VITS def test_model_type(model_type): From 5c04f7bfb84a1f2f3b307d824a1355c9c8d30a20 Mon Sep 17 00:00:00 2001 From: goddamnVincent <84380030+goddamnVincent@users.noreply.github.com> Date: Sun, 8 Dec 2024 11:17:15 +0800 Subject: [PATCH 42/59] 'try to fix 'compute_fbank_kespeech_splits.py: error: unrecognized arguments: --speed-perturb true'' (#1812) --- .../ASR/local/compute_fbank_kespeech_dev_test.py | 12 +++++++++++- .../ASR/local/compute_fbank_kespeech_splits.py | 13 +++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py index 6f75dbfa4..5e169e894 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py @@ -52,13 +52,19 @@ def get_parser(): default=80, help="""The number of mel bins for Fbank""", ) - parser.add_argument( "--whisper-fbank", type=str2bool, default=False, help="Use WhisperFbank instead of Fbank. Default: False.", ) + parser.add_argument( + "--speed-perturb", + type=str2bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) + return parser @@ -104,6 +110,10 @@ def compute_fbank_kespeech_dev_test(args): keep_overlapping=False, min_duration=None ) + if args.speed_perturb: + cut_set = ( + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + ) logging.info("Computing features") cut_set = cut_set.compute_and_store_features_batch( extractor=extractor, diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py index c398411f6..6bb8af0d6 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py @@ -106,6 +106,14 @@ def get_parser(): default=False, help="Use WhisperFbank instead of Fbank. Default: False.", ) + + parser.add_argument( + "--speed-perturb", + type=str2bool, + default=False, + help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.", + ) + return parser @@ -158,6 +166,11 @@ def compute_fbank_kespeech_splits(args): keep_overlapping=False, min_duration=None ) + if args.speed_perturb: + cut_set = ( + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + ) + logging.info("Computing features") cut_set = cut_set.compute_and_store_features_batch( extractor=extractor, From d33f67817641b3911f91c2b76698b266290a5e01 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sun, 8 Dec 2024 16:37:24 +0800 Subject: [PATCH 43/59] fixed the formatting issue of PR#1812 (#1828) --- .../ASR/local/compute_fbank_kespeech_dev_test.py | 5 ++--- egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py | 4 +--- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py index 5e169e894..2bbe28560 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_dev_test.py @@ -111,9 +111,8 @@ def compute_fbank_kespeech_dev_test(args): ) if args.speed_perturb: - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + logging.info("Computing features") cut_set = cut_set.compute_and_store_features_batch( extractor=extractor, diff --git a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py index 6bb8af0d6..fe7f87337 100755 --- a/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py +++ b/egs/multi_zh-hans/ASR/local/compute_fbank_kespeech_splits.py @@ -167,9 +167,7 @@ def compute_fbank_kespeech_splits(args): ) if args.speed_perturb: - cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) - ) + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) logging.info("Computing features") cut_set = cut_set.compute_and_store_features_batch( From 32b7a449e7ed87efdf0a49f74b01c846e831c8a3 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Sun, 8 Dec 2024 17:36:08 +0800 Subject: [PATCH 44/59] removed unnecessary type check (#1827) --- egs/wenetspeech4tts/TTS/valle/valle.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/egs/wenetspeech4tts/TTS/valle/valle.py b/egs/wenetspeech4tts/TTS/valle/valle.py index 4bfa2b577..772317428 100644 --- a/egs/wenetspeech4tts/TTS/valle/valle.py +++ b/egs/wenetspeech4tts/TTS/valle/valle.py @@ -1686,8 +1686,6 @@ class VALLE(nn.Module): decoder_outputs = decoder_outputs.to("cpu").type(torch.float32).detach().numpy() vmin, vmax = 0, 1024 # Encodec - if decoder_outputs.dtype == np.float32: - vmin, vmax = -6, 0 # Fbank num_figures = 3 for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])): From 08caa1e4e52f9c0684a91fcfce02487382fae45a Mon Sep 17 00:00:00 2001 From: zr_jin Date: Mon, 9 Dec 2024 22:59:29 +0800 Subject: [PATCH 45/59] minor fixes to the matcha recipe --- egs/ljspeech/TTS/matcha/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py index 31135f623..853042413 100755 --- a/egs/ljspeech/TTS/matcha/train.py +++ b/egs/ljspeech/TTS/matcha/train.py @@ -488,9 +488,10 @@ def train_one_epoch( loss = sum(losses.values()) - optimizer.zero_grad() scaler.scale(loss).backward() scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() loss_info = MetricsTracker() loss_info["samples"] = batch_size From a43480af47896329e82917d89ceee65f15afbf25 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Tue, 10 Dec 2024 11:15:49 +0800 Subject: [PATCH 46/59] fixed the not found python 3.8 env (#1830) --- .github/workflows/style_check.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index 0681ece60..2a077fa91 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -36,7 +36,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: [3.8] + python-version: [3.10.15] fail-fast: false steps: From b7acf0f57b3ad03cb98752a6e57216dc539eee14 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Wed, 11 Dec 2024 14:33:47 +0800 Subject: [PATCH 47/59] minor fixes --- egs/ljspeech/TTS/README.md | 2 +- egs/ljspeech/TTS/matcha/onnx_pretrained.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/ljspeech/TTS/README.md b/egs/ljspeech/TTS/README.md index 82850cd04..39280437b 100644 --- a/egs/ljspeech/TTS/README.md +++ b/egs/ljspeech/TTS/README.md @@ -131,7 +131,7 @@ To inference, use: wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 -./matcha/synth.py \ +./matcha/infer.py \ --exp-dir ./matcha/exp-new-3 \ --epoch 4000 \ --tokens ./data/tokens.txt \ diff --git a/egs/ljspeech/TTS/matcha/onnx_pretrained.py b/egs/ljspeech/TTS/matcha/onnx_pretrained.py index be34343d3..4eff9a084 100755 --- a/egs/ljspeech/TTS/matcha/onnx_pretrained.py +++ b/egs/ljspeech/TTS/matcha/onnx_pretrained.py @@ -8,7 +8,7 @@ import logging import onnxruntime as ort import soundfile as sf import torch -from inference import load_vocoder +from infer import load_vocoder from tokenizer import Tokenizer From 3e4da5f78160d3dba3bdf97968bd7ceb8c11631f Mon Sep 17 00:00:00 2001 From: Li Peng Date: Mon, 16 Dec 2024 10:24:16 +0800 Subject: [PATCH 48/59] Replace deprecated pytorch methods (#1814) * Replace deprecated pytorch methods - torch.cuda.amp.GradScaler(...) => torch.amp.GradScaler("cuda", ...) - torch.cuda.amp.autocast(...) => torch.amp.autocast("cuda", ...) * Replace `with autocast(...)` with `with autocast("cuda", ...)` Co-authored-by: Li Peng --- .../ASR/pruned_transducer_stateless2/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless2/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless3/model.py | 4 ++-- .../ASR/pruned_transducer_stateless3/train.py | 8 ++++---- .../do_not_use_it_directly.py | 8 ++++---- .../ASR/pruned_transducer_stateless7/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless7_bbpe/train.py | 8 ++++---- .../do_not_use_it_directly.py | 8 ++++---- .../pruned_transducer_stateless7_streaming/train.py | 8 ++++---- egs/aishell/ASR/whisper/train.py | 8 ++++---- egs/aishell/ASR/zipformer/train.py | 8 ++++---- egs/aishell/ASR/zipformer/train_bbpe.py | 8 ++++---- .../ASR/pruned_transducer_stateless5/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless5/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless2/train.py | 8 ++++---- .../ASR_v2/pruned_transducer_stateless7/train.py | 8 ++++---- egs/ami/ASR/pruned_transducer_stateless7/train.py | 8 ++++---- egs/ami/SURT/dprnn_zipformer/train.py | 6 +++--- egs/ami/SURT/dprnn_zipformer/train_adapt.py | 6 +++--- egs/audioset/AT/zipformer/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless7/train.py | 8 ++++---- .../do_not_use_it_directly.py | 8 ++++---- .../finetune.py | 8 ++++---- .../pruned_transducer_stateless7_streaming/train.py | 8 ++++---- egs/commonvoice/ASR/zipformer/train.py | 8 ++++---- egs/commonvoice/ASR/zipformer/train_char.py | 8 ++++---- .../do_not_use_it_directly.py | 8 ++++---- .../pruned_transducer_stateless7_streaming/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless2/train.py | 8 ++++---- egs/gigaspeech/ASR/zipformer/train.py | 8 ++++---- egs/gigaspeech/KWS/zipformer/finetune.py | 6 +++--- egs/gigaspeech/KWS/zipformer/train.py | 8 ++++---- .../pruned_transducer_stateless7_streaming/train.py | 8 ++++---- egs/ksponspeech/ASR/zipformer/train.py | 8 ++++---- egs/libricss/SURT/dprnn_zipformer/model.py | 4 ++-- egs/libricss/SURT/dprnn_zipformer/scaling.py | 6 +++--- egs/libricss/SURT/dprnn_zipformer/train.py | 6 +++--- egs/libricss/SURT/dprnn_zipformer/train_adapt.py | 6 +++--- egs/libriheavy/ASR/zipformer/train.py | 8 ++++---- .../ASR/zipformer_prompt_asr/model_baseline.py | 4 ++-- .../ASR/zipformer_prompt_asr/model_with_BERT.py | 4 ++-- egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py | 10 +++++----- .../ASR/zipformer_prompt_asr/train_baseline.py | 8 ++++---- .../ASR/zipformer_prompt_asr/train_bert_encoder.py | 8 ++++---- egs/libriheavy/ASR/zipformer_prompt_asr/zipformer.py | 4 ++-- egs/librilight/SSL/zipformer/finetune.py | 8 ++++---- egs/librilight/SSL/zipformer/pretrain.py | 8 ++++---- egs/librispeech/ASR/conformer_ctc2/train.py | 8 ++++---- egs/librispeech/ASR/conformer_ctc3/train.py | 8 ++++---- .../ASR/conv_emformer_transducer_stateless/train.py | 8 ++++---- .../do_not_use_it_directly.py | 8 ++++---- .../ASR/conv_emformer_transducer_stateless2/train.py | 8 ++++---- .../ASR/lstm_transducer_stateless/model.py | 4 ++-- .../ASR/lstm_transducer_stateless/train.py | 8 ++++---- .../ASR/lstm_transducer_stateless2/model.py | 4 ++-- .../ASR/lstm_transducer_stateless2/train.py | 8 ++++---- .../ASR/lstm_transducer_stateless3/train.py | 8 ++++---- egs/librispeech/ASR/pruned2_knowledge/model.py | 4 ++-- egs/librispeech/ASR/pruned2_knowledge/sampling.py | 6 +++--- egs/librispeech/ASR/pruned2_knowledge/train.py | 8 ++++---- .../ASR/pruned_stateless_emformer_rnnt2/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless2/model.py | 4 ++-- .../ASR/pruned_transducer_stateless2/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless3/model.py | 4 ++-- .../ASR/pruned_transducer_stateless3/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless4/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless5/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless6/model.py | 4 ++-- .../ASR/pruned_transducer_stateless6/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless7/finetune.py | 8 ++++---- .../ASR/pruned_transducer_stateless7/model.py | 4 ++-- .../ASR/pruned_transducer_stateless7/scaling.py | 6 +++--- .../ASR/pruned_transducer_stateless7/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless7/zipformer.py | 2 +- .../ASR/pruned_transducer_stateless7_ctc/model.py | 4 ++-- .../ASR/pruned_transducer_stateless7_ctc/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless7_ctc_bs/model.py | 4 ++-- .../ASR/pruned_transducer_stateless7_ctc_bs/train.py | 8 ++++---- .../do_not_use_it_directly.py | 8 ++++---- .../pruned_transducer_stateless7_streaming/train.py | 8 ++++---- .../zipformer.py | 2 +- .../zipformer_for_ncnn_export_only.py | 2 +- .../train.py | 8 ++++---- .../ASR/pruned_transducer_stateless8/model.py | 4 ++-- .../ASR/pruned_transducer_stateless8/train.py | 8 ++++---- egs/librispeech/ASR/tiny_transducer_ctc/train.py | 8 ++++---- egs/librispeech/ASR/zipformer/finetune.py | 8 ++++---- egs/librispeech/ASR/zipformer/model.py | 4 ++-- egs/librispeech/ASR/zipformer/scaling.py | 10 +++++----- egs/librispeech/ASR/zipformer/train.py | 8 ++++---- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- egs/librispeech/ASR/zipformer_adapter/train.py | 8 ++++---- egs/librispeech/ASR/zipformer_adapter/zipformer.py | 2 +- egs/librispeech/ASR/zipformer_ctc/train.py | 6 +++--- egs/librispeech/ASR/zipformer_lora/finetune.py | 8 ++++---- egs/librispeech/ASR/zipformer_lora/scaling.py | 10 +++++----- egs/librispeech/ASR/zipformer_lora/train.py | 8 ++++---- egs/librispeech/ASR/zipformer_lora/zipformer.py | 2 +- egs/librispeech/ASR/zipformer_mmi/train.py | 8 ++++---- egs/librispeech/SSL/hubert/finetune.py | 8 ++++---- egs/librispeech/SSL/hubert/finetune_ce.py | 8 ++++---- egs/librispeech/SSL/hubert/model.py | 4 ++-- egs/librispeech/SSL/hubert/pretrain.py | 8 ++++---- egs/librispeech/SSL/hubert/pretrain_ce.py | 8 ++++---- egs/librispeech/SSL/zipformer/finetune.py | 8 ++++---- egs/librispeech/SSL/zipformer/model.py | 4 ++-- egs/librispeech/SSL/zipformer/pretrain.py | 8 ++++---- egs/librispeech/SSL/zipformer/zipformer.py | 2 +- egs/librispeech/WSASR/conformer_ctc2/train.py | 8 ++++---- egs/librispeech/WSASR/conformer_ctc2/train_phone.py | 8 ++++---- egs/libritts/ASR/zipformer/train.py | 12 ++++++------ egs/libritts/CODEC/encodec/encodec.py | 6 +++--- egs/libritts/CODEC/encodec/train.py | 12 ++++++------ egs/libritts/TTS/vits/train.py | 12 ++++++------ egs/ljspeech/TTS/matcha/train.py | 6 +++--- egs/ljspeech/TTS/vits/train.py | 12 ++++++------ egs/ljspeech/TTS/vits/utils.py | 2 +- egs/ljspeech/TTS/vits/vits.py | 6 +++--- egs/mdcc/ASR/zipformer/train.py | 8 ++++---- egs/mgb2/ASR/pruned_transducer_stateless5/train.py | 8 ++++---- egs/multi_zh-hans/ASR/whisper/train.py | 8 ++++---- egs/multi_zh-hans/ASR/zipformer/train.py | 8 ++++---- egs/multi_zh_en/ASR/zipformer/train.py | 8 ++++---- .../ASR/zipformer/do_not_use_it_directly.py | 8 ++++---- egs/reazonspeech/ASR/zipformer/train.py | 8 ++++---- egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py | 4 ++-- .../ASR/pruned_transducer_stateless2/train.py | 8 ++++---- egs/spgispeech/ASR/zipformer/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless5/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless7_bbpe/train.py | 8 ++++---- egs/tedlium3/ASR/conformer_ctc2/train.py | 8 ++++---- egs/tedlium3/ASR/zipformer/model.py | 4 ++-- egs/tedlium3/ASR/zipformer/train.py | 8 ++++---- egs/vctk/TTS/vits/train.py | 12 ++++++------ .../ASR/pruned_transducer_stateless2/finetune.py | 8 ++++---- .../ASR/pruned_transducer_stateless2/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless5/train.py | 8 ++++---- egs/wenetspeech/ASR/whisper/train.py | 8 ++++---- egs/wenetspeech/ASR/zipformer/train.py | 8 ++++---- egs/wenetspeech/KWS/zipformer/finetune.py | 6 +++--- egs/wenetspeech/KWS/zipformer/train.py | 8 ++++---- egs/wenetspeech4tts/TTS/valle/train.py | 12 +++++++----- .../ASR/pruned_transducer_stateless5/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless7/train.py | 8 ++++---- icefall/checkpoint.py | 2 +- icefall/rnn_lm/train.py | 4 ++-- icefall/transformer_lm/train.py | 4 ++-- 147 files changed, 520 insertions(+), 518 deletions(-) diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py index fa809b768..9088378fa 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py @@ -67,7 +67,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -638,7 +638,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -843,7 +843,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -912,7 +912,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/train.py b/egs/aishell/ASR/pruned_transducer_stateless2/train.py index 60f014c48..dda098e99 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/train.py @@ -60,7 +60,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -688,7 +688,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -888,7 +888,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -989,7 +989,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/model.py b/egs/aishell/ASR/pruned_transducer_stateless3/model.py index a4dda0d6d..cafc9d1bb 100644 --- a/egs/aishell/ASR/pruned_transducer_stateless3/model.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/model.py @@ -184,7 +184,7 @@ class Transducer(nn.Module): lm = simple_lm_proj(decoder_out) am = simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -219,7 +219,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py index 7c23041ca..bf60c4fad 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py @@ -79,7 +79,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -797,7 +797,7 @@ def train_one_epoch( aishell = is_aishell(batch["supervisions"]["cut"][0]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1096,7 +1096,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1202,7 +1202,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py b/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py index 058d0ff6b..9a9d92c20 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py @@ -74,7 +74,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -812,7 +812,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1107,7 +1107,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1206,7 +1206,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train.py b/egs/aishell/ASR/pruned_transducer_stateless7/train.py index 2dc835f3b..ede2bd3e5 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/train.py @@ -70,7 +70,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -809,7 +809,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1107,7 +1107,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1206,7 +1206,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py index 811269989..be48d6dde 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py @@ -64,7 +64,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -802,7 +802,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1102,7 +1102,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1202,7 +1202,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index 6653d9d9c..e3387e670 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -63,7 +63,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer_for_ncnn_export_only import Zipformer @@ -813,7 +813,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1105,7 +1105,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1205,7 +1205,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py index f3b0f1e11..cba312214 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py @@ -63,7 +63,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -812,7 +812,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1104,7 +1104,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1202,7 +1202,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py index d77f8c270..e84dcf156 100755 --- a/egs/aishell/ASR/whisper/train.py +++ b/egs/aishell/ASR/whisper/train.py @@ -62,7 +62,7 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.functional import pad as pad_tensor from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -514,7 +514,7 @@ def compute_validation_loss( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(valid_dl): - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, tokenizer=tokenizer, @@ -608,7 +608,7 @@ def train_one_epoch( ) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, tokenizer=tokenizer, @@ -812,7 +812,7 @@ def run(rank, world_size, args): train_dl = aishell.train_dataloaders(aishell.train_cuts()) valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/aishell/ASR/zipformer/train.py b/egs/aishell/ASR/zipformer/train.py index cd253c597..ab568b20f 100755 --- a/egs/aishell/ASR/zipformer/train.py +++ b/egs/aishell/ASR/zipformer/train.py @@ -71,7 +71,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -910,7 +910,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1201,7 +1201,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1302,7 +1302,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/zipformer/train_bbpe.py b/egs/aishell/ASR/zipformer/train_bbpe.py index 46a5506db..2dac0cc64 100755 --- a/egs/aishell/ASR/zipformer/train_bbpe.py +++ b/egs/aishell/ASR/zipformer/train_bbpe.py @@ -61,7 +61,7 @@ from lhotse.cut import Cut from lhotse.utils import fix_random_seed from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from train import ( @@ -495,7 +495,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -795,7 +795,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -895,7 +895,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py index 8c7448d4c..772d9e6bf 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py @@ -75,7 +75,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -734,7 +734,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -963,7 +963,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1062,7 +1062,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py index a354f761e..0eb9271f5 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py @@ -68,7 +68,7 @@ from local.text_normalize import text_normalize from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -727,7 +727,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) # print(batch["supervisions"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -963,7 +963,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1034,7 +1034,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py index 30154291d..2b1b6f9b4 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py @@ -67,7 +67,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -638,7 +638,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -843,7 +843,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -912,7 +912,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py index 30879d8d2..e321deeb1 100755 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py @@ -55,7 +55,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -782,7 +782,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1031,7 +1031,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1127,7 +1127,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/ami/ASR/pruned_transducer_stateless7/train.py b/egs/ami/ASR/pruned_transducer_stateless7/train.py index d62cdadb7..97ebc5bcf 100755 --- a/egs/ami/ASR/pruned_transducer_stateless7/train.py +++ b/egs/ami/ASR/pruned_transducer_stateless7/train.py @@ -55,7 +55,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -773,7 +773,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1034,7 +1034,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1134,7 +1134,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/ami/SURT/dprnn_zipformer/train.py b/egs/ami/SURT/dprnn_zipformer/train.py index adc6a8495..9e77c0527 100755 --- a/egs/ami/SURT/dprnn_zipformer/train.py +++ b/egs/ami/SURT/dprnn_zipformer/train.py @@ -61,7 +61,7 @@ from model import SURT from optim import Eden, ScaledAdam from scaling import ScaledLinear, ScaledLSTM from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -1067,7 +1067,7 @@ def train_one_epoch( batch_size = batch["inputs"].shape[0] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1314,7 +1314,7 @@ def run(rank, world_size, args): ) valid_dl = ami.valid_dataloaders(dev_cuts) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/ami/SURT/dprnn_zipformer/train_adapt.py b/egs/ami/SURT/dprnn_zipformer/train_adapt.py index ac5b0dadc..0647a7c78 100755 --- a/egs/ami/SURT/dprnn_zipformer/train_adapt.py +++ b/egs/ami/SURT/dprnn_zipformer/train_adapt.py @@ -61,7 +61,7 @@ from model import SURT from optim import Eden, ScaledAdam from scaling import ScaledLinear, ScaledLSTM from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -1058,7 +1058,7 @@ def train_one_epoch( batch_size = batch["inputs"].shape[0] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1305,7 +1305,7 @@ def run(rank, world_size, args): ) valid_dl = ami.valid_dataloaders(dev_cuts) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/audioset/AT/zipformer/train.py b/egs/audioset/AT/zipformer/train.py index 67c703364..9532ed906 100644 --- a/egs/audioset/AT/zipformer/train.py +++ b/egs/audioset/AT/zipformer/train.py @@ -53,7 +53,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -799,7 +799,7 @@ def train_one_epoch( num_samples += batch_size try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1057,7 +1057,7 @@ def run(rank, world_size, args): valid_cuts = audioset.audioset_eval_cuts() valid_dl = audioset.valid_dataloaders(valid_cuts) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1148,7 +1148,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py index 5e98084ec..486ab73df 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py @@ -66,7 +66,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -825,7 +825,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1120,7 +1120,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1220,7 +1220,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index aefe88f3f..fa241abe7 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -65,7 +65,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer_for_ncnn_export_only import Zipformer @@ -818,7 +818,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1109,7 +1109,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1209,7 +1209,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py index 976004eca..8905dc617 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py @@ -68,7 +68,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -895,7 +895,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1193,7 +1193,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1293,7 +1293,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py index 67e1a8133..8260c4985 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py @@ -65,7 +65,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -840,7 +840,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1137,7 +1137,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1237,7 +1237,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/commonvoice/ASR/zipformer/train.py b/egs/commonvoice/ASR/zipformer/train.py index 271014db0..c0219df19 100755 --- a/egs/commonvoice/ASR/zipformer/train.py +++ b/egs/commonvoice/ASR/zipformer/train.py @@ -75,7 +75,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -969,7 +969,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1265,7 +1265,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1365,7 +1365,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/commonvoice/ASR/zipformer/train_char.py b/egs/commonvoice/ASR/zipformer/train_char.py index 0aa7856cc..639e1067a 100755 --- a/egs/commonvoice/ASR/zipformer/train_char.py +++ b/egs/commonvoice/ASR/zipformer/train_char.py @@ -67,7 +67,7 @@ from lhotse.cut import Cut from lhotse.utils import fix_random_seed from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from train import ( @@ -604,7 +604,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -784,7 +784,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, @@ -979,7 +979,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index 6d256308c..661bfa6ca 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -67,7 +67,7 @@ from model import Transducer from optim import Eden, ScaledAdam from tokenizer import Tokenizer from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer_for_ncnn_export_only import Zipformer @@ -839,7 +839,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1146,7 +1146,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1246,7 +1246,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py index ef7ea9013..8f07fc42f 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py @@ -67,7 +67,7 @@ from model import Transducer from optim import Eden, ScaledAdam from tokenizer import Tokenizer from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -838,7 +838,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1145,7 +1145,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1245,7 +1245,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py index a7772b62f..e0e11fc70 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py @@ -64,7 +64,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -675,7 +675,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -873,7 +873,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -944,7 +944,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/gigaspeech/ASR/zipformer/train.py b/egs/gigaspeech/ASR/zipformer/train.py index 4c122effe..5092ef8cb 100755 --- a/egs/gigaspeech/ASR/zipformer/train.py +++ b/egs/gigaspeech/ASR/zipformer/train.py @@ -75,7 +75,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -958,7 +958,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1217,7 +1217,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1317,7 +1317,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/gigaspeech/KWS/zipformer/finetune.py b/egs/gigaspeech/KWS/zipformer/finetune.py index a7ba56127..49e8aef1a 100755 --- a/egs/gigaspeech/KWS/zipformer/finetune.py +++ b/egs/gigaspeech/KWS/zipformer/finetune.py @@ -73,7 +73,7 @@ from lhotse.utils import fix_random_seed from model import AsrModel from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from train import ( @@ -291,7 +291,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -570,7 +570,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/gigaspeech/KWS/zipformer/train.py b/egs/gigaspeech/KWS/zipformer/train.py index 39d8fc6cd..f2283cb03 100755 --- a/egs/gigaspeech/KWS/zipformer/train.py +++ b/egs/gigaspeech/KWS/zipformer/train.py @@ -75,7 +75,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -961,7 +961,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1220,7 +1220,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1320,7 +1320,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py index bf50bf5ea..30d9f0e51 100755 --- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -61,7 +61,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -805,7 +805,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1096,7 +1096,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1196,7 +1196,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/ksponspeech/ASR/zipformer/train.py b/egs/ksponspeech/ASR/zipformer/train.py index 485ea69c9..5f6ee7cca 100755 --- a/egs/ksponspeech/ASR/zipformer/train.py +++ b/egs/ksponspeech/ASR/zipformer/train.py @@ -70,7 +70,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -942,7 +942,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1233,7 +1233,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1333,7 +1333,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/libricss/SURT/dprnn_zipformer/model.py b/egs/libricss/SURT/dprnn_zipformer/model.py index 688e1e78d..0e88357d1 100644 --- a/egs/libricss/SURT/dprnn_zipformer/model.py +++ b/egs/libricss/SURT/dprnn_zipformer/model.py @@ -140,7 +140,7 @@ class SURT(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -175,7 +175,7 @@ class SURT(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/libricss/SURT/dprnn_zipformer/scaling.py b/egs/libricss/SURT/dprnn_zipformer/scaling.py index 4040a7b89..d46cb224e 100644 --- a/egs/libricss/SURT/dprnn_zipformer/scaling.py +++ b/egs/libricss/SURT/dprnn_zipformer/scaling.py @@ -287,7 +287,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): (ans,) = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) x_grad = ans_grad * ans @@ -1065,7 +1065,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): def backward(ctx, x_grad: Tensor): (x_orig,) = ctx.saved_tensors with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): x_detached = x_orig.to(torch.float32).detach() x_detached.requires_grad = True @@ -1263,7 +1263,7 @@ class MaxEig(torch.nn.Module): ): return _no_op(x) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): eps = 1.0e-20 orig_x = x x = x.to(torch.float32) diff --git a/egs/libricss/SURT/dprnn_zipformer/train.py b/egs/libricss/SURT/dprnn_zipformer/train.py index 148cafd4b..33ea7c5a6 100755 --- a/egs/libricss/SURT/dprnn_zipformer/train.py +++ b/egs/libricss/SURT/dprnn_zipformer/train.py @@ -69,7 +69,7 @@ from model import SURT from optim import Eden, ScaledAdam from scaling import ScaledLSTM from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -1096,7 +1096,7 @@ def train_one_epoch( batch_size = batch["inputs"].shape[0] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1343,7 +1343,7 @@ def run(rank, world_size, args): train_dl_ov40 = libricss.train_dataloaders(train_cuts_ov40) valid_dl = libricss.valid_dataloaders(dev_cuts) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/libricss/SURT/dprnn_zipformer/train_adapt.py b/egs/libricss/SURT/dprnn_zipformer/train_adapt.py index 8c37430ec..82b61baa0 100755 --- a/egs/libricss/SURT/dprnn_zipformer/train_adapt.py +++ b/egs/libricss/SURT/dprnn_zipformer/train_adapt.py @@ -67,7 +67,7 @@ from model import SURT from optim import Eden, ScaledAdam from scaling import ScaledLinear, ScaledLSTM from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -985,7 +985,7 @@ def train_one_epoch( batch_size = batch["inputs"].shape[0] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1237,7 +1237,7 @@ def run(rank, world_size, args): ) valid_dl = libricss.valid_dataloaders(dev_cuts) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/libriheavy/ASR/zipformer/train.py b/egs/libriheavy/ASR/zipformer/train.py index 357e8a827..524273ec5 100644 --- a/egs/libriheavy/ASR/zipformer/train.py +++ b/egs/libriheavy/ASR/zipformer/train.py @@ -78,7 +78,7 @@ from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from text_normalization import remove_punc_to_upper from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -958,7 +958,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1268,7 +1268,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1367,7 +1367,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/model_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/model_baseline.py index 77b4057c4..66328bb89 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/model_baseline.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/model_baseline.py @@ -186,7 +186,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -221,7 +221,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py b/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py index 21c7b4fac..80fbf09f0 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py @@ -245,7 +245,7 @@ class PromptedTransducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -287,7 +287,7 @@ class PromptedTransducer(nn.Module): logits = self.joiner(am_pruned, lm_pruned, context=context, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py b/egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py index 0e6764ba0..a260d828e 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py @@ -271,7 +271,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): (ans,) = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) x_grad = ans_grad * ans @@ -685,7 +685,7 @@ class BalancerFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): x = x.to(torch.float32) x = x.detach() x.requires_grad = True @@ -940,7 +940,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): x_detached = x_orig.to(torch.float32).detach() x_detached.requires_grad = True @@ -1280,7 +1280,7 @@ class SwooshLFunction(torch.autograd.Function): coeff = -0.08 - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True @@ -1351,7 +1351,7 @@ class SwooshRFunction(torch.autograd.Function): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py index 93f7e1248..bfca5a0db 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py @@ -89,7 +89,7 @@ from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from text_normalization import train_text_normalization, upper_only_alpha from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -975,7 +975,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1271,7 +1271,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1371,7 +1371,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py index 2a2c206aa..36c6d6464 100755 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py @@ -103,7 +103,7 @@ from text_normalization import ( upper_only_alpha, ) from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -1321,7 +1321,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1647,7 +1647,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1749,7 +1749,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/zipformer.py b/egs/libriheavy/ASR/zipformer_prompt_asr/zipformer.py index d1cf90ffb..405c95acc 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/zipformer.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/zipformer.py @@ -1561,7 +1561,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): attn_weights = attn_weights.to(torch.float32) attn_weights_entropy = ( -((attn_weights + 1.0e-20).log() * attn_weights) @@ -1844,7 +1844,7 @@ class MultiheadAttentionWeights(nn.Module): (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): attn_weights = attn_weights.to(torch.float32) attn_weights_entropy = ( -((attn_weights + 1.0e-20).log() * attn_weights) diff --git a/egs/librilight/SSL/zipformer/finetune.py b/egs/librilight/SSL/zipformer/finetune.py index 50dbd5f2d..568096c6a 100644 --- a/egs/librilight/SSL/zipformer/finetune.py +++ b/egs/librilight/SSL/zipformer/finetune.py @@ -67,7 +67,7 @@ from lhotse.utils import fix_random_seed from model import AsrModel from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -1116,7 +1116,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1407,7 +1407,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1505,7 +1505,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librilight/SSL/zipformer/pretrain.py b/egs/librilight/SSL/zipformer/pretrain.py index 5728dbe75..019f77ea3 100644 --- a/egs/librilight/SSL/zipformer/pretrain.py +++ b/egs/librilight/SSL/zipformer/pretrain.py @@ -57,7 +57,7 @@ from lhotse.utils import fix_random_seed from optim import Eden, ScaledAdam from ssl_datamodule import LibriLightDataModule from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -936,7 +936,7 @@ def train_one_epoch( batch_size = batch["kmeans"].shape[0] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1229,7 +1229,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1320,7 +1320,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py index c4a13b101..b0b5da1c0 100755 --- a/egs/librispeech/ASR/conformer_ctc2/train.py +++ b/egs/librispeech/ASR/conformer_ctc2/train.py @@ -65,7 +65,7 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -676,7 +676,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -965,7 +965,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1036,7 +1036,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/conformer_ctc3/train.py b/egs/librispeech/ASR/conformer_ctc3/train.py index a2f1125ca..7e819a2d8 100755 --- a/egs/librispeech/ASR/conformer_ctc3/train.py +++ b/egs/librispeech/ASR/conformer_ctc3/train.py @@ -76,7 +76,7 @@ from lhotse.utils import fix_random_seed from model import CTCModel from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -743,7 +743,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1004,7 +1004,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1073,7 +1073,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index ca21bd6bf..130a7c97f 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -80,7 +80,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -772,7 +772,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1002,7 +1002,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1071,7 +1071,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/do_not_use_it_directly.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/do_not_use_it_directly.py index d614f0914..16ae4e4e2 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/do_not_use_it_directly.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/do_not_use_it_directly.py @@ -80,7 +80,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -774,7 +774,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1003,7 +1003,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1074,7 +1074,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index 23ddb6bec..28d094a76 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -80,7 +80,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -772,7 +772,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1001,7 +1001,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1072,7 +1072,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/model.py b/egs/librispeech/ASR/lstm_transducer_stateless/model.py index e7bad7ed8..1ec9a8fc6 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/model.py @@ -156,7 +156,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -192,7 +192,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py index feb81d500..1e50ce090 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -66,7 +66,7 @@ from lstm import RNN from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -763,7 +763,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1023,7 +1023,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1092,7 +1092,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py index 4957d14b1..a758c550d 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py @@ -195,7 +195,7 @@ class Transducer(nn.Module): lm = simple_lm_proj(decoder_out) am = simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -231,7 +231,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py index 4fc4fa7f8..4d4f3e132 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py @@ -74,7 +74,7 @@ from lstm import RNN from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -848,7 +848,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1176,7 +1176,7 @@ def run(rank, world_size, args): else: logging.info("Skip scan_pessimistic_batches_for_oom") - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1247,7 +1247,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index 2c1cef3a3..ae4cd1c6a 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -66,7 +66,7 @@ from lstm import RNN from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -793,7 +793,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1067,7 +1067,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1136,7 +1136,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned2_knowledge/model.py b/egs/librispeech/ASR/pruned2_knowledge/model.py index ca8c28af1..2ffea06e7 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/model.py +++ b/egs/librispeech/ASR/pruned2_knowledge/model.py @@ -141,7 +141,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -176,7 +176,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py index 5b595c76c..3d2fdd6d8 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py @@ -10,7 +10,7 @@ from typing import Optional, Tuple import torch from scaling import ScaledLinear from torch import Tensor, nn -from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd +from torch.amp import GradScaler, custom_bwd, custom_fwd from torch_scheduled_sampling import sample_combined # The main exports of this file are the module KnowledgeBaseLookup and the @@ -330,14 +330,14 @@ def _test_knowledge_base_lookup_autocast(): optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device) - scaler = GradScaler(enabled=True) + scaler = GradScaler("cuda", enabled=True) start = timeit.default_timer() for epoch in range(150): for n, (x, y) in enumerate(train_pairs): y_out = m(x) - with torch.cuda.amp.autocast(enabled=True): + with torch.amp.autocast("cuda", enabled=True): loss = ((y_out - y) ** 2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index 931341cc4..8c117dd60 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -66,7 +66,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -650,7 +650,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -868,7 +868,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -937,7 +937,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py index 2b872f1d5..b25a84a6b 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py @@ -55,7 +55,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from noam import Noam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -693,7 +693,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -939,7 +939,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1004,7 +1004,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 272d06c37..59ed8310c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -157,7 +157,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -193,7 +193,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 6c19f2cb0..e86ec8052 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -78,7 +78,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -759,7 +759,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1000,7 +1000,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 0 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1067,7 +1067,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py index d45f6dadc..0495c8a29 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py @@ -195,7 +195,7 @@ class Transducer(nn.Module): lm = simple_lm_proj(decoder_out) am = simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -231,7 +231,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index fdafa5a87..8ef207518 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -74,7 +74,7 @@ from librispeech import LibriSpeech from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -827,7 +827,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1126,7 +1126,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 0 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1195,7 +1195,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 875b03f7f..b6682908b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -80,7 +80,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -789,7 +789,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1047,7 +1047,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1116,7 +1116,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 66dc5f991..2b559a27c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -68,7 +68,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -814,7 +814,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1078,7 +1078,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1147,7 +1147,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index daadb70c9..20b730a08 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -185,7 +185,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -220,7 +220,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index 8f033cb9a..93663505a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -80,7 +80,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -781,7 +781,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1039,7 +1039,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1108,7 +1108,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py index e7546ec45..d29010a23 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py @@ -66,7 +66,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -903,7 +903,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1219,7 +1219,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1319,7 +1319,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py index add0e6a18..49076b96f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py @@ -150,7 +150,7 @@ class Transducer(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -185,7 +185,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 30a737061..16d86fe2d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -289,7 +289,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): (ans,) = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) x_grad = ans_grad * ans @@ -669,7 +669,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): def backward(ctx, x_grad: Tensor): (x_orig,) = ctx.saved_tensors with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): x_detached = x_orig.to(torch.float32).detach() x_detached.requires_grad = True @@ -867,7 +867,7 @@ class MaxEig(torch.nn.Module): ): return _no_op(x) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): eps = 1.0e-20 orig_x = x x = x.to(torch.float32) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 436ec53b4..91fccd58d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -67,7 +67,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -809,7 +809,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1106,7 +1106,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1206,7 +1206,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index cbde2a2e4..ebef7e977 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1421,7 +1421,7 @@ class RelPositionMultiheadAttention(nn.Module): bsz = n // num_heads with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) attn_weights_entropy = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py index a6e919e2f..0224c15d7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py @@ -150,7 +150,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -185,7 +185,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py index b35e56abc..395b07b05 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py @@ -67,7 +67,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -833,7 +833,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1128,7 +1128,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1228,7 +1228,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py index 0582b289f..4675697c1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py @@ -178,7 +178,7 @@ class Transducer(nn.Module): am = self.simple_am_proj(encoder_out_fr) lm = self.simple_lm_proj(decoder_out) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -213,7 +213,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py index c2d877a93..a431b278d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -63,7 +63,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -822,7 +822,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1118,7 +1118,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1217,7 +1217,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index 8e239e322..dc3493425 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -66,7 +66,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer_for_ncnn_export_only import Zipformer @@ -811,7 +811,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1106,7 +1106,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1206,7 +1206,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py index 8bd00bbef..a8f47d941 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -66,7 +66,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -810,7 +810,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1124,7 +1124,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1224,7 +1224,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py index c7e45564f..e3b8b3725 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -2408,7 +2408,7 @@ class RelPositionMultiheadAttention(nn.Module): bsz = n // num_heads with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) attn_weights_entropy = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py index 5284ed627..ff23725b7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py @@ -2708,7 +2708,7 @@ class RelPositionMultiheadAttention(nn.Module): bsz = n // num_heads with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) attn_weights_entropy = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py index da5e144c9..4c8c239a1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py @@ -70,7 +70,7 @@ from librispeech import LibriSpeech from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -866,7 +866,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1218,7 +1218,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1320,7 +1320,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py index 39a360796..c0b9113b7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py @@ -172,7 +172,7 @@ class Transducer(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -207,7 +207,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index 646f30ca1..0ccef210e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -75,7 +75,7 @@ from librispeech import LibriSpeech from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -866,7 +866,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1219,7 +1219,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1321,7 +1321,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/train.py b/egs/librispeech/ASR/tiny_transducer_ctc/train.py index 1bfd071de..0536e89b3 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/train.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/train.py @@ -51,7 +51,7 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import Transducer from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import AdamW from torch.optim.lr_scheduler import StepLR @@ -809,7 +809,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1092,7 +1092,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1198,7 +1198,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer/finetune.py b/egs/librispeech/ASR/zipformer/finetune.py index 2ff631914..5da903d38 100755 --- a/egs/librispeech/ASR/zipformer/finetune.py +++ b/egs/librispeech/ASR/zipformer/finetune.py @@ -78,7 +78,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -1049,7 +1049,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1373,7 +1373,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1474,7 +1474,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index c7dbe1e0a..b0bb7c7fe 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -285,7 +285,7 @@ class AsrModel(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -320,7 +320,7 @@ class AsrModel(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index d345c2931..46df86bf7 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -306,7 +306,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): (ans,) = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) x_grad = ans_grad * ans @@ -759,7 +759,7 @@ class BalancerFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): x = x.to(torch.float32) x = x.detach() x.requires_grad = True @@ -1014,7 +1014,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): x_detached = x_orig.to(torch.float32).detach() x_detached.requires_grad = True @@ -1353,7 +1353,7 @@ class SwooshLFunction(torch.autograd.Function): coeff = -0.08 - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True @@ -1430,7 +1430,7 @@ class SwooshRFunction(torch.autograd.Function): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index c074c32ec..71d045ea0 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -79,7 +79,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -1101,7 +1101,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast( + with torch.amp.autocast("cuda", enabled=params.use_autocast, dtype=params.dtype ): loss, loss_info = compute_loss( @@ -1438,7 +1438,7 @@ def run(rank, world_size, args): spec_augment=spec_augment, ) - scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_autocast, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1540,7 +1540,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast( + with torch.amp.autocast("cuda", enabled=params.use_autocast, dtype=params.dtype ): loss, _ = compute_loss( diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 2a0ae0129..bdfd2175c 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1873,7 +1873,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): attn_weights = attn_weights.to(torch.float32) attn_weights_entropy = ( -((attn_weights + 1.0e-20).log() * attn_weights) diff --git a/egs/librispeech/ASR/zipformer_adapter/train.py b/egs/librispeech/ASR/zipformer_adapter/train.py index 3511590da..0207fc26e 100755 --- a/egs/librispeech/ASR/zipformer_adapter/train.py +++ b/egs/librispeech/ASR/zipformer_adapter/train.py @@ -67,7 +67,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -1052,7 +1052,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1397,7 +1397,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1498,7 +1498,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer_adapter/zipformer.py b/egs/librispeech/ASR/zipformer_adapter/zipformer.py index 8e2dfdd72..6224d136a 100644 --- a/egs/librispeech/ASR/zipformer_adapter/zipformer.py +++ b/egs/librispeech/ASR/zipformer_adapter/zipformer.py @@ -1916,7 +1916,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): attn_weights = attn_weights.to(torch.float32) attn_weights_entropy = ( -((attn_weights + 1.0e-20).log() * attn_weights) diff --git a/egs/librispeech/ASR/zipformer_ctc/train.py b/egs/librispeech/ASR/zipformer_ctc/train.py index 60112a84e..dfe702d2f 100755 --- a/egs/librispeech/ASR/zipformer_ctc/train.py +++ b/egs/librispeech/ASR/zipformer_ctc/train.py @@ -46,7 +46,7 @@ from lhotse.utils import fix_random_seed from model import CTCModel from optim import Eden, LRScheduler, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter @@ -726,7 +726,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -987,7 +987,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/zipformer_lora/finetune.py b/egs/librispeech/ASR/zipformer_lora/finetune.py index 3f36f229f..53152971d 100755 --- a/egs/librispeech/ASR/zipformer_lora/finetune.py +++ b/egs/librispeech/ASR/zipformer_lora/finetune.py @@ -78,7 +78,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -1065,7 +1065,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1406,7 +1406,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1507,7 +1507,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer_lora/scaling.py b/egs/librispeech/ASR/zipformer_lora/scaling.py index 8d7aa8027..a1e77fe0e 100644 --- a/egs/librispeech/ASR/zipformer_lora/scaling.py +++ b/egs/librispeech/ASR/zipformer_lora/scaling.py @@ -307,7 +307,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): (ans,) = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) x_grad = ans_grad * ans @@ -863,7 +863,7 @@ class BalancerFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): x = x.to(torch.float32) x = x.detach() x.requires_grad = True @@ -1118,7 +1118,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): x_detached = x_orig.to(torch.float32).detach() x_detached.requires_grad = True @@ -1457,7 +1457,7 @@ class SwooshLFunction(torch.autograd.Function): coeff = -0.08 - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True @@ -1534,7 +1534,7 @@ class SwooshRFunction(torch.autograd.Function): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True diff --git a/egs/librispeech/ASR/zipformer_lora/train.py b/egs/librispeech/ASR/zipformer_lora/train.py index 9ab214e86..592bc0fd4 100755 --- a/egs/librispeech/ASR/zipformer_lora/train.py +++ b/egs/librispeech/ASR/zipformer_lora/train.py @@ -76,7 +76,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -947,7 +947,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1252,7 +1252,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1352,7 +1352,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer_lora/zipformer.py b/egs/librispeech/ASR/zipformer_lora/zipformer.py index 43865609a..ece7c3df1 100644 --- a/egs/librispeech/ASR/zipformer_lora/zipformer.py +++ b/egs/librispeech/ASR/zipformer_lora/zipformer.py @@ -1905,7 +1905,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): attn_weights = attn_weights.to(torch.float32) attn_weights_entropy = ( -((attn_weights + 1.0e-20).log() * attn_weights) diff --git a/egs/librispeech/ASR/zipformer_mmi/train.py b/egs/librispeech/ASR/zipformer_mmi/train.py index c1785a328..bed3cfa04 100755 --- a/egs/librispeech/ASR/zipformer_mmi/train.py +++ b/egs/librispeech/ASR/zipformer_mmi/train.py @@ -64,7 +64,7 @@ from lhotse.utils import fix_random_seed from model import CTCModel from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -744,7 +744,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1037,7 +1037,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1138,7 +1138,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/hubert/finetune.py b/egs/librispeech/SSL/hubert/finetune.py index 17daa3c9d..9717d579d 100644 --- a/egs/librispeech/SSL/hubert/finetune.py +++ b/egs/librispeech/SSL/hubert/finetune.py @@ -66,7 +66,7 @@ from lhotse.utils import fix_random_seed from model import AsrModel from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -816,7 +816,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1109,7 +1109,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1207,7 +1207,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/hubert/finetune_ce.py b/egs/librispeech/SSL/hubert/finetune_ce.py index 2723cc770..340aa4aa2 100644 --- a/egs/librispeech/SSL/hubert/finetune_ce.py +++ b/egs/librispeech/SSL/hubert/finetune_ce.py @@ -66,7 +66,7 @@ from lhotse.utils import fix_random_seed from model import AsrModel from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -816,7 +816,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1109,7 +1109,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1207,7 +1207,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/hubert/model.py b/egs/librispeech/SSL/hubert/model.py index 46a968b69..b23fa32ea 100644 --- a/egs/librispeech/SSL/hubert/model.py +++ b/egs/librispeech/SSL/hubert/model.py @@ -221,7 +221,7 @@ class AsrModel(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -256,7 +256,7 @@ class AsrModel(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/SSL/hubert/pretrain.py b/egs/librispeech/SSL/hubert/pretrain.py index f183d90fd..1868bf0a6 100644 --- a/egs/librispeech/SSL/hubert/pretrain.py +++ b/egs/librispeech/SSL/hubert/pretrain.py @@ -59,7 +59,7 @@ from lhotse.utils import fix_random_seed from optim import Eden, ScaledAdam from ssl_datamodule import LibriSpeechDataModule from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.functional import pad from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -644,7 +644,7 @@ def train_one_epoch( batch_size = batch["kmeans"].shape[0] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -945,7 +945,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1036,7 +1036,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/hubert/pretrain_ce.py b/egs/librispeech/SSL/hubert/pretrain_ce.py index 94948695d..97efd983b 100644 --- a/egs/librispeech/SSL/hubert/pretrain_ce.py +++ b/egs/librispeech/SSL/hubert/pretrain_ce.py @@ -59,7 +59,7 @@ from lhotse.utils import fix_random_seed from optim import Eden, ScaledAdam from ssl_datamodule import LibriSpeechDataModule from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.functional import pad from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -644,7 +644,7 @@ def train_one_epoch( batch_size = batch["kmeans"].shape[0] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -945,7 +945,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1036,7 +1036,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/zipformer/finetune.py b/egs/librispeech/SSL/zipformer/finetune.py index c907b41c5..6bfab9d00 100644 --- a/egs/librispeech/SSL/zipformer/finetune.py +++ b/egs/librispeech/SSL/zipformer/finetune.py @@ -66,7 +66,7 @@ from lhotse.utils import fix_random_seed from model import AsrModel from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -1115,7 +1115,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1406,7 +1406,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1504,7 +1504,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/zipformer/model.py b/egs/librispeech/SSL/zipformer/model.py index 46a968b69..b23fa32ea 100644 --- a/egs/librispeech/SSL/zipformer/model.py +++ b/egs/librispeech/SSL/zipformer/model.py @@ -221,7 +221,7 @@ class AsrModel(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -256,7 +256,7 @@ class AsrModel(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/SSL/zipformer/pretrain.py b/egs/librispeech/SSL/zipformer/pretrain.py index 937fb382e..767c3bacb 100644 --- a/egs/librispeech/SSL/zipformer/pretrain.py +++ b/egs/librispeech/SSL/zipformer/pretrain.py @@ -58,7 +58,7 @@ from lhotse.utils import fix_random_seed from optim import Eden, ScaledAdam from ssl_datamodule import LibriSpeechDataModule from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -944,7 +944,7 @@ def train_one_epoch( batch_size = batch["kmeans"].shape[0] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1243,7 +1243,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1334,7 +1334,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/zipformer/zipformer.py b/egs/librispeech/SSL/zipformer/zipformer.py index e9eff3357..7e9ccb51f 100644 --- a/egs/librispeech/SSL/zipformer/zipformer.py +++ b/egs/librispeech/SSL/zipformer/zipformer.py @@ -1849,7 +1849,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): attn_weights = attn_weights.to(torch.float32) attn_weights_entropy = ( -((attn_weights + 1.0e-20).log() * attn_weights) diff --git a/egs/librispeech/WSASR/conformer_ctc2/train.py b/egs/librispeech/WSASR/conformer_ctc2/train.py index 82c68803f..fc7728562 100755 --- a/egs/librispeech/WSASR/conformer_ctc2/train.py +++ b/egs/librispeech/WSASR/conformer_ctc2/train.py @@ -62,7 +62,7 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -757,7 +757,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1005,7 +1005,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1076,7 +1076,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/WSASR/conformer_ctc2/train_phone.py b/egs/librispeech/WSASR/conformer_ctc2/train_phone.py index b276d0587..1c4bd50bf 100755 --- a/egs/librispeech/WSASR/conformer_ctc2/train_phone.py +++ b/egs/librispeech/WSASR/conformer_ctc2/train_phone.py @@ -62,7 +62,7 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -758,7 +758,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1007,7 +1007,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1078,7 +1078,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/libritts/ASR/zipformer/train.py b/egs/libritts/ASR/zipformer/train.py index 5485eaf0a..78e3330bd 100755 --- a/egs/libritts/ASR/zipformer/train.py +++ b/egs/libritts/ASR/zipformer/train.py @@ -80,7 +80,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -1049,8 +1049,8 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast( - enabled=params.use_autocast, dtype=params.dtype + with torch.amp.autocast( + "cuda", enabled=params.use_autocast, dtype=params.dtype ): loss, loss_info = compute_loss( params=params, @@ -1378,7 +1378,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_autocast, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1478,8 +1478,8 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast( - enabled=params.use_autocast, dtype=params.dtype + with torch.amp.autocast( + "cuda", enabled=params.use_autocast, dtype=params.dtype ): loss, _ = compute_loss( params=params, diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py index f21d494b6..31fc4f126 100644 --- a/egs/libritts/CODEC/encodec/encodec.py +++ b/egs/libritts/CODEC/encodec/encodec.py @@ -29,7 +29,7 @@ from loss import ( WavReconstructionLoss, ) from torch import nn -from torch.cuda.amp import autocast +from torch.amp import autocast class Encodec(nn.Module): @@ -148,7 +148,7 @@ class Encodec(nn.Module): ) # calculate losses - with autocast(enabled=False): + with autocast("cuda", enabled=False): gen_stft_adv_loss = self.generator_adversarial_loss(outputs=y_hat) if self.multi_period_discriminator is not None: @@ -272,7 +272,7 @@ class Encodec(nn.Module): speech_hat.contiguous().detach(), ) # calculate losses - with autocast(enabled=False): + with autocast("cuda", enabled=False): ( disc_stft_real_adv_loss, disc_stft_fake_adv_loss, diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index a4f2eb7ab..31349df43 100755 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -34,7 +34,7 @@ from encodec import Encodec from lhotse.utils import fix_random_seed from scheduler import WarmupCosineLrScheduler from torch import nn -from torch.cuda.amp import GradScaler, autocast +from torch.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter @@ -466,7 +466,7 @@ def train_one_epoch( loss_info["samples"] = batch_size try: - with autocast(enabled=params.use_fp16): + with autocast("cuda", enabled=params.use_fp16): d_weight = train_discriminator( params.lambda_adv, params.cur_epoch, @@ -502,7 +502,7 @@ def train_one_epoch( scaler.scale(disc_loss).backward() scaler.step(optimizer_d) - with autocast(enabled=params.use_fp16): + with autocast("cuda", enabled=params.use_fp16): g_weight = train_discriminator( params.lambda_adv, params.cur_epoch, @@ -846,7 +846,7 @@ def scan_pessimistic_batches_for_oom( ) = prepare_input(params, batch, device) try: # for discriminator - with autocast(enabled=params.use_fp16): + with autocast("cuda", enabled=params.use_fp16): ( disc_stft_real_adv_loss, disc_stft_fake_adv_loss, @@ -876,7 +876,7 @@ def scan_pessimistic_batches_for_oom( optimizer_d.zero_grad() loss_d.backward() # for generator - with autocast(enabled=params.use_fp16): + with autocast("cuda", enabled=params.use_fp16): ( commit_loss, gen_stft_adv_loss, @@ -1102,7 +1102,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/libritts/TTS/vits/train.py b/egs/libritts/TTS/vits/train.py index 447fbcf5d..6803d6eb2 100755 --- a/egs/libritts/TTS/vits/train.py +++ b/egs/libritts/TTS/vits/train.py @@ -32,7 +32,7 @@ from lhotse.cut import Cut from lhotse.features.io import KaldiReader from lhotse.utils import fix_random_seed from tokenizer import Tokenizer -from torch.cuda.amp import GradScaler, autocast +from torch.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter @@ -456,7 +456,7 @@ def train_one_epoch( loss_info["samples"] = batch_size try: - with autocast(enabled=params.use_fp16): + with autocast("cuda", enabled=params.use_fp16): # forward discriminator loss_d, stats_d = model( text=tokens, @@ -475,7 +475,7 @@ def train_one_epoch( scaler.scale(loss_d).backward() scaler.step(optimizer_d) - with autocast(enabled=params.use_fp16): + with autocast("cuda", enabled=params.use_fp16): # forward generator loss_g, stats_g = model( text=tokens, @@ -748,7 +748,7 @@ def scan_pessimistic_batches_for_oom( ) = prepare_input(batch, tokenizer, device, train_speaker_map) try: # for discriminator - with autocast(enabled=params.use_fp16): + with autocast("cuda", enabled=params.use_fp16): loss_d, stats_d = model( text=tokens, text_lengths=tokens_lens, @@ -762,7 +762,7 @@ def scan_pessimistic_batches_for_oom( optimizer_d.zero_grad() loss_d.backward() # for generator - with autocast(enabled=params.use_fp16): + with autocast("cuda", enabled=params.use_fp16): loss_g, stats_g = model( text=tokens, text_lengths=tokens_lens, @@ -922,7 +922,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py index 853042413..a25cc8723 100755 --- a/egs/ljspeech/TTS/matcha/train.py +++ b/egs/ljspeech/TTS/matcha/train.py @@ -17,7 +17,7 @@ from lhotse.utils import fix_random_seed from model import fix_len_compatibility from models.matcha_tts import MatchaTTS from tokenizer import Tokenizer -from torch.cuda.amp import GradScaler, autocast +from torch.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter @@ -474,7 +474,7 @@ def train_one_epoch( tokens_lens, ) = prepare_input(batch, tokenizer, device, params) try: - with autocast(enabled=params.use_fp16): + with autocast("cuda", enabled=params.use_fp16): losses = get_losses( { "x": tokens, @@ -645,7 +645,7 @@ def run(rank, world_size, args): valid_cuts = ljspeech.valid_cuts() valid_dl = ljspeech.valid_dataloaders(valid_cuts) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py index 184ae79af..e9994319a 100755 --- a/egs/ljspeech/TTS/vits/train.py +++ b/egs/ljspeech/TTS/vits/train.py @@ -30,7 +30,7 @@ import torch.nn as nn from lhotse.cut import Cut from lhotse.utils import fix_random_seed from tokenizer import Tokenizer -from torch.cuda.amp import GradScaler, autocast +from torch.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter @@ -396,7 +396,7 @@ def train_one_epoch( loss_info["samples"] = batch_size try: - with autocast(enabled=params.use_fp16): + with autocast("cuda", enabled=params.use_fp16): # forward discriminator loss_d, stats_d = model( text=tokens, @@ -414,7 +414,7 @@ def train_one_epoch( scaler.scale(loss_d).backward() scaler.step(optimizer_d) - with autocast(enabled=params.use_fp16): + with autocast("cuda", enabled=params.use_fp16): # forward generator loss_g, stats_g = model( text=tokens, @@ -673,7 +673,7 @@ def scan_pessimistic_batches_for_oom( ) try: # for discriminator - with autocast(enabled=params.use_fp16): + with autocast("cuda", enabled=params.use_fp16): loss_d, stats_d = model( text=tokens, text_lengths=tokens_lens, @@ -686,7 +686,7 @@ def scan_pessimistic_batches_for_oom( optimizer_d.zero_grad() loss_d.backward() # for generator - with autocast(enabled=params.use_fp16): + with autocast("cuda", enabled=params.use_fp16): loss_g, stats_g = model( text=tokens, text_lengths=tokens_lens, @@ -838,7 +838,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/ljspeech/TTS/vits/utils.py b/egs/ljspeech/TTS/vits/utils.py index 6a067f596..d51ff5f5c 100644 --- a/egs/ljspeech/TTS/vits/utils.py +++ b/egs/ljspeech/TTS/vits/utils.py @@ -23,7 +23,7 @@ import torch import torch.distributed as dist import torch.nn as nn from lhotse.dataset.sampling.base import CutSampler -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py index a1fabf9ad..6fd6d219b 100644 --- a/egs/ljspeech/TTS/vits/vits.py +++ b/egs/ljspeech/TTS/vits/vits.py @@ -25,7 +25,7 @@ from loss import ( KLDivergenceLoss, MelSpectrogramLoss, ) -from torch.cuda.amp import autocast +from torch.amp import autocast from utils import get_segments AVAILABLE_GENERATERS = { @@ -410,7 +410,7 @@ class VITS(nn.Module): p = self.discriminator(speech_) # calculate losses - with autocast(enabled=False): + with autocast("cuda", enabled=False): if not return_sample: mel_loss = self.mel_loss(speech_hat_, speech_) else: @@ -518,7 +518,7 @@ class VITS(nn.Module): p = self.discriminator(speech_) # calculate losses - with autocast(enabled=False): + with autocast("cuda", enabled=False): real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p) loss = real_loss + fake_loss diff --git a/egs/mdcc/ASR/zipformer/train.py b/egs/mdcc/ASR/zipformer/train.py index 730db7718..22249286a 100755 --- a/egs/mdcc/ASR/zipformer/train.py +++ b/egs/mdcc/ASR/zipformer/train.py @@ -68,7 +68,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -906,7 +906,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1197,7 +1197,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1298,7 +1298,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/train.py b/egs/mgb2/ASR/pruned_transducer_stateless5/train.py index 48468cfbd..916ada475 100755 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/train.py +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/train.py @@ -66,7 +66,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter @@ -751,7 +751,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info, inf_flag = compute_loss( params=params, model=model, @@ -1012,7 +1012,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1115,7 +1115,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _, _ = compute_loss( params=params, model=model, diff --git a/egs/multi_zh-hans/ASR/whisper/train.py b/egs/multi_zh-hans/ASR/whisper/train.py index fe2d950c1..1a11d01af 100755 --- a/egs/multi_zh-hans/ASR/whisper/train.py +++ b/egs/multi_zh-hans/ASR/whisper/train.py @@ -61,7 +61,7 @@ from lhotse.utils import fix_random_seed from multi_dataset import MultiDataset from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.functional import pad as pad_tensor from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -566,7 +566,7 @@ def compute_validation_loss( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(valid_dl): - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, tokenizer=tokenizer, @@ -675,7 +675,7 @@ def train_one_epoch( ) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, tokenizer=tokenizer, @@ -913,7 +913,7 @@ def run(rank, world_size, args): valid_cuts = multi_dataset.dev_cuts() valid_dl = data_module.valid_dataloaders(valid_cuts) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/multi_zh-hans/ASR/zipformer/train.py b/egs/multi_zh-hans/ASR/zipformer/train.py index 3dbfc48eb..047253d5b 100755 --- a/egs/multi_zh-hans/ASR/zipformer/train.py +++ b/egs/multi_zh-hans/ASR/zipformer/train.py @@ -75,7 +75,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -987,7 +987,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1278,7 +1278,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1378,7 +1378,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/multi_zh_en/ASR/zipformer/train.py b/egs/multi_zh_en/ASR/zipformer/train.py index 04bb41214..9e64defa3 100755 --- a/egs/multi_zh_en/ASR/zipformer/train.py +++ b/egs/multi_zh_en/ASR/zipformer/train.py @@ -75,7 +75,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -969,7 +969,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1269,7 +1269,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1369,7 +1369,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/reazonspeech/ASR/zipformer/do_not_use_it_directly.py b/egs/reazonspeech/ASR/zipformer/do_not_use_it_directly.py index 072679cfc..c01e4d336 100755 --- a/egs/reazonspeech/ASR/zipformer/do_not_use_it_directly.py +++ b/egs/reazonspeech/ASR/zipformer/do_not_use_it_directly.py @@ -67,7 +67,7 @@ from model import Transducer from optim import Eden, ScaledAdam from tokenizer import Tokenizer from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer_for_ncnn_export_only import Zipformer @@ -822,7 +822,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1113,7 +1113,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1213,7 +1213,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/reazonspeech/ASR/zipformer/train.py b/egs/reazonspeech/ASR/zipformer/train.py index 30bd3efba..8829a18ca 100755 --- a/egs/reazonspeech/ASR/zipformer/train.py +++ b/egs/reazonspeech/ASR/zipformer/train.py @@ -74,7 +74,7 @@ from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from tokenizer import Tokenizer from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -945,7 +945,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1235,7 +1235,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1335,7 +1335,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py index 5f224c984..5de2cf2b0 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py @@ -451,7 +451,7 @@ def compute_validation_loss( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(valid_dl): - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, tokenizer=tokenizer, @@ -566,7 +566,7 @@ def train_one_epoch( f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}" ) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, tokenizer=tokenizer, diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py index a9146a0fe..1e55ada87 100755 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py @@ -65,7 +65,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -649,7 +649,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -857,7 +857,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -957,7 +957,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/spgispeech/ASR/zipformer/train.py b/egs/spgispeech/ASR/zipformer/train.py index dfc21c968..319713b02 100755 --- a/egs/spgispeech/ASR/zipformer/train.py +++ b/egs/spgispeech/ASR/zipformer/train.py @@ -74,7 +74,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -946,7 +946,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1217,7 +1217,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1317,7 +1317,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py index c0aedd725..c44e30b89 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py @@ -69,7 +69,7 @@ from local.tokenize_with_bpe_model import tokenize_by_bpe_model from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -726,7 +726,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) # print(batch["supervisions"]) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -967,7 +967,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1039,7 +1039,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py index 2108266ec..dd9576d99 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py @@ -64,7 +64,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -801,7 +801,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1101,7 +1101,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1201,7 +1201,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/tedlium3/ASR/conformer_ctc2/train.py b/egs/tedlium3/ASR/conformer_ctc2/train.py index fc3e3b2d9..179dcf14a 100755 --- a/egs/tedlium3/ASR/conformer_ctc2/train.py +++ b/egs/tedlium3/ASR/conformer_ctc2/train.py @@ -57,7 +57,7 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -710,7 +710,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -941,7 +941,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1011,7 +1011,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/tedlium3/ASR/zipformer/model.py b/egs/tedlium3/ASR/zipformer/model.py index 65b052ab9..0d9b395ed 100644 --- a/egs/tedlium3/ASR/zipformer/model.py +++ b/egs/tedlium3/ASR/zipformer/model.py @@ -173,7 +173,7 @@ class Transducer(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -209,7 +209,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.cuda.amp.autocast(enabled=False): + with torch.amp.autocast("cuda", enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/tedlium3/ASR/zipformer/train.py b/egs/tedlium3/ASR/zipformer/train.py index 14a44efb3..ffe876863 100755 --- a/egs/tedlium3/ASR/zipformer/train.py +++ b/egs/tedlium3/ASR/zipformer/train.py @@ -73,7 +73,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -911,7 +911,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1160,7 +1160,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1260,7 +1260,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/vctk/TTS/vits/train.py b/egs/vctk/TTS/vits/train.py index 4686de169..6249640d4 100755 --- a/egs/vctk/TTS/vits/train.py +++ b/egs/vctk/TTS/vits/train.py @@ -31,7 +31,7 @@ import torch.nn as nn from lhotse.cut import Cut from lhotse.utils import fix_random_seed from tokenizer import Tokenizer -from torch.cuda.amp import GradScaler, autocast +from torch.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter @@ -448,7 +448,7 @@ def train_one_epoch( loss_info["samples"] = batch_size try: - with autocast(enabled=params.use_fp16): + with autocast("cuda", enabled=params.use_fp16): # forward discriminator loss_d, stats_d = model( text=tokens, @@ -467,7 +467,7 @@ def train_one_epoch( scaler.scale(loss_d).backward() scaler.step(optimizer_d) - with autocast(enabled=params.use_fp16): + with autocast("cuda", enabled=params.use_fp16): # forward generator loss_g, stats_g = model( text=tokens, @@ -740,7 +740,7 @@ def scan_pessimistic_batches_for_oom( ) = prepare_input(batch, tokenizer, device, speaker_map) try: # for discriminator - with autocast(enabled=params.use_fp16): + with autocast("cuda", enabled=params.use_fp16): loss_d, stats_d = model( text=tokens, text_lengths=tokens_lens, @@ -754,7 +754,7 @@ def scan_pessimistic_batches_for_oom( optimizer_d.zero_grad() loss_d.backward() # for generator - with autocast(enabled=params.use_fp16): + with autocast("cuda", enabled=params.use_fp16): loss_g, stats_g = model( text=tokens, text_lengths=tokens_lens, @@ -910,7 +910,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py index c34f1593d..2fd6f6478 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py @@ -52,7 +52,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -718,7 +718,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -907,7 +907,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1005,7 +1005,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py index 49977e01b..c90f03f08 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py @@ -101,7 +101,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -687,7 +687,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -921,7 +921,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1019,7 +1019,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py index 931e699d9..7b05eca97 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py @@ -81,7 +81,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -796,7 +796,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1056,7 +1056,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1158,7 +1158,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/wenetspeech/ASR/whisper/train.py b/egs/wenetspeech/ASR/whisper/train.py index 4e55fd6a8..c46a4d84c 100644 --- a/egs/wenetspeech/ASR/whisper/train.py +++ b/egs/wenetspeech/ASR/whisper/train.py @@ -61,7 +61,7 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.functional import pad as pad_tensor from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -513,7 +513,7 @@ def compute_validation_loss( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(valid_dl): - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, tokenizer=tokenizer, @@ -621,7 +621,7 @@ def train_one_epoch( f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}" ) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, tokenizer=tokenizer, @@ -843,7 +843,7 @@ def run(rank, world_size, args): train_dl = wenetspeech.train_dataloaders(train_cuts) valid_dl = wenetspeech.valid_dataloaders(wenetspeech.valid_cuts()) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/wenetspeech/ASR/zipformer/train.py b/egs/wenetspeech/ASR/zipformer/train.py index 25b16f632..b6d55447f 100755 --- a/egs/wenetspeech/ASR/zipformer/train.py +++ b/egs/wenetspeech/ASR/zipformer/train.py @@ -71,7 +71,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -910,7 +910,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1201,7 +1201,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1302,7 +1302,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/wenetspeech/KWS/zipformer/finetune.py b/egs/wenetspeech/KWS/zipformer/finetune.py index d19172b38..00db4309d 100755 --- a/egs/wenetspeech/KWS/zipformer/finetune.py +++ b/egs/wenetspeech/KWS/zipformer/finetune.py @@ -82,7 +82,7 @@ from lhotse.cut import Cut, CutSet from lhotse.utils import fix_random_seed from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from train import ( @@ -414,7 +414,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -703,7 +703,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/wenetspeech/KWS/zipformer/train.py b/egs/wenetspeech/KWS/zipformer/train.py index 40960c2ae..4dc30ad89 100755 --- a/egs/wenetspeech/KWS/zipformer/train.py +++ b/egs/wenetspeech/KWS/zipformer/train.py @@ -73,7 +73,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -967,7 +967,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1252,7 +1252,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1353,7 +1353,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/wenetspeech4tts/TTS/valle/train.py b/egs/wenetspeech4tts/TTS/valle/train.py index e9ec548f3..1c6972e93 100755 --- a/egs/wenetspeech4tts/TTS/valle/train.py +++ b/egs/wenetspeech4tts/TTS/valle/train.py @@ -65,7 +65,7 @@ from lhotse.utils import fix_random_seed from optim import Eden, ScaledAdam from tokenizer import TextTokenCollater, get_text_token_collater from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from tts_datamodule import TtsDataModule @@ -764,7 +764,7 @@ def train_one_epoch( batch_size = len(batch["text"]) try: - with torch.cuda.amp.autocast(dtype=dtype, enabled=enabled): + with torch.amp.autocast("cuda", dtype=dtype, enabled=enabled): _, loss, loss_info = compute_loss( params=params, model=model, @@ -897,7 +897,7 @@ def train_one_epoch( # Calculate validation loss in Rank 0 model.eval() logging.info("Computing validation loss") - with torch.cuda.amp.autocast(dtype=dtype): + with torch.amp.autocast("cuda", dtype=dtype): valid_info = compute_validation_loss( params=params, model=model, @@ -1102,7 +1102,9 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0) + scaler = GradScaler( + "cuda", enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0 + ) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1196,7 +1198,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(dtype=dtype): + with torch.amp.autocast("cuda", dtype=dtype): _, loss, _ = compute_loss( params=params, model=model, diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py index a6fa46b17..5c3000a57 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py @@ -68,7 +68,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -814,7 +814,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1072,7 +1072,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler(enabled=params.use_fp16) + scaler = GradScaler("cuda", enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1141,7 +1141,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py index dd72551d9..a1b3be246 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py @@ -67,7 +67,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -785,7 +785,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1074,7 +1074,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1174,7 +1174,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index d31ce1301..b3a0fb865 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -27,7 +27,7 @@ import torch import torch.nn as nn from lhotse.dataset.sampling.base import CutSampler from torch import Tensor -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py index 0178b80bf..257cdb09a 100755 --- a/icefall/rnn_lm/train.py +++ b/icefall/rnn_lm/train.py @@ -401,7 +401,7 @@ def compute_validation_loss( for batch_idx, batch in enumerate(valid_dl): x, y, sentence_lengths = batch - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( model=model, x=x, @@ -470,7 +470,7 @@ def train_one_epoch( params.batch_idx_train += 1 x, y, sentence_lengths = batch batch_size = x.size(0) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( model=model, x=x, diff --git a/icefall/transformer_lm/train.py b/icefall/transformer_lm/train.py index c36abfcdf..6faa63484 100644 --- a/icefall/transformer_lm/train.py +++ b/icefall/transformer_lm/train.py @@ -341,7 +341,7 @@ def compute_validation_loss( for batch_idx, batch in enumerate(valid_dl): x, y, sentence_lengths = batch - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( model=model, x=x, @@ -403,7 +403,7 @@ def train_one_epoch( params.batch_idx_train += 1 x, y, sentence_lengths = batch batch_size = x.size(0) - with torch.cuda.amp.autocast(enabled=params.use_fp16): + with torch.amp.autocast("cuda", enabled=params.use_fp16): loss, loss_info = compute_loss( model=model, x=x, From d4d4f281ecefad7b779def552975d2881620a724 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 18 Dec 2024 16:49:57 +0800 Subject: [PATCH 49/59] Revert "Replace deprecated pytorch methods (#1814)" (#1841) This reverts commit 3e4da5f78160d3dba3bdf97968bd7ceb8c11631f. --- .../ASR/pruned_transducer_stateless2/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless2/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless3/model.py | 4 ++-- .../ASR/pruned_transducer_stateless3/train.py | 8 ++++---- .../do_not_use_it_directly.py | 8 ++++---- .../ASR/pruned_transducer_stateless7/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless7_bbpe/train.py | 8 ++++---- .../do_not_use_it_directly.py | 8 ++++---- .../pruned_transducer_stateless7_streaming/train.py | 8 ++++---- egs/aishell/ASR/whisper/train.py | 8 ++++---- egs/aishell/ASR/zipformer/train.py | 8 ++++---- egs/aishell/ASR/zipformer/train_bbpe.py | 8 ++++---- .../ASR/pruned_transducer_stateless5/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless5/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless2/train.py | 8 ++++---- .../ASR_v2/pruned_transducer_stateless7/train.py | 8 ++++---- egs/ami/ASR/pruned_transducer_stateless7/train.py | 8 ++++---- egs/ami/SURT/dprnn_zipformer/train.py | 6 +++--- egs/ami/SURT/dprnn_zipformer/train_adapt.py | 6 +++--- egs/audioset/AT/zipformer/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless7/train.py | 8 ++++---- .../do_not_use_it_directly.py | 8 ++++---- .../finetune.py | 8 ++++---- .../pruned_transducer_stateless7_streaming/train.py | 8 ++++---- egs/commonvoice/ASR/zipformer/train.py | 8 ++++---- egs/commonvoice/ASR/zipformer/train_char.py | 8 ++++---- .../do_not_use_it_directly.py | 8 ++++---- .../pruned_transducer_stateless7_streaming/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless2/train.py | 8 ++++---- egs/gigaspeech/ASR/zipformer/train.py | 8 ++++---- egs/gigaspeech/KWS/zipformer/finetune.py | 6 +++--- egs/gigaspeech/KWS/zipformer/train.py | 8 ++++---- .../pruned_transducer_stateless7_streaming/train.py | 8 ++++---- egs/ksponspeech/ASR/zipformer/train.py | 8 ++++---- egs/libricss/SURT/dprnn_zipformer/model.py | 4 ++-- egs/libricss/SURT/dprnn_zipformer/scaling.py | 6 +++--- egs/libricss/SURT/dprnn_zipformer/train.py | 6 +++--- egs/libricss/SURT/dprnn_zipformer/train_adapt.py | 6 +++--- egs/libriheavy/ASR/zipformer/train.py | 8 ++++---- .../ASR/zipformer_prompt_asr/model_baseline.py | 4 ++-- .../ASR/zipformer_prompt_asr/model_with_BERT.py | 4 ++-- egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py | 10 +++++----- .../ASR/zipformer_prompt_asr/train_baseline.py | 8 ++++---- .../ASR/zipformer_prompt_asr/train_bert_encoder.py | 8 ++++---- egs/libriheavy/ASR/zipformer_prompt_asr/zipformer.py | 4 ++-- egs/librilight/SSL/zipformer/finetune.py | 8 ++++---- egs/librilight/SSL/zipformer/pretrain.py | 8 ++++---- egs/librispeech/ASR/conformer_ctc2/train.py | 8 ++++---- egs/librispeech/ASR/conformer_ctc3/train.py | 8 ++++---- .../ASR/conv_emformer_transducer_stateless/train.py | 8 ++++---- .../do_not_use_it_directly.py | 8 ++++---- .../ASR/conv_emformer_transducer_stateless2/train.py | 8 ++++---- .../ASR/lstm_transducer_stateless/model.py | 4 ++-- .../ASR/lstm_transducer_stateless/train.py | 8 ++++---- .../ASR/lstm_transducer_stateless2/model.py | 4 ++-- .../ASR/lstm_transducer_stateless2/train.py | 8 ++++---- .../ASR/lstm_transducer_stateless3/train.py | 8 ++++---- egs/librispeech/ASR/pruned2_knowledge/model.py | 4 ++-- egs/librispeech/ASR/pruned2_knowledge/sampling.py | 6 +++--- egs/librispeech/ASR/pruned2_knowledge/train.py | 8 ++++---- .../ASR/pruned_stateless_emformer_rnnt2/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless2/model.py | 4 ++-- .../ASR/pruned_transducer_stateless2/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless3/model.py | 4 ++-- .../ASR/pruned_transducer_stateless3/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless4/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless5/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless6/model.py | 4 ++-- .../ASR/pruned_transducer_stateless6/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless7/finetune.py | 8 ++++---- .../ASR/pruned_transducer_stateless7/model.py | 4 ++-- .../ASR/pruned_transducer_stateless7/scaling.py | 6 +++--- .../ASR/pruned_transducer_stateless7/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless7/zipformer.py | 2 +- .../ASR/pruned_transducer_stateless7_ctc/model.py | 4 ++-- .../ASR/pruned_transducer_stateless7_ctc/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless7_ctc_bs/model.py | 4 ++-- .../ASR/pruned_transducer_stateless7_ctc_bs/train.py | 8 ++++---- .../do_not_use_it_directly.py | 8 ++++---- .../pruned_transducer_stateless7_streaming/train.py | 8 ++++---- .../zipformer.py | 2 +- .../zipformer_for_ncnn_export_only.py | 2 +- .../train.py | 8 ++++---- .../ASR/pruned_transducer_stateless8/model.py | 4 ++-- .../ASR/pruned_transducer_stateless8/train.py | 8 ++++---- egs/librispeech/ASR/tiny_transducer_ctc/train.py | 8 ++++---- egs/librispeech/ASR/zipformer/finetune.py | 8 ++++---- egs/librispeech/ASR/zipformer/model.py | 4 ++-- egs/librispeech/ASR/zipformer/scaling.py | 10 +++++----- egs/librispeech/ASR/zipformer/train.py | 8 ++++---- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- egs/librispeech/ASR/zipformer_adapter/train.py | 8 ++++---- egs/librispeech/ASR/zipformer_adapter/zipformer.py | 2 +- egs/librispeech/ASR/zipformer_ctc/train.py | 6 +++--- egs/librispeech/ASR/zipformer_lora/finetune.py | 8 ++++---- egs/librispeech/ASR/zipformer_lora/scaling.py | 10 +++++----- egs/librispeech/ASR/zipformer_lora/train.py | 8 ++++---- egs/librispeech/ASR/zipformer_lora/zipformer.py | 2 +- egs/librispeech/ASR/zipformer_mmi/train.py | 8 ++++---- egs/librispeech/SSL/hubert/finetune.py | 8 ++++---- egs/librispeech/SSL/hubert/finetune_ce.py | 8 ++++---- egs/librispeech/SSL/hubert/model.py | 4 ++-- egs/librispeech/SSL/hubert/pretrain.py | 8 ++++---- egs/librispeech/SSL/hubert/pretrain_ce.py | 8 ++++---- egs/librispeech/SSL/zipformer/finetune.py | 8 ++++---- egs/librispeech/SSL/zipformer/model.py | 4 ++-- egs/librispeech/SSL/zipformer/pretrain.py | 8 ++++---- egs/librispeech/SSL/zipformer/zipformer.py | 2 +- egs/librispeech/WSASR/conformer_ctc2/train.py | 8 ++++---- egs/librispeech/WSASR/conformer_ctc2/train_phone.py | 8 ++++---- egs/libritts/ASR/zipformer/train.py | 12 ++++++------ egs/libritts/CODEC/encodec/encodec.py | 6 +++--- egs/libritts/CODEC/encodec/train.py | 12 ++++++------ egs/libritts/TTS/vits/train.py | 12 ++++++------ egs/ljspeech/TTS/matcha/train.py | 6 +++--- egs/ljspeech/TTS/vits/train.py | 12 ++++++------ egs/ljspeech/TTS/vits/utils.py | 2 +- egs/ljspeech/TTS/vits/vits.py | 6 +++--- egs/mdcc/ASR/zipformer/train.py | 8 ++++---- egs/mgb2/ASR/pruned_transducer_stateless5/train.py | 8 ++++---- egs/multi_zh-hans/ASR/whisper/train.py | 8 ++++---- egs/multi_zh-hans/ASR/zipformer/train.py | 8 ++++---- egs/multi_zh_en/ASR/zipformer/train.py | 8 ++++---- .../ASR/zipformer/do_not_use_it_directly.py | 8 ++++---- egs/reazonspeech/ASR/zipformer/train.py | 8 ++++---- egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py | 4 ++-- .../ASR/pruned_transducer_stateless2/train.py | 8 ++++---- egs/spgispeech/ASR/zipformer/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless5/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless7_bbpe/train.py | 8 ++++---- egs/tedlium3/ASR/conformer_ctc2/train.py | 8 ++++---- egs/tedlium3/ASR/zipformer/model.py | 4 ++-- egs/tedlium3/ASR/zipformer/train.py | 8 ++++---- egs/vctk/TTS/vits/train.py | 12 ++++++------ .../ASR/pruned_transducer_stateless2/finetune.py | 8 ++++---- .../ASR/pruned_transducer_stateless2/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless5/train.py | 8 ++++---- egs/wenetspeech/ASR/whisper/train.py | 8 ++++---- egs/wenetspeech/ASR/zipformer/train.py | 8 ++++---- egs/wenetspeech/KWS/zipformer/finetune.py | 6 +++--- egs/wenetspeech/KWS/zipformer/train.py | 8 ++++---- egs/wenetspeech4tts/TTS/valle/train.py | 12 +++++------- .../ASR/pruned_transducer_stateless5/train.py | 8 ++++---- .../ASR/pruned_transducer_stateless7/train.py | 8 ++++---- icefall/checkpoint.py | 2 +- icefall/rnn_lm/train.py | 4 ++-- icefall/transformer_lm/train.py | 4 ++-- 147 files changed, 518 insertions(+), 520 deletions(-) diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py index 9088378fa..fa809b768 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py @@ -67,7 +67,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -638,7 +638,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -843,7 +843,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -912,7 +912,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/pruned_transducer_stateless2/train.py b/egs/aishell/ASR/pruned_transducer_stateless2/train.py index dda098e99..60f014c48 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless2/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless2/train.py @@ -60,7 +60,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -688,7 +688,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -888,7 +888,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -989,7 +989,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/model.py b/egs/aishell/ASR/pruned_transducer_stateless3/model.py index cafc9d1bb..a4dda0d6d 100644 --- a/egs/aishell/ASR/pruned_transducer_stateless3/model.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/model.py @@ -184,7 +184,7 @@ class Transducer(nn.Module): lm = simple_lm_proj(decoder_out) am = simple_am_proj(encoder_out) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -219,7 +219,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = joiner(am_pruned, lm_pruned, project_input=False) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py index bf60c4fad..7c23041ca 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py @@ -79,7 +79,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -797,7 +797,7 @@ def train_one_epoch( aishell = is_aishell(batch["supervisions"]["cut"][0]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1096,7 +1096,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1202,7 +1202,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py b/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py index 9a9d92c20..058d0ff6b 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/do_not_use_it_directly.py @@ -74,7 +74,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -812,7 +812,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1107,7 +1107,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1206,7 +1206,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/train.py b/egs/aishell/ASR/pruned_transducer_stateless7/train.py index ede2bd3e5..2dc835f3b 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/train.py @@ -70,7 +70,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -809,7 +809,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1107,7 +1107,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1206,7 +1206,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py index be48d6dde..811269989 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_bbpe/train.py @@ -64,7 +64,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -802,7 +802,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1102,7 +1102,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1202,7 +1202,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index e3387e670..6653d9d9c 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -63,7 +63,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer_for_ncnn_export_only import Zipformer @@ -813,7 +813,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1105,7 +1105,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1205,7 +1205,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py index cba312214..f3b0f1e11 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py @@ -63,7 +63,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -812,7 +812,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1104,7 +1104,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1202,7 +1202,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py index e84dcf156..d77f8c270 100755 --- a/egs/aishell/ASR/whisper/train.py +++ b/egs/aishell/ASR/whisper/train.py @@ -62,7 +62,7 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.functional import pad as pad_tensor from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -514,7 +514,7 @@ def compute_validation_loss( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(valid_dl): - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, tokenizer=tokenizer, @@ -608,7 +608,7 @@ def train_one_epoch( ) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, tokenizer=tokenizer, @@ -812,7 +812,7 @@ def run(rank, world_size, args): train_dl = aishell.train_dataloaders(aishell.train_cuts()) valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/aishell/ASR/zipformer/train.py b/egs/aishell/ASR/zipformer/train.py index ab568b20f..cd253c597 100755 --- a/egs/aishell/ASR/zipformer/train.py +++ b/egs/aishell/ASR/zipformer/train.py @@ -71,7 +71,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -910,7 +910,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1201,7 +1201,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1302,7 +1302,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell/ASR/zipformer/train_bbpe.py b/egs/aishell/ASR/zipformer/train_bbpe.py index 2dac0cc64..46a5506db 100755 --- a/egs/aishell/ASR/zipformer/train_bbpe.py +++ b/egs/aishell/ASR/zipformer/train_bbpe.py @@ -61,7 +61,7 @@ from lhotse.cut import Cut from lhotse.utils import fix_random_seed from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from train import ( @@ -495,7 +495,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -795,7 +795,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -895,7 +895,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py index 772d9e6bf..8c7448d4c 100755 --- a/egs/aishell2/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell2/ASR/pruned_transducer_stateless5/train.py @@ -75,7 +75,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -734,7 +734,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -963,7 +963,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1062,7 +1062,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py index 0eb9271f5..a354f761e 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/train.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/train.py @@ -68,7 +68,7 @@ from local.text_normalize import text_normalize from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -727,7 +727,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) # print(batch["supervisions"]) - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -963,7 +963,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1034,7 +1034,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py index 2b1b6f9b4..30154291d 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/train.py @@ -67,7 +67,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -638,7 +638,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -843,7 +843,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -912,7 +912,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py index e321deeb1..30879d8d2 100755 --- a/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py +++ b/egs/alimeeting/ASR_v2/pruned_transducer_stateless7/train.py @@ -55,7 +55,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -782,7 +782,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1031,7 +1031,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1127,7 +1127,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/ami/ASR/pruned_transducer_stateless7/train.py b/egs/ami/ASR/pruned_transducer_stateless7/train.py index 97ebc5bcf..d62cdadb7 100755 --- a/egs/ami/ASR/pruned_transducer_stateless7/train.py +++ b/egs/ami/ASR/pruned_transducer_stateless7/train.py @@ -55,7 +55,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -773,7 +773,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1034,7 +1034,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1134,7 +1134,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/ami/SURT/dprnn_zipformer/train.py b/egs/ami/SURT/dprnn_zipformer/train.py index 9e77c0527..adc6a8495 100755 --- a/egs/ami/SURT/dprnn_zipformer/train.py +++ b/egs/ami/SURT/dprnn_zipformer/train.py @@ -61,7 +61,7 @@ from model import SURT from optim import Eden, ScaledAdam from scaling import ScaledLinear, ScaledLSTM from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -1067,7 +1067,7 @@ def train_one_epoch( batch_size = batch["inputs"].shape[0] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1314,7 +1314,7 @@ def run(rank, world_size, args): ) valid_dl = ami.valid_dataloaders(dev_cuts) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/ami/SURT/dprnn_zipformer/train_adapt.py b/egs/ami/SURT/dprnn_zipformer/train_adapt.py index 0647a7c78..ac5b0dadc 100755 --- a/egs/ami/SURT/dprnn_zipformer/train_adapt.py +++ b/egs/ami/SURT/dprnn_zipformer/train_adapt.py @@ -61,7 +61,7 @@ from model import SURT from optim import Eden, ScaledAdam from scaling import ScaledLinear, ScaledLSTM from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -1058,7 +1058,7 @@ def train_one_epoch( batch_size = batch["inputs"].shape[0] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1305,7 +1305,7 @@ def run(rank, world_size, args): ) valid_dl = ami.valid_dataloaders(dev_cuts) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/audioset/AT/zipformer/train.py b/egs/audioset/AT/zipformer/train.py index 9532ed906..67c703364 100644 --- a/egs/audioset/AT/zipformer/train.py +++ b/egs/audioset/AT/zipformer/train.py @@ -53,7 +53,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -799,7 +799,7 @@ def train_one_epoch( num_samples += batch_size try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1057,7 +1057,7 @@ def run(rank, world_size, args): valid_cuts = audioset.audioset_eval_cuts() valid_dl = audioset.valid_dataloaders(valid_cuts) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1148,7 +1148,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py index 486ab73df..5e98084ec 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7/train.py @@ -66,7 +66,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -825,7 +825,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1120,7 +1120,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1220,7 +1220,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index fa241abe7..aefe88f3f 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -65,7 +65,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer_for_ncnn_export_only import Zipformer @@ -818,7 +818,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1109,7 +1109,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1209,7 +1209,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py index 8905dc617..976004eca 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/finetune.py @@ -68,7 +68,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -895,7 +895,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1193,7 +1193,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1293,7 +1293,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py index 8260c4985..67e1a8133 100755 --- a/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/commonvoice/ASR/pruned_transducer_stateless7_streaming/train.py @@ -65,7 +65,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -840,7 +840,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1137,7 +1137,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1237,7 +1237,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/commonvoice/ASR/zipformer/train.py b/egs/commonvoice/ASR/zipformer/train.py index c0219df19..271014db0 100755 --- a/egs/commonvoice/ASR/zipformer/train.py +++ b/egs/commonvoice/ASR/zipformer/train.py @@ -75,7 +75,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -969,7 +969,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1265,7 +1265,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1365,7 +1365,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/commonvoice/ASR/zipformer/train_char.py b/egs/commonvoice/ASR/zipformer/train_char.py index 639e1067a..0aa7856cc 100755 --- a/egs/commonvoice/ASR/zipformer/train_char.py +++ b/egs/commonvoice/ASR/zipformer/train_char.py @@ -67,7 +67,7 @@ from lhotse.cut import Cut from lhotse.utils import fix_random_seed from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from train import ( @@ -604,7 +604,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -784,7 +784,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, @@ -979,7 +979,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index 661bfa6ca..6d256308c 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -67,7 +67,7 @@ from model import Transducer from optim import Eden, ScaledAdam from tokenizer import Tokenizer from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer_for_ncnn_export_only import Zipformer @@ -839,7 +839,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1146,7 +1146,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1246,7 +1246,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py index 8f07fc42f..ef7ea9013 100755 --- a/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/csj/ASR/pruned_transducer_stateless7_streaming/train.py @@ -67,7 +67,7 @@ from model import Transducer from optim import Eden, ScaledAdam from tokenizer import Tokenizer from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -838,7 +838,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1145,7 +1145,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1245,7 +1245,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py index e0e11fc70..a7772b62f 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py @@ -64,7 +64,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -675,7 +675,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -873,7 +873,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -944,7 +944,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/gigaspeech/ASR/zipformer/train.py b/egs/gigaspeech/ASR/zipformer/train.py index 5092ef8cb..4c122effe 100755 --- a/egs/gigaspeech/ASR/zipformer/train.py +++ b/egs/gigaspeech/ASR/zipformer/train.py @@ -75,7 +75,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -958,7 +958,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1217,7 +1217,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1317,7 +1317,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/gigaspeech/KWS/zipformer/finetune.py b/egs/gigaspeech/KWS/zipformer/finetune.py index 49e8aef1a..a7ba56127 100755 --- a/egs/gigaspeech/KWS/zipformer/finetune.py +++ b/egs/gigaspeech/KWS/zipformer/finetune.py @@ -73,7 +73,7 @@ from lhotse.utils import fix_random_seed from model import AsrModel from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from train import ( @@ -291,7 +291,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -570,7 +570,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/gigaspeech/KWS/zipformer/train.py b/egs/gigaspeech/KWS/zipformer/train.py index f2283cb03..39d8fc6cd 100755 --- a/egs/gigaspeech/KWS/zipformer/train.py +++ b/egs/gigaspeech/KWS/zipformer/train.py @@ -75,7 +75,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -961,7 +961,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1220,7 +1220,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1320,7 +1320,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py index 30d9f0e51..bf50bf5ea 100755 --- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -61,7 +61,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -805,7 +805,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1096,7 +1096,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1196,7 +1196,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/ksponspeech/ASR/zipformer/train.py b/egs/ksponspeech/ASR/zipformer/train.py index 5f6ee7cca..485ea69c9 100755 --- a/egs/ksponspeech/ASR/zipformer/train.py +++ b/egs/ksponspeech/ASR/zipformer/train.py @@ -70,7 +70,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -942,7 +942,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1233,7 +1233,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1333,7 +1333,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/libricss/SURT/dprnn_zipformer/model.py b/egs/libricss/SURT/dprnn_zipformer/model.py index 0e88357d1..688e1e78d 100644 --- a/egs/libricss/SURT/dprnn_zipformer/model.py +++ b/egs/libricss/SURT/dprnn_zipformer/model.py @@ -140,7 +140,7 @@ class SURT(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -175,7 +175,7 @@ class SURT(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/libricss/SURT/dprnn_zipformer/scaling.py b/egs/libricss/SURT/dprnn_zipformer/scaling.py index d46cb224e..4040a7b89 100644 --- a/egs/libricss/SURT/dprnn_zipformer/scaling.py +++ b/egs/libricss/SURT/dprnn_zipformer/scaling.py @@ -287,7 +287,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): (ans,) = ctx.saved_tensors - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) x_grad = ans_grad * ans @@ -1065,7 +1065,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): def backward(ctx, x_grad: Tensor): (x_orig,) = ctx.saved_tensors with torch.enable_grad(): - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() x_detached.requires_grad = True @@ -1263,7 +1263,7 @@ class MaxEig(torch.nn.Module): ): return _no_op(x) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): eps = 1.0e-20 orig_x = x x = x.to(torch.float32) diff --git a/egs/libricss/SURT/dprnn_zipformer/train.py b/egs/libricss/SURT/dprnn_zipformer/train.py index 33ea7c5a6..148cafd4b 100755 --- a/egs/libricss/SURT/dprnn_zipformer/train.py +++ b/egs/libricss/SURT/dprnn_zipformer/train.py @@ -69,7 +69,7 @@ from model import SURT from optim import Eden, ScaledAdam from scaling import ScaledLSTM from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -1096,7 +1096,7 @@ def train_one_epoch( batch_size = batch["inputs"].shape[0] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1343,7 +1343,7 @@ def run(rank, world_size, args): train_dl_ov40 = libricss.train_dataloaders(train_cuts_ov40) valid_dl = libricss.valid_dataloaders(dev_cuts) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/libricss/SURT/dprnn_zipformer/train_adapt.py b/egs/libricss/SURT/dprnn_zipformer/train_adapt.py index 82b61baa0..8c37430ec 100755 --- a/egs/libricss/SURT/dprnn_zipformer/train_adapt.py +++ b/egs/libricss/SURT/dprnn_zipformer/train_adapt.py @@ -67,7 +67,7 @@ from model import SURT from optim import Eden, ScaledAdam from scaling import ScaledLinear, ScaledLSTM from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -985,7 +985,7 @@ def train_one_epoch( batch_size = batch["inputs"].shape[0] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1237,7 +1237,7 @@ def run(rank, world_size, args): ) valid_dl = libricss.valid_dataloaders(dev_cuts) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/libriheavy/ASR/zipformer/train.py b/egs/libriheavy/ASR/zipformer/train.py index 524273ec5..357e8a827 100644 --- a/egs/libriheavy/ASR/zipformer/train.py +++ b/egs/libriheavy/ASR/zipformer/train.py @@ -78,7 +78,7 @@ from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from text_normalization import remove_punc_to_upper from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -958,7 +958,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1268,7 +1268,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1367,7 +1367,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/model_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/model_baseline.py index 66328bb89..77b4057c4 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/model_baseline.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/model_baseline.py @@ -186,7 +186,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -221,7 +221,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py b/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py index 80fbf09f0..21c7b4fac 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py @@ -245,7 +245,7 @@ class PromptedTransducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -287,7 +287,7 @@ class PromptedTransducer(nn.Module): logits = self.joiner(am_pruned, lm_pruned, context=context, project_input=False) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py b/egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py index a260d828e..0e6764ba0 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/scaling.py @@ -271,7 +271,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): (ans,) = ctx.saved_tensors - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) x_grad = ans_grad * ans @@ -685,7 +685,7 @@ class BalancerFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): x = x.to(torch.float32) x = x.detach() x.requires_grad = True @@ -940,7 +940,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() x_detached.requires_grad = True @@ -1280,7 +1280,7 @@ class SwooshLFunction(torch.autograd.Function): coeff = -0.08 - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True @@ -1351,7 +1351,7 @@ class SwooshRFunction(torch.autograd.Function): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py index bfca5a0db..93f7e1248 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_baseline.py @@ -89,7 +89,7 @@ from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from text_normalization import train_text_normalization, upper_only_alpha from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -975,7 +975,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1271,7 +1271,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1371,7 +1371,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py index 36c6d6464..2a2c206aa 100755 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py @@ -103,7 +103,7 @@ from text_normalization import ( upper_only_alpha, ) from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -1321,7 +1321,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1647,7 +1647,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1749,7 +1749,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/zipformer.py b/egs/libriheavy/ASR/zipformer_prompt_asr/zipformer.py index 405c95acc..d1cf90ffb 100644 --- a/egs/libriheavy/ASR/zipformer_prompt_asr/zipformer.py +++ b/egs/libriheavy/ASR/zipformer_prompt_asr/zipformer.py @@ -1561,7 +1561,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_weights_entropy = ( -((attn_weights + 1.0e-20).log() * attn_weights) @@ -1844,7 +1844,7 @@ class MultiheadAttentionWeights(nn.Module): (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_weights_entropy = ( -((attn_weights + 1.0e-20).log() * attn_weights) diff --git a/egs/librilight/SSL/zipformer/finetune.py b/egs/librilight/SSL/zipformer/finetune.py index 568096c6a..50dbd5f2d 100644 --- a/egs/librilight/SSL/zipformer/finetune.py +++ b/egs/librilight/SSL/zipformer/finetune.py @@ -67,7 +67,7 @@ from lhotse.utils import fix_random_seed from model import AsrModel from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -1116,7 +1116,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1407,7 +1407,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1505,7 +1505,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librilight/SSL/zipformer/pretrain.py b/egs/librilight/SSL/zipformer/pretrain.py index 019f77ea3..5728dbe75 100644 --- a/egs/librilight/SSL/zipformer/pretrain.py +++ b/egs/librilight/SSL/zipformer/pretrain.py @@ -57,7 +57,7 @@ from lhotse.utils import fix_random_seed from optim import Eden, ScaledAdam from ssl_datamodule import LibriLightDataModule from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -936,7 +936,7 @@ def train_one_epoch( batch_size = batch["kmeans"].shape[0] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1229,7 +1229,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1320,7 +1320,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py index b0b5da1c0..c4a13b101 100755 --- a/egs/librispeech/ASR/conformer_ctc2/train.py +++ b/egs/librispeech/ASR/conformer_ctc2/train.py @@ -65,7 +65,7 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -676,7 +676,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -965,7 +965,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1036,7 +1036,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/conformer_ctc3/train.py b/egs/librispeech/ASR/conformer_ctc3/train.py index 7e819a2d8..a2f1125ca 100755 --- a/egs/librispeech/ASR/conformer_ctc3/train.py +++ b/egs/librispeech/ASR/conformer_ctc3/train.py @@ -76,7 +76,7 @@ from lhotse.utils import fix_random_seed from model import CTCModel from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -743,7 +743,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1004,7 +1004,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1073,7 +1073,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index 130a7c97f..ca21bd6bf 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -80,7 +80,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -772,7 +772,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1002,7 +1002,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1071,7 +1071,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/do_not_use_it_directly.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/do_not_use_it_directly.py index 16ae4e4e2..d614f0914 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/do_not_use_it_directly.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/do_not_use_it_directly.py @@ -80,7 +80,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -774,7 +774,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1003,7 +1003,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1074,7 +1074,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index 28d094a76..23ddb6bec 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -80,7 +80,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -772,7 +772,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1001,7 +1001,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1072,7 +1072,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/model.py b/egs/librispeech/ASR/lstm_transducer_stateless/model.py index 1ec9a8fc6..e7bad7ed8 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/model.py @@ -156,7 +156,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -192,7 +192,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py index 1e50ce090..feb81d500 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -66,7 +66,7 @@ from lstm import RNN from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -763,7 +763,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1023,7 +1023,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1092,7 +1092,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py index a758c550d..4957d14b1 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py @@ -195,7 +195,7 @@ class Transducer(nn.Module): lm = simple_lm_proj(decoder_out) am = simple_am_proj(encoder_out) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -231,7 +231,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = joiner(am_pruned, lm_pruned, project_input=False) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py index 4d4f3e132..4fc4fa7f8 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py @@ -74,7 +74,7 @@ from lstm import RNN from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -848,7 +848,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1176,7 +1176,7 @@ def run(rank, world_size, args): else: logging.info("Skip scan_pessimistic_batches_for_oom") - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1247,7 +1247,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index ae4cd1c6a..2c1cef3a3 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -66,7 +66,7 @@ from lstm import RNN from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -793,7 +793,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1067,7 +1067,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1136,7 +1136,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned2_knowledge/model.py b/egs/librispeech/ASR/pruned2_knowledge/model.py index 2ffea06e7..ca8c28af1 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/model.py +++ b/egs/librispeech/ASR/pruned2_knowledge/model.py @@ -141,7 +141,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -176,7 +176,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py index 3d2fdd6d8..5b595c76c 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py @@ -10,7 +10,7 @@ from typing import Optional, Tuple import torch from scaling import ScaledLinear from torch import Tensor, nn -from torch.amp import GradScaler, custom_bwd, custom_fwd +from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd from torch_scheduled_sampling import sample_combined # The main exports of this file are the module KnowledgeBaseLookup and the @@ -330,14 +330,14 @@ def _test_knowledge_base_lookup_autocast(): optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device) - scaler = GradScaler("cuda", enabled=True) + scaler = GradScaler(enabled=True) start = timeit.default_timer() for epoch in range(150): for n, (x, y) in enumerate(train_pairs): y_out = m(x) - with torch.amp.autocast("cuda", enabled=True): + with torch.cuda.amp.autocast(enabled=True): loss = ((y_out - y) ** 2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index 8c117dd60..931341cc4 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -66,7 +66,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -650,7 +650,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -868,7 +868,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -937,7 +937,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py index b25a84a6b..2b872f1d5 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py @@ -55,7 +55,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from noam import Noam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -693,7 +693,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -939,7 +939,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1004,7 +1004,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 59ed8310c..272d06c37 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -157,7 +157,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -193,7 +193,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index e86ec8052..6c19f2cb0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -78,7 +78,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -759,7 +759,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1000,7 +1000,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 0 else 1.0, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1067,7 +1067,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py index 0495c8a29..d45f6dadc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py @@ -195,7 +195,7 @@ class Transducer(nn.Module): lm = simple_lm_proj(decoder_out) am = simple_am_proj(encoder_out) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -231,7 +231,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = joiner(am_pruned, lm_pruned, project_input=False) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 8ef207518..fdafa5a87 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -74,7 +74,7 @@ from librispeech import LibriSpeech from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -827,7 +827,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1126,7 +1126,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 0 else 1.0, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1195,7 +1195,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index b6682908b..875b03f7f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -80,7 +80,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -789,7 +789,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1047,7 +1047,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1116,7 +1116,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 2b559a27c..66dc5f991 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -68,7 +68,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -814,7 +814,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1078,7 +1078,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1147,7 +1147,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index 20b730a08..daadb70c9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -185,7 +185,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -220,7 +220,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index 93663505a..8f033cb9a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -80,7 +80,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -781,7 +781,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1039,7 +1039,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1108,7 +1108,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py index d29010a23..e7546ec45 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py @@ -66,7 +66,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -903,7 +903,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1219,7 +1219,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1319,7 +1319,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py index 49076b96f..add0e6a18 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py @@ -150,7 +150,7 @@ class Transducer(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -185,7 +185,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 16d86fe2d..30a737061 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -289,7 +289,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): (ans,) = ctx.saved_tensors - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) x_grad = ans_grad * ans @@ -669,7 +669,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): def backward(ctx, x_grad: Tensor): (x_orig,) = ctx.saved_tensors with torch.enable_grad(): - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() x_detached.requires_grad = True @@ -867,7 +867,7 @@ class MaxEig(torch.nn.Module): ): return _no_op(x) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): eps = 1.0e-20 orig_x = x x = x.to(torch.float32) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 91fccd58d..436ec53b4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -67,7 +67,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -809,7 +809,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1106,7 +1106,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1206,7 +1206,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index ebef7e977..cbde2a2e4 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -1421,7 +1421,7 @@ class RelPositionMultiheadAttention(nn.Module): bsz = n // num_heads with torch.no_grad(): - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) attn_weights_entropy = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py index 0224c15d7..a6e919e2f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py @@ -150,7 +150,7 @@ class Transducer(nn.Module): lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -185,7 +185,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py index 395b07b05..b35e56abc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py @@ -67,7 +67,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -833,7 +833,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1128,7 +1128,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1228,7 +1228,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py index 4675697c1..0582b289f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/model.py @@ -178,7 +178,7 @@ class Transducer(nn.Module): am = self.simple_am_proj(encoder_out_fr) lm = self.simple_lm_proj(decoder_out) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -213,7 +213,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py index a431b278d..c2d877a93 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/train.py @@ -63,7 +63,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -822,7 +822,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1118,7 +1118,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1217,7 +1217,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py index dc3493425..8e239e322 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py @@ -66,7 +66,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer_for_ncnn_export_only import Zipformer @@ -811,7 +811,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1106,7 +1106,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1206,7 +1206,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py index a8f47d941..8bd00bbef 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py @@ -66,7 +66,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -810,7 +810,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1124,7 +1124,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1224,7 +1224,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py index e3b8b3725..c7e45564f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -2408,7 +2408,7 @@ class RelPositionMultiheadAttention(nn.Module): bsz = n // num_heads with torch.no_grad(): - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) attn_weights_entropy = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py index ff23725b7..5284ed627 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py @@ -2708,7 +2708,7 @@ class RelPositionMultiheadAttention(nn.Module): bsz = n // num_heads with torch.no_grad(): - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) attn_weights_entropy = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py index 4c8c239a1..da5e144c9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming_multi/train.py @@ -70,7 +70,7 @@ from librispeech import LibriSpeech from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -866,7 +866,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1218,7 +1218,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1320,7 +1320,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py index c0b9113b7..39a360796 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py @@ -172,7 +172,7 @@ class Transducer(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -207,7 +207,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = joiner(am_pruned, lm_pruned, project_input=False) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index 0ccef210e..646f30ca1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -75,7 +75,7 @@ from librispeech import LibriSpeech from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -866,7 +866,7 @@ def train_one_epoch( libri = is_libri(batch["supervisions"]["cut"][0]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1219,7 +1219,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1321,7 +1321,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/train.py b/egs/librispeech/ASR/tiny_transducer_ctc/train.py index 0536e89b3..1bfd071de 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/train.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/train.py @@ -51,7 +51,7 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import Transducer from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import AdamW from torch.optim.lr_scheduler import StepLR @@ -809,7 +809,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1092,7 +1092,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1198,7 +1198,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer/finetune.py b/egs/librispeech/ASR/zipformer/finetune.py index 5da903d38..2ff631914 100755 --- a/egs/librispeech/ASR/zipformer/finetune.py +++ b/egs/librispeech/ASR/zipformer/finetune.py @@ -78,7 +78,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -1049,7 +1049,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1373,7 +1373,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1474,7 +1474,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index b0bb7c7fe..c7dbe1e0a 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -285,7 +285,7 @@ class AsrModel(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -320,7 +320,7 @@ class AsrModel(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 46df86bf7..d345c2931 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -306,7 +306,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): (ans,) = ctx.saved_tensors - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) x_grad = ans_grad * ans @@ -759,7 +759,7 @@ class BalancerFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): x = x.to(torch.float32) x = x.detach() x.requires_grad = True @@ -1014,7 +1014,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() x_detached.requires_grad = True @@ -1353,7 +1353,7 @@ class SwooshLFunction(torch.autograd.Function): coeff = -0.08 - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True @@ -1430,7 +1430,7 @@ class SwooshRFunction(torch.autograd.Function): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 71d045ea0..c074c32ec 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -79,7 +79,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -1101,7 +1101,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", + with torch.cuda.amp.autocast( enabled=params.use_autocast, dtype=params.dtype ): loss, loss_info = compute_loss( @@ -1438,7 +1438,7 @@ def run(rank, world_size, args): spec_augment=spec_augment, ) - scaler = GradScaler("cuda", enabled=params.use_autocast, init_scale=1.0) + scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1540,7 +1540,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", + with torch.cuda.amp.autocast( enabled=params.use_autocast, dtype=params.dtype ): loss, _ = compute_loss( diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index bdfd2175c..2a0ae0129 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -1873,7 +1873,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_weights_entropy = ( -((attn_weights + 1.0e-20).log() * attn_weights) diff --git a/egs/librispeech/ASR/zipformer_adapter/train.py b/egs/librispeech/ASR/zipformer_adapter/train.py index 0207fc26e..3511590da 100755 --- a/egs/librispeech/ASR/zipformer_adapter/train.py +++ b/egs/librispeech/ASR/zipformer_adapter/train.py @@ -67,7 +67,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -1052,7 +1052,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1397,7 +1397,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1498,7 +1498,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer_adapter/zipformer.py b/egs/librispeech/ASR/zipformer_adapter/zipformer.py index 6224d136a..8e2dfdd72 100644 --- a/egs/librispeech/ASR/zipformer_adapter/zipformer.py +++ b/egs/librispeech/ASR/zipformer_adapter/zipformer.py @@ -1916,7 +1916,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_weights_entropy = ( -((attn_weights + 1.0e-20).log() * attn_weights) diff --git a/egs/librispeech/ASR/zipformer_ctc/train.py b/egs/librispeech/ASR/zipformer_ctc/train.py index dfe702d2f..60112a84e 100755 --- a/egs/librispeech/ASR/zipformer_ctc/train.py +++ b/egs/librispeech/ASR/zipformer_ctc/train.py @@ -46,7 +46,7 @@ from lhotse.utils import fix_random_seed from model import CTCModel from optim import Eden, LRScheduler, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter @@ -726,7 +726,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -987,7 +987,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/librispeech/ASR/zipformer_lora/finetune.py b/egs/librispeech/ASR/zipformer_lora/finetune.py index 53152971d..3f36f229f 100755 --- a/egs/librispeech/ASR/zipformer_lora/finetune.py +++ b/egs/librispeech/ASR/zipformer_lora/finetune.py @@ -78,7 +78,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -1065,7 +1065,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1406,7 +1406,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1507,7 +1507,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer_lora/scaling.py b/egs/librispeech/ASR/zipformer_lora/scaling.py index a1e77fe0e..8d7aa8027 100644 --- a/egs/librispeech/ASR/zipformer_lora/scaling.py +++ b/egs/librispeech/ASR/zipformer_lora/scaling.py @@ -307,7 +307,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): (ans,) = ctx.saved_tensors - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) x_grad = ans_grad * ans @@ -863,7 +863,7 @@ class BalancerFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): x = x.to(torch.float32) x = x.detach() x.requires_grad = True @@ -1118,7 +1118,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function): try: with torch.enable_grad(): - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() x_detached.requires_grad = True @@ -1457,7 +1457,7 @@ class SwooshLFunction(torch.autograd.Function): coeff = -0.08 - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True @@ -1534,7 +1534,7 @@ class SwooshRFunction(torch.autograd.Function): zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): with torch.enable_grad(): x = x.detach() x.requires_grad = True diff --git a/egs/librispeech/ASR/zipformer_lora/train.py b/egs/librispeech/ASR/zipformer_lora/train.py index 592bc0fd4..9ab214e86 100755 --- a/egs/librispeech/ASR/zipformer_lora/train.py +++ b/egs/librispeech/ASR/zipformer_lora/train.py @@ -76,7 +76,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -947,7 +947,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1252,7 +1252,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1352,7 +1352,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/ASR/zipformer_lora/zipformer.py b/egs/librispeech/ASR/zipformer_lora/zipformer.py index ece7c3df1..43865609a 100644 --- a/egs/librispeech/ASR/zipformer_lora/zipformer.py +++ b/egs/librispeech/ASR/zipformer_lora/zipformer.py @@ -1905,7 +1905,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_weights_entropy = ( -((attn_weights + 1.0e-20).log() * attn_weights) diff --git a/egs/librispeech/ASR/zipformer_mmi/train.py b/egs/librispeech/ASR/zipformer_mmi/train.py index bed3cfa04..c1785a328 100755 --- a/egs/librispeech/ASR/zipformer_mmi/train.py +++ b/egs/librispeech/ASR/zipformer_mmi/train.py @@ -64,7 +64,7 @@ from lhotse.utils import fix_random_seed from model import CTCModel from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -744,7 +744,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1037,7 +1037,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1138,7 +1138,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/hubert/finetune.py b/egs/librispeech/SSL/hubert/finetune.py index 9717d579d..17daa3c9d 100644 --- a/egs/librispeech/SSL/hubert/finetune.py +++ b/egs/librispeech/SSL/hubert/finetune.py @@ -66,7 +66,7 @@ from lhotse.utils import fix_random_seed from model import AsrModel from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -816,7 +816,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1109,7 +1109,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1207,7 +1207,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/hubert/finetune_ce.py b/egs/librispeech/SSL/hubert/finetune_ce.py index 340aa4aa2..2723cc770 100644 --- a/egs/librispeech/SSL/hubert/finetune_ce.py +++ b/egs/librispeech/SSL/hubert/finetune_ce.py @@ -66,7 +66,7 @@ from lhotse.utils import fix_random_seed from model import AsrModel from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -816,7 +816,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1109,7 +1109,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1207,7 +1207,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/hubert/model.py b/egs/librispeech/SSL/hubert/model.py index b23fa32ea..46a968b69 100644 --- a/egs/librispeech/SSL/hubert/model.py +++ b/egs/librispeech/SSL/hubert/model.py @@ -221,7 +221,7 @@ class AsrModel(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -256,7 +256,7 @@ class AsrModel(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/SSL/hubert/pretrain.py b/egs/librispeech/SSL/hubert/pretrain.py index 1868bf0a6..f183d90fd 100644 --- a/egs/librispeech/SSL/hubert/pretrain.py +++ b/egs/librispeech/SSL/hubert/pretrain.py @@ -59,7 +59,7 @@ from lhotse.utils import fix_random_seed from optim import Eden, ScaledAdam from ssl_datamodule import LibriSpeechDataModule from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.functional import pad from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -644,7 +644,7 @@ def train_one_epoch( batch_size = batch["kmeans"].shape[0] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -945,7 +945,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1036,7 +1036,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/hubert/pretrain_ce.py b/egs/librispeech/SSL/hubert/pretrain_ce.py index 97efd983b..94948695d 100644 --- a/egs/librispeech/SSL/hubert/pretrain_ce.py +++ b/egs/librispeech/SSL/hubert/pretrain_ce.py @@ -59,7 +59,7 @@ from lhotse.utils import fix_random_seed from optim import Eden, ScaledAdam from ssl_datamodule import LibriSpeechDataModule from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.functional import pad from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -644,7 +644,7 @@ def train_one_epoch( batch_size = batch["kmeans"].shape[0] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -945,7 +945,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1036,7 +1036,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/zipformer/finetune.py b/egs/librispeech/SSL/zipformer/finetune.py index 6bfab9d00..c907b41c5 100644 --- a/egs/librispeech/SSL/zipformer/finetune.py +++ b/egs/librispeech/SSL/zipformer/finetune.py @@ -66,7 +66,7 @@ from lhotse.utils import fix_random_seed from model import AsrModel from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -1115,7 +1115,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1406,7 +1406,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1504,7 +1504,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/zipformer/model.py b/egs/librispeech/SSL/zipformer/model.py index b23fa32ea..46a968b69 100644 --- a/egs/librispeech/SSL/zipformer/model.py +++ b/egs/librispeech/SSL/zipformer/model.py @@ -221,7 +221,7 @@ class AsrModel(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -256,7 +256,7 @@ class AsrModel(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/librispeech/SSL/zipformer/pretrain.py b/egs/librispeech/SSL/zipformer/pretrain.py index 767c3bacb..937fb382e 100644 --- a/egs/librispeech/SSL/zipformer/pretrain.py +++ b/egs/librispeech/SSL/zipformer/pretrain.py @@ -58,7 +58,7 @@ from lhotse.utils import fix_random_seed from optim import Eden, ScaledAdam from ssl_datamodule import LibriSpeechDataModule from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -944,7 +944,7 @@ def train_one_epoch( batch_size = batch["kmeans"].shape[0] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1243,7 +1243,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1334,7 +1334,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/SSL/zipformer/zipformer.py b/egs/librispeech/SSL/zipformer/zipformer.py index 7e9ccb51f..e9eff3357 100644 --- a/egs/librispeech/SSL/zipformer/zipformer.py +++ b/egs/librispeech/SSL/zipformer/zipformer.py @@ -1849,7 +1849,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module): (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape with torch.no_grad(): - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_weights_entropy = ( -((attn_weights + 1.0e-20).log() * attn_weights) diff --git a/egs/librispeech/WSASR/conformer_ctc2/train.py b/egs/librispeech/WSASR/conformer_ctc2/train.py index fc7728562..82c68803f 100755 --- a/egs/librispeech/WSASR/conformer_ctc2/train.py +++ b/egs/librispeech/WSASR/conformer_ctc2/train.py @@ -62,7 +62,7 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -757,7 +757,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1005,7 +1005,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1076,7 +1076,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/librispeech/WSASR/conformer_ctc2/train_phone.py b/egs/librispeech/WSASR/conformer_ctc2/train_phone.py index 1c4bd50bf..b276d0587 100755 --- a/egs/librispeech/WSASR/conformer_ctc2/train_phone.py +++ b/egs/librispeech/WSASR/conformer_ctc2/train_phone.py @@ -62,7 +62,7 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -758,7 +758,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1007,7 +1007,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1078,7 +1078,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/libritts/ASR/zipformer/train.py b/egs/libritts/ASR/zipformer/train.py index 78e3330bd..5485eaf0a 100755 --- a/egs/libritts/ASR/zipformer/train.py +++ b/egs/libritts/ASR/zipformer/train.py @@ -80,7 +80,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -1049,8 +1049,8 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast( - "cuda", enabled=params.use_autocast, dtype=params.dtype + with torch.cuda.amp.autocast( + enabled=params.use_autocast, dtype=params.dtype ): loss, loss_info = compute_loss( params=params, @@ -1378,7 +1378,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_autocast, init_scale=1.0) + scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1478,8 +1478,8 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast( - "cuda", enabled=params.use_autocast, dtype=params.dtype + with torch.cuda.amp.autocast( + enabled=params.use_autocast, dtype=params.dtype ): loss, _ = compute_loss( params=params, diff --git a/egs/libritts/CODEC/encodec/encodec.py b/egs/libritts/CODEC/encodec/encodec.py index 31fc4f126..f21d494b6 100644 --- a/egs/libritts/CODEC/encodec/encodec.py +++ b/egs/libritts/CODEC/encodec/encodec.py @@ -29,7 +29,7 @@ from loss import ( WavReconstructionLoss, ) from torch import nn -from torch.amp import autocast +from torch.cuda.amp import autocast class Encodec(nn.Module): @@ -148,7 +148,7 @@ class Encodec(nn.Module): ) # calculate losses - with autocast("cuda", enabled=False): + with autocast(enabled=False): gen_stft_adv_loss = self.generator_adversarial_loss(outputs=y_hat) if self.multi_period_discriminator is not None: @@ -272,7 +272,7 @@ class Encodec(nn.Module): speech_hat.contiguous().detach(), ) # calculate losses - with autocast("cuda", enabled=False): + with autocast(enabled=False): ( disc_stft_real_adv_loss, disc_stft_fake_adv_loss, diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index 31349df43..a4f2eb7ab 100755 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -34,7 +34,7 @@ from encodec import Encodec from lhotse.utils import fix_random_seed from scheduler import WarmupCosineLrScheduler from torch import nn -from torch.amp import GradScaler, autocast +from torch.cuda.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter @@ -466,7 +466,7 @@ def train_one_epoch( loss_info["samples"] = batch_size try: - with autocast("cuda", enabled=params.use_fp16): + with autocast(enabled=params.use_fp16): d_weight = train_discriminator( params.lambda_adv, params.cur_epoch, @@ -502,7 +502,7 @@ def train_one_epoch( scaler.scale(disc_loss).backward() scaler.step(optimizer_d) - with autocast("cuda", enabled=params.use_fp16): + with autocast(enabled=params.use_fp16): g_weight = train_discriminator( params.lambda_adv, params.cur_epoch, @@ -846,7 +846,7 @@ def scan_pessimistic_batches_for_oom( ) = prepare_input(params, batch, device) try: # for discriminator - with autocast("cuda", enabled=params.use_fp16): + with autocast(enabled=params.use_fp16): ( disc_stft_real_adv_loss, disc_stft_fake_adv_loss, @@ -876,7 +876,7 @@ def scan_pessimistic_batches_for_oom( optimizer_d.zero_grad() loss_d.backward() # for generator - with autocast("cuda", enabled=params.use_fp16): + with autocast(enabled=params.use_fp16): ( commit_loss, gen_stft_adv_loss, @@ -1102,7 +1102,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/libritts/TTS/vits/train.py b/egs/libritts/TTS/vits/train.py index 6803d6eb2..447fbcf5d 100755 --- a/egs/libritts/TTS/vits/train.py +++ b/egs/libritts/TTS/vits/train.py @@ -32,7 +32,7 @@ from lhotse.cut import Cut from lhotse.features.io import KaldiReader from lhotse.utils import fix_random_seed from tokenizer import Tokenizer -from torch.amp import GradScaler, autocast +from torch.cuda.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter @@ -456,7 +456,7 @@ def train_one_epoch( loss_info["samples"] = batch_size try: - with autocast("cuda", enabled=params.use_fp16): + with autocast(enabled=params.use_fp16): # forward discriminator loss_d, stats_d = model( text=tokens, @@ -475,7 +475,7 @@ def train_one_epoch( scaler.scale(loss_d).backward() scaler.step(optimizer_d) - with autocast("cuda", enabled=params.use_fp16): + with autocast(enabled=params.use_fp16): # forward generator loss_g, stats_g = model( text=tokens, @@ -748,7 +748,7 @@ def scan_pessimistic_batches_for_oom( ) = prepare_input(batch, tokenizer, device, train_speaker_map) try: # for discriminator - with autocast("cuda", enabled=params.use_fp16): + with autocast(enabled=params.use_fp16): loss_d, stats_d = model( text=tokens, text_lengths=tokens_lens, @@ -762,7 +762,7 @@ def scan_pessimistic_batches_for_oom( optimizer_d.zero_grad() loss_d.backward() # for generator - with autocast("cuda", enabled=params.use_fp16): + with autocast(enabled=params.use_fp16): loss_g, stats_g = model( text=tokens, text_lengths=tokens_lens, @@ -922,7 +922,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py index a25cc8723..853042413 100755 --- a/egs/ljspeech/TTS/matcha/train.py +++ b/egs/ljspeech/TTS/matcha/train.py @@ -17,7 +17,7 @@ from lhotse.utils import fix_random_seed from model import fix_len_compatibility from models.matcha_tts import MatchaTTS from tokenizer import Tokenizer -from torch.amp import GradScaler, autocast +from torch.cuda.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter @@ -474,7 +474,7 @@ def train_one_epoch( tokens_lens, ) = prepare_input(batch, tokenizer, device, params) try: - with autocast("cuda", enabled=params.use_fp16): + with autocast(enabled=params.use_fp16): losses = get_losses( { "x": tokens, @@ -645,7 +645,7 @@ def run(rank, world_size, args): valid_cuts = ljspeech.valid_cuts() valid_dl = ljspeech.valid_dataloaders(valid_cuts) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py index e9994319a..184ae79af 100755 --- a/egs/ljspeech/TTS/vits/train.py +++ b/egs/ljspeech/TTS/vits/train.py @@ -30,7 +30,7 @@ import torch.nn as nn from lhotse.cut import Cut from lhotse.utils import fix_random_seed from tokenizer import Tokenizer -from torch.amp import GradScaler, autocast +from torch.cuda.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter @@ -396,7 +396,7 @@ def train_one_epoch( loss_info["samples"] = batch_size try: - with autocast("cuda", enabled=params.use_fp16): + with autocast(enabled=params.use_fp16): # forward discriminator loss_d, stats_d = model( text=tokens, @@ -414,7 +414,7 @@ def train_one_epoch( scaler.scale(loss_d).backward() scaler.step(optimizer_d) - with autocast("cuda", enabled=params.use_fp16): + with autocast(enabled=params.use_fp16): # forward generator loss_g, stats_g = model( text=tokens, @@ -673,7 +673,7 @@ def scan_pessimistic_batches_for_oom( ) try: # for discriminator - with autocast("cuda", enabled=params.use_fp16): + with autocast(enabled=params.use_fp16): loss_d, stats_d = model( text=tokens, text_lengths=tokens_lens, @@ -686,7 +686,7 @@ def scan_pessimistic_batches_for_oom( optimizer_d.zero_grad() loss_d.backward() # for generator - with autocast("cuda", enabled=params.use_fp16): + with autocast(enabled=params.use_fp16): loss_g, stats_g = model( text=tokens, text_lengths=tokens_lens, @@ -838,7 +838,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/ljspeech/TTS/vits/utils.py b/egs/ljspeech/TTS/vits/utils.py index d51ff5f5c..6a067f596 100644 --- a/egs/ljspeech/TTS/vits/utils.py +++ b/egs/ljspeech/TTS/vits/utils.py @@ -23,7 +23,7 @@ import torch import torch.distributed as dist import torch.nn as nn from lhotse.dataset.sampling.base import CutSampler -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py index 6fd6d219b..a1fabf9ad 100644 --- a/egs/ljspeech/TTS/vits/vits.py +++ b/egs/ljspeech/TTS/vits/vits.py @@ -25,7 +25,7 @@ from loss import ( KLDivergenceLoss, MelSpectrogramLoss, ) -from torch.amp import autocast +from torch.cuda.amp import autocast from utils import get_segments AVAILABLE_GENERATERS = { @@ -410,7 +410,7 @@ class VITS(nn.Module): p = self.discriminator(speech_) # calculate losses - with autocast("cuda", enabled=False): + with autocast(enabled=False): if not return_sample: mel_loss = self.mel_loss(speech_hat_, speech_) else: @@ -518,7 +518,7 @@ class VITS(nn.Module): p = self.discriminator(speech_) # calculate losses - with autocast("cuda", enabled=False): + with autocast(enabled=False): real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p) loss = real_loss + fake_loss diff --git a/egs/mdcc/ASR/zipformer/train.py b/egs/mdcc/ASR/zipformer/train.py index 22249286a..730db7718 100755 --- a/egs/mdcc/ASR/zipformer/train.py +++ b/egs/mdcc/ASR/zipformer/train.py @@ -68,7 +68,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -906,7 +906,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1197,7 +1197,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1298,7 +1298,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/train.py b/egs/mgb2/ASR/pruned_transducer_stateless5/train.py index 916ada475..48468cfbd 100755 --- a/egs/mgb2/ASR/pruned_transducer_stateless5/train.py +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/train.py @@ -66,7 +66,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter @@ -751,7 +751,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info, inf_flag = compute_loss( params=params, model=model, @@ -1012,7 +1012,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1115,7 +1115,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _, _ = compute_loss( params=params, model=model, diff --git a/egs/multi_zh-hans/ASR/whisper/train.py b/egs/multi_zh-hans/ASR/whisper/train.py index 1a11d01af..fe2d950c1 100755 --- a/egs/multi_zh-hans/ASR/whisper/train.py +++ b/egs/multi_zh-hans/ASR/whisper/train.py @@ -61,7 +61,7 @@ from lhotse.utils import fix_random_seed from multi_dataset import MultiDataset from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.functional import pad as pad_tensor from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -566,7 +566,7 @@ def compute_validation_loss( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(valid_dl): - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, tokenizer=tokenizer, @@ -675,7 +675,7 @@ def train_one_epoch( ) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, tokenizer=tokenizer, @@ -913,7 +913,7 @@ def run(rank, world_size, args): valid_cuts = multi_dataset.dev_cuts() valid_dl = data_module.valid_dataloaders(valid_cuts) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/multi_zh-hans/ASR/zipformer/train.py b/egs/multi_zh-hans/ASR/zipformer/train.py index 047253d5b..3dbfc48eb 100755 --- a/egs/multi_zh-hans/ASR/zipformer/train.py +++ b/egs/multi_zh-hans/ASR/zipformer/train.py @@ -75,7 +75,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -987,7 +987,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1278,7 +1278,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1378,7 +1378,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/multi_zh_en/ASR/zipformer/train.py b/egs/multi_zh_en/ASR/zipformer/train.py index 9e64defa3..04bb41214 100755 --- a/egs/multi_zh_en/ASR/zipformer/train.py +++ b/egs/multi_zh_en/ASR/zipformer/train.py @@ -75,7 +75,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -969,7 +969,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1269,7 +1269,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1369,7 +1369,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/reazonspeech/ASR/zipformer/do_not_use_it_directly.py b/egs/reazonspeech/ASR/zipformer/do_not_use_it_directly.py index c01e4d336..072679cfc 100755 --- a/egs/reazonspeech/ASR/zipformer/do_not_use_it_directly.py +++ b/egs/reazonspeech/ASR/zipformer/do_not_use_it_directly.py @@ -67,7 +67,7 @@ from model import Transducer from optim import Eden, ScaledAdam from tokenizer import Tokenizer from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer_for_ncnn_export_only import Zipformer @@ -822,7 +822,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1113,7 +1113,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1213,7 +1213,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/reazonspeech/ASR/zipformer/train.py b/egs/reazonspeech/ASR/zipformer/train.py index 8829a18ca..30bd3efba 100755 --- a/egs/reazonspeech/ASR/zipformer/train.py +++ b/egs/reazonspeech/ASR/zipformer/train.py @@ -74,7 +74,7 @@ from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from tokenizer import Tokenizer from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -945,7 +945,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1235,7 +1235,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1335,7 +1335,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py index 5de2cf2b0..5f224c984 100755 --- a/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py +++ b/egs/speech_llm/ASR_LLM/whisper_llm_zh/train.py @@ -451,7 +451,7 @@ def compute_validation_loss( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(valid_dl): - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, tokenizer=tokenizer, @@ -566,7 +566,7 @@ def train_one_epoch( f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}" ) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, tokenizer=tokenizer, diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py index 1e55ada87..a9146a0fe 100755 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py @@ -65,7 +65,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -649,7 +649,7 @@ def train_one_epoch( params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -857,7 +857,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -957,7 +957,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/spgispeech/ASR/zipformer/train.py b/egs/spgispeech/ASR/zipformer/train.py index 319713b02..dfc21c968 100755 --- a/egs/spgispeech/ASR/zipformer/train.py +++ b/egs/spgispeech/ASR/zipformer/train.py @@ -74,7 +74,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -946,7 +946,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1217,7 +1217,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1317,7 +1317,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py index c44e30b89..c0aedd725 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py @@ -69,7 +69,7 @@ from local.tokenize_with_bpe_model import tokenize_by_bpe_model from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -726,7 +726,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) # print(batch["supervisions"]) - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -967,7 +967,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1039,7 +1039,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py index dd9576d99..2108266ec 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless7_bbpe/train.py @@ -64,7 +64,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -801,7 +801,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1101,7 +1101,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1201,7 +1201,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/tedlium3/ASR/conformer_ctc2/train.py b/egs/tedlium3/ASR/conformer_ctc2/train.py index 179dcf14a..fc3e3b2d9 100755 --- a/egs/tedlium3/ASR/conformer_ctc2/train.py +++ b/egs/tedlium3/ASR/conformer_ctc2/train.py @@ -57,7 +57,7 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -710,7 +710,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -941,7 +941,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1011,7 +1011,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/tedlium3/ASR/zipformer/model.py b/egs/tedlium3/ASR/zipformer/model.py index 0d9b395ed..65b052ab9 100644 --- a/egs/tedlium3/ASR/zipformer/model.py +++ b/egs/tedlium3/ASR/zipformer/model.py @@ -173,7 +173,7 @@ class Transducer(nn.Module): # if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( lm=lm.float(), am=am.float(), @@ -209,7 +209,7 @@ class Transducer(nn.Module): # prior to do_rnnt_pruning (this is an optimization for speed). logits = self.joiner(am_pruned, lm_pruned, project_input=False) - with torch.amp.autocast("cuda", enabled=False): + with torch.cuda.amp.autocast(enabled=False): pruned_loss = k2.rnnt_loss_pruned( logits=logits.float(), symbols=y_padded, diff --git a/egs/tedlium3/ASR/zipformer/train.py b/egs/tedlium3/ASR/zipformer/train.py index ffe876863..14a44efb3 100755 --- a/egs/tedlium3/ASR/zipformer/train.py +++ b/egs/tedlium3/ASR/zipformer/train.py @@ -73,7 +73,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -911,7 +911,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1160,7 +1160,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1260,7 +1260,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/vctk/TTS/vits/train.py b/egs/vctk/TTS/vits/train.py index 6249640d4..4686de169 100755 --- a/egs/vctk/TTS/vits/train.py +++ b/egs/vctk/TTS/vits/train.py @@ -31,7 +31,7 @@ import torch.nn as nn from lhotse.cut import Cut from lhotse.utils import fix_random_seed from tokenizer import Tokenizer -from torch.amp import GradScaler, autocast +from torch.cuda.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torch.utils.tensorboard import SummaryWriter @@ -448,7 +448,7 @@ def train_one_epoch( loss_info["samples"] = batch_size try: - with autocast("cuda", enabled=params.use_fp16): + with autocast(enabled=params.use_fp16): # forward discriminator loss_d, stats_d = model( text=tokens, @@ -467,7 +467,7 @@ def train_one_epoch( scaler.scale(loss_d).backward() scaler.step(optimizer_d) - with autocast("cuda", enabled=params.use_fp16): + with autocast(enabled=params.use_fp16): # forward generator loss_g, stats_g = model( text=tokens, @@ -740,7 +740,7 @@ def scan_pessimistic_batches_for_oom( ) = prepare_input(batch, tokenizer, device, speaker_map) try: # for discriminator - with autocast("cuda", enabled=params.use_fp16): + with autocast(enabled=params.use_fp16): loss_d, stats_d = model( text=tokens, text_lengths=tokens_lens, @@ -754,7 +754,7 @@ def scan_pessimistic_batches_for_oom( optimizer_d.zero_grad() loss_d.backward() # for generator - with autocast("cuda", enabled=params.use_fp16): + with autocast(enabled=params.use_fp16): loss_g, stats_g = model( text=tokens, text_lengths=tokens_lens, @@ -910,7 +910,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py index 2fd6f6478..c34f1593d 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py @@ -52,7 +52,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -718,7 +718,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -907,7 +907,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1005,7 +1005,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py index c90f03f08..49977e01b 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py @@ -101,7 +101,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -687,7 +687,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -921,7 +921,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1019,7 +1019,7 @@ def scan_pessimistic_batches_for_oom( # warmup = 0.0 is so that the derivs for the pruned loss stay zero # (i.e. are not remembered by the decaying-average in adam), because # we want to avoid these params being subject to shrinkage in adam. - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py index 7b05eca97..931e699d9 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py @@ -81,7 +81,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -796,7 +796,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1056,7 +1056,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1158,7 +1158,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/wenetspeech/ASR/whisper/train.py b/egs/wenetspeech/ASR/whisper/train.py index c46a4d84c..4e55fd6a8 100644 --- a/egs/wenetspeech/ASR/whisper/train.py +++ b/egs/wenetspeech/ASR/whisper/train.py @@ -61,7 +61,7 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.functional import pad as pad_tensor from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -513,7 +513,7 @@ def compute_validation_loss( tot_loss = MetricsTracker() for batch_idx, batch in enumerate(valid_dl): - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, tokenizer=tokenizer, @@ -621,7 +621,7 @@ def train_one_epoch( f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}" ) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, tokenizer=tokenizer, @@ -843,7 +843,7 @@ def run(rank, world_size, args): train_dl = wenetspeech.train_dataloaders(train_cuts) valid_dl = wenetspeech.valid_dataloaders(wenetspeech.valid_cuts()) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/wenetspeech/ASR/zipformer/train.py b/egs/wenetspeech/ASR/zipformer/train.py index b6d55447f..25b16f632 100755 --- a/egs/wenetspeech/ASR/zipformer/train.py +++ b/egs/wenetspeech/ASR/zipformer/train.py @@ -71,7 +71,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -910,7 +910,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1201,7 +1201,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1302,7 +1302,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/wenetspeech/KWS/zipformer/finetune.py b/egs/wenetspeech/KWS/zipformer/finetune.py index 00db4309d..d19172b38 100755 --- a/egs/wenetspeech/KWS/zipformer/finetune.py +++ b/egs/wenetspeech/KWS/zipformer/finetune.py @@ -82,7 +82,7 @@ from lhotse.cut import Cut, CutSet from lhotse.utils import fix_random_seed from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from train import ( @@ -414,7 +414,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -703,7 +703,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) diff --git a/egs/wenetspeech/KWS/zipformer/train.py b/egs/wenetspeech/KWS/zipformer/train.py index 4dc30ad89..40960c2ae 100755 --- a/egs/wenetspeech/KWS/zipformer/train.py +++ b/egs/wenetspeech/KWS/zipformer/train.py @@ -73,7 +73,7 @@ from optim import Eden, ScaledAdam from scaling import ScheduledFloat from subsampling import Conv2dSubsampling from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer2 @@ -967,7 +967,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1252,7 +1252,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1353,7 +1353,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/wenetspeech4tts/TTS/valle/train.py b/egs/wenetspeech4tts/TTS/valle/train.py index 1c6972e93..e9ec548f3 100755 --- a/egs/wenetspeech4tts/TTS/valle/train.py +++ b/egs/wenetspeech4tts/TTS/valle/train.py @@ -65,7 +65,7 @@ from lhotse.utils import fix_random_seed from optim import Eden, ScaledAdam from tokenizer import TextTokenCollater, get_text_token_collater from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from tts_datamodule import TtsDataModule @@ -764,7 +764,7 @@ def train_one_epoch( batch_size = len(batch["text"]) try: - with torch.amp.autocast("cuda", dtype=dtype, enabled=enabled): + with torch.cuda.amp.autocast(dtype=dtype, enabled=enabled): _, loss, loss_info = compute_loss( params=params, model=model, @@ -897,7 +897,7 @@ def train_one_epoch( # Calculate validation loss in Rank 0 model.eval() logging.info("Computing validation loss") - with torch.amp.autocast("cuda", dtype=dtype): + with torch.cuda.amp.autocast(dtype=dtype): valid_info = compute_validation_loss( params=params, model=model, @@ -1102,9 +1102,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler( - "cuda", enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0 - ) + scaler = GradScaler(enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1198,7 +1196,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", dtype=dtype): + with torch.cuda.amp.autocast(dtype=dtype): _, loss, _ = compute_loss( params=params, model=model, diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py index 5c3000a57..a6fa46b17 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py @@ -68,7 +68,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -814,7 +814,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1072,7 +1072,7 @@ def run(rank, world_size, args): warmup=0.0 if params.start_epoch == 1 else 1.0, ) - scaler = GradScaler("cuda", enabled=params.use_fp16) + scaler = GradScaler(enabled=params.use_fp16) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1141,7 +1141,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py index a1b3be246..dd72551d9 100755 --- a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py @@ -67,7 +67,7 @@ from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, ScaledAdam from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer @@ -785,7 +785,7 @@ def train_one_epoch( batch_size = len(batch["supervisions"]["text"]) try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( params=params, model=model, @@ -1074,7 +1074,7 @@ def run(rank, world_size, args): params=params, ) - scaler = GradScaler("cuda", enabled=params.use_fp16, init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1174,7 +1174,7 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, model=model, diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index b3a0fb865..d31ce1301 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -27,7 +27,7 @@ import torch import torch.nn as nn from lhotse.dataset.sampling.base import CutSampler from torch import Tensor -from torch.amp import GradScaler +from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py index 257cdb09a..0178b80bf 100755 --- a/icefall/rnn_lm/train.py +++ b/icefall/rnn_lm/train.py @@ -401,7 +401,7 @@ def compute_validation_loss( for batch_idx, batch in enumerate(valid_dl): x, y, sentence_lengths = batch - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( model=model, x=x, @@ -470,7 +470,7 @@ def train_one_epoch( params.batch_idx_train += 1 x, y, sentence_lengths = batch batch_size = x.size(0) - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( model=model, x=x, diff --git a/icefall/transformer_lm/train.py b/icefall/transformer_lm/train.py index 6faa63484..c36abfcdf 100644 --- a/icefall/transformer_lm/train.py +++ b/icefall/transformer_lm/train.py @@ -341,7 +341,7 @@ def compute_validation_loss( for batch_idx, batch in enumerate(valid_dl): x, y, sentence_lengths = batch - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( model=model, x=x, @@ -403,7 +403,7 @@ def train_one_epoch( params.batch_idx_train += 1 x, y, sentence_lengths = batch batch_size = x.size(0) - with torch.amp.autocast("cuda", enabled=params.use_fp16): + with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, loss_info = compute_loss( model=model, x=x, From 92ed1708c0786e31f59628408869078b330faffa Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 18 Dec 2024 16:50:14 +0800 Subject: [PATCH 50/59] Add torch 1.13 and 2.0 to CI tests (#1840) --- .github/scripts/docker/generate_build_matrix.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/.github/scripts/docker/generate_build_matrix.py b/.github/scripts/docker/generate_build_matrix.py index 9c53a38df..c5a1a54cb 100755 --- a/.github/scripts/docker/generate_build_matrix.py +++ b/.github/scripts/docker/generate_build_matrix.py @@ -45,13 +45,13 @@ def get_torchaudio_version(torch_version): def get_matrix(): k2_version = "1.24.4.dev20241029" kaldifeat_version = "1.25.5.dev20241029" - version = "20241029" + version = "20241218" # torchaudio 2.5.0 does not support python 3.13 python_version = ["3.8", "3.9", "3.10", "3.11", "3.12"] torch_version = [] - # torch_version += ["1.13.0", "1.13.1"] - # torch_version += ["2.0.0", "2.0.1"] + torch_version += ["1.13.0", "1.13.1"] + torch_version += ["2.0.0", "2.0.1"] # torch_version += ["2.1.0", "2.1.1", "2.1.2"] # torch_version += ["2.2.0", "2.2.1", "2.2.2"] # Test only torch >= 2.3.0 @@ -59,6 +59,7 @@ def get_matrix(): torch_version += ["2.4.0"] torch_version += ["2.4.1"] torch_version += ["2.5.0"] + torch_version += ["2.5.1"] matrix = [] for p in python_version: @@ -79,8 +80,12 @@ def get_matrix(): # torch>=2.5 requires python 3.10 continue - k2_version_2 = k2_version - kaldifeat_version_2 = kaldifeat_version + if t == "2.5.1": + k2_version_2 = "1.24.4.dev20241122" + kaldifeat_version_2 = "1.25.5.dev20241126" + else: + k2_version_2 = k2_version + kaldifeat_version_2 = kaldifeat_version matrix.append( { From ad966fb81d76c9b6780cac6844d9c4aa1782a46b Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 19 Dec 2024 15:19:41 +0800 Subject: [PATCH 51/59] Minor fixes to the onnx inference script for ljspeech matcha-tts. (#1838) --- .github/scripts/ljspeech/TTS/run-matcha.sh | 20 +++++++++++++------- egs/ljspeech/TTS/matcha/export_onnx.py | 2 +- egs/ljspeech/TTS/matcha/onnx_pretrained.py | 19 ++++++++++++++----- 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/.github/scripts/ljspeech/TTS/run-matcha.sh b/.github/scripts/ljspeech/TTS/run-matcha.sh index 0876cb47f..352d685a0 100755 --- a/.github/scripts/ljspeech/TTS/run-matcha.sh +++ b/.github/scripts/ljspeech/TTS/run-matcha.sh @@ -57,6 +57,7 @@ function infer() { curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 ./matcha/infer.py \ + --num-buckets 2 \ --epoch 1 \ --exp-dir ./matcha/exp \ --tokens data/tokens.txt \ @@ -97,19 +98,23 @@ function export_onnx() { python3 ./matcha/export_onnx_hifigan.py else curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v1.onnx + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v2.onnx + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v3.onnx fi ls -lh *.onnx - python3 ./matcha/onnx_pretrained.py \ - --acoustic-model ./model-steps-6.onnx \ - --vocoder ./hifigan_v1.onnx \ - --tokens ./data/tokens.txt \ - --input-text "how are you doing?" \ - --output-wav /icefall/generated-matcha-tts-steps-6-v1.wav + for v in v1 v2 v3; do + python3 ./matcha/onnx_pretrained.py \ + --acoustic-model ./model-steps-6.onnx \ + --vocoder ./hifigan_$v.onnx \ + --tokens ./data/tokens.txt \ + --input-text "how are you doing?" \ + --output-wav /icefall/generated-matcha-tts-steps-6-$v.wav + done ls -lh /icefall/*.wav - soxi /icefall/generated-matcha-tts-steps-6-v1.wav + soxi /icefall/generated-matcha-tts-steps-6-*.wav } prepare_data @@ -118,3 +123,4 @@ infer export_onnx rm -rfv generator_v* matcha/exp +git checkout . diff --git a/egs/ljspeech/TTS/matcha/export_onnx.py b/egs/ljspeech/TTS/matcha/export_onnx.py index 487ea2995..623517431 100755 --- a/egs/ljspeech/TTS/matcha/export_onnx.py +++ b/egs/ljspeech/TTS/matcha/export_onnx.py @@ -163,7 +163,7 @@ def main(): (x, x_lengths, temperature, length_scale), filename, opset_version=opset_version, - input_names=["x", "x_length", "temperature", "length_scale"], + input_names=["x", "x_length", "noise_scale", "length_scale"], output_names=["mel"], dynamic_axes={ "x": {0: "N", 1: "L"}, diff --git a/egs/ljspeech/TTS/matcha/onnx_pretrained.py b/egs/ljspeech/TTS/matcha/onnx_pretrained.py index 4eff9a084..6d92b16eb 100755 --- a/egs/ljspeech/TTS/matcha/onnx_pretrained.py +++ b/egs/ljspeech/TTS/matcha/onnx_pretrained.py @@ -89,6 +89,7 @@ class OnnxHifiGANModel: self.model.get_inputs()[0].name: x.numpy(), }, )[0] + # audio: (batch_size, num_samples) return torch.from_numpy(audio) @@ -97,19 +98,24 @@ class OnnxModel: def __init__( self, filename: str, + tokens: str, ): session_opts = ort.SessionOptions() session_opts.inter_op_num_threads = 1 session_opts.intra_op_num_threads = 2 self.session_opts = session_opts - self.tokenizer = Tokenizer("./data/tokens.txt") + self.tokenizer = Tokenizer(tokens) self.model = ort.InferenceSession( filename, sess_options=self.session_opts, providers=["CPUExecutionProvider"], ) + logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") + metadata = self.model.get_modelmeta().custom_metadata_map + self.sample_rate = int(metadata["sample_rate"]) + for i in self.model.get_inputs(): print(i) @@ -138,6 +144,7 @@ class OnnxModel: self.model.get_inputs()[3].name: length_scale.numpy(), }, )[0] + # mel: (batch_size, feat_dim, num_frames) return torch.from_numpy(mel) @@ -147,7 +154,7 @@ def main(): params = get_parser().parse_args() logging.info(vars(params)) - model = OnnxModel(params.acoustic_model) + model = OnnxModel(params.acoustic_model, params.tokens) vocoder = OnnxHifiGANModel(params.vocoder) text = params.input_text x = model.tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True) @@ -164,15 +171,17 @@ def main(): print("audio", audio.shape) # (1, 1, num_samples) audio = audio.squeeze() + sample_rate = model.sample_rate + t = (end_t - start_t).total_seconds() t2 = (end_t2 - start_t2).total_seconds() - rtf_am = t * 22050 / audio.shape[-1] - rtf_vocoder = t2 * 22050 / audio.shape[-1] + rtf_am = t * sample_rate / audio.shape[-1] + rtf_vocoder = t2 * sample_rate / audio.shape[-1] print("RTF for acoustic model ", rtf_am) print("RTF for vocoder", rtf_vocoder) # skip denoiser - sf.write(params.output_wav, audio, 22050, "PCM_16") + sf.write(params.output_wav, audio, sample_rate, "PCM_16") logging.info(f"Saved to {params.output_wav}") From 57e9f2a8db43eaa62d8701ef456f8323e9bcb8ff Mon Sep 17 00:00:00 2001 From: Han Zhu <1106766460@qq.com> Date: Mon, 30 Dec 2024 15:27:05 +0800 Subject: [PATCH 52/59] Add the "rms-sort" diagnostics (#1851) --- icefall/diagnostics.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 37872f233..e5eaba619 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -63,12 +63,22 @@ def get_tensor_stats( "rms" -> square before summing, we'll take sqrt later "value" -> just sum x itself "max", "min" -> take the maximum or minimum [over all other dims but dim] instead of summing + "rms-sort" -> this is a bit different than the others, it's based on computing the + rms over the specified dim and returning percentiles of the result (11 of them). Returns: stats: a Tensor of shape (x.shape[dim],). count: an integer saying how many items were counted in each element of stats. """ + if stats_type == "rms-sort": + rms = (x**2).mean(dim=dim).sqrt() + rms = rms.flatten() + rms = rms.sort()[0] + rms = rms[(torch.arange(11) * rms.numel() // 10).clamp(max=rms.numel() - 1)] + count = 1.0 + return rms, count + count = x.numel() // x.shape[dim] if stats_type == "eigs": @@ -164,7 +174,17 @@ class TensorDiagnostic(object): for dim in range(ndim): this_dim_stats = self.stats[dim] if ndim > 1: - stats_types = ["abs", "max", "min", "positive", "value", "rms"] + # rms-sort is different from the others, it's based on summing over just this + # dim, then sorting and returning the percentiles. + stats_types = [ + "abs", + "max", + "min", + "positive", + "value", + "rms", + "rms-sort", + ] if x.shape[dim] <= self.opts.max_eig_dim: stats_types.append("eigs") else: From 48088cb80703e3e5a98a1cab7558669b77875233 Mon Sep 17 00:00:00 2001 From: Han Zhu <1106766460@qq.com> Date: Mon, 30 Dec 2024 15:30:02 +0800 Subject: [PATCH 53/59] Refactor optimizer (#1837) * Print indexes of largest grad --- egs/librispeech/ASR/zipformer/optim.py | 445 ++++++++++++------------- 1 file changed, 212 insertions(+), 233 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index 8434fab13..8a1764651 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -121,6 +121,139 @@ class BatchedOptimizer(Optimizer): p.copy_(stacked_params[i]) +def basic_step(group, p, state, grad): + # computes basic Adam update using beta2 (dividing by gradient stddev) only. no momentum yet. + lr = group["lr"] + if p.numel() == p.shape[0]: + lr = lr * group["scalar_lr_scale"] + beta2 = group["betas"][1] + eps = group["eps"] + # p shape: (batch_size,) or (batch_size, 1, [1,..]) + try: + exp_avg_sq = state[ + "exp_avg_sq" + ] # shape: (batch_size,) or (batch_size, 1, [1,..]) + except KeyError: + exp_avg_sq = torch.zeros(*p.shape, device=p.device, dtype=torch.float) + state["exp_avg_sq"] = exp_avg_sq + + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # bias_correction2 is like in Adam. + # slower update at the start will help stability anyway. + bias_correction2 = 1 - beta2 ** (state["step"] + 1) + if bias_correction2 < 0.99: + # note: not in-place. + exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) + denom = exp_avg_sq.sqrt().add_(eps) + + return -lr * grad / denom + + +def scaling_step(group, p, state, grad): + delta = basic_step(group, p, state, grad) + if p.numel() == p.shape[0]: + return delta # there is no scaling for scalar parameters. (p.shape[0] is the batch of parameters.) + + step = state["step"] + size_update_period = group["size_update_period"] + + try: + param_rms = state["param_rms"] + scale_grads = state["scale_grads"] + scale_exp_avg_sq = state["scale_exp_avg_sq"] + except KeyError: + # we know p.ndim > 1 because we'd have returned above if not, so don't worry + # about the speial case of dim=[] that pytorch treats inconsistently. + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + param_rms = param_rms.to(torch.float) + scale_exp_avg_sq = torch.zeros_like(param_rms) + scale_grads = torch.zeros( + size_update_period, *param_rms.shape, dtype=torch.float, device=p.device + ) + state["param_rms"] = param_rms + state["scale_grads"] = scale_grads + state["scale_exp_avg_sq"] = scale_exp_avg_sq + + # on every step, update the gradient w.r.t. the scale of the parameter, we + # store these as a batch and periodically update the size (for speed only, to + # avoid too many operations). + scale_grads[step % size_update_period] = (p * grad).sum( + dim=list(range(1, p.ndim)), keepdim=True + ) + + # periodically recompute the value of param_rms. + if step % size_update_period == size_update_period - 1: + param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()) + + param_min_rms = group["param_min_rms"] + + # scale the step size by param_rms. This is the most important "scaling" part of + # ScaledAdam + delta *= param_rms.clamp(min=param_min_rms) + + if step % size_update_period == size_update_period - 1 and step > 0: + # This block updates the size of parameter by adding a step ("delta") value in + # the direction of either shrinking or growing it. + beta2 = group["betas"][1] + size_lr = group["lr"] * group["scalar_lr_scale"] + param_max_rms = group["param_max_rms"] + eps = group["eps"] + batch_size = p.shape[0] + # correct beta2 for the size update period: we will have + # faster decay at this level. + beta2_corr = beta2**size_update_period + scale_exp_avg_sq.mul_(beta2_corr).add_( + (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` + alpha=1 - beta2_corr, + ) # shape is (batch_size, 1, 1, ...) + + # The 1st time we reach here is when size_step == 1. + size_step = (step + 1) // size_update_period + bias_correction2 = 1 - beta2_corr**size_step + + denom = scale_exp_avg_sq.sqrt() + eps + + scale_step = ( + -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom + ) + + is_too_small = param_rms < param_min_rms + + # when the param gets too small, just don't shrink it any further. + scale_step.masked_fill_(is_too_small, 0.0) + + # The following may help prevent instability: don't allow the scale step to be too large in + # either direction. + scale_step.clamp_(min=-0.1, max=0.1) + + # and ensure the parameter rms after update never exceeds param_max_rms. + # We have to look at the trained model for parameters at or around the + # param_max_rms, because sometimes they can indicate a problem with the + # topology or settings. + scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms) + + delta.add_(p * scale_step) + + return delta + + +def momentum_step(group, p, state, grad): + delta = scaling_step(group, p, state, grad) + beta1 = group["betas"][0] + try: + stored_delta = state["delta"] + except KeyError: + stored_delta = torch.zeros(*p.shape, device=p.device, dtype=torch.float) + state["delta"] = stored_delta + stored_delta.mul_(beta1) + stored_delta.add_(delta, alpha=(1 - beta1)) + # we don't bother doing the "bias correction" part of Adam for beta1 because this is just + # an edge effect that affects the first 10 or so batches; and the effect of not doing it + # is just to do a slower update for the first few batches, which will help stability. + return stored_delta + + class ScaledAdam(BatchedOptimizer): """ Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update @@ -352,58 +485,26 @@ class ScaledAdam(BatchedOptimizer): raise RuntimeError( "ScaledAdam optimizer does not support sparse gradients" ) - # State initialization - if len(state) == 0: - self._init_state(group, p, state) - self._step_one_batch(group, p, state, clipping_scale) + try: + cur_step = state["step"] + except KeyError: + state["step"] = 0 + cur_step = 0 + + grad = ( + p.grad if clipping_scale == 1.0 else p.grad.mul_(clipping_scale) + ) + p += momentum_step(group, p.detach(), state, grad) + + if p.numel() == p.shape[0]: # scalar parameter + scalar_max = group["scalar_max"] + p.clamp_(min=-scalar_max, max=scalar_max) + + state["step"] = cur_step + 1 return loss - def _init_state(self, group: dict, p: Tensor, state: dict): - """ - Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p - is actually the batch dimension, corresponding to batched-together - parameters of a given shape. - - - Args: - group: Dict to look up configuration values. - p: The parameter that we are initializing the state for - state: Dict from string to whatever state we are initializing - """ - size_update_period = group["size_update_period"] - - state["step"] = 0 - - kwargs = {"device": p.device, "dtype": p.dtype} - - # 'delta' implements conventional momentum. There are - # several different kinds of update going on, so rather than - # compute "exp_avg" like in Adam, we store and decay a - # parameter-change "delta", which combines all forms of - # update. this is equivalent to how it's done in Adam, - # except for the first few steps. - state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) - - batch_size = p.shape[0] - numel = p.numel() // batch_size - - if numel > 1: - # "param_rms" just periodically records the scalar root-mean-square value of - # the parameter tensor. - # it has a shape like (batch_size, 1, 1, 1, 1) - param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - state["param_rms"] = param_rms - - state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) - state["scale_grads"] = torch.zeros( - size_update_period, *param_rms.shape, **kwargs - ) - - # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. - state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) - def _get_clipping_scale( self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] ) -> float: @@ -484,7 +585,7 @@ class ScaledAdam(BatchedOptimizer): ) first_state["num_clipped"] = 0 quartiles = " ".join(["%.3e" % x for x in quartiles]) - logging.warn( + logging.warning( f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" ) @@ -499,8 +600,8 @@ class ScaledAdam(BatchedOptimizer): ans = 0.0 if ans < 1.0: first_state["num_clipped"] += 1 - if ans < 0.1: - logging.warn( + if ans < 0.5: + logging.warning( f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" ) if self.show_dominant_parameters: @@ -508,6 +609,7 @@ class ScaledAdam(BatchedOptimizer): self._show_gradient_dominating_parameter( tuples, tot_sumsq, group["scalar_lr_scale"] ) + self._show_param_with_unusual_grad(tuples) if ans == 0.0: for (p, state, param_names) in tuples: @@ -515,6 +617,55 @@ class ScaledAdam(BatchedOptimizer): return ans + def _show_param_with_unusual_grad( + self, + tuples: List[Tuple[Tensor, dict, List[str]]], + ): + """ + Print information about parameter which has the largest ratio of grad-on-this-batch + divided by normal grad size. + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + """ + largest_ratio = 0.0 + largest_name = "" + # ratios_names is a list of 3-tuples: (grad_ratio, param_name, tensor) + ratios_names = [] + for (p, state, batch_param_names) in tuples: + dims = list(range(1, p.ndim)) + + def mean(x): + # workaround for bad interface of torch's "mean" for when dims is the empty list. + if len(dims) > 0: + return x.mean(dim=dims) + else: + return x + + grad_ratio = ( + (mean(p.grad**2) / state["exp_avg_sq"].mean(dim=dims)) + .sqrt() + .to("cpu") + ) + + ratios_names += zip( + grad_ratio.tolist(), batch_param_names, p.grad.unbind(dim=0) + ) + + ratios_names = sorted(ratios_names, reverse=True) + ratios_names = ratios_names[:10] + ratios_names = [ + (ratio, name, largest_index(tensor)) + for (ratio, name, tensor) in ratios_names + ] + + logging.warning( + f"Parameters with most larger-than-usual grads, with ratios, are: {ratios_names}" + ) + def _show_gradient_dominating_parameter( self, tuples: List[Tuple[Tensor, dict, List[str]]], @@ -572,7 +723,7 @@ class ScaledAdam(BatchedOptimizer): dominant_rms, dominant_grad, ) = sorted_by_proportion[dominant_param_name] - logging.warn( + logging.warning( f"Parameter dominating tot_sumsq {dominant_param_name}" f" with proportion {dominant_proportion:.2f}," f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" @@ -581,182 +732,11 @@ class ScaledAdam(BatchedOptimizer): f" orig_rms_sq={(dominant_rms**2).item():.3e}" ) - def _step_one_batch( - self, group: dict, p: Tensor, state: dict, clipping_scale: float - ): - """ - Do the step for one parameter, which is actually going to be a batch of - `real` parameters, with dim 0 as the batch dim. - Args: - group: dict to look up configuration values - p: parameter to update (actually multiple parameters stacked together - as a batch) - state: state-dict for p, to look up the optimizer state - """ - lr = group["lr"] - size_update_period = group["size_update_period"] - beta1 = group["betas"][0] - grad = p.grad - if clipping_scale != 1.0: - grad *= clipping_scale - step = state["step"] - delta = state["delta"] - - delta.mul_(beta1) - batch_size = p.shape[0] - numel = p.numel() // batch_size - if numel > 1: - # Update the size/scale of p, and set param_rms - scale_grads = state["scale_grads"] - scale_grads[step % size_update_period] = (p * grad).sum( - dim=list(range(1, p.ndim)), keepdim=True - ) - if step % size_update_period == size_update_period - 1: - param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) - param_rms.copy_( - (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - ) - if step > 0: - # self._size_update() learns the overall scale on the - # parameter, by shrinking or expanding it. - self._size_update(group, scale_grads, p, state) - - if numel == 1: - # For parameters with 1 element we just use regular Adam. - # Updates delta. - self._step_scalar(group, p, state) - else: - self._step(group, p, state) - - state["step"] = step + 1 - - def _size_update( - self, group: dict, scale_grads: Tensor, p: Tensor, state: dict - ) -> None: - """ - Called only where p.numel() > 1, this updates the scale of the parameter. - If we imagine: p = underlying_param * scale.exp(), and we are doing - gradient descent on underlying param and on scale, this function does the update - on `scale`. - - Args: - group: dict to look up configuration values - scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing - grads w.r.t. the scales. - p: The parameter to update - state: The state-dict of p - """ - - param_rms = state["param_rms"] - beta1, beta2 = group["betas"] - size_lr = group["lr"] * group["scalar_lr_scale"] - param_min_rms = group["param_min_rms"] - param_max_rms = group["param_max_rms"] - eps = group["eps"] - step = state["step"] - batch_size = p.shape[0] - - size_update_period = scale_grads.shape[0] - # correct beta2 for the size update period: we will have - # faster decay at this level. - beta2_corr = beta2**size_update_period - - scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) - scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` - alpha=1 - beta2_corr, - ) # shape is (batch_size, 1, 1, ...) - - # The 1st time we reach here is when size_step == 1. - size_step = (step + 1) // size_update_period - bias_correction2 = 1 - beta2_corr**size_step - # we don't bother with bias_correction1; this will help prevent divergence - # at the start of training. - - denom = scale_exp_avg_sq.sqrt() + eps - - scale_step = ( - -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom - ) - - is_too_small = param_rms < param_min_rms - - # when the param gets too small, just don't shrink it any further. - scale_step.masked_fill_(is_too_small, 0.0) - - # and ensure the parameter rms after update never exceeds param_max_rms. - # We have to look at the trained model for parameters at or around the - # param_max_rms, because sometimes they can indicate a problem with the - # topology or settings. - scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms) - - delta = state["delta"] - # the factor of (1-beta1) relates to momentum. - delta.add_(p * scale_step, alpha=(1 - beta1)) - - def _step(self, group: dict, p: Tensor, state: dict): - """ - This function does the core update of self.step(), in the case where the members of - the batch have more than 1 element. - - Args: - group: A dict which will be used to look up configuration values - p: The parameter to be updated - grad: The grad of p - state: The state-dict corresponding to parameter p - - This function modifies p. - """ - grad = p.grad - lr = group["lr"] - beta1, beta2 = group["betas"] - eps = group["eps"] - param_min_rms = group["param_min_rms"] - step = state["step"] - - exp_avg_sq = state["exp_avg_sq"] - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) - - this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) - bias_correction2 = 1 - beta2 ** (this_step + 1) - if bias_correction2 < 0.99: - # note: not in-place. - exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) - - denom = exp_avg_sq.sqrt() - denom += eps - grad = grad / denom - - alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) - - delta = state["delta"] - delta.add_(grad * alpha) - p.add_(delta) - - def _step_scalar(self, group: dict, p: Tensor, state: dict): - """ - A simplified form of the core update for scalar tensors, where we cannot get a good - estimate of the parameter rms. - """ - beta1, beta2 = group["betas"] - scalar_max = group["scalar_max"] - eps = group["eps"] - lr = group["lr"] * group["scalar_lr_scale"] - grad = p.grad - - exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - - # bias_correction2 is like in Adam. Don't bother with bias_correction1; - # slower update at the start will help stability anyway. - bias_correction2 = 1 - beta2 ** (state["step"] + 1) - denom = (exp_avg_sq / bias_correction2).sqrt() + eps - - delta = state["delta"] - delta.add_(grad / denom, alpha=-lr * (1 - beta1)) - p.clamp_(min=-scalar_max, max=scalar_max) - p.add_(delta) +def largest_index(x: Tensor): + x = x.contiguous() + argmax = x.abs().argmax().item() + return [(argmax // x.stride(i)) % x.size(i) for i in range(x.ndim)] class LRScheduler(object): @@ -787,9 +767,9 @@ class LRScheduler(object): is not the optimizer. """ return { - # the user might try to override the base_lr, so don't include this in the state. - # previously they were included. - # "base_lrs": self.base_lrs, + # the user might try to override the base_lr, so don't include this in the state. + # previously they were included. + # "base_lrs": self.base_lrs, "epoch": self.epoch, "batch": self.batch, } @@ -807,7 +787,6 @@ class LRScheduler(object): self.__dict__.update(state_dict) self.base_lrs = base_lrs - def get_last_lr(self) -> List[float]: """Return last computed learning rate by current scheduler. Will be a list of float.""" return self._last_lr @@ -853,7 +832,7 @@ class LRScheduler(object): def print_lr(self, is_verbose, group, lr): """Display the current learning rate.""" if is_verbose: - logging.warn( + logging.warning( f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" f" of group {group} to {lr:.4e}." ) @@ -1184,7 +1163,7 @@ def _test_scaled_adam(hidden_dim: int): if iter == 0: optim = Eve(m.parameters(), lr=0.003) elif iter == 1: - optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) + optim = ScaledAdam(m.named_parameters(), lr=0.03, clipping_scale=2.0) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) start = timeit.default_timer() From a2b0f6057c41a1f0eff64ed95356d14751a0c791 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Tue, 31 Dec 2024 07:41:44 +0800 Subject: [PATCH 54/59] Small fix (#1853) --- egs/wenetspeech4tts/TTS/prepare.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/wenetspeech4tts/TTS/prepare.sh b/egs/wenetspeech4tts/TTS/prepare.sh index 54e140dbb..3d7ffadb1 100755 --- a/egs/wenetspeech4tts/TTS/prepare.sh +++ b/egs/wenetspeech4tts/TTS/prepare.sh @@ -10,9 +10,9 @@ stop_stage=4 dl_dir=$PWD/download -dataset_parts="Premium" # Basic for all 10k hours data, Premium for about 10% of the data +dataset_parts="Premium" # Basic for all 7226 hours data, Premium for 945 hours subset. -text_extractor="pypinyin_initials_finals" # default is espeak for English +text_extractor="pypinyin_initials_finals" # default is espeak for English audio_extractor="Encodec" # or Fbank audio_feats_dir=data/tokenized @@ -63,7 +63,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then --audio-extractor ${audio_extractor} \ --batch-duration 2500 --prefix "wenetspeech4tts" \ --src-dir "data/manifests" \ - --split 100 \ + --split 100 \ --output-dir "${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100" cp ${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100/unique_text_tokens.k2symbols ${audio_feats_dir} fi From df46a3eaf94c0089c485eb23b0e08bf2b63cb53a Mon Sep 17 00:00:00 2001 From: Han Zhu <1106766460@qq.com> Date: Tue, 31 Dec 2024 16:52:06 +0800 Subject: [PATCH 55/59] Warn instead of raising exceptions in inf-check (#1852) --- icefall/hooks.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/icefall/hooks.py b/icefall/hooks.py index 1c5bd2ae6..83f2750fa 100644 --- a/icefall/hooks.py +++ b/icefall/hooks.py @@ -40,9 +40,7 @@ def register_inf_check_hooks(model: nn.Module) -> None: def forward_hook(_module, _input, _output, _name=name): if isinstance(_output, Tensor): if not torch.isfinite(_output.to(torch.float32).sum()): - raise ValueError( - f"The sum of {_name}.output is not finite: {_output}" - ) + logging.warning(f"The sum of {_name}.output is not finite") elif isinstance(_output, tuple): for i, o in enumerate(_output): if isinstance(o, tuple): @@ -50,9 +48,7 @@ def register_inf_check_hooks(model: nn.Module) -> None: if not isinstance(o, Tensor): continue if not torch.isfinite(o.to(torch.float32).sum()): - raise ValueError( - f"The sum of {_name}.output[{i}] is not finite: {_output}" - ) + logging.warning(f"The sum of {_name}.output[{i}] is not finite") # default param _name is a way to capture the current value of the variable "name". def backward_hook(_module, _input, _output, _name=name): From bfffda5afb74b193068078c9b51db380ec005afe Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 31 Dec 2024 17:17:05 +0800 Subject: [PATCH 56/59] Add MatchaTTS for the Chinese dataset Baker (#1849) --- .github/scripts/baker_zh/TTS/run-matcha.sh | 167 ++++ .../scripts/docker/generate_build_matrix.py | 18 +- .github/scripts/ljspeech/TTS/run-matcha.sh | 2 +- .github/workflows/baker_zh.yml | 152 ++++ egs/baker_zh/TTS/.gitignore | 6 + egs/baker_zh/TTS/README.md | 146 ++++ egs/baker_zh/TTS/local/audio.py | 1 + .../TTS/local/compute_fbank_baker_zh.py | 110 +++ .../TTS/local/compute_fbank_statistics.py | 1 + .../TTS/local/convert_text_to_tokens.py | 121 +++ egs/baker_zh/TTS/local/fbank.py | 1 + egs/baker_zh/TTS/local/generate_tokens.py | 85 +++ egs/baker_zh/TTS/local/validate_manifest.py | 70 ++ egs/baker_zh/TTS/matcha/__init__.py | 0 egs/baker_zh/TTS/matcha/audio.py | 1 + egs/baker_zh/TTS/matcha/export_onnx.py | 207 +++++ .../TTS/matcha/export_onnx_hifigan.py | 1 + egs/baker_zh/TTS/matcha/fbank.py | 1 + egs/baker_zh/TTS/matcha/generate_lexicon.py | 42 + egs/baker_zh/TTS/matcha/hifigan | 1 + egs/baker_zh/TTS/matcha/infer.py | 342 +++++++++ egs/baker_zh/TTS/matcha/model.py | 1 + egs/baker_zh/TTS/matcha/models | 1 + egs/baker_zh/TTS/matcha/monotonic_align | 1 + egs/baker_zh/TTS/matcha/onnx_pretrained.py | 316 ++++++++ egs/baker_zh/TTS/matcha/tokenizer.py | 119 +++ egs/baker_zh/TTS/matcha/train.py | 717 ++++++++++++++++++ egs/baker_zh/TTS/matcha/tts_datamodule.py | 340 +++++++++ egs/baker_zh/TTS/matcha/utils.py | 1 + egs/baker_zh/TTS/prepare.sh | 151 ++++ egs/baker_zh/TTS/shared | 1 + egs/ljspeech/TTS/README.md | 2 +- egs/ljspeech/TTS/matcha/export_onnx.py | 11 +- egs/ljspeech/TTS/matcha/onnx_pretrained.py | 4 +- 34 files changed, 3128 insertions(+), 12 deletions(-) create mode 100755 .github/scripts/baker_zh/TTS/run-matcha.sh create mode 100644 .github/workflows/baker_zh.yml create mode 100644 egs/baker_zh/TTS/.gitignore create mode 100644 egs/baker_zh/TTS/README.md create mode 120000 egs/baker_zh/TTS/local/audio.py create mode 100755 egs/baker_zh/TTS/local/compute_fbank_baker_zh.py create mode 120000 egs/baker_zh/TTS/local/compute_fbank_statistics.py create mode 100755 egs/baker_zh/TTS/local/convert_text_to_tokens.py create mode 120000 egs/baker_zh/TTS/local/fbank.py create mode 100755 egs/baker_zh/TTS/local/generate_tokens.py create mode 100755 egs/baker_zh/TTS/local/validate_manifest.py create mode 100644 egs/baker_zh/TTS/matcha/__init__.py create mode 120000 egs/baker_zh/TTS/matcha/audio.py create mode 100755 egs/baker_zh/TTS/matcha/export_onnx.py create mode 120000 egs/baker_zh/TTS/matcha/export_onnx_hifigan.py create mode 120000 egs/baker_zh/TTS/matcha/fbank.py create mode 100755 egs/baker_zh/TTS/matcha/generate_lexicon.py create mode 120000 egs/baker_zh/TTS/matcha/hifigan create mode 100755 egs/baker_zh/TTS/matcha/infer.py create mode 120000 egs/baker_zh/TTS/matcha/model.py create mode 120000 egs/baker_zh/TTS/matcha/models create mode 120000 egs/baker_zh/TTS/matcha/monotonic_align create mode 100755 egs/baker_zh/TTS/matcha/onnx_pretrained.py create mode 100644 egs/baker_zh/TTS/matcha/tokenizer.py create mode 100755 egs/baker_zh/TTS/matcha/train.py create mode 100644 egs/baker_zh/TTS/matcha/tts_datamodule.py create mode 120000 egs/baker_zh/TTS/matcha/utils.py create mode 100755 egs/baker_zh/TTS/prepare.sh create mode 120000 egs/baker_zh/TTS/shared diff --git a/.github/scripts/baker_zh/TTS/run-matcha.sh b/.github/scripts/baker_zh/TTS/run-matcha.sh new file mode 100755 index 000000000..150f023ae --- /dev/null +++ b/.github/scripts/baker_zh/TTS/run-matcha.sh @@ -0,0 +1,167 @@ +#!/usr/bin/env bash + +set -ex + +apt-get update +apt-get install -y sox + +python3 -m pip install numba conformer==0.3.2 diffusers librosa +python3 -m pip install jieba + + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/baker_zh/TTS + +sed -i.bak s/600/8/g ./prepare.sh +sed -i.bak s/"first 100"/"first 3"/g ./prepare.sh +sed -i.bak s/500/5/g ./prepare.sh +git diff + +function prepare_data() { + # We have created a subset of the data for testing + # + mkdir -p download + pushd download + wget -q https://huggingface.co/csukuangfj/tmp-files/resolve/main/BZNSYP-samples.tar.bz2 + tar xvf BZNSYP-samples.tar.bz2 + mv BZNSYP-samples BZNSYP + rm BZNSYP-samples.tar.bz2 + popd + + ./prepare.sh + tree . +} + +function train() { + pushd ./matcha + sed -i.bak s/1500/3/g ./train.py + git diff . + popd + + ./matcha/train.py \ + --exp-dir matcha/exp \ + --num-epochs 1 \ + --save-every-n 1 \ + --num-buckets 2 \ + --tokens data/tokens.txt \ + --max-duration 20 + + ls -lh matcha/exp +} + +function infer() { + curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2 + + ./matcha/infer.py \ + --num-buckets 2 \ + --epoch 1 \ + --exp-dir ./matcha/exp \ + --tokens data/tokens.txt \ + --cmvn ./data/fbank/cmvn.json \ + --vocoder ./generator_v2 \ + --input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \ + --output-wav ./generated.wav + + ls -lh *.wav + soxi ./generated.wav + rm -v ./generated.wav + rm -v generator_v2 +} + +function export_onnx() { + pushd matcha/exp + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-baker-matcha-zh-2024-12-27/resolve/main/epoch-2000.pt + popd + + pushd data/fbank + rm -v *.json + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-baker-matcha-zh-2024-12-27/resolve/main/cmvn.json + popd + + ./matcha/export_onnx.py \ + --exp-dir ./matcha/exp \ + --epoch 2000 \ + --tokens ./data/tokens.txt \ + --cmvn ./data/fbank/cmvn.json + + ls -lh *.onnx + + if false; then + # The CI machine does not have enough memory to run it + # + curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 + curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2 + curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v3 + python3 ./matcha/export_onnx_hifigan.py + else + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v1.onnx + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v2.onnx + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/exp/hifigan_v3.onnx + fi + + ls -lh *.onnx + + python3 ./matcha/generate_lexicon.py + + for v in v1 v2 v3; do + python3 ./matcha/onnx_pretrained.py \ + --acoustic-model ./model-steps-6.onnx \ + --vocoder ./hifigan_$v.onnx \ + --tokens ./data/tokens.txt \ + --lexicon ./lexicon.txt \ + --input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \ + --output-wav /icefall/generated-matcha-tts-steps-6-$v.wav + done + + ls -lh /icefall/*.wav + soxi /icefall/generated-matcha-tts-steps-6-*.wav + cp ./model-steps-*.onnx /icefall + + d=matcha-icefall-zh-baker + mkdir $d + cp -v data/tokens.txt $d + cp -v lexicon.txt $d + cp model-steps-3.onnx $d + pushd $d + curl -SL -O https://github.com/csukuangfj/cppjieba/releases/download/sherpa-onnx-2024-04-19/dict.tar.bz2 + tar xvf dict.tar.bz2 + rm dict.tar.bz2 + + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/date.fst + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/number.fst + curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/phone.fst + +cat >README.md < + +The training command is given below: +```bash +python3 ./matcha/train.py \ + --exp-dir ./matcha/exp-1/ \ + --num-workers 4 \ + --world-size 1 \ + --num-epochs 2000 \ + --max-duration 1200 \ + --bucketing-sampler 1 \ + --start-epoch 1 +``` + +To inference, use: + +```bash +# Download Hifigan vocoder. We use Hifigan v2 below. You can select from v1, v2, or v3 + +wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2 + +python3 ./matcha/infer.py \ + --epoch 2000 \ + --exp-dir ./matcha/exp-1 \ + --vocoder ./generator_v2 \ + --tokens ./data/tokens.txt \ + --cmvn ./data/fbank/cmvn.json \ + --input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \ + --output-wav ./generated.wav +``` + +```bash +soxi ./generated.wav +``` + +prints: +``` +Input File : './generated.wav' +Channels : 1 +Sample Rate : 22050 +Precision : 16-bit +Duration : 00:00:17.31 = 381696 samples ~ 1298.29 CDDA sectors +File Size : 763k +Bit Rate : 353k +Sample Encoding: 16-bit Signed Integer PCM +``` + +https://github.com/user-attachments/assets/88d4e88f-ebc4-4f32-b216-16d46b966024 + + +To export the checkpoint to onnx: +```bash +python3 ./matcha/export_onnx.py \ + --exp-dir ./matcha/exp-1 \ + --epoch 2000 \ + --tokens ./data/tokens.txt \ + --cmvn ./data/fbank/cmvn.json +``` + +The above command generates the following files: +``` +-rw-r--r-- 1 kuangfangjun root 72M Dec 27 18:53 model-steps-2.onnx +-rw-r--r-- 1 kuangfangjun root 73M Dec 27 18:54 model-steps-3.onnx +-rw-r--r-- 1 kuangfangjun root 73M Dec 27 18:54 model-steps-4.onnx +-rw-r--r-- 1 kuangfangjun root 74M Dec 27 18:55 model-steps-5.onnx +-rw-r--r-- 1 kuangfangjun root 74M Dec 27 18:57 model-steps-6.onnx +``` + +where the 2 in `model-steps-2.onnx` means it uses 2 steps for the ODE solver. + +**HINT**: If you get the following error while running `export_onnx.py`: + +``` +torch.onnx.errors.UnsupportedOperatorError: Exporting the operator +'aten::scaled_dot_product_attention' to ONNX opset version 14 is not supported. +``` + +please use `torch>=2.2.0`. + +To export the Hifigan vocoder to onnx, please use: + +```bash +wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 +wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v2 +wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v3 + +python3 ./matcha/export_onnx_hifigan.py +``` + +The above command generates 3 files: + + - hifigan_v1.onnx + - hifigan_v2.onnx + - hifigan_v3.onnx + +**HINT**: You can download pre-exported hifigan ONNX models from + + +To use the generated onnx files to generate speech from text, please run: + +```bash + +# First, generate ./lexicon.txt +python3 ./matcha/generate_lexicon.py + +python3 ./matcha/onnx_pretrained.py \ + --acoustic-model ./model-steps-4.onnx \ + --vocoder ./hifigan_v2.onnx \ + --tokens ./data/tokens.txt \ + --lexicon ./lexicon.txt \ + --input-text "在一个阳光明媚的夏天,小马、小羊和小狗它们一块儿在广阔的草地上,嬉戏玩耍,这时小猴来了,还带着它心爱的足球活蹦乱跳地跑前、跑后教小马、小羊、小狗踢足球。" \ + --output-wav ./1.wav +``` + +```bash +soxi ./1.wav + +Input File : './1.wav' +Channels : 1 +Sample Rate : 22050 +Precision : 16-bit +Duration : 00:00:16.37 = 360960 samples ~ 1227.76 CDDA sectors +File Size : 722k +Bit Rate : 353k +Sample Encoding: 16-bit Signed Integer PCM +``` + +https://github.com/user-attachments/assets/578d04bb-fee8-47e5-9984-a868dcce610e + diff --git a/egs/baker_zh/TTS/local/audio.py b/egs/baker_zh/TTS/local/audio.py new file mode 120000 index 000000000..b70d91c92 --- /dev/null +++ b/egs/baker_zh/TTS/local/audio.py @@ -0,0 +1 @@ +../matcha/audio.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/local/compute_fbank_baker_zh.py b/egs/baker_zh/TTS/local/compute_fbank_baker_zh.py new file mode 100755 index 000000000..0720158f2 --- /dev/null +++ b/egs/baker_zh/TTS/local/compute_fbank_baker_zh.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the baker-zh dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path + +import torch +from fbank import MatchaFbank, MatchaFbankConfig +from lhotse import CutSet, LilcomChunkyWriter, load_manifest +from lhotse.audio import RecordingSet +from lhotse.supervision import SupervisionSet + +from icefall.utils import get_executor + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--num-jobs", + type=int, + default=4, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + return parser + + +def compute_fbank_baker_zh(num_jobs: int): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + + if num_jobs < 1: + num_jobs = os.cpu_count() + + logging.info(f"num_jobs: {num_jobs}") + logging.info(f"src_dir: {src_dir}") + logging.info(f"output_dir: {output_dir}") + config = MatchaFbankConfig( + n_fft=1024, + n_mels=80, + sampling_rate=22050, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, + ) + + prefix = "baker_zh" + suffix = "jsonl.gz" + + extractor = MatchaFbank(config) + + with get_executor() as ex: # Initialize the executor only once. + cuts_filename = f"{prefix}_cuts.{suffix}" + logging.info(f"Processing {cuts_filename}") + cut_set = load_manifest(src_dir / cuts_filename).resample(22050) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats", + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + # Torch's multithreaded behavior needs to be disabled or + # it wastes a lot of CPU and slow things down. + # Do this outside of main() in case it needs to take effect + # even when we are not invoking the main (e.g. when spawning subprocesses). + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_parser().parse_args() + compute_fbank_baker_zh(args.num_jobs) diff --git a/egs/baker_zh/TTS/local/compute_fbank_statistics.py b/egs/baker_zh/TTS/local/compute_fbank_statistics.py new file mode 120000 index 000000000..fd1d8b52e --- /dev/null +++ b/egs/baker_zh/TTS/local/compute_fbank_statistics.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/local/compute_fbank_statistics.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/local/convert_text_to_tokens.py b/egs/baker_zh/TTS/local/convert_text_to_tokens.py new file mode 100755 index 000000000..bf59cb466 --- /dev/null +++ b/egs/baker_zh/TTS/local/convert_text_to_tokens.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 + +import argparse +import re +from typing import List + +import jieba +from lhotse import load_manifest +from pypinyin import Style, lazy_pinyin, load_phrases_dict + +load_phrases_dict( + { + "行长": [["hang2"], ["zhang3"]], + "银行行长": [["yin2"], ["hang2"], ["hang2"], ["zhang3"]], + } +) + +whiter_space_re = re.compile(r"\s+") + +punctuations_re = [ + (re.compile(x[0], re.IGNORECASE), x[1]) + for x in [ + (",", ","), + ("。", "."), + ("!", "!"), + ("?", "?"), + ("“", '"'), + ("”", '"'), + ("‘", "'"), + ("’", "'"), + (":", ":"), + ("、", ","), + ("B", "逼"), + ("P", "批"), + ] +] + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--in-file", + type=str, + required=True, + help="Input cutset.", + ) + + parser.add_argument( + "--out-file", + type=str, + required=True, + help="Output cutset.", + ) + + return parser + + +def normalize_white_spaces(text): + return whiter_space_re.sub(" ", text) + + +def normalize_punctuations(text): + for regex, replacement in punctuations_re: + text = re.sub(regex, replacement, text) + return text + + +def split_text(text: str) -> List[str]: + """ + Example input: '你好呀,You are 一个好人。 去银行存钱?How about you?' + Example output: ['你好', '呀', ',', 'you are', '一个', '好人', '.', '去', '银行', '存钱', '?', 'how about you', '?'] + """ + text = text.lower() + text = normalize_white_spaces(text) + text = normalize_punctuations(text) + ans = [] + + for seg in jieba.cut(text): + if seg in ",.!?:\"'": + ans.append(seg) + elif seg == " " and len(ans) > 0: + if ord("a") <= ord(ans[-1][-1]) <= ord("z"): + ans[-1] += seg + elif ord("a") <= ord(seg[0]) <= ord("z"): + if len(ans) == 0: + ans.append(seg) + continue + + if ans[-1][-1] == " ": + ans[-1] += seg + continue + + ans.append(seg) + else: + ans.append(seg) + + ans = [s.strip() for s in ans] + return ans + + +def main(): + args = get_parser().parse_args() + cuts = load_manifest(args.in_file) + for c in cuts: + assert len(c.supervisions) == 1, (len(c.supervisions), c.supervisions) + text = c.supervisions[0].normalized_text + + text_list = split_text(text) + tokens = lazy_pinyin(text_list, style=Style.TONE3, tone_sandhi=True) + + c.tokens = tokens + + cuts.to_file(args.out_file) + + print(f"saved to {args.out_file}") + + +if __name__ == "__main__": + main() diff --git a/egs/baker_zh/TTS/local/fbank.py b/egs/baker_zh/TTS/local/fbank.py new file mode 120000 index 000000000..5bcf1fde5 --- /dev/null +++ b/egs/baker_zh/TTS/local/fbank.py @@ -0,0 +1 @@ +../matcha/fbank.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/local/generate_tokens.py b/egs/baker_zh/TTS/local/generate_tokens.py new file mode 100755 index 000000000..b2abe1a71 --- /dev/null +++ b/egs/baker_zh/TTS/local/generate_tokens.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python3 + +""" +This file generates the file tokens.txt. + +Usage: + +python3 ./local/generate_tokens.py > data/tokens.txt +""" + + +import argparse +from typing import List + +import jieba +from pypinyin import Style, lazy_pinyin, pinyin_dict + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--tokens", + type=str, + required=True, + help="Path to to save tokens.txt.", + ) + + return parser + + +def generate_token_list() -> List[str]: + token_set = set() + + word_dict = pinyin_dict.pinyin_dict + i = 0 + for key in word_dict: + if not (0x4E00 <= key <= 0x9FFF): + continue + + w = chr(key) + t = lazy_pinyin(w, style=Style.TONE3, tone_sandhi=True)[0] + token_set.add(t) + + no_digit = set() + for t in token_set: + if t[-1] not in "1234": + no_digit.add(t) + else: + no_digit.add(t[:-1]) + + no_digit.add("dei") + no_digit.add("tou") + no_digit.add("dia") + + for t in no_digit: + token_set.add(t) + for i in range(1, 5): + token_set.add(f"{t}{i}") + + ans = list(token_set) + ans.sort() + + punctuations = list(",.!?:\"'") + ans = punctuations + ans + + # use ID 0 for blank + # Use ID 1 of _ for padding + ans.insert(0, " ") + ans.insert(1, "_") # + + return ans + + +def main(): + args = get_parser().parse_args() + token_list = generate_token_list() + with open(args.tokens, "w", encoding="utf-8") as f: + for indx, token in enumerate(token_list): + f.write(f"{token} {indx}\n") + + +if __name__ == "__main__": + main() diff --git a/egs/baker_zh/TTS/local/validate_manifest.py b/egs/baker_zh/TTS/local/validate_manifest.py new file mode 100755 index 000000000..4e31028f7 --- /dev/null +++ b/egs/baker_zh/TTS/local/validate_manifest.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script checks the following assumptions of the generated manifest: + +- Single supervision per cut + +We will add more checks later if needed. + +Usage example: + + python3 ./local/validate_manifest.py \ + ./data/spectrogram/baker_zh_cuts_all.jsonl.gz + +""" + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet, load_manifest_lazy +from lhotse.dataset.speech_synthesis import validate_for_tts + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "manifest", + type=Path, + help="Path to the manifest file", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + manifest = args.manifest + logging.info(f"Validating {manifest}") + + assert manifest.is_file(), f"{manifest} does not exist" + cut_set = load_manifest_lazy(manifest) + assert isinstance(cut_set, CutSet), type(cut_set) + + validate_for_tts(cut_set) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/baker_zh/TTS/matcha/__init__.py b/egs/baker_zh/TTS/matcha/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/baker_zh/TTS/matcha/audio.py b/egs/baker_zh/TTS/matcha/audio.py new file mode 120000 index 000000000..62d3959d6 --- /dev/null +++ b/egs/baker_zh/TTS/matcha/audio.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/audio.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/export_onnx.py b/egs/baker_zh/TTS/matcha/export_onnx.py new file mode 100755 index 000000000..28efbfe61 --- /dev/null +++ b/egs/baker_zh/TTS/matcha/export_onnx.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +This script exports a Matcha-TTS model to ONNX. +Note that the model outputs fbank. You need to use a vocoder to convert +it to audio. See also ./export_onnx_hifigan.py + +python3 ./matcha/export_onnx.py \ + --exp-dir ./matcha/exp-1 \ + --epoch 2000 \ + --tokens ./data/tokens.txt \ + --cmvn ./data/fbank/cmvn.json + +""" + +import argparse +import json +import logging +from pathlib import Path +from typing import Any, Dict + +import onnx +import torch +from tokenizer import Tokenizer +from train import get_model, get_params + +from icefall.checkpoint import load_checkpoint + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=2000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="matcha/exp-new-3", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=Path, + default="data/tokens.txt", + ) + + parser.add_argument( + "--cmvn", + type=str, + default="data/fbank/cmvn.json", + help="""Path to vocabulary.""", + ) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, Any]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + + while len(model.metadata_props): + model.metadata_props.pop() + + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = str(value) + + onnx.save(model, filename) + + +class ModelWrapper(torch.nn.Module): + def __init__(self, model, num_steps: int = 5): + super().__init__() + self.model = model + self.num_steps = num_steps + + def forward( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + noise_scale: torch.Tensor, + length_scale: torch.Tensor, + ) -> torch.Tensor: + """ + Args: : + x: (batch_size, num_tokens), torch.int64 + x_lengths: (batch_size,), torch.int64 + noise_scale: (1,), torch.float32 + length_scale (1,), torch.float32 + Returns: + audio: (batch_size, num_samples) + + """ + mel = self.model.synthesise( + x=x, + x_lengths=x_lengths, + n_timesteps=self.num_steps, + temperature=noise_scale, + length_scale=length_scale, + )["mel"] + # mel: (batch_size, feat_dim, num_frames) + + return mel + + +@torch.inference_mode() +def main(): + parser = get_parser() + args = parser.parse_args() + params = get_params() + + params.update(vars(args)) + + tokenizer = Tokenizer(params.tokens) + params.pad_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + params.model_args.n_vocab = params.vocab_size + + with open(params.cmvn) as f: + stats = json.load(f) + params.data_args.data_statistics.mel_mean = stats["fbank_mean"] + params.data_args.data_statistics.mel_std = stats["fbank_std"] + + params.model_args.data_statistics.mel_mean = stats["fbank_mean"] + params.model_args.data_statistics.mel_std = stats["fbank_std"] + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + + for num_steps in [2, 3, 4, 5, 6]: + logging.info(f"num_steps: {num_steps}") + wrapper = ModelWrapper(model, num_steps=num_steps) + wrapper.eval() + + # Use a large value so the rotary position embedding in the text + # encoder has a large initial length + x = torch.ones(1, 1000, dtype=torch.int64) + x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64) + noise_scale = torch.tensor([1.0]) + length_scale = torch.tensor([1.0]) + + opset_version = 14 + filename = f"model-steps-{num_steps}.onnx" + torch.onnx.export( + wrapper, + (x, x_lengths, noise_scale, length_scale), + filename, + opset_version=opset_version, + input_names=["x", "x_length", "noise_scale", "length_scale"], + output_names=["mel"], + dynamic_axes={ + "x": {0: "N", 1: "L"}, + "x_length": {0: "N"}, + "mel": {0: "N", 2: "L"}, + }, + ) + + meta_data = { + "model_type": "matcha-tts", + "language": "Chinese", + "has_espeak": 0, + "n_speakers": 1, + "jieba": 1, + "sample_rate": 22050, + "version": 1, + "pad_id": params.pad_id, + "model_author": "icefall", + "maintainer": "k2-fsa", + "dataset": "baker-zh", + "use_eos_bos": 0, + "dataset_url": "https://www.data-baker.com/open_source.html", + "dataset_comment": "The dataset is for non-commercial use only.", + "num_ode_steps": num_steps, + } + add_meta_data(filename=filename, meta_data=meta_data) + print(meta_data) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/baker_zh/TTS/matcha/export_onnx_hifigan.py b/egs/baker_zh/TTS/matcha/export_onnx_hifigan.py new file mode 120000 index 000000000..d0b8af15b --- /dev/null +++ b/egs/baker_zh/TTS/matcha/export_onnx_hifigan.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/export_onnx_hifigan.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/fbank.py b/egs/baker_zh/TTS/matcha/fbank.py new file mode 120000 index 000000000..3cfb7fe3f --- /dev/null +++ b/egs/baker_zh/TTS/matcha/fbank.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/fbank.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/generate_lexicon.py b/egs/baker_zh/TTS/matcha/generate_lexicon.py new file mode 100755 index 000000000..f26f28e91 --- /dev/null +++ b/egs/baker_zh/TTS/matcha/generate_lexicon.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 + +import jieba +from pypinyin import Style, lazy_pinyin, load_phrases_dict, phrases_dict, pinyin_dict +from tokenizer import Tokenizer + +load_phrases_dict( + { + "行长": [["hang2"], ["zhang3"]], + "银行行长": [["yin2"], ["hang2"], ["hang2"], ["zhang3"]], + } +) + + +def main(): + filename = "lexicon.txt" + tokens = "./data/tokens.txt" + tokenizer = Tokenizer(tokens) + + word_dict = pinyin_dict.pinyin_dict + phrases = phrases_dict.phrases_dict + + i = 0 + with open(filename, "w", encoding="utf-8") as f: + for key in word_dict: + if not (0x4E00 <= key <= 0x9FFF): + continue + + w = chr(key) + tokens = lazy_pinyin(w, style=Style.TONE3, tone_sandhi=True)[0] + + f.write(f"{w} {tokens}\n") + + for key in phrases: + tokens = lazy_pinyin(key, style=Style.TONE3, tone_sandhi=True) + tokens = " ".join(tokens) + + f.write(f"{key} {tokens}\n") + + +if __name__ == "__main__": + main() diff --git a/egs/baker_zh/TTS/matcha/hifigan b/egs/baker_zh/TTS/matcha/hifigan new file mode 120000 index 000000000..c0a91072c --- /dev/null +++ b/egs/baker_zh/TTS/matcha/hifigan @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/hifigan \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/infer.py b/egs/baker_zh/TTS/matcha/infer.py new file mode 100755 index 000000000..b90c2fdbd --- /dev/null +++ b/egs/baker_zh/TTS/matcha/infer.py @@ -0,0 +1,342 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) +""" +python3 ./matcha/infer.py \ + --epoch 2000 \ + --exp-dir ./matcha/exp-1 \ + --vocoder ./generator_v2 \ + --tokens ./data/tokens.txt \ + --cmvn ./data/fbank/cmvn.json \ + --input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \ + --output-wav ./generated.wav +""" + +import argparse +import datetime as dt +import json +import logging +from pathlib import Path + +import soundfile as sf +import torch +import torch.nn as nn +from hifigan.config import v1, v2, v3 +from hifigan.denoiser import Denoiser +from hifigan.models import Generator as HiFiGAN +from local.convert_text_to_tokens import split_text +from pypinyin import Style, lazy_pinyin +from tokenizer import Tokenizer +from train import get_model, get_params +from tts_datamodule import BakerZhTtsDataModule + +from icefall.checkpoint import load_checkpoint +from icefall.utils import AttributeDict, setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=4000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="matcha/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--vocoder", + type=Path, + default="./generator_v1", + help="Path to the vocoder", + ) + + parser.add_argument( + "--tokens", + type=Path, + default="data/tokens.txt", + ) + + parser.add_argument( + "--cmvn", + type=str, + default="data/fbank/cmvn.json", + help="""Path to vocabulary.""", + ) + + # The following arguments are used for inference on single text + parser.add_argument( + "--input-text", + type=str, + required=False, + help="The text to generate speech for", + ) + + parser.add_argument( + "--output-wav", + type=str, + required=False, + help="The filename of the wave to save the generated speech", + ) + + parser.add_argument( + "--sampling-rate", + type=int, + default=22050, + help="The sampling rate of the generated speech (default: 22050 for baker_zh)", + ) + + return parser + + +def load_vocoder(checkpoint_path: Path) -> nn.Module: + checkpoint_path = str(checkpoint_path) + if checkpoint_path.endswith("v1"): + h = AttributeDict(v1) + elif checkpoint_path.endswith("v2"): + h = AttributeDict(v2) + elif checkpoint_path.endswith("v3"): + h = AttributeDict(v3) + else: + raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}") + + hifigan = HiFiGAN(h).to("cpu") + hifigan.load_state_dict( + torch.load(checkpoint_path, map_location="cpu")["generator"] + ) + _ = hifigan.eval() + hifigan.remove_weight_norm() + return hifigan + + +def to_waveform( + mel: torch.Tensor, vocoder: nn.Module, denoiser: nn.Module +) -> torch.Tensor: + audio = vocoder(mel).clamp(-1, 1) + audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze() + return audio.squeeze() + + +def process_text(text: str, tokenizer: Tokenizer, device: str = "cpu") -> dict: + text = split_text(text) + tokens = lazy_pinyin(text, style=Style.TONE3, tone_sandhi=True) + + x = tokenizer.texts_to_token_ids([tokens]) + x = torch.tensor(x, dtype=torch.long, device=device) + x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device) + return {"x_orig": text, "x": x, "x_lengths": x_lengths} + + +def synthesize( + model: nn.Module, + tokenizer: Tokenizer, + n_timesteps: int, + text: str, + length_scale: float, + temperature: float, + device: str = "cpu", + spks=None, +) -> dict: + text_processed = process_text(text=text, tokenizer=tokenizer, device=device) + start_t = dt.datetime.now() + output = model.synthesise( + text_processed["x"], + text_processed["x_lengths"], + n_timesteps=n_timesteps, + temperature=temperature, + spks=spks, + length_scale=length_scale, + ) + # merge everything to one dict + output.update({"start_t": start_t, **text_processed}) + return output + + +def infer_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + vocoder: nn.Module, + denoiser: nn.Module, + tokenizer: Tokenizer, +) -> None: + """Decode dataset. + The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + tokenizer: + Used to convert text to phonemes. + """ + + device = next(model.parameters()).device + num_cuts = 0 + log_interval = 5 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + for batch_idx, batch in enumerate(dl): + batch_size = len(batch["tokens"]) + + texts = [c.supervisions[0].normalized_text for c in batch["cut"]] + + audio = batch["audio"] + audio_lens = batch["audio_lens"].tolist() + cut_ids = [cut.id for cut in batch["cut"]] + + for i in range(batch_size): + output = synthesize( + model=model, + tokenizer=tokenizer, + n_timesteps=params.n_timesteps, + text=texts[i], + length_scale=params.length_scale, + temperature=params.temperature, + device=device, + ) + output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) + + sf.write( + file=params.save_wav_dir / f"{cut_ids[i]}_pred.wav", + data=output["waveform"], + samplerate=params.data_args.sampling_rate, + subtype="PCM_16", + ) + sf.write( + file=params.save_wav_dir / f"{cut_ids[i]}_gt.wav", + data=audio[i].numpy(), + samplerate=params.data_args.sampling_rate, + subtype="PCM_16", + ) + + num_cuts += batch_size + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + +@torch.inference_mode() +def main(): + parser = get_parser() + BakerZhTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.suffix = f"epoch-{params.epoch}" + + params.res_dir = params.exp_dir / "infer" / params.suffix + params.save_wav_dir = params.res_dir / "wav" + params.save_wav_dir.mkdir(parents=True, exist_ok=True) + + setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") + logging.info("Infer started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + logging.info(f"Device: {device}") + + tokenizer = Tokenizer(params.tokens) + params.vocab_size = tokenizer.vocab_size + params.model_args.n_vocab = params.vocab_size + + with open(params.cmvn) as f: + stats = json.load(f) + params.data_args.data_statistics.mel_mean = stats["fbank_mean"] + params.data_args.data_statistics.mel_std = stats["fbank_std"] + + params.model_args.data_statistics.mel_mean = stats["fbank_mean"] + params.model_args.data_statistics.mel_std = stats["fbank_std"] + + # Number of ODE Solver steps + params.n_timesteps = 2 + + # Changes to the speaking rate + params.length_scale = 1.0 + + # Sampling temperature + params.temperature = 0.667 + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + model.to(device) + model.eval() + + # we need cut ids to organize tts results. + args.return_cuts = True + baker_zh = BakerZhTtsDataModule(args) + + test_cuts = baker_zh.test_cuts() + test_dl = baker_zh.test_dataloaders(test_cuts) + + if not Path(params.vocoder).is_file(): + raise ValueError(f"{params.vocoder} does not exist") + + vocoder = load_vocoder(params.vocoder) + vocoder.to(device) + + denoiser = Denoiser(vocoder, mode="zeros") + denoiser.to(device) + + if params.input_text is not None and params.output_wav is not None: + logging.info("Synthesizing a single text") + output = synthesize( + model=model, + tokenizer=tokenizer, + n_timesteps=params.n_timesteps, + text=params.input_text, + length_scale=params.length_scale, + temperature=params.temperature, + device=device, + ) + output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) + + sf.write( + file=params.output_wav, + data=output["waveform"], + samplerate=params.sampling_rate, + subtype="PCM_16", + ) + else: + logging.info("Decoding the test set") + infer_dataset( + dl=test_dl, + params=params, + model=model, + vocoder=vocoder, + denoiser=denoiser, + tokenizer=tokenizer, + ) + + +if __name__ == "__main__": + main() diff --git a/egs/baker_zh/TTS/matcha/model.py b/egs/baker_zh/TTS/matcha/model.py new file mode 120000 index 000000000..8a1b812a9 --- /dev/null +++ b/egs/baker_zh/TTS/matcha/model.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/model.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/models b/egs/baker_zh/TTS/matcha/models new file mode 120000 index 000000000..09a862665 --- /dev/null +++ b/egs/baker_zh/TTS/matcha/models @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/models \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/monotonic_align b/egs/baker_zh/TTS/matcha/monotonic_align new file mode 120000 index 000000000..d0a0dd6b5 --- /dev/null +++ b/egs/baker_zh/TTS/matcha/monotonic_align @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/monotonic_align \ No newline at end of file diff --git a/egs/baker_zh/TTS/matcha/onnx_pretrained.py b/egs/baker_zh/TTS/matcha/onnx_pretrained.py new file mode 100755 index 000000000..f6b7f7cae --- /dev/null +++ b/egs/baker_zh/TTS/matcha/onnx_pretrained.py @@ -0,0 +1,316 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +python3 ./matcha/onnx_pretrained.py \ + --acoustic-model ./model-steps-4.onnx \ + --vocoder ./hifigan_v2.onnx \ + --tokens ./data/tokens.txt \ + --lexicon ./lexicon.txt \ + --input-text "当夜幕降临,星光点点,伴随着微风拂面,我在静谧中感受着时光的流转,思念如涟漪荡漾,梦境如画卷展开,我与自然融为一体,沉静在这片宁静的美丽之中,感受着生命的奇迹与温柔。" \ + --output-wav ./b.wav +""" + +import argparse +import datetime as dt +import logging +import re +from typing import Dict, List + +import jieba +import onnxruntime as ort +import soundfile as sf +import torch +from infer import load_vocoder +from utils import intersperse + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--acoustic-model", + type=str, + required=True, + help="Path to the acoustic model", + ) + + parser.add_argument( + "--tokens", + type=str, + required=True, + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--lexicon", + type=str, + required=True, + help="Path to the lexicon.txt", + ) + + parser.add_argument( + "--vocoder", + type=str, + required=True, + help="Path to the vocoder", + ) + + parser.add_argument( + "--input-text", + type=str, + required=True, + help="The text to generate speech for", + ) + + parser.add_argument( + "--output-wav", + type=str, + required=True, + help="The filename of the wave to save the generated speech", + ) + + return parser + + +class OnnxHifiGANModel: + def __init__( + self, + filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + self.model = ort.InferenceSession( + filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + + for i in self.model.get_inputs(): + print(i) + + print("-----") + + for i in self.model.get_outputs(): + print(i) + + def __call__(self, x: torch.tensor): + assert x.ndim == 3, x.shape + assert x.shape[0] == 1, x.shape + + audio = self.model.run( + [self.model.get_outputs()[0].name], + { + self.model.get_inputs()[0].name: x.numpy(), + }, + )[0] + # audio: (batch_size, num_samples) + + return torch.from_numpy(audio) + + +class OnnxModel: + def __init__( + self, + filename: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 2 + + self.session_opts = session_opts + self.model = ort.InferenceSession( + filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + + logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") + metadata = self.model.get_modelmeta().custom_metadata_map + self.sample_rate = int(metadata["sample_rate"]) + + for i in self.model.get_inputs(): + print(i) + + print("-----") + + for i in self.model.get_outputs(): + print(i) + + def __call__(self, x: torch.tensor): + assert x.ndim == 2, x.shape + assert x.shape[0] == 1, x.shape + + x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64) + print("x_lengths", x_lengths) + print("x", x.shape) + + noise_scale = torch.tensor([1.0], dtype=torch.float32) + length_scale = torch.tensor([1.0], dtype=torch.float32) + + mel = self.model.run( + [self.model.get_outputs()[0].name], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lengths.numpy(), + self.model.get_inputs()[2].name: noise_scale.numpy(), + self.model.get_inputs()[3].name: length_scale.numpy(), + }, + )[0] + # mel: (batch_size, feat_dim, num_frames) + + return torch.from_numpy(mel) + + +def read_tokens(filename: str) -> Dict[str, int]: + token2id = dict() + with open(filename, encoding="utf-8") as f: + for line in f.readlines(): + info = line.rstrip().split() + if len(info) == 1: + # case of space + token = " " + idx = int(info[0]) + else: + token, idx = info[0], int(info[1]) + assert token not in token2id, token + token2id[token] = idx + return token2id + + +def read_lexicon(filename: str) -> Dict[str, List[str]]: + word2token = dict() + with open(filename, encoding="utf-8") as f: + for line in f.readlines(): + info = line.rstrip().split() + w = info[0] + tokens = info[1:] + word2token[w] = tokens + return word2token + + +def convert_word_to_tokens(word2tokens: Dict[str, List[str]], word: str) -> List[str]: + if word in word2tokens: + return word2tokens[word] + + if len(word) == 1: + return [] + + ans = [] + for w in word: + t = convert_word_to_tokens(word2tokens, w) + ans.extend(t) + return ans + + +def normalize_text(text): + whiter_space_re = re.compile(r"\s+") + + punctuations_re = [ + (re.compile(x[0], re.IGNORECASE), x[1]) + for x in [ + (",", ","), + ("。", "."), + ("!", "!"), + ("?", "?"), + ("“", '"'), + ("”", '"'), + ("‘", "'"), + ("’", "'"), + (":", ":"), + ("、", ","), + ] + ] + + for regex, replacement in punctuations_re: + text = re.sub(regex, replacement, text) + return text + + +@torch.no_grad() +def main(): + params = get_parser().parse_args() + logging.info(vars(params)) + token2id = read_tokens(params.tokens) + word2tokens = read_lexicon(params.lexicon) + + text = normalize_text(params.input_text) + seg = jieba.cut(text) + tokens = [] + for s in seg: + if s in token2id: + tokens.append(s) + continue + + t = convert_word_to_tokens(word2tokens, s) + if t: + tokens.extend(t) + + model = OnnxModel(params.acoustic_model) + vocoder = OnnxHifiGANModel(params.vocoder) + + x = [] + for t in tokens: + if t in token2id: + x.append(token2id[t]) + + x = intersperse(x, item=token2id["_"]) + + x = torch.tensor(x, dtype=torch.int64).unsqueeze(0) + + start_t = dt.datetime.now() + mel = model(x) + end_t = dt.datetime.now() + + start_t2 = dt.datetime.now() + audio = vocoder(mel) + end_t2 = dt.datetime.now() + + print("audio", audio.shape) # (1, 1, num_samples) + audio = audio.squeeze() + + sample_rate = model.sample_rate + + t = (end_t - start_t).total_seconds() + t2 = (end_t2 - start_t2).total_seconds() + rtf_am = t * sample_rate / audio.shape[-1] + rtf_vocoder = t2 * sample_rate / audio.shape[-1] + print("RTF for acoustic model ", rtf_am) + print("RTF for vocoder", rtf_vocoder) + + # skip denoiser + sf.write(params.output_wav, audio, sample_rate, "PCM_16") + logging.info(f"Saved to {params.output_wav}") + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() + +""" + +|HifiGAN |RTF |#Parameters (M)| +|----------|-----|---------------| +|v1 |0.818| 13.926 | +|v2 |0.101| 0.925 | +|v3 |0.118| 1.462 | + +|Num steps|Acoustic Model RTF| +|---------|------------------| +| 2 | 0.039 | +| 3 | 0.047 | +| 4 | 0.071 | +| 5 | 0.076 | +| 6 | 0.103 | + +""" diff --git a/egs/baker_zh/TTS/matcha/tokenizer.py b/egs/baker_zh/TTS/matcha/tokenizer.py new file mode 100644 index 000000000..dda82c29d --- /dev/null +++ b/egs/baker_zh/TTS/matcha/tokenizer.py @@ -0,0 +1,119 @@ +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + +import logging +from typing import Dict, List + +import tacotron_cleaner.cleaners + +try: + from piper_phonemize import phonemize_espeak +except Exception as ex: + raise RuntimeError( + f"{ex}\nPlease run\n" + "pip install piper_phonemize -f https://k2-fsa.github.io/icefall/piper_phonemize.html" + ) + +from utils import intersperse + + +# This tokenizer supports both English and Chinese. +# We assume you have used +# ../local/convert_text_to_tokens.py +# to process your text +class Tokenizer(object): + def __init__(self, tokens: str): + """ + Args: + tokens: the file that maps tokens to ids + """ + # Parse token file + self.token2id: Dict[str, int] = {} + with open(tokens, "r", encoding="utf-8") as f: + for line in f.readlines(): + info = line.rstrip().split() + if len(info) == 1: + # case of space + token = " " + id = int(info[0]) + else: + token, id = info[0], int(info[1]) + assert token not in self.token2id, token + self.token2id[token] = id + + # Refer to https://github.com/rhasspy/piper/blob/master/TRAINING.md + self.pad_id = self.token2id["_"] # padding + self.space_id = self.token2id[" "] # word separator (whitespace) + + self.vocab_size = len(self.token2id) + + def texts_to_token_ids( + self, + sentence_list: List[List[str]], + intersperse_blank: bool = True, + lang: str = "en-us", + ) -> List[List[int]]: + """ + Args: + sentence_list: + A list of sentences. + intersperse_blank: + Whether to intersperse blanks in the token sequence. + lang: + Language argument passed to phonemize_espeak(). + + Returns: + Return a list of token id list [utterance][token_id] + """ + token_ids_list = [] + + for sentence in sentence_list: + tokens_list = [] + for word in sentence: + if word in self.token2id: + tokens_list.append(word) + continue + + tmp_tokens_list = phonemize_espeak(word, lang) + for t in tmp_tokens_list: + tokens_list.extend(t) + + token_ids = [] + for t in tokens_list: + if t not in self.token2id: + logging.warning(f"Skip OOV {t} {sentence}") + continue + + if t == " " and len(token_ids) > 0 and token_ids[-1] == self.space_id: + continue + + token_ids.append(self.token2id[t]) + + if intersperse_blank: + token_ids = intersperse(token_ids, self.pad_id) + + token_ids_list.append(token_ids) + + return token_ids_list + + +def test_tokenizer(): + import jieba + from pypinyin import Style, lazy_pinyin + + tokenizer = Tokenizer("data/tokens.txt") + text1 = "今天is Monday, tomorrow is 星期二" + text2 = "你好吗? 我很好, how about you?" + + text1 = list(jieba.cut(text1)) + text2 = list(jieba.cut(text2)) + tokens1 = lazy_pinyin(text1, style=Style.TONE3, tone_sandhi=True) + tokens2 = lazy_pinyin(text2, style=Style.TONE3, tone_sandhi=True) + print(tokens1) + print(tokens2) + + ids = tokenizer.texts_to_token_ids([tokens1, tokens2]) + print(ids) + + +if __name__ == "__main__": + test_tokenizer() diff --git a/egs/baker_zh/TTS/matcha/train.py b/egs/baker_zh/TTS/matcha/train.py new file mode 100755 index 000000000..ed2ba49b9 --- /dev/null +++ b/egs/baker_zh/TTS/matcha/train.py @@ -0,0 +1,717 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + + +import argparse +import json +import logging +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Union + +import k2 +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse.utils import fix_random_seed +from model import fix_len_compatibility +from models.matcha_tts import MatchaTTS +from tokenizer import Tokenizer +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import BakerZhTtsDataModule +from utils import MetricsTracker + +from icefall.checkpoint import load_checkpoint, save_checkpoint +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import AttributeDict, setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12335, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=1000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="matcha/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + parser.add_argument( + "--cmvn", + type=str, + default="data/fbank/cmvn.json", + help="""Path to vocabulary.""", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=10, + help="""Save checkpoint after processing this number of epochs" + periodically. We save checkpoint to exp-dir/ whenever + params.cur_epoch % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. + Since it will take around 1000 epochs, we suggest using a large + save_every_n to save disk space. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + return parser + + +def get_data_statistics(): + return AttributeDict( + { + "mel_mean": 0, + "mel_std": 1, + } + ) + + +def _get_data_params() -> AttributeDict: + params = AttributeDict( + { + "name": "baker-zh", + "train_filelist_path": "./filelists/ljs_audio_text_train_filelist.txt", + "valid_filelist_path": "./filelists/ljs_audio_text_val_filelist.txt", + # "batch_size": 64, + # "num_workers": 1, + # "pin_memory": False, + "cleaners": ["english_cleaners2"], + "add_blank": True, + "n_spks": 1, + "n_fft": 1024, + "n_feats": 80, + "sampling_rate": 22050, + "hop_length": 256, + "win_length": 1024, + "f_min": 0, + "f_max": 8000, + "seed": 1234, + "load_durations": False, + "data_statistics": get_data_statistics(), + } + ) + return params + + +def _get_model_params() -> AttributeDict: + n_feats = 80 + filter_channels_dp = 256 + encoder_params_p_dropout = 0.1 + params = AttributeDict( + { + "n_spks": 1, # for baker-zh. + "spk_emb_dim": 64, + "n_feats": n_feats, + "out_size": None, # or use 172 + "prior_loss": True, + "use_precomputed_durations": False, + "data_statistics": get_data_statistics(), + "encoder": AttributeDict( + { + "encoder_type": "RoPE Encoder", # not used + "encoder_params": AttributeDict( + { + "n_feats": n_feats, + "n_channels": 192, + "filter_channels": 768, + "filter_channels_dp": filter_channels_dp, + "n_heads": 2, + "n_layers": 6, + "kernel_size": 3, + "p_dropout": encoder_params_p_dropout, + "spk_emb_dim": 64, + "n_spks": 1, + "prenet": True, + } + ), + "duration_predictor_params": AttributeDict( + { + "filter_channels_dp": filter_channels_dp, + "kernel_size": 3, + "p_dropout": encoder_params_p_dropout, + } + ), + } + ), + "decoder": AttributeDict( + { + "channels": [256, 256], + "dropout": 0.05, + "attention_head_dim": 64, + "n_blocks": 1, + "num_mid_blocks": 2, + "num_heads": 2, + "act_fn": "snakebeta", + } + ), + "cfm": AttributeDict( + { + "name": "CFM", + "solver": "euler", + "sigma_min": 1e-4, + } + ), + "optimizer": AttributeDict( + { + "lr": 1e-4, + "weight_decay": 0.0, + } + ), + } + ) + + return params + + +def get_params(): + params = AttributeDict( + { + "model_args": _get_model_params(), + "data_args": _get_data_params(), + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": -1, # 0 + "log_interval": 10, + "valid_interval": 1500, + "env_info": get_env_info(), + } + ) + return params + + +def get_model(params): + m = MatchaTTS(**params.model_args) + return m + + +def load_checkpoint_if_available( + params: AttributeDict, model: nn.Module +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint(filename, model=model) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device, params): + """Parse batch data""" + mel_mean = params.data_args.data_statistics.mel_mean + mel_std_inv = 1 / params.data_args.data_statistics.mel_std + for i in range(batch["features"].shape[0]): + n = batch["features_lens"][i] + batch["features"][i : i + 1, :n, :] = ( + batch["features"][i : i + 1, :n, :] - mel_mean + ) * mel_std_inv + batch["features"][i : i + 1, n:, :] = 0 + + audio = batch["audio"].to(device) + features = batch["features"].to(device) + audio_lens = batch["audio_lens"].to(device) + features_lens = batch["features_lens"].to(device) + tokens = batch["tokens"] + + tokens = tokenizer.texts_to_token_ids(tokens, intersperse_blank=True) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # a tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) + + max_feature_length = fix_len_compatibility(features.shape[1]) + if max_feature_length > features.shape[1]: + pad = max_feature_length - features.shape[1] + features = torch.nn.functional.pad(features, (0, 0, 0, pad)) + + # features_lens[features_lens.argmax()] += pad + + return audio, audio_lens, features, features_lens.long(), tokens, tokens_lens.long() + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, + rank: int = 0, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses + + # used to summary the stats over iterations + tot_loss = MetricsTracker() + + with torch.no_grad(): + for batch_idx, batch in enumerate(valid_dl): + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + ) = prepare_input(batch, tokenizer, device, params) + + losses = get_losses( + { + "x": tokens, + "x_lengths": tokens_lens, + "y": features.permute(0, 2, 1), + "y_lengths": features_lens, + "spks": None, # should change it for multi-speakers + "durations": None, + } + ) + + batch_size = len(batch["tokens"]) + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + s = 0 + + for key, value in losses.items(): + v = value.detach().item() + loss_info[key] = v * batch_size + s += v * batch_size + + loss_info["tot_loss"] = s + + # summary stats + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(device) + + loss_value = tot_loss["tot_loss"] / tot_loss["samples"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + optimizer: Optimizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + """ + model.train() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + get_losses = model.module.get_losses if isinstance(model, DDP) else model.get_losses + + # used to track the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + params=params, + optimizer=optimizer, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + # audio: (N, T), float32 + # features: (N, T, C), float32 + # audio_lens, (N,), int32 + # features_lens, (N,), int32 + # tokens: List[List[str]], len(tokens) == N + + batch_size = len(batch["tokens"]) + + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + ) = prepare_input(batch, tokenizer, device, params) + try: + with autocast(enabled=params.use_fp16): + losses = get_losses( + { + "x": tokens, + "x_lengths": tokens_lens, + "y": features.permute(0, 2, 1), + "y_lengths": features_lens, + "spks": None, # should change it for multi-speakers + "durations": None, + } + ) + + loss = sum(losses.values()) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + s = 0 + + for key, value in losses.items(): + v = value.detach().item() + loss_info[key] = v * batch_size + s += v * batch_size + + loss_info["tot_loss"] = s + + tot_loss = tot_loss + loss_info + except: # noqa + save_bad_model() + raise + + if params.batch_idx_train % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. + # The _growth_interval of the grad scaler is configurable, + # but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or ( + cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 + ): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if params.batch_idx_train % params.log_interval == 0: + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"global_batch_idx: {params.batch_idx_train}, " + f"batch size: {batch_size}, " + f"loss[{loss_info}], tot_loss[{tot_loss}], " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if params.batch_idx_train % params.valid_interval == 1: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + tokenizer=tokenizer, + valid_dl=valid_dl, + world_size=world_size, + rank=rank, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + "Maximum memory allocated so far is " + f"{torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["tot_loss"] / tot_loss["samples"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + tokenizer = Tokenizer(params.tokens) + params.pad_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + params.model_args.n_vocab = params.vocab_size + + with open(params.cmvn) as f: + stats = json.load(f) + params.data_args.data_statistics.mel_mean = stats["fbank_mean"] + params.data_args.data_statistics.mel_std = stats["fbank_std"] + + params.model_args.data_statistics.mel_mean = stats["fbank_mean"] + params.model_args.data_statistics.mel_std = stats["fbank_std"] + + logging.info(params) + print(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of parameters: {num_param}") + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = torch.optim.Adam(model.parameters(), **params.model_args.optimizer) + + logging.info("About to create datamodule") + + baker_zh = BakerZhTtsDataModule(args) + + train_cuts = baker_zh.train_cuts() + train_dl = baker_zh.train_dataloaders(train_cuts) + + valid_cuts = baker_zh.valid_cuts() + valid_dl = baker_zh.valid_dataloaders(valid_cuts) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + fix_random_seed(params.seed + epoch - 1) + if "sampler" in train_dl: + train_dl.sampler.set_epoch(epoch - 1) + + params.cur_epoch = epoch + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + train_one_epoch( + params=params, + model=model, + tokenizer=tokenizer, + optimizer=optimizer, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if epoch % params.save_every_n == 0 or epoch == params.num_epochs: + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer=optimizer, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + BakerZhTtsDataModule.add_arguments(parser) + args = parser.parse_args() + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + main() diff --git a/egs/baker_zh/TTS/matcha/tts_datamodule.py b/egs/baker_zh/TTS/matcha/tts_datamodule.py new file mode 100644 index 000000000..d2bdfb96c --- /dev/null +++ b/egs/baker_zh/TTS/matcha/tts_datamodule.py @@ -0,0 +1,340 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2023 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from fbank import MatchaFbank, MatchaFbankConfig +from lhotse import CutSet, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, + SpeechSynthesisDataset, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class BakerZhTtsDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=False, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = MatchaFbankConfig( + n_fft=1024, + n_mels=80, + sampling_rate=sampling_rate, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, + ) + train = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=True, + pin_memory=True, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = MatchaFbankConfig( + n_fft=1024, + n_mels=80, + sampling_rate=sampling_rate, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, + ) + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), + return_cuts=self.args.return_cuts, + ) + else: + validate = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create valid dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=True, + pin_memory=True, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + sampling_rate = 22050 + config = MatchaFbankConfig( + n_fft=1024, + n_mels=80, + sampling_rate=sampling_rate, + hop_length=256, + win_length=1024, + f_min=0, + f_max=8000, + ) + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), + return_cuts=self.args.return_cuts, + ) + else: + test = SpeechSynthesisDataset( + return_text=False, + return_tokens=True, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "baker_zh_cuts_train.jsonl.gz" + ) + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get validation cuts") + return load_manifest_lazy( + self.args.manifest_dir / "baker_zh_cuts_valid.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "baker_zh_cuts_test.jsonl.gz" + ) diff --git a/egs/baker_zh/TTS/matcha/utils.py b/egs/baker_zh/TTS/matcha/utils.py new file mode 120000 index 000000000..ceaaea196 --- /dev/null +++ b/egs/baker_zh/TTS/matcha/utils.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/matcha/utils.py \ No newline at end of file diff --git a/egs/baker_zh/TTS/prepare.sh b/egs/baker_zh/TTS/prepare.sh new file mode 100755 index 000000000..e15e3d850 --- /dev/null +++ b/egs/baker_zh/TTS/prepare.sh @@ -0,0 +1,151 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=-1 +stop_stage=100 + +dl_dir=$PWD/download +mkdir -p $dl_dir + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: build monotonic_align lib (used by ./matcha)" + for recipe in matcha; do + if [ ! -d $recipe/monotonic_align/build ]; then + cd $recipe/monotonic_align + python3 setup.py build_ext --inplace + cd ../../ + else + log "monotonic_align lib for $recipe already built" + fi + done +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # The directory $dl_dir/BANSYP contains the following 3 directories + + # ls -lh $dl_dir/BZNSYP/ + # total 0 + # drwxr-xr-x 10002 kuangfangjun root 0 Jan 4 2019 PhoneLabeling + # drwxr-xr-x 3 kuangfangjun root 0 Jan 31 2019 ProsodyLabeling + # drwxr-xr-x 10003 kuangfangjun root 0 Aug 26 17:45 Wave + + # If you have trouble accessing huggingface.co, please use + # + # cd $dl_dir + # wget https://huggingface.co/openspeech/BZNSYP/resolve/main/BZNSYP.tar.bz2 + # tar xf BZNSYP.tar.bz2 + # cd .. + + # If you have pre-downloaded it to /path/to/BZNSYP, you can create a symlink + # + # ln -sfv /path/to/BZNSYP $dl_dir/BZNSYP + # + if [ ! -d $dl_dir/BZNSYP/Wave ]; then + lhotse download baker-zh $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare baker-zh manifest" + # We assume that you have downloaded the baker corpus + # to $dl_dir/BZNSYP + mkdir -p data/manifests + if [ ! -e data/manifests/.baker-zh.done ]; then + lhotse prepare baker-zh $dl_dir/BZNSYP data/manifests + touch data/manifests/.baker-zh.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Generate tokens.txt" + if [ ! -e data/tokens.txt ]; then + python3 ./local/generate_tokens.py --tokens data/tokens.txt + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Generate raw cutset" + if [ ! -e data/manifests/baker_zh_cuts_raw.jsonl.gz ]; then + lhotse cut simple \ + -r ./data/manifests/baker_zh_recordings_all.jsonl.gz \ + -s ./data/manifests/baker_zh_supervisions_all.jsonl.gz \ + ./data/manifests/baker_zh_cuts_raw.jsonl.gz + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Convert text to tokens" + if [ ! -e data/manifests/baker_zh_cuts.jsonl.gz ]; then + python3 ./local/convert_text_to_tokens.py \ + --in-file ./data/manifests/baker_zh_cuts_raw.jsonl.gz \ + --out-file ./data/manifests/baker_zh_cuts.jsonl.gz + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Generate fbank (used by ./matcha)" + mkdir -p data/fbank + if [ ! -e data/fbank/.baker-zh.done ]; then + ./local/compute_fbank_baker_zh.py + touch data/fbank/.baker-zh.done + fi + + if [ ! -e data/fbank/.baker-zh-validated.done ]; then + log "Validating data/fbank for baker-zh (used by ./matcha)" + python3 ./local/validate_manifest.py \ + data/fbank/baker_zh_cuts.jsonl.gz + touch data/fbank/.baker-zh-validated.done + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Split the baker-zh cuts into train, valid and test sets (used by ./matcha)" + if [ ! -e data/fbank/.baker_zh_split.done ]; then + lhotse subset --last 600 \ + data/fbank/baker_zh_cuts.jsonl.gz \ + data/fbank/baker_zh_cuts_validtest.jsonl.gz + lhotse subset --first 100 \ + data/fbank/baker_zh_cuts_validtest.jsonl.gz \ + data/fbank/baker_zh_cuts_valid.jsonl.gz + lhotse subset --last 500 \ + data/fbank/baker_zh_cuts_validtest.jsonl.gz \ + data/fbank/baker_zh_cuts_test.jsonl.gz + + rm data/fbank/baker_zh_cuts_validtest.jsonl.gz + + n=$(( $(gunzip -c data/fbank/baker_zh_cuts.jsonl.gz | wc -l) - 600 )) + + lhotse subset --first $n \ + data/fbank/baker_zh_cuts.jsonl.gz \ + data/fbank/baker_zh_cuts_train.jsonl.gz + + touch data/fbank/.baker_zh_split.done + fi +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 6: Compute fbank mean and std (used by ./matcha)" + if [ ! -f ./data/fbank/cmvn.json ]; then + ./local/compute_fbank_statistics.py ./data/fbank/baker_zh_cuts_train.jsonl.gz ./data/fbank/cmvn.json + fi +fi diff --git a/egs/baker_zh/TTS/shared b/egs/baker_zh/TTS/shared new file mode 120000 index 000000000..4cbd91a7e --- /dev/null +++ b/egs/baker_zh/TTS/shared @@ -0,0 +1 @@ +../../../icefall/shared \ No newline at end of file diff --git a/egs/ljspeech/TTS/README.md b/egs/ljspeech/TTS/README.md index 39280437b..c9cfc22fd 100644 --- a/egs/ljspeech/TTS/README.md +++ b/egs/ljspeech/TTS/README.md @@ -166,7 +166,7 @@ To export the checkpoint to onnx: --tokens ./data/tokens.txt ``` -The above command generate the following files: +The above command generates the following files: - model-steps-2.onnx - model-steps-3.onnx diff --git a/egs/ljspeech/TTS/matcha/export_onnx.py b/egs/ljspeech/TTS/matcha/export_onnx.py index 623517431..39709cc36 100755 --- a/egs/ljspeech/TTS/matcha/export_onnx.py +++ b/egs/ljspeech/TTS/matcha/export_onnx.py @@ -93,14 +93,14 @@ class ModelWrapper(torch.nn.Module): self, x: torch.Tensor, x_lengths: torch.Tensor, - temperature: torch.Tensor, + noise_scale: torch.Tensor, length_scale: torch.Tensor, ) -> torch.Tensor: """ Args: : x: (batch_size, num_tokens), torch.int64 x_lengths: (batch_size,), torch.int64 - temperature: (1,), torch.float32 + noise_scale: (1,), torch.float32 length_scale (1,), torch.float32 Returns: audio: (batch_size, num_samples) @@ -110,7 +110,7 @@ class ModelWrapper(torch.nn.Module): x=x, x_lengths=x_lengths, n_timesteps=self.num_steps, - temperature=temperature, + temperature=noise_scale, length_scale=length_scale, )["mel"] # mel: (batch_size, feat_dim, num_frames) @@ -127,7 +127,6 @@ def main(): params.update(vars(args)) tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.pad_id params.vocab_size = tokenizer.vocab_size params.model_args.n_vocab = params.vocab_size @@ -153,14 +152,14 @@ def main(): # encoder has a large initial length x = torch.ones(1, 1000, dtype=torch.int64) x_lengths = torch.tensor([x.shape[1]], dtype=torch.int64) - temperature = torch.tensor([1.0]) + noise_scale = torch.tensor([1.0]) length_scale = torch.tensor([1.0]) opset_version = 14 filename = f"model-steps-{num_steps}.onnx" torch.onnx.export( wrapper, - (x, x_lengths, temperature, length_scale), + (x, x_lengths, noise_scale, length_scale), filename, opset_version=opset_version, input_names=["x", "x_length", "noise_scale", "length_scale"], diff --git a/egs/ljspeech/TTS/matcha/onnx_pretrained.py b/egs/ljspeech/TTS/matcha/onnx_pretrained.py index 6d92b16eb..19e9b49cb 100755 --- a/egs/ljspeech/TTS/matcha/onnx_pretrained.py +++ b/egs/ljspeech/TTS/matcha/onnx_pretrained.py @@ -132,7 +132,7 @@ class OnnxModel: print("x_lengths", x_lengths) print("x", x.shape) - temperature = torch.tensor([1.0], dtype=torch.float32) + noise_scale = torch.tensor([1.0], dtype=torch.float32) length_scale = torch.tensor([1.0], dtype=torch.float32) mel = self.model.run( @@ -140,7 +140,7 @@ class OnnxModel: { self.model.get_inputs()[0].name: x.numpy(), self.model.get_inputs()[1].name: x_lengths.numpy(), - self.model.get_inputs()[2].name: temperature.numpy(), + self.model.get_inputs()[2].name: noise_scale.numpy(), self.model.get_inputs()[3].name: length_scale.numpy(), }, )[0] From 3b263539cd34fb14b53d72339bc7c095028f4578 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 2 Jan 2025 15:54:34 +0800 Subject: [PATCH 57/59] Publish MatchaTTS onnx models trained with LJSpeech to huggingface (#1854) --- .github/scripts/docker/Dockerfile | 2 +- .github/scripts/ljspeech/TTS/run-matcha.sh | 33 +++++++++- .github/workflows/ljspeech.yml | 74 +++++++++++++++++++++- egs/ljspeech/TTS/README.md | 9 +++ egs/ljspeech/TTS/matcha/export_onnx.py | 4 ++ 5 files changed, 118 insertions(+), 4 deletions(-) diff --git a/.github/scripts/docker/Dockerfile b/.github/scripts/docker/Dockerfile index 94e8d8e1e..cf0523401 100644 --- a/.github/scripts/docker/Dockerfile +++ b/.github/scripts/docker/Dockerfile @@ -49,7 +49,7 @@ RUN pip install --no-cache-dir \ kaldifst \ kaldilm \ librosa \ - matplotlib \ + "matplotlib<=3.9.4" \ multi_quantization \ numba \ "numpy<2.0" \ diff --git a/.github/scripts/ljspeech/TTS/run-matcha.sh b/.github/scripts/ljspeech/TTS/run-matcha.sh index 954dd5bd8..bfb37fb6d 100755 --- a/.github/scripts/ljspeech/TTS/run-matcha.sh +++ b/.github/scripts/ljspeech/TTS/run-matcha.sh @@ -77,7 +77,7 @@ function export_onnx() { popd pushd data/fbank - rm -v *.json + rm -fv *.json curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-ljspeech-matcha-en-2024-10-28/resolve/main/data/cmvn.json popd @@ -115,6 +115,37 @@ function export_onnx() { ls -lh /icefall/*.wav soxi /icefall/generated-matcha-tts-steps-6-*.wav + + cp ./model-steps-*.onnx /icefall + + d=matcha-icefall-en_US-ljspeech + mkdir $d + cp -v data/tokens.txt $d + cp model-steps-3.onnx $d + pushd $d + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2 + tar xf espeak-ng-data.tar.bz2 + rm espeak-ng-data.tar.bz2 + +cat >README.md <=2.2.0`. + To export the Hifigan vocoder to onnx, please use: diff --git a/egs/ljspeech/TTS/matcha/export_onnx.py b/egs/ljspeech/TTS/matcha/export_onnx.py index 39709cc36..3c653fbf1 100755 --- a/egs/ljspeech/TTS/matcha/export_onnx.py +++ b/egs/ljspeech/TTS/matcha/export_onnx.py @@ -176,12 +176,16 @@ def main(): "language": "English", "voice": "en-us", "has_espeak": 1, + "jieba": 0, "n_speakers": 1, "sample_rate": 22050, "version": 1, + "pad_id": tokenizer.pad_id, "model_author": "icefall", "maintainer": "k2-fsa", + "use_eos_bos": 1, "dataset": "LJ Speech", + "dataset_url": "https://keithito.com/LJ-Speech-Dataset/", "num_ode_steps": num_steps, } add_meta_data(filename=filename, meta_data=meta_data) From 3b6d54007b7b9d0f2ee28ced3d91caed773ae3c1 Mon Sep 17 00:00:00 2001 From: Seonuk Kim <49300300+snkii@users.noreply.github.com> Date: Mon, 6 Jan 2025 14:17:02 +0900 Subject: [PATCH 58/59] Update conformer.py (#1857) * Update conformer.py feedforward dimention -> feedforward dimension * Update conformer.py feedforward dimention -> feedforward dimension * Update conformer.py feedforward dimention -> feedforward dimension * Update conformer.py feedforward dimention -> feedforward dimension * Update conformer.py feedforward dimention -> feedforward dimension * Update conformer.py feedforward dimention -> feedforward dimension * Update conformer.py feedforward dimention -> feedforward dimension * Update conformer.py feedforward dimention -> feedforward dimension * Update conformer.py feedforward dimention -> feedforward dimension --- egs/librispeech/ASR/conformer_ctc/conformer.py | 2 +- egs/librispeech/ASR/conformer_ctc2/conformer.py | 2 +- egs/librispeech/ASR/conformer_mmi/conformer.py | 2 +- egs/librispeech/ASR/pruned2_knowledge/conformer.py | 2 +- egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py | 2 +- egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py | 2 +- egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py | 2 +- egs/librispeech/ASR/streaming_conformer_ctc/conformer.py | 2 +- egs/librispeech/ASR/transducer_stateless/conformer.py | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py index a1cfe6e75..3ac60e32f 100644 --- a/egs/librispeech/ASR/conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -32,7 +32,7 @@ class Conformer(Transformer): subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) d_model (int): attention dimension nhead (int): number of head - dim_feedforward (int): feedforward dimention + dim_feedforward (int): feedforward dimension num_encoder_layers (int): number of encoder layers num_decoder_layers (int): number of decoder layers dropout (float): dropout rate diff --git a/egs/librispeech/ASR/conformer_ctc2/conformer.py b/egs/librispeech/ASR/conformer_ctc2/conformer.py index 09f1eb000..02ea80a46 100644 --- a/egs/librispeech/ASR/conformer_ctc2/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/conformer.py @@ -42,7 +42,7 @@ class Conformer(Transformer): subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) d_model (int): attention dimension, also the output dimension nhead (int): number of head - dim_feedforward (int): feedforward dimention + dim_feedforward (int): feedforward dimension num_encoder_layers (int): number of encoder layers num_decoder_layers (int): number of decoder layers dropout (float): dropout rate diff --git a/egs/librispeech/ASR/conformer_mmi/conformer.py b/egs/librispeech/ASR/conformer_mmi/conformer.py index 53e48eb13..cffe3df28 100644 --- a/egs/librispeech/ASR/conformer_mmi/conformer.py +++ b/egs/librispeech/ASR/conformer_mmi/conformer.py @@ -33,7 +33,7 @@ class Conformer(Transformer): subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) d_model (int): attention dimension nhead (int): number of head - dim_feedforward (int): feedforward dimention + dim_feedforward (int): feedforward dimension num_encoder_layers (int): number of encoder layers num_decoder_layers (int): number of decoder layers dropout (float): dropout rate diff --git a/egs/librispeech/ASR/pruned2_knowledge/conformer.py b/egs/librispeech/ASR/pruned2_knowledge/conformer.py index de367c234..69cc59756 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/conformer.py +++ b/egs/librispeech/ASR/pruned2_knowledge/conformer.py @@ -42,7 +42,7 @@ class Conformer(EncoderInterface): subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) d_model (int): attention dimension, also the output dimension nhead (int): number of head - dim_feedforward (int): feedforward dimention + dim_feedforward (int): feedforward dimension num_encoder_layers (int): number of encoder layers dropout (float): dropout rate layer_dropout (float): layer-dropout rate. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index ab46e233b..85e61ebab 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -42,7 +42,7 @@ class Conformer(EncoderInterface): subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) d_model (int): attention dimension, also the output dimension nhead (int): number of head - dim_feedforward (int): feedforward dimention + dim_feedforward (int): feedforward dimension num_encoder_layers (int): number of encoder layers dropout (float): dropout rate layer_dropout (float): layer-dropout rate. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index 8bbceec61..968ea4150 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -42,7 +42,7 @@ class Conformer(EncoderInterface): subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) d_model (int): attention dimension, also the output dimension nhead (int): number of head - dim_feedforward (int): feedforward dimention + dim_feedforward (int): feedforward dimension num_encoder_layers (int): number of encoder layers dropout (float): dropout rate layer_dropout (float): layer-dropout rate. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py index 0667e7f61..8c1529500 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py @@ -42,7 +42,7 @@ class Conformer(EncoderInterface): subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) d_model (int): attention dimension, also the output dimension nhead (int): number of head - dim_feedforward (int): feedforward dimention + dim_feedforward (int): feedforward dimension num_encoder_layers (int): number of encoder layers dropout (float): dropout rate layer_dropout (float): layer-dropout rate. diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py index 0b982f4bf..72842cc28 100644 --- a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py @@ -69,7 +69,7 @@ class Conformer(Transformer): subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) d_model (int): attention dimension nhead (int): number of head - dim_feedforward (int): feedforward dimention + dim_feedforward (int): feedforward dimension num_encoder_layers (int): number of encoder layers num_decoder_layers (int): number of decoder layers dropout (float): dropout rate diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 90b722bde..9b11df673 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -35,7 +35,7 @@ class Conformer(Transformer): subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) d_model (int): attention dimension nhead (int): number of head - dim_feedforward (int): feedforward dimention + dim_feedforward (int): feedforward dimension num_encoder_layers (int): number of encoder layers dropout (float): dropout rate cnn_module_kernel (int): Kernel size of convolution module From 8d602806c3b141a5a75ac7d165292dc3d13d19b8 Mon Sep 17 00:00:00 2001 From: Seonuk Kim <49300300+snkii@users.noreply.github.com> Date: Mon, 6 Jan 2025 18:31:13 +0900 Subject: [PATCH 59/59] Update conformer.py (#1859) * Update conformer.py feedforward dimention -> feedforward dimension * Update conformer.py feedforward dimention -> feedforward dimension * Update conformer.py feedforward dimention -> feedforward dimension * Update conformer.py feedforward dimention -> feedforward dimension * Update conformer.py feedforward dimention -> feedforward dimension * Update conformer.py feedforward dimention -> feedforward dimension * Update conformer.py feedforward dimention -> feedforward dimension * Update conformer.py feedforward dimention -> feedforward dimension * Update conformer.py feedforward dimention -> feedforward dimension * Update conformer.py Swich -? Swish --- egs/librispeech/ASR/conformer_ctc/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py index 3ac60e32f..ea793ce2f 100644 --- a/egs/librispeech/ASR/conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -902,7 +902,7 @@ class Swish(torch.nn.Module): """Construct an Swish object.""" def forward(self, x: Tensor) -> Tensor: - """Return Swich activation function.""" + """Return Swish activation function.""" return x * torch.sigmoid(x)