Merge branch 'k2-fsa:master' into k2ssl

Yifan Yang 2024-03-10 13:10:36 +08:00 committed by GitHub
commit 660f647886
181 changed files with 16099 additions and 413 deletions

View File

@@ -16,7 +16,7 @@ jobs:
       fail-fast: false
       matrix:
         os: [ubuntu-latest]
-        image: ["torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
+        image: ["torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
     steps:
       # refer to https://github.com/actions/checkout

View File

@@ -14,7 +14,7 @@ jobs:
       fail-fast: false
       matrix:
         os: [ubuntu-latest]
-        image: ["torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
+        image: ["torch2.2.1-cuda12.1", "torch2.2.1-cuda11.8", "torch2.2.0-cuda12.1", "torch2.2.0-cuda11.8", "torch2.1.0-cuda12.1", "torch2.1.0-cuda11.8", "torch2.0.0-cuda11.7", "torch1.13.0-cuda11.6", "torch1.12.1-cuda11.3", "torch1.9.0-cuda10.2"]
     steps:
       # refer to https://github.com/actions/checkout
       - uses: actions/checkout@v2

View File

@@ -49,7 +49,7 @@ jobs:
       - name: Install Python dependencies
         run: |
-          python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0
+          python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0 isort==5.10.1
           # Click issue fixed in https://github.com/psf/black/pull/2966

       - name: Run flake8
@@ -67,3 +67,9 @@ jobs:
         working-directory: ${{github.workspace}}
         run: |
           black --check --diff .
+
+      - name: Run isort
+        shell: bash
+        working-directory: ${{github.workspace}}
+        run: |
+          isort --check --diff .
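The new CI step can be reproduced locally with the versions pinned in this workflow (a sketch; the flags mirror the workflow exactly):

```bash
pip install black==22.3.0 flake8==5.0.4 click==8.1.0 isort==5.10.1
flake8 .
black --check --diff .
isort --check --diff .
```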

View File

@@ -26,7 +26,7 @@ repos:
       # E121,E123,E126,E226,E24,E704,W503,W504

   - repo: https://github.com/pycqa/isort
-    rev: 5.11.5
+    rev: 5.10.1
     hooks:
       - id: isort
         args: ["--profile=black"]
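Since the same hook is pinned in `.pre-commit-config.yaml`, the check can also run through pre-commit (a sketch, assuming pre-commit is installed):

```bash
pip install pre-commit
pre-commit run isort --all-files
```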

View File

@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.7
-ARG K2_VERSION="1.24.4.dev20240211+cuda11.3.torch1.12.1"
-ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.3.torch1.12.1"
+ARG K2_VERSION="1.24.4.dev20240223+cuda11.3.torch1.12.1"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.3.torch1.12.1"
 ARG TORCHAUDIO_VERSION="0.12.1+cu113"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"

View File

@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.9
-ARG K2_VERSION="1.24.4.dev20240211+cuda11.6.torch1.13.0"
-ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.6.torch1.13.0"
+ARG K2_VERSION="1.24.4.dev20240223+cuda11.6.torch1.13.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.6.torch1.13.0"
 ARG TORCHAUDIO_VERSION="0.13.0+cu116"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"

View File

@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.7
-ARG K2_VERSION="1.24.4.dev20240211+cuda10.2.torch1.9.0"
-ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda10.2.torch1.9.0"
+ARG K2_VERSION="1.24.4.dev20240223+cuda10.2.torch1.9.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda10.2.torch1.9.0"
 ARG TORCHAUDIO_VERSION="0.9.0"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"

View File

@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.10
-ARG K2_VERSION="1.24.4.dev20240211+cuda11.7.torch2.0.0"
-ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.7.torch2.0.0"
+ARG K2_VERSION="1.24.4.dev20240223+cuda11.7.torch2.0.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.7.torch2.0.0"
 ARG TORCHAUDIO_VERSION="2.0.0+cu117"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"

View File

@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.10
-ARG K2_VERSION="1.24.4.dev20240211+cuda11.8.torch2.1.0"
-ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.8.torch2.1.0"
+ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.1.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.1.0"
 ARG TORCHAUDIO_VERSION="2.1.0+cu118"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"

View File

@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.10
-ARG K2_VERSION="1.24.4.dev20240211+cuda12.1.torch2.1.0"
-ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda12.1.torch2.1.0"
+ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.1.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.1.0"
 ARG TORCHAUDIO_VERSION="2.1.0+cu121"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"

View File

@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.10
-ARG K2_VERSION="1.24.4.dev20240211+cuda11.8.torch2.2.0"
-ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda11.8.torch2.2.0"
+ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.2.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.2.0"
 ARG TORCHAUDIO_VERSION="2.2.0+cu118"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"

View File

@@ -5,8 +5,8 @@ ENV LC_ALL C.UTF-8
 ARG DEBIAN_FRONTEND=noninteractive

 # python 3.10
-ARG K2_VERSION="1.24.4.dev20240211+cuda12.1.torch2.2.0"
-ARG KALDIFEAT_VERSION="1.25.4.dev20240210+cuda12.1.torch2.2.0"
+ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.2.0"
+ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.2.0"
 ARG TORCHAUDIO_VERSION="2.2.0+cu121"

 LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"

View File

@@ -0,0 +1,70 @@
FROM pytorch/pytorch:2.2.1-cuda11.8-cudnn8-devel
ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
ARG K2_VERSION="1.24.4.dev20240223+cuda11.8.torch2.2.1"
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda11.8.torch2.2.1"
ARG TORCHAUDIO_VERSION="2.2.1+cu118"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"
RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
libssl-dev \
autoconf \
automake \
bzip2 \
ca-certificates \
ffmpeg \
g++ \
gfortran \
git \
libtool \
make \
patch \
sox \
subversion \
unzip \
valgrind \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
# Install dependencies
RUN pip install --no-cache-dir \
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
kaldi_native_io \
kaldialign \
kaldifst \
kaldilm \
sentencepiece>=0.1.96 \
tensorboard \
typeguard \
dill \
onnx \
onnxruntime \
onnxmltools \
multi_quantization \
typeguard \
numpy \
pytest \
graphviz
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall
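A minimal sketch of building and entering this new image locally (the dockerfile path and image tag are illustrative, not taken from the diff):

```bash
docker build -t icefall:torch2.2.1-cuda11.8 -f docker/torch2.2.1-cuda11.8.dockerfile .
docker run --rm -it --gpus all icefall:torch2.2.1-cuda11.8 /bin/bash
```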

View File

@@ -0,0 +1,70 @@
FROM pytorch/pytorch:2.2.1-cuda12.1-cudnn8-devel
ENV LC_ALL C.UTF-8
ARG DEBIAN_FRONTEND=noninteractive
# python 3.10
ARG K2_VERSION="1.24.4.dev20240223+cuda12.1.torch2.2.1"
ARG KALDIFEAT_VERSION="1.25.4.dev20240223+cuda12.1.torch2.2.1"
ARG TORCHAUDIO_VERSION="2.2.1+cu121"
LABEL authors="Fangjun Kuang <csukuangfj@gmail.com>"
LABEL k2_version=${K2_VERSION}
LABEL kaldifeat_version=${KALDIFEAT_VERSION}
LABEL github_repo="https://github.com/k2-fsa/icefall"
RUN apt-get update && \
apt-get install -y --no-install-recommends \
curl \
vim \
libssl-dev \
autoconf \
automake \
bzip2 \
ca-certificates \
ffmpeg \
g++ \
gfortran \
git \
libtool \
make \
patch \
sox \
subversion \
unzip \
valgrind \
wget \
zlib1g-dev \
&& rm -rf /var/lib/apt/lists/*
# Install dependencies
RUN pip install --no-cache-dir \
torchaudio==${TORCHAUDIO_VERSION} -f https://download.pytorch.org/whl/torch_stable.html \
k2==${K2_VERSION} -f https://k2-fsa.github.io/k2/cuda.html \
git+https://github.com/lhotse-speech/lhotse \
kaldifeat==${KALDIFEAT_VERSION} -f https://csukuangfj.github.io/kaldifeat/cuda.html \
kaldi_native_io \
kaldialign \
kaldifst \
kaldilm \
sentencepiece>=0.1.96 \
tensorboard \
typeguard \
dill \
onnx \
onnxruntime \
onnxmltools \
multi_quantization \
typeguard \
numpy \
pytest \
graphviz
RUN git clone https://github.com/k2-fsa/icefall /workspace/icefall && \
cd /workspace/icefall && \
pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /workspace/icefall:$PYTHONPATH
WORKDIR /workspace/icefall

View File

@@ -34,6 +34,8 @@ which will give you something like below:

 .. code-block:: bash

+  "torch2.2.1-cuda12.1"
+  "torch2.2.1-cuda11.8"
   "torch2.2.0-cuda12.1"
   "torch2.2.0-cuda11.8"
   "torch2.1.0-cuda12.1"

View File

@@ -1,11 +1,11 @@
-VITS
+VITS-LJSpeech
 ===============

 This tutorial shows you how to train an VITS model
 with the `LJSpeech <https://keithito.com/LJ-Speech-Dataset/>`_ dataset.

 .. note::

     TTS related recipes require packages in ``requirements-tts.txt``.

 .. note::
@@ -120,4 +120,4 @@ Download pretrained models
 If you don't want to train from scratch, you can download the pretrained models
 by visiting the following link:

-  - `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2023-11-29>`_
+  - `<https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28>`_

View File

@@ -1,11 +1,11 @@
-VITS
+VITS-VCTK
 ===============

 This tutorial shows you how to train an VITS model
 with the `VCTK <https://datashare.ed.ac.uk/handle/10283/3443>`_ dataset.

 .. note::

     TTS related recipes require packages in ``requirements-tts.txt``.

 .. note::

View File

@@ -75,7 +75,7 @@ It's reworked Zipformer with Pruned RNNT loss, trained with Byte-level BPE, `voc
 | fast beam search | 4.43 | 4.17 | --epoch 40 --avg 10 |

 ```bash
 ./prepare.sh

 export CUDA_VISIBLE_DEVICES="0,1"

View File

@@ -250,7 +250,7 @@ def get_parser():
     parser.add_argument(
         "--context-size",
         type=int,
-        default=1,
+        default=2,
         help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )

     parser.add_argument(
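For readers unfamiliar with the flag being re-defaulted here: in icefall's stateless transducer decoder, `--context-size` is the number of previous output symbols the decoder conditions on. A minimal sketch of the mechanism (illustrative names; the real Decoder additionally handles blanks and padding):

```python
import torch
import torch.nn as nn


class TinyStatelessDecoder(nn.Module):
    """Sketch: embed the last `context_size` symbols, mix them with a Conv1d."""

    def __init__(self, vocab_size: int, dim: int, context_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        # kernel_size == context_size, so each output step sees a bigram
        # (context_size=1) or trigram (context_size=2) symbol history.
        self.conv = nn.Conv1d(dim, dim, kernel_size=context_size)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (batch, context_size) holding the previous symbol ids.
        emb = self.embedding(y).permute(0, 2, 1)  # (batch, dim, context_size)
        return self.conv(emb).squeeze(-1)  # (batch, dim)


decoder = TinyStatelessDecoder(vocab_size=500, dim=512, context_size=2)
out = decoder(torch.zeros(4, 2, dtype=torch.long))
assert out.shape == (4, 512)
```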

View File

@@ -1,6 +1,6 @@
 ## Results

 ### Aishell2 char-based training results

 #### Pruned transducer stateless 5

View File

@@ -29,7 +29,14 @@ import os
 from pathlib import Path

 import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse import (
+    CutSet,
+    Fbank,
+    FbankConfig,
+    LilcomChunkyWriter,
+    WhisperFbank,
+    WhisperFbankConfig,
+)
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor, str2bool
@@ -42,10 +49,12 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)


-def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
+def compute_fbank_aishell2(
+    num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
+):
     src_dir = Path("data/manifests")
     output_dir = Path("data/fbank")
-    num_jobs = min(15, os.cpu_count())
+    num_jobs = min(8, os.cpu_count())

     dataset_parts = (
         "train",
@@ -68,8 +77,12 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
         list(manifests.keys()),
         dataset_parts,
     )
-
-    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+    if whisper_fbank:
+        extractor = WhisperFbank(
+            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+        )
+    else:
+        extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
@@ -82,7 +95,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False):
                 supervisions=m["supervisions"],
             )
             if "train" in partition and perturb_speed:
-                logging.info(f"Doing speed perturb")
+                logging.info("Doing speed perturb")
                 cut_set = (
                     cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
@@ -111,7 +124,12 @@ def get_args():
         default=False,
         help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
     )
+    parser.add_argument(
+        "--whisper-fbank",
+        type=str2bool,
+        default=False,
+        help="Use WhisperFbank instead of Fbank. Default: False.",
+    )

     return parser.parse_args()
@@ -122,5 +140,7 @@ if __name__ == "__main__":
     args = get_args()

     compute_fbank_aishell2(
-        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
+        num_mel_bins=args.num_mel_bins,
+        perturb_speed=args.perturb_speed,
+        whisper_fbank=args.whisper_fbank,
     )
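With the new flag, 80-bin Whisper-style features can be extracted directly; this mirrors the prepare.sh stage added in the next hunk:

```bash
./local/compute_fbank_aishell2.py \
  --perturb-speed true \
  --num-mel-bins 80 \
  --whisper-fbank true
```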

View File

@@ -108,6 +108,16 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   fi
 fi

+whisper_mel_bins=80
+if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
+  log "Stage 30: Compute whisper fbank for aishell2"
+  if [ ! -f data/fbank/.aishell2.whisper.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_aishell2.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
+    touch data/fbank/.aishell2.whisper.done
+  fi
+fi
+
 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
   log "Stage 4: Compute fbank for musan"
   if [ ! -f data/fbank/.msuan.done ]; then
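For existing setups, only the new Stage 30 needs to run; icefall's prepare.sh scripts accept stage bounds on the command line (a sketch, assuming the usual shared/parse_options.sh flag handling):

```bash
./prepare.sh --stage 30 --stop-stage 30
```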

View File

@@ -3,7 +3,7 @@
 This recipe contains some various ASR models trained with Aishell4 (including S, M and L three subsets).

 The AISHELL-4 is a sizable real-recorded Mandarin speech dataset collected by 8-channel circular microphone array for speech processing in conference scenarios. The dataset consists of 211 recorded meeting sessions, each containing 4 to 8 speakers, with a total length of 120 hours. This dataset aims to bridge the advanced research on multi-speaker processing and the practical application scenario in three aspects. With real recorded meetings, AISHELL-4 provides realistic acoustics and rich natural speech characteristics in conversation such as short pause, speech overlap, quick speaker turn, noise, etc. Meanwhile, the accurate transcription and speaker voice activity are provided for each meeting in AISHELL-4. This allows the researchers to explore different aspects in meeting processing, ranging from individual tasks such as speech front-end processing, speech recognition and speaker diarization, to multi-modality modeling and joint optimization of relevant tasks.

 (From [Open Speech and Language Resources](https://www.openslr.org/111/))

View File

@@ -29,7 +29,14 @@ import os
 from pathlib import Path

 import torch
-from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
+from lhotse import (
+    CutSet,
+    Fbank,
+    FbankConfig,
+    LilcomChunkyWriter,
+    WhisperFbank,
+    WhisperFbankConfig,
+)
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor, str2bool
@@ -42,10 +49,12 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)


-def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
+def compute_fbank_aishell4(
+    num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
+):
     src_dir = Path("data/manifests/aishell4")
     output_dir = Path("data/fbank")
-    num_jobs = min(15, os.cpu_count())
+    num_jobs = min(8, os.cpu_count())

     dataset_parts = (
         "train_S",
@@ -70,7 +79,12 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
         dataset_parts,
     )

-    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+    if whisper_fbank:
+        extractor = WhisperFbank(
+            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+        )
+    else:
+        extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
@@ -84,7 +98,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
                 supervisions=m["supervisions"],
             )
             if "train" in partition and perturb_speed:
-                logging.info(f"Doing speed perturb")
+                logging.info("Doing speed perturb")
                 cut_set = (
                     cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
@@ -95,7 +109,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False):
                 # when an executor is specified, make more partitions
                 num_jobs=num_jobs if ex is None else 80,
                 executor=ex,
-                storage_type=ChunkedLilcomHdf5Writer,
+                storage_type=LilcomChunkyWriter,
             )

     logging.info("About splitting cuts into smaller chunks")
@@ -121,7 +135,12 @@ def get_args():
         default=False,
         help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
     )
+    parser.add_argument(
+        "--whisper-fbank",
+        type=str2bool,
+        default=False,
+        help="Use WhisperFbank instead of Fbank. Default: False.",
+    )

     return parser.parse_args()
@@ -132,5 +151,7 @@ if __name__ == "__main__":
     args = get_args()

     compute_fbank_aishell4(
-        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
+        num_mel_bins=args.num_mel_bins,
+        perturb_speed=args.perturb_speed,
+        whisper_fbank=args.whisper_fbank,
     )

View File

@@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 set -eou pipefail

 stage=-1
-stop_stage=100
+stop_stage=7
 perturb_speed=true

@@ -76,11 +76,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
 fi

 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
-  log "Stage 2: Process aishell4"
+  log "Stage 2: Compute fbank for aishell4"
   if [ ! -f data/fbank/aishell4/.fbank.done ]; then
-    mkdir -p data/fbank/aishell4
+    mkdir -p data/fbank
     ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
-    touch data/fbank/aishell4/.fbank.done
+    touch data/fbank/.fbank.done
+  fi
+fi
+
+whisper_mel_bins=80
+if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
+  log "Stage 20: Compute whisper fbank for aishell4"
+  if [ ! -f data/fbank/aishell4/.fbank.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
+    touch data/fbank/.fbank.done
   fi
 fi
@@ -106,16 +116,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
 fi

 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-  log "Stage 5: Compute fbank for aishell4"
-  if [ ! -f data/fbank/.aishell4.done ]; then
-    mkdir -p data/fbank
-    ./local/compute_fbank_aishell4.py --perturb-speed ${perturb_speed}
-    touch data/fbank/.aishell4.done
-  fi
-fi
-
-if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-  log "Stage 6: Prepare char based lang"
+  log "Stage 5: Prepare char based lang"
   lang_char_dir=data/lang_char
   mkdir -p $lang_char_dir

View File

@@ -29,7 +29,14 @@ import os
 from pathlib import Path

 import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter
+from lhotse import (
+    CutSet,
+    Fbank,
+    FbankConfig,
+    LilcomChunkyWriter,
+    WhisperFbank,
+    WhisperFbankConfig,
+)
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor, str2bool
@@ -42,10 +49,12 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)


-def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False):
+def compute_fbank_alimeeting(
+    num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
+):
     src_dir = Path("data/manifests/alimeeting")
     output_dir = Path("data/fbank")
-    num_jobs = min(15, os.cpu_count())
+    num_jobs = min(8, os.cpu_count())

     dataset_parts = (
         "train",
@@ -53,7 +62,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
         "test",
     )

-    prefix = "alimeeting"
+    prefix = "alimeeting-far"
     suffix = "jsonl.gz"
     manifests = read_manifests_if_cached(
         dataset_parts=dataset_parts,
@@ -70,7 +79,12 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
         dataset_parts,
     )

-    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+    if whisper_fbank:
+        extractor = WhisperFbank(
+            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+        )
+    else:
+        extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
@@ -83,7 +97,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
                 supervisions=m["supervisions"],
             )
             if "train" in partition and perturb_speed:
-                logging.info(f"Doing speed perturb")
+                logging.info("Doing speed perturb")
                 cut_set = (
                     cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
@@ -121,7 +135,12 @@ def get_args():
         default=False,
         help="Enable 0.9 and 1.1 speed perturbation for data augmentation. Default: False.",
     )
+    parser.add_argument(
+        "--whisper-fbank",
+        type=str2bool,
+        default=False,
+        help="Use the Whisper Fbank feature extractor. Default: False.",
+    )

     return parser.parse_args()
@@ -132,5 +151,7 @@ if __name__ == "__main__":
     args = get_args()

     compute_fbank_alimeeting(
-        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed
+        num_mel_bins=args.num_mel_bins,
+        perturb_speed=args.perturb_speed,
+        whisper_fbank=args.whisper_fbank,
     )

View File

@@ -6,7 +6,7 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 set -eou pipefail

 stage=-1
-stop_stage=100
+stop_stage=7
 perturb_speed=true

 # We assume dl_dir (download dir) contains the following
@@ -66,10 +66,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
 fi

 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
-  log "Stage 2: Process alimeeting"
-  if [ ! -f data/fbank/alimeeting/.fbank.done ]; then
-    mkdir -p data/fbank/alimeeting
+  log "Stage 2: compute fbank for alimeeting"
+  if [ ! -f data/fbank/.fbank.done ]; then
+    mkdir -p data/fbank
     ./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed}
+    touch data/fbank/.fbank.done
+  fi
+fi
+
+whisper_mel_bins=80
+if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then
+  log "Stage 20: compute whisper fbank for alimeeting"
+  if [ ! -f data/fbank/.fbank.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_alimeeting.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
+    touch data/fbank/.fbank.done
   fi
 fi
@@ -95,16 +106,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
 fi

 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-  log "Stage 5: Compute fbank for alimeeting"
-  if [ ! -f data/fbank/.alimeeting.done ]; then
-    mkdir -p data/fbank
-    ./local/compute_fbank_alimeeting.py --perturb-speed True
-    touch data/fbank/.alimeeting.done
-  fi
-fi
-
-if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
-  log "Stage 6: Prepare char based lang"
+  log "Stage 5: Prepare char based lang"
   lang_char_dir=data/lang_char
   mkdir -p $lang_char_dir

View File

@@ -1 +0,0 @@
-../../../librispeech/ASR/local/compile_hlg.py

View File

@@ -0,0 +1,168 @@
#!/usr/bin/env python3
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input lang_dir and generates HLG from
- H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
- L, the lexicon, built from lang_dir/L_disambig.pt
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from data/lm/G_n_gram.fst.txt
The generated HLG is saved in $lang_dir/HLG.pt
"""
import argparse
import logging
from pathlib import Path
import k2
import torch
from icefall.lexicon import Lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lm",
type=str,
default="G_3_gram",
help="""Stem name for LM used in HLG compiling.
""",
)
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
""",
)
return parser.parse_args()
def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
"""
Args:
lang_dir:
The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
lm:
The language stem base name.
Return:
An FSA representing HLG.
"""
lexicon = Lexicon(lang_dir)
max_token_id = max(lexicon.tokens)
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
H = k2.ctc_topo(max_token_id)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
logging.info(f"Loading pre-compiled {lm}")
d = torch.load(f"{lang_dir}/lm/{lm}.pt")
G = k2.Fsa.from_dict(d)
else:
logging.info(f"Loading {lm}.fst.txt")
with open(f"{lang_dir}/lm/{lm}.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt")
first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]
L = k2.arc_sort(L)
G = k2.arc_sort(G)
logging.info("Intersecting L and G")
LG = k2.compose(L, G)
logging.info(f"LG shape: {LG.shape}")
logging.info("Connecting LG")
LG = k2.connect(LG)
logging.info(f"LG shape after k2.connect: {LG.shape}")
logging.info(type(LG.aux_labels))
logging.info("Determinizing LG")
LG = k2.determinize(LG)
logging.info(type(LG.aux_labels))
logging.info("Connecting LG after k2.determinize")
LG = k2.connect(LG)
logging.info("Removing disambiguation symbols on LG")
# LG.labels[LG.labels >= first_token_disambig_id] = 0
# see https://github.com/k2-fsa/k2/pull/1140
labels = LG.labels
labels[labels >= first_token_disambig_id] = 0
LG.labels = labels
assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
LG = k2.connect(LG)
LG.aux_labels = LG.aux_labels.remove_values_eq(0)
logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
logging.info("Composing H and LG")
# CAUTION: The name of the inner_labels is fixed
# to `tokens`. If you want to change it, please
# also change other places in icefall that are using
# it.
HLG = k2.compose(H, LG, inner_labels="tokens")
logging.info("Connecting LG")
HLG = k2.connect(HLG)
logging.info("Arc sorting LG")
HLG = k2.arc_sort(HLG)
logging.info(f"HLG.shape: {HLG.shape}")
return HLG
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
if (lang_dir / "HLG.pt").is_file():
logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
return
logging.info(f"Processing {lang_dir}")
HLG = compile_HLG(lang_dir, args.lm)
logging.info(f"Saving HLG.pt to {lang_dir}")
torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
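A typical invocation of the new script (a sketch; the lang dir name follows the examples in the docstring, and the G_3_gram.fst.txt file must be in place where the code above reads it):

```bash
./local/compile_hlg.py --lang-dir data/lang_phone --lm G_3_gram
```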

View File

@@ -1 +0,0 @@
-../../../librispeech/ASR/local/compile_lg.py

View File

@@ -0,0 +1,149 @@
#!/usr/bin/env python3
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Kang Wei,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input lang_dir and generates LG from
- L, the lexicon, built from lang_dir/L_disambig.pt
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from lang_dir/lm/G_3_gram.fst.txt
The generated LG is saved in $lang_dir/LG.pt
"""
import argparse
import logging
from pathlib import Path
import k2
import torch
from icefall.lexicon import Lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
""",
)
parser.add_argument(
"--lm",
type=str,
default="G_3_gram",
help="""Stem name for LM used in HLG compiling.
""",
)
return parser.parse_args()
def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
"""
Args:
lang_dir:
The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
Return:
An FSA representing LG.
"""
lexicon = Lexicon(lang_dir)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
logging.info(f"Loading pre-compiled {lm}")
d = torch.load(f"{lang_dir}/lm/{lm}.pt")
G = k2.Fsa.from_dict(d)
else:
logging.info(f"Loading {lm}.fst.txt")
with open(f"{lang_dir}/lm/{lm}.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt")
first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]
L = k2.arc_sort(L)
G = k2.arc_sort(G)
logging.info("Intersecting L and G")
LG = k2.compose(L, G)
logging.info(f"LG shape: {LG.shape}")
logging.info("Connecting LG")
LG = k2.connect(LG)
logging.info(f"LG shape after k2.connect: {LG.shape}")
logging.info(type(LG.aux_labels))
logging.info("Determinizing LG")
LG = k2.determinize(LG, k2.DeterminizeWeightPushingType.kLogWeightPushing)
logging.info(type(LG.aux_labels))
logging.info("Connecting LG after k2.determinize")
LG = k2.connect(LG)
logging.info("Removing disambiguation symbols on LG")
# LG.labels[LG.labels >= first_token_disambig_id] = 0
# see https://github.com/k2-fsa/k2/pull/1140
labels = LG.labels
labels[labels >= first_token_disambig_id] = 0
LG.labels = labels
assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
LG = k2.connect(LG)
LG.aux_labels = LG.aux_labels.remove_values_eq(0)
logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
return LG
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
if (lang_dir / "LG.pt").is_file():
logging.info(f"{lang_dir}/LG.pt already exists - skipping")
return
logging.info(f"Processing {lang_dir}")
LG = compile_LG(lang_dir, args.lm)
logging.info(f"Saving LG.pt to {lang_dir}")
torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
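And the LG counterpart, under the same assumptions as for compile_hlg.py above:

```bash
./local/compile_lg.py --lang-dir data/lang_phone --lm G_3_gram
```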

View File

@@ -48,8 +48,27 @@ def normalize_text(utt: str, language: str) -> str:
     utt = re.sub("’", "'", utt)
     if language == "en":
         return re.sub(r"[^a-zA-Z\s]", "", utt).upper()
-    if language == "fr":
+    elif language == "fr":
         return re.sub(r"[^A-ZÀÂÆÇÉÈÊËÎÏÔŒÙÛÜ' ]", "", utt).upper()
+    elif language == "pl":
+        return re.sub(r"[^a-ząćęłńóśźżA-ZĄĆĘŁŃÓŚŹŻ' ]", "", utt).upper()
+    elif language == "yue":
+        return (
+            utt.replace(" ", "")
+            .replace("，", "")
+            .replace("。", " ")
+            .replace("！", "")
+            .replace("？", "")
+            .replace("?", "")
+        )
+    else:
+        raise NotImplementedError(
+            f"""
+            Text normalization not implemented for language: {language},
+            please consider implementing it in the local/preprocess_commonvoice.py
+            or raise an issue on GitHub to request it.
+            """
+        )


 def preprocess_commonvoice(
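A quick sanity check of the new dispatch (a sketch; the expected values follow directly from the regexes in the hunk above):

```python
# English: strip everything but letters and whitespace, then uppercase.
assert normalize_text("Hello, world!", "en") == "HELLO WORLD"

# Languages without a rule now fail loudly instead of passing through silently.
try:
    normalize_text("Hallo Welt", "de")
except NotImplementedError:
    pass
```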

View File

@@ -381,9 +381,11 @@ class CommonVoiceAsrDataModule:
     def test_dataloaders(self, cuts: CutSet) -> DataLoader:
         logging.debug("About to create test dataset")
         test = K2SpeechRecognitionDataset(
-            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
-            if self.args.on_the_fly_feats
-            else eval(self.args.input_strategy)(),
+            input_strategy=(
+                OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+                if self.args.on_the_fly_feats
+                else eval(self.args.input_strategy)()
+            ),
             return_cuts=self.args.return_cuts,
         )
         sampler = DynamicBucketingSampler(

View File

@@ -79,10 +79,10 @@ It will generate the following 3 files inside $repo/exp:
 import argparse
 import logging

-from icefall import is_module_available
+import torch
 from onnx_pretrained import OnnxModel

-import torch
+from icefall import is_module_available


 def get_parser():

View File

@@ -31,7 +31,7 @@ from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
     DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
     PrecomputedFeatures,
-    SingleCutSampler,
+    SimpleCutSampler,
     SpecAugment,
 )
 from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
@@ -315,8 +315,8 @@ class CommonVoiceAsrDataModule:
                 drop_last=self.args.drop_last,
             )
         else:
-            logging.info("Using SingleCutSampler.")
-            train_sampler = SingleCutSampler(
+            logging.info("Using SimpleCutSampler.")
+            train_sampler = SimpleCutSampler(
                 cuts_train,
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
@@ -383,9 +383,11 @@ class CommonVoiceAsrDataModule:
     def test_dataloaders(self, cuts: CutSet) -> DataLoader:
         logging.debug("About to create test dataset")
         test = K2SpeechRecognitionDataset(
-            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
-            if self.args.on_the_fly_feats
-            else eval(self.args.input_strategy)(),
+            input_strategy=(
+                OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+                if self.args.on_the_fly_feats
+                else eval(self.args.input_strategy)()
+            ),
             return_cuts=self.args.return_cuts,
         )
         sampler = DynamicBucketingSampler(
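Background for this rename (a sketch; lhotse deprecated SingleCutSampler in favor of the identically behaving SimpleCutSampler):

```python
from lhotse import CutSet
from lhotse.dataset import SimpleCutSampler

# Drop-in replacement: constructor arguments are unchanged from SingleCutSampler.
cuts = CutSet.from_file("data/fbank/cuts_train.jsonl.gz")  # illustrative path
sampler = SimpleCutSampler(cuts, max_duration=200.0, shuffle=True)
```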

View File

@@ -70,9 +70,9 @@ import logging
 from pathlib import Path

 import torch
+from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
 from scaling_converter import convert_scaled_to_non_scaled
 from tokenizer import Tokenizer
-from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model

 from icefall.checkpoint import (
     average_checkpoints,

View File

@@ -23,6 +23,7 @@ from pathlib import Path

 from lhotse import CutSet, SupervisionSegment
 from lhotse.recipes.utils import read_manifests_if_cached
+
 from icefall.utils import str2bool

 # Similar text filtering and normalization procedure as in:

View File

@@ -76,6 +76,7 @@ from beam_search import (
 )
 from gigaspeech_scoring import asr_text_post_processing
 from train import get_params, get_transducer_model
+
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,

View File

@@ -88,7 +88,7 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import GigaSpeechAsrDataModule
-from train import add_model_arguments, get_params, get_model
+from train import add_model_arguments, get_model, get_params

 from icefall.checkpoint import (
     average_checkpoints,

View File

@@ -51,7 +51,7 @@ from streaming_beam_search import (
 )
 from torch import Tensor, nn
 from torch.nn.utils.rnn import pad_sequence
-from train import add_model_arguments, get_params, get_model
+from train import add_model_arguments, get_model, get_params

 from icefall.checkpoint import (
     average_checkpoints,

View File

@@ -42,12 +42,10 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import GigaSpeechAsrDataModule
-from beam_search import (
-    keywords_search,
-)
+from beam_search import keywords_search
+from lhotse.cut import Cut
 from train import add_model_arguments, get_model, get_params
-from lhotse.cut import Cut

 from icefall import ContextGraph
 from icefall.checkpoint import (
     average_checkpoints,

View File

@@ -76,6 +76,20 @@ from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
+from train import (
+    add_model_arguments,
+    add_training_arguments,
+    compute_loss,
+    compute_validation_loss,
+    display_and_save_batch,
+    get_adjusted_batch_count,
+    get_model,
+    get_params,
+    load_checkpoint_if_available,
+    save_checkpoint,
+    scan_pessimistic_batches_for_oom,
+    set_batch_count,
+)

 from icefall import diagnostics
 from icefall.checkpoint import remove_checkpoints
@@ -95,21 +109,6 @@ from icefall.utils import (
     str2bool,
 )

-from train import (
-    add_model_arguments,
-    add_training_arguments,
-    compute_loss,
-    compute_validation_loss,
-    display_and_save_batch,
-    get_adjusted_batch_count,
-    get_model,
-    get_params,
-    load_checkpoint_if_available,
-    save_checkpoint,
-    scan_pessimistic_batches_for_oom,
-    set_batch_count,
-)

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

View File

@@ -425,9 +425,11 @@ class LibriHeavyAsrDataModule:
     def test_dataloaders(self, cuts: CutSet) -> DataLoader:
         logging.debug("About to create test dataset")
         test = K2SpeechRecognitionDataset(
-            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
-            if self.args.on_the_fly_feats
-            else PrecomputedFeatures(),
+            input_strategy=(
+                OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+                if self.args.on_the_fly_feats
+                else PrecomputedFeatures()
+            ),
             return_cuts=self.args.return_cuts,
         )
         sampler = DynamicBucketingSampler(

View File

@@ -35,6 +35,7 @@ The following table lists the differences among them.
 | `lstm_transducer_stateless2` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gigaspeech (multi-dataset setup) |
 | `lstm_transducer_stateless3` | LSTM | Embedding + Conv1d | Using LSTM with mechanisms in reworked model + gradient filter + delay penalty |
 | `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe |
+| `zipformer_adapter` | Upgraded Zipformer | Embedding + Conv1d | It supports domain adaptation of Zipformer using parameter efficient adapters |

 The decoder in `transducer_stateless` is modified from the paper
 [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).

View File

@@ -24,8 +24,7 @@ To run this file, do:
 """

 import torch
-
-from train import get_params, get_ctc_model
+from train import get_ctc_model, get_params


 def test_model():

View File

@@ -59,9 +59,9 @@ import onnx
 import torch
 import torch.nn as nn
 from decoder import Decoder
+from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
 from emformer import Emformer
 from scaling_converter import convert_scaled_to_non_scaled
-from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model

 from icefall.checkpoint import (
     average_checkpoints,

View File

@@ -39,7 +39,7 @@ Usage of this script:
 import argparse
 import logging
 import math
-from typing import List
+from typing import List, Optional

 import kaldifeat
 import sentencepiece as spm
@@ -47,7 +47,6 @@ import torch
 import torchaudio
 from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
 from torch.nn.utils.rnn import pad_sequence
-from typing import Optional, List


 def get_parser():

View File

@@ -31,28 +31,28 @@ https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stat
 """

 import argparse
-import torch.multiprocessing as mp
-import torch
-import torch.nn as nn
 import logging
 from concurrent.futures import ThreadPoolExecutor
-from typing import List, Optional, Tuple
 from pathlib import Path
+from typing import List, Optional, Tuple

 import k2
 import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
 from asr_datamodule import AsrDataModule
 from beam_search import (
     fast_beam_search_one_best,
     greedy_search_batch,
     modified_beam_search,
 )
-from icefall.utils import AttributeDict, convert_timestamp, setup_logger
 from lhotse import CutSet, load_manifest_lazy
 from lhotse.cut import Cut
-from lhotse.supervision import AlignmentItem
 from lhotse.serialization import SequentialJsonlWriter
+from lhotse.supervision import AlignmentItem
+
+from icefall.utils import AttributeDict, convert_timestamp, setup_logger


 def get_parser():

View File

@@ -73,12 +73,11 @@ It will generate the following 3 files inside $repo/exp:
 import argparse
 import logging

+import torch
 from onnx_pretrained import OnnxModel

 from icefall import is_module_available
-
-import torch


 def get_parser():
     parser = argparse.ArgumentParser(

View File

@@ -22,11 +22,12 @@ Usage: ./pruned_transducer_stateless/my_profile.py

 import argparse
 import logging

 import sentencepiece as spm
 import torch
+from train import add_model_arguments, get_encoder_model, get_params

 from icefall.profiler import get_model_profile
-from train import get_encoder_model, add_model_arguments, get_params


 def get_parser():

View File

@@ -75,8 +75,7 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-
-from onnx_pretrained import greedy_search, OnnxModel
+from onnx_pretrained import OnnxModel, greedy_search

 from icefall.utils import setup_logger, store_transcripts, write_error_stats

View File

@@ -78,10 +78,10 @@ It will generate the following 3 files inside $repo/exp:
 import argparse
 import logging

-from icefall import is_module_available
+import torch
 from onnx_pretrained import OnnxModel

-import torch
+from icefall import is_module_available


 def get_parser():

View File

@@ -76,8 +76,7 @@ import torch
 import torch.nn as nn
 from asr_datamodule import AsrDataModule
 from librispeech import LibriSpeech
-
-from onnx_pretrained import greedy_search, OnnxModel
+from onnx_pretrained import OnnxModel, greedy_search

 from icefall.utils import setup_logger, store_transcripts, write_error_stats

View File

@@ -22,15 +22,15 @@ Usage: ./pruned_transducer_stateless4/my_profile.py

 import argparse
 import logging
+from typing import Tuple

 import sentencepiece as spm
 import torch
-from typing import Tuple
+from scaling import BasicNorm, DoubleSwish
 from torch import Tensor, nn
+from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params

 from icefall.profiler import get_model_profile
-from scaling import BasicNorm, DoubleSwish
-from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params


 def get_parser():

View File

@@ -82,8 +82,7 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-
-from onnx_pretrained import greedy_search, OnnxModel
+from onnx_pretrained import OnnxModel, greedy_search

 from icefall.utils import setup_logger, store_transcripts, write_error_stats

View File

@@ -20,7 +20,6 @@ from typing import List

 import k2
 import torch
 from beam_search import Hypothesis, HypothesisList, get_hyps_shape
-

 # The force alignment problem can be formulated as finding

View File

@@ -107,9 +107,6 @@ import k2
 import sentencepiece as spm
 import torch
 import torch.nn as nn
-
-# from asr_datamodule import LibriSpeechAsrDataModule
-from gigaspeech import GigaSpeechAsrDataModule
 from beam_search import (
     beam_search,
     fast_beam_search_nbest,
@@ -120,6 +117,9 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
+
+# from asr_datamodule import LibriSpeechAsrDataModule
+from gigaspeech import GigaSpeechAsrDataModule
 from gigaspeech_scoring import asr_text_post_processing
 from train import add_model_arguments, get_params, get_transducer_model

View File

@@ -65,16 +65,15 @@ from typing import Dict, List

 import sentencepiece as spm
 import torch
 from train import add_model_arguments, get_params, get_transducer_model

-from icefall.utils import str2bool
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
     find_checkpoints,
     load_checkpoint,
 )
+from icefall.utils import str2bool


 def get_parser():

View File

@@ -22,15 +22,15 @@ Usage: ./pruned_transducer_stateless7/my_profile.py

 import argparse
 import logging
+from typing import Tuple

 import sentencepiece as spm
 import torch
-from typing import Tuple
+from scaling import BasicNorm, DoubleSwish
 from torch import Tensor, nn
+from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params

 from icefall.profiler import get_model_profile
-from scaling import BasicNorm, DoubleSwish
-from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params


 def get_parser():

View File

@@ -75,8 +75,7 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-
-from onnx_pretrained import greedy_search, OnnxModel
+from onnx_pretrained import OnnxModel, greedy_search

 from icefall.utils import setup_logger, store_transcripts, write_error_stats

View File

@@ -24,7 +24,6 @@ To run this file, do:
 """

 import torch
-
 from scaling_converter import convert_scaled_to_non_scaled
 from train import get_params, get_transducer_model

@@ -118,8 +118,8 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
-from train import add_model_arguments, get_params, get_transducer_model
 from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_params, get_transducer_model

 from icefall.checkpoint import (
     average_checkpoints,

@@ -18,10 +18,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
-from scaling import (
-    ActivationBalancer,
-    ScaledConv1d,
-)
+from scaling import ActivationBalancer, ScaledConv1d

 class LConv(nn.Module):

@@ -52,7 +52,7 @@ import onnxruntime as ort
 import sentencepiece as spm
 import torch
 import torchaudio
-from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
+from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence

 from icefall.utils import make_pad_mask

@@ -14,6 +14,7 @@
 import torch
 from torch import nn

 from icefall.utils import make_pad_mask

@@ -4,7 +4,6 @@
 import ncnn
 import numpy as np

 layer_list = []

@@ -42,7 +42,6 @@ import ncnn
 import torch
 import torchaudio
 from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
 from ncnn_custom_layer import RegisterCustomLayers

@@ -1,10 +1,11 @@
 import argparse
 import logging
 import math
+import pprint
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
-import pprint

 import k2
 import sentencepiece as spm
 import torch

@@ -88,7 +88,7 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from train import add_model_arguments, get_params, get_model
+from train import add_model_arguments, get_model, get_params

 from icefall.checkpoint import (
     average_checkpoints,

@@ -22,9 +22,9 @@ import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
+from scaling import ScaledLinear

 from icefall.utils import add_sos, make_pad_mask
-from scaling import ScaledLinear

 class AsrModel(nn.Module):

@@ -22,24 +22,24 @@ Usage: ./zipformer/my_profile.py
 import argparse
 import logging
+from typing import Tuple

 import sentencepiece as spm
 import torch
-from typing import Tuple
-from torch import Tensor, nn
-from icefall.utils import make_pad_mask
-from icefall.profiler import get_model_profile
 from scaling import BiasNorm
+from torch import Tensor, nn
 from train import (
+    add_model_arguments,
     get_encoder_embed,
     get_encoder_model,
     get_joiner_model,
-    add_model_arguments,
     get_params,
 )
 from zipformer import BypassModule
+from icefall.profiler import get_model_profile
+from icefall.utils import make_pad_mask

 def get_parser():
     parser = argparse.ArgumentParser(

@@ -77,11 +77,10 @@ from typing import List, Tuple
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from onnx_pretrained import greedy_search, OnnxModel
+from k2 import SymbolTable
+from onnx_pretrained import OnnxModel, greedy_search

 from icefall.utils import setup_logger, store_transcripts, write_error_stats
-from k2 import SymbolTable

 def get_parser():

@@ -27,11 +27,10 @@ https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
 import argparse
 import logging
 import math
-from typing import List, Tuple
+from typing import Dict, List, Tuple

 import k2
 import kaldifeat
-from typing import Dict
 import kaldifst
 import onnxruntime as ort
 import torch

@@ -27,11 +27,10 @@ https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
 import argparse
 import logging
 import math
-from typing import List, Tuple
+from typing import Dict, List, Tuple

 import k2
 import kaldifeat
-from typing import Dict
 import kaldifst
 import onnxruntime as ort
 import torch

@@ -27,11 +27,10 @@ https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
 import argparse
 import logging
 import math
-from typing import List, Tuple
+from typing import Dict, List, Tuple

 import k2
 import kaldifeat
-from typing import Dict
 import kaldifst
 import onnxruntime as ort
 import torch

@@ -15,15 +15,16 @@
 # limitations under the License.

-from typing import Optional, Tuple, Union
 import logging
-import k2
-from torch.cuda.amp import custom_fwd, custom_bwd
-import random
-import torch
 import math
+import random
+from typing import Optional, Tuple, Union

+import k2
+import torch
 import torch.nn as nn
 from torch import Tensor
+from torch.cuda.amp import custom_bwd, custom_fwd

 def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:

@@ -51,7 +51,7 @@ from streaming_beam_search import (
 )
 from torch import Tensor, nn
 from torch.nn.utils.rnn import pad_sequence
-from train import add_model_arguments, get_params, get_model
+from train import add_model_arguments, get_model, get_params

 from icefall.checkpoint import (
     average_checkpoints,

@@ -16,11 +16,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Tuple
 import warnings
+from typing import Tuple

 import torch
-from torch import Tensor, nn
 from scaling import (
     Balancer,
     BiasNorm,
@@ -34,6 +33,7 @@ from scaling import (
     SwooshR,
     Whiten,
 )
+from torch import Tensor, nn

 class ConvNeXt(nn.Module):

@@ -0,0 +1 @@
../tdnn_lstm_ctc/asr_datamodule.py

@@ -0,0 +1 @@
../pruned_transducer_stateless2/beam_search.py

File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -0,0 +1 @@
../zipformer/decoder.py

@@ -0,0 +1 @@
../transducer_stateless/encoder_interface.py

@@ -0,0 +1,621 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
"""
This script exports a transducer model from PyTorch to ONNX.
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
2. Export the model to ONNX
./zipformer/export-onnx.py \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--num-encoder-layers "2,2,3,4,3,2" \
--downsampling-factor "1,2,4,8,4,2" \
--feedforward-dim "512,768,1024,1536,1024,768" \
--num-heads "4,4,4,8,4,4" \
--encoder-dim "192,256,384,512,384,256" \
--query-head-dim 32 \
--value-head-dim 12 \
--pos-head-dim 4 \
--pos-dim 48 \
--encoder-unmasked-dim "192,192,256,256,256,192" \
--cnn-module-kernel "31,31,15,15,15,31" \
--decoder-dim 512 \
--joiner-dim 512 \
--causal False \
--chunk-size "16,32,64,-1" \
--left-context-frames "64,128,256,-1"
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
See ./onnx_pretrained.py and ./onnx_check.py for how to
use the exported ONNX models.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple
import k2
import onnx
import torch
import torch.nn as nn
from decoder import Decoder
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_finetune_arguments, add_model_arguments, get_model, get_params
from zipformer import Zipformer2
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, num_tokens, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--tokens",
type=str,
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
add_finetune_arguments(parser)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
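# Editor's note: a quick way to confirm what add_meta_data() wrote is to
# read it back with onnxruntime (illustrative snippet, not part of this
# file; the filename is hypothetical):
#
#   import onnxruntime as ort
#
#   sess = ort.InferenceSession("encoder-epoch-99-avg-1.onnx")
#   print(sess.get_modelmeta().custom_metadata_map)
#   # -> {"model_type": "zipformer2", "version": "1", ...}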
class OnnxEncoder(nn.Module):
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
def __init__(
self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
):
"""
Args:
encoder:
A Zipformer encoder.
encoder_proj:
The projection layer for encoder from the joiner.
"""
super().__init__()
self.encoder = encoder
self.encoder_embed = encoder_embed
self.encoder_proj = encoder_proj
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Please see the help information of Zipformer.forward
Args:
x:
A 3-D tensor of shape (N, T, C)
x_lens:
A 1-D tensor of shape (N,). Its dtype is torch.int64
Returns:
Return a tuple containing:
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
- encoder_out_lens, A 1-D tensor of shape (N,)
"""
x, x_lens = self.encoder_embed(x, x_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2)
encoder_out = self.encoder_proj(encoder_out)
# Now encoder_out is of shape (N, T, joiner_dim)
return encoder_out, encoder_out_lens
class OnnxDecoder(nn.Module):
"""A wrapper for Decoder and the decoder_proj from the joiner"""
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
super().__init__()
self.decoder = decoder
self.decoder_proj = decoder_proj
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, context_size).
Returns:
Return a 2-D tensor of shape (N, joiner_dim)
"""
need_pad = False
decoder_output = self.decoder(y, need_pad=need_pad)
decoder_output = decoder_output.squeeze(1)
output = self.decoder_proj(decoder_output)
return output
class OnnxJoiner(nn.Module):
"""A wrapper for the joiner"""
def __init__(self, output_linear: nn.Linear):
super().__init__()
self.output_linear = output_linear
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
logit = encoder_out + decoder_out
logit = self.output_linear(torch.tanh(logit))
return logit
def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the given encoder model to ONNX format.
The exported model has two inputs:
- x, a tensor of shape (N, T, C); dtype is torch.float32
- x_lens, a tensor of shape (N,); dtype is torch.int64
and it has two outputs:
- encoder_out, a tensor of shape (N, T', joiner_dim)
- encoder_out_lens, a tensor of shape (N,)
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
x = torch.zeros(1, 100, 80, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64)
encoder_model = torch.jit.trace(encoder_model, (x, x_lens))
torch.onnx.export(
encoder_model,
(x, x_lens),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens"],
output_names=["encoder_out", "encoder_out_lens"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"encoder_out": {0: "N", 1: "T"},
"encoder_out_lens": {0: "N"},
},
)
meta_data = {
"model_type": "zipformer2",
"version": "1",
"model_author": "k2-fsa",
"comment": "non-streaming zipformer2",
}
logging.info(f"meta_data: {meta_data}")
add_meta_data(filename=encoder_filename, meta_data=meta_data)
def export_decoder_model_onnx(
decoder_model: OnnxDecoder,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export(
decoder_model,
y,
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
meta_data = {
"context_size": str(context_size),
"vocab_size": str(vocab_size),
}
add_meta_data(filename=decoder_filename, meta_data=meta_data)
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- encoder_out: a tensor of shape (N, joiner_dim)
- decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
"""
joiner_dim = joiner_model.output_linear.weight.shape[1]
logging.info(f"joiner dim: {joiner_dim}")
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"encoder_out",
"decoder_out",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
meta_data = {
"joiner_dim": str(joiner_dim),
}
add_meta_data(filename=joiner_filename, meta_data=meta_data)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1
logging.info(params)
logging.info("About to create model")
model = get_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
encoder = OnnxEncoder(
encoder=model.encoder,
encoder_embed=model.encoder_embed,
encoder_proj=model.joiner.encoder_proj,
)
decoder = OnnxDecoder(
decoder=model.decoder,
decoder_proj=model.joiner.decoder_proj,
)
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
if params.iter > 0:
suffix = f"iter-{params.iter}"
else:
suffix = f"epoch-{params.epoch}"
suffix += f"-avg-{params.avg}"
opset_version = 13
logging.info("Exporting encoder")
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
export_encoder_model_onnx(
encoder,
encoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported encoder to {encoder_filename}")
logging.info("Exporting decoder")
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
export_decoder_model_onnx(
decoder,
decoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported decoder to {decoder_filename}")
logging.info("Exporting joiner")
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
export_joiner_model_onnx(
joiner,
joiner_filename,
opset_version=opset_version,
)
logging.info(f"Exported joiner to {joiner_filename}")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=encoder_filename,
model_output=encoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul", "Gather"],
weight_type=QuantType.QInt8,
)
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
quantize_dynamic(
model_input=joiner_filename,
model_output=joiner_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
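# Editor's note: an optional sanity check after quantization (illustrative,
# not part of this script) is to run the float32 and int8 encoders on the
# same dummy input and compare outputs:
#
#   import numpy as np
#   import onnxruntime as ort
#
#   x = np.random.randn(1, 100, 80).astype(np.float32)
#   x_lens = np.array([100], dtype=np.int64)
#   fp32 = ort.InferenceSession(str(encoder_filename))
#   int8 = ort.InferenceSession(str(encoder_filename_int8))
#   a = fp32.run(None, {"x": x, "x_lens": x_lens})[0]
#   b = int8.run(None, {"x": x, "x_lens": x_lens})[0]
#   print(np.abs(a - b).max())  # small drift is expected for int8 weights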
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
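Editor's note: the three exported files are consumed together at decode time. Below is a rough sketch of one greedy-search pass over them (illustrative only, not the repo's onnx_pretrained.py; it assumes blank id 0 for the <blk> token, context size 2, and the file names produced by the script above):

```
import numpy as np
import onnxruntime as ort

enc = ort.InferenceSession("encoder-epoch-99-avg-1.onnx")
dec = ort.InferenceSession("decoder-epoch-99-avg-1.onnx")
join = ort.InferenceSession("joiner-epoch-99-avg-1.onnx")

x = np.zeros((1, 100, 80), dtype=np.float32)  # (N, T, C) fbank features
x_lens = np.array([100], dtype=np.int64)
encoder_out, encoder_out_lens = enc.run(None, {"x": x, "x_lens": x_lens})

hyp = [0, 0]  # context_size blanks to prime the decoder
decoder_out = dec.run(None, {"y": np.array([hyp[-2:]], dtype=np.int64)})[0]
for t in range(int(encoder_out_lens[0])):
    # The joiner combines one projected encoder frame with the current
    # decoder output to produce a distribution over tokens.
    logit = join.run(
        None,
        {"encoder_out": encoder_out[:, t], "decoder_out": decoder_out},
    )[0]
    y = int(logit.argmax(axis=-1)[0])
    if y != 0:  # non-blank: emit the token and refresh the decoder
        hyp.append(y)
        decoder_out = dec.run(
            None, {"y": np.array([hyp[-2:]], dtype=np.int64)}
        )[0]
print(hyp[2:])  # token ids of the hypothesis
```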

@@ -0,0 +1 @@
../zipformer/joiner.py

@@ -0,0 +1 @@
../zipformer/model.py

@@ -0,0 +1,386 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads ONNX exported models and uses them to decode the test sets.
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
2. Export the model to ONNX
./zipformer/export-onnx.py \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--causal False
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
3. Run this file
./zipformer/onnx_decode.py \
--exp-dir $repo/exp \
--max-duration 600 \
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
--tokens $repo/data/lang_bpe_500/tokens.txt \
"""
import argparse
import logging
import time
from pathlib import Path
from typing import List, Tuple
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from k2 import SymbolTable
from onnx_pretrained import OnnxModel, greedy_search
from icefall.utils import setup_logger, store_transcripts, write_error_stats
conversational_filler = [
"UH",
"UHH",
"UM",
"EH",
"MM",
"HM",
"AH",
"HUH",
"HA",
"ER",
"OOF",
"HEE",
"ACH",
"EEE",
"EW",
]
unk_tags = ["<UNK>", "<unk>"]
gigaspeech_punctuations = [
"<COMMA>",
"<PERIOD>",
"<QUESTIONMARK>",
"<EXCLAMATIONPOINT>",
]
gigaspeech_garbage_utterance_tags = ["<SIL>", "<NOISE>", "<MUSIC>", "<OTHER>"]
non_scoring_words = (
conversational_filler
+ unk_tags
+ gigaspeech_punctuations
+ gigaspeech_garbage_utterance_tags
)
def asr_text_post_processing(text: str) -> str:
# 1. convert to uppercase
text = text.upper()
# 2. remove hyphen
# "E-COMMERCE" -> "E COMMERCE", "STATE-OF-THE-ART" -> "STATE OF THE ART"
text = text.replace("-", " ")
# 3. remove non-scoring words from evaluation
remaining_words = []
for word in text.split():
if word in non_scoring_words:
continue
remaining_words.append(word)
return " ".join(remaining_words)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder onnx model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner onnx model. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--tokens",
type=str,
help="""Path to tokens.txt.""",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="Valid values are greedy_search and modified_beam_search",
)
return parser
def post_processing(
results: List[Tuple[str, List[str], List[str]]],
) -> List[Tuple[str, List[str], List[str]]]:
new_results = []
for key, ref, hyp in results:
new_ref = asr_text_post_processing(" ".join(ref)).split()
new_hyp = asr_text_post_processing(" ".join(hyp)).split()
new_results.append((key, new_ref, new_hyp))
return new_results
def decode_one_batch(
model: OnnxModel, token_table: SymbolTable, batch: dict
) -> List[List[str]]:
"""Decode one batch and return the result.
Currently only greedy_search is supported.
Args:
model:
The neural model.
token_table:
The token table.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
Returns:
Return the decoded results for each utterance.
"""
feature = batch["inputs"]
assert feature.ndim == 3
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(dtype=torch.int64)
encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens)
hyps = greedy_search(
model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
hyps = [token_ids_to_words(h).split() for h in hyps]
return hyps
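# Editor's note: the "▁" in token_ids_to_words above is U+2581, the BPE
# word-boundary marker, not a plain underscore. Illustration:
#
#   pieces = ["▁HE", "LLO", "▁WORLD"]
#   "".join(pieces).replace("▁", " ").strip()  # -> "HELLO WORLD"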
def decode_dataset(
dl: torch.utils.data.DataLoader,
model: nn.Module,
token_table: SymbolTable,
) -> Tuple[List[Tuple[str, List[str], List[str]]], float]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
model:
The neural model.
token_table:
The token table.
Returns:
- A list of tuples. Each tuple contains three elements:
- cut_id,
- reference transcript,
- predicted result.
- The total duration (in seconds) of the dataset.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
log_interval = 10
total_duration = 0
results = []
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]])
hyps = decode_one_batch(model=model, token_table=token_table, batch=batch)
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results.extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results, total_duration
def save_results(
res_dir: Path,
test_set_name: str,
results: List[Tuple[str, List[str], List[str]]],
):
recog_path = res_dir / f"recogs-{test_set_name}.txt"
results = post_processing(results)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = res_dir / f"errs-{test_set_name}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
errs_info = res_dir / f"wer-summary-{test_set_name}.txt"
with open(errs_info, "w") as f:
print("WER", file=f)
print(wer, file=f)
s = "\nFor {}, WER is {}:\n".format(test_set_name, wer)
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
assert (
args.decoding_method == "greedy_search"
), "Only supports greedy_search currently."
res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}"
setup_logger(f"{res_dir}/log-decode")
logging.info("Decoding started")
device = torch.device("cpu")
logging.info(f"Device: {device}")
token_table = SymbolTable.from_file(args.tokens)
logging.info(vars(args))
logging.info("About to create model")
model = OnnxModel(
encoder_model_filename=args.encoder_model_filename,
decoder_model_filename=args.decoder_model_filename,
joiner_model_filename=args.joiner_model_filename,
)
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
gigaspeech_dev_cuts = librispeech.gigaspeech_dev_cuts()
gigaspeech_test_cuts = librispeech.gigaspeech_test_cuts()
dev_dl = librispeech.test_dataloaders(gigaspeech_dev_cuts)
test_dl = librispeech.test_dataloaders(gigaspeech_test_cuts)
test_sets = ["dev", "test"]
test_dl = [dev_dl, test_dl]
for test_set, test_dl in zip(test_sets, test_dl):
start_time = time.time()
results, total_duration = decode_dataset(
dl=test_dl, model=model, token_table=token_table
)
end_time = time.time()
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / total_duration
logging.info(f"Elapsed time: {elapsed_seconds:.3f} s")
logging.info(f"Wave duration: {total_duration:.3f} s")
logging.info(
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
)
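# Editor's note: the RTF logged above is wall-clock decoding time divided
# by total audio duration, so e.g. decoding 3600 s of audio in 36 s gives
# RTF = 36/3600 = 0.01; values below 1 are faster than real time.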
save_results(res_dir=res_dir, test_set_name=test_set, results=results)
logging.info("Done!")
if __name__ == "__main__":
main()

@@ -0,0 +1 @@
../zipformer/onnx_pretrained.py

@@ -0,0 +1 @@
../zipformer/optim.py

@@ -0,0 +1 @@
../zipformer/scaling.py

@@ -0,0 +1 @@
../zipformer/scaling_converter.py

@@ -0,0 +1 @@
../zipformer/subsampling.py

File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -31,6 +31,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --exp-dir conformer_ctc2/exp \
   --lang-dir data/lang_bpe_200 \
   --otc-token "<star>" \
+  --feature-dim 768 \
   --allow-bypass-arc true \
   --allow-self-loop-arc true \
   --initial-bypass-weight -19 \
@@ -160,6 +161,14 @@ def get_parser():
         """,
     )
+    parser.add_argument(
+        "--feature-dim",
+        type=int,
+        default=768,
+        help="""Number of features extracted in the feature extraction stage
+        (the last dimension of the feature vector): 80 when using fbank
+        features, and 768 or 1024 when using wav2vec.""",
+    )
     parser.add_argument(
         "--initial-lr",
         type=float,
@@ -385,7 +394,6 @@ def get_params() -> AttributeDict:
         "valid_interval": 800,  # For the 100h subset, use 800
         "alignment_interval": 25,
         # parameters for conformer
-        "feature_dim": 768,
         "subsampling_factor": 2,
         "encoder_dim": 512,
         "nhead": 8,

@@ -0,0 +1,38 @@
# Introduction
This is a public domain speech dataset consisting of 13,100 short audio clips of a single speaker reading passages from 7 non-fiction books.
A transcription is provided for each clip.
Clips vary in length from 1 to 10 seconds and have a total length of approximately 24 hours.
The texts were published between 1884 and 1964, and are in the public domain.
The audio was recorded in 2016-17 by the [LibriVox](https://librivox.org/) project and is also in the public domain.
The above information is from the [LJSpeech website](https://keithito.com/LJ-Speech-Dataset/).
# VITS
This recipe provides a VITS model trained on the LJSpeech dataset.
A pretrained model can be found [here](https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28).
For tutorial and more details, please refer to the [VITS documentation](https://k2-fsa.github.io/icefall/recipes/TTS/ljspeech/vits.html).
The training command is given below:
```
export CUDA_VISIBLE_DEVICES=0,1,2,3
./vits/train.py \
--world-size 4 \
--num-epochs 1000 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir vits/exp \
--max-duration 500
```
To run inference, use:
```
./vits/infer.py \
--exp-dir vits/exp \
--epoch 1000 \
--tokens data/tokens.txt
```

@@ -17,7 +17,7 @@
 """
-This file reads the texts in given manifest and generates the file that maps tokens to IDs.
+This file generates the file that maps tokens to IDs.
 """

 import argparse
@@ -25,80 +25,38 @@ import logging
 from pathlib import Path
 from typing import Dict

-from lhotse import load_manifest
+from piper_phonemize import get_espeak_map


 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--manifest-file",
-        type=Path,
-        default=Path("data/spectrogram/ljspeech_cuts_train.jsonl.gz"),
-        help="Path to the manifest file",
-    )
     parser.add_argument(
         "--tokens",
         type=Path,
         default=Path("data/tokens.txt"),
-        help="Path to the tokens",
+        help="Path to the dict that maps the text tokens to IDs",
     )
     return parser.parse_args()


-def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
-    """Write a symbol to ID mapping to a file.
-
-    Note:
-      No need to implement `read_mapping` as it can be done
-      through :func:`k2.SymbolTable.from_file`.
-
-    Args:
-      filename:
-        Filename to save the mapping.
-      sym2id:
-        A dict mapping symbols to IDs.
-    Returns:
-      Return None.
-    """
+def get_token2id(filename: Path) -> Dict[str, int]:
+    """Get a dict that maps token to IDs, and save it to the given filename."""
+    all_tokens = get_espeak_map()  # token: [token_id]
+    all_tokens = {token: token_id[0] for token, token_id in all_tokens.items()}
+    # sort by token_id
+    all_tokens = sorted(all_tokens.items(), key=lambda x: x[1])
     with open(filename, "w", encoding="utf-8") as f:
-        for sym, i in sym2id.items():
-            f.write(f"{sym} {i}\n")
-
-
-def get_token2id(manifest_file: Path) -> Dict[str, int]:
-    """Return a dict that maps token to IDs."""
-    extra_tokens = [
-        "<blk>",  # 0 for blank
-        "<sos/eos>",  # 1 for sos and eos symbols.
-        "<unk>",  # 2 for OOV
-    ]
-    all_tokens = set()
-    cut_set = load_manifest(manifest_file)
-    for cut in cut_set:
-        # Each cut only contain one supervision
-        assert len(cut.supervisions) == 1, len(cut.supervisions)
-        for t in cut.tokens:
-            all_tokens.add(t)
-    all_tokens = extra_tokens + list(all_tokens)
-    token2id: Dict[str, int] = {token: i for i, token in enumerate(all_tokens)}
-    return token2id
+        for token, token_id in all_tokens:
+            f.write(f"{token} {token_id}\n")


 if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)

     args = get_args()
-    manifest_file = Path(args.manifest_file)
     out_file = Path(args.tokens)
-
-    token2id = get_token2id(manifest_file)
-    write_mapping(out_file, token2id)
+    get_token2id(out_file)

@@ -23,9 +23,9 @@ This file reads the texts in given manifest and save the new cuts with phoneme t
 import logging
 from pathlib import Path

-import g2p_en
 import tacotron_cleaner.cleaners
 from lhotse import CutSet, load_manifest
+from piper_phonemize import phonemize_espeak


 def prepare_tokens_ljspeech():
@@ -35,17 +35,20 @@ def prepare_tokens_ljspeech():
     partition = "all"

     cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}")
-    g2p = g2p_en.G2p()

     new_cuts = []
     for cut in cut_set:
         # Each cut only contains one supervision
-        assert len(cut.supervisions) == 1, len(cut.supervisions)
+        assert len(cut.supervisions) == 1, (len(cut.supervisions), cut)
         text = cut.supervisions[0].normalized_text
         # Text normalization
         text = tacotron_cleaner.cleaners.custom_english_cleaners(text)
         # Convert to phonemes
-        cut.tokens = g2p(text)
+        tokens_list = phonemize_espeak(text, "en-us")
+        tokens = []
+        for t in tokens_list:
+            tokens.extend(t)
+        cut.tokens = tokens
         new_cuts.append(cut)

     new_cut_set = CutSet.from_cuts(new_cuts)

@@ -30,7 +30,7 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
     cd vits/monotonic_align
     python setup.py build_ext --inplace
     cd ../../
   else
     log "monotonic_align lib already built"
   fi
 fi
@@ -80,6 +80,11 @@ fi
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   log "Stage 3: Prepare phoneme tokens for LJSpeech"
+  # We assume you have installed piper_phonemize and espnet_tts_frontend.
+  # If not, please install them with:
+  # - piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize,
+  #   could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
+  # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
   if [ ! -e data/spectrogram/.ljspeech_with_token.done ]; then
     ./local/prepare_tokens_ljspeech.py
     mv data/spectrogram/ljspeech_cuts_with_tokens_all.jsonl.gz \
@@ -113,13 +118,12 @@ fi
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   log "Stage 5: Generate token file"
-  # We assume you have installed g2p_en and espnet_tts_frontend.
+  # We assume you have installed piper_phonemize and espnet_tts_frontend.
   # If not, please install them with:
-  # - g2p_en: `pip install g2p_en`, refer to https://github.com/Kyubyong/g2p
+  # - piper_phonemize: refer to https://github.com/rhasspy/piper-phonemize,
+  #   could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5
   # - espnet_tts_frontend, `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/
   if [ ! -e data/tokens.txt ]; then
-    ./local/prepare_token_file.py \
-      --manifest-file data/spectrogram/ljspeech_cuts_train.jsonl.gz \
-      --tokens data/tokens.txt
+    ./local/prepare_token_file.py --tokens data/tokens.txt
   fi
 fi

Some files were not shown because too many files have changed in this diff.