From 9ee8a4defaf22758bcb1df3f4e2f0a711772ac1c Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 2 Mar 2022 14:07:38 +0800 Subject: [PATCH] Update results. --- ...ransducer-stateless-modified-2-aishell.yml | 2 - ...-transducer-stateless-modified-aishell.yml | 153 ++++++++ egs/aishell/ASR/RESULTS.md | 69 +++- .../transducer_stateless_modified-2/decode.py | 4 +- .../pretrained.py | 3 +- .../transducer_stateless_modified/decode.py | 4 +- .../transducer_stateless_modified/export.py | 246 +++++++++++++ .../pretrained.py | 331 ++++++++++++++++++ .../transducer_stateless_modified/train.py | 11 +- 9 files changed, 811 insertions(+), 12 deletions(-) create mode 100644 .github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml create mode 100755 egs/aishell/ASR/transducer_stateless_modified/export.py create mode 100755 egs/aishell/ASR/transducer_stateless_modified/pretrained.py diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml index 763a55665..c27ffc374 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml @@ -76,8 +76,6 @@ jobs: git lfs install git clone https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2-2022-03-01 - - cd .. tree tmp soxi tmp/icefall-aishell-transducer-stateless-modified-2-2022-03-01/test_wavs/*.wav diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml new file mode 100644 index 000000000..2e38abb5a --- /dev/null +++ b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml @@ -0,0 +1,153 @@ +# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: run-pre-trained-trandsucer-stateless-modified-aishell + +on: + push: + branches: + - master + pull_request: + types: [labeled] + +jobs: + run_pre_trained_transducer_stateless_modified_aishell: + if: github.event.label.name == 'ready' || github.event_name == 'push' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-18.04] + python-version: [3.7, 3.8, 3.9] + torch: ["1.10.0"] + torchaudio: ["0.10.0"] + k2-version: ["1.9.dev20211101"] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Python dependencies + run: | + python3 -m pip install --upgrade pip pytest + # numpy 1.20.x does not support python 3.6 + pip install numpy==1.19 + pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html + pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/ + + python3 -m pip install git+https://github.com/lhotse-speech/lhotse + python3 -m pip install kaldifeat + # We are in ./icefall and there is a file: requirements.txt in it + pip install -r requirements.txt + + - name: Install graphviz + shell: bash + run: | + python3 -m pip install -qq graphviz + sudo apt-get -qq install graphviz + + - name: Download pre-trained model + shell: bash + run: | + sudo apt-get -qq install git-lfs tree sox + cd egs/aishell/ASR + mkdir tmp + cd tmp + git lfs install + git clone https://huggingface.co/csukuangfj/icefall-aishell-transducer-stateless-modified-2022-03-01 + + cd .. + tree tmp + soxi tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/*.wav + ls -lh tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/*.wav + + - name: Run greedy search decoding (max-sym-per-frame 1) + shell: bash + run: | + export PYTHONPATH=$PWD:PYTHONPATH + cd egs/aishell/ASR + ./transducer_stateless_modified/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame 1 \ + --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \ + --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \ + ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \ + ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \ + ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav + + - name: Run greedy search decoding (max-sym-per-frame 2) + shell: bash + run: | + export PYTHONPATH=$PWD:PYTHONPATH + cd egs/aishell/ASR + ./transducer_stateless_modified/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame 2 \ + --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \ + --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \ + ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \ + ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \ + ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav + + - name: Run greedy search decoding (max-sym-per-frame 3) + shell: bash + run: | + export PYTHONPATH=$PWD:PYTHONPATH + cd egs/aishell/ASR + ./transducer_stateless_modified/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame 3 \ + --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \ + --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \ + ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \ + ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \ + ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav + + - name: Run beam search decoding + shell: bash + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + cd egs/aishell/ASR + ./transducer_stateless_modified/pretrained.py \ + --method beam_search \ + --beam-size 4 \ + --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \ + --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \ + ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \ + ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \ + ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav + + + - name: Run modified beam search decoding + shell: bash + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + cd egs/aishell/ASR + ./transducer_stateless_modified/pretrained.py \ + --method modified_beam_search \ + --beam-size 4 \ + --checkpoint ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/exp/pretrained.pt \ + --lang-dir ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/data/lang_char \ + ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0121.wav \ + ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0122.wav \ + ./tmp/icefall-aishell-transducer-stateless-modified-2022-03-01/test_wavs/BAC009S0764W0123.wav diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index d8e2e5fc4..ecc93c21b 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -69,9 +69,76 @@ for epoch in 89; do done ``` -You can find a pre-trained model and decoding logs and results at +You can find a pre-trained model, decoding logs, and decoding results at +#### 2022-03-01 + +[./transducer_stateless_modified](./transducer_stateless_modified) + +Stateless transducer + modified transducer. + +| | test |comment | +|------------------------|------|----------------------------------------------------------------| +| greedy search | 5.22 |--epoch 64, --avg 33, --max-duration 100, --max-sym-per-frame 1 | +| modified beam search | 5.02 |--epoch 64, --avg 33, --max-duration 100 --beam-size 4 | + +The training commands are: + +```bash +cd egs/aishell/ASR +./prepare.sh --stop-stage 6 + +export CUDA_VISIBLE_DEVICES="0,1,2" + +./transducer_stateless_modified/train.py \ + --world-size 3 \ + --num-epochs 90 \ + --start-epoch 0 \ + --exp-dir transducer_stateless_modified/exp-4 \ + --max-duration 250 \ + --lr-factor 2.0 \ + --context-size 2 \ + --modified-transducer-prob 0.25 +``` + +The tensorboard log is available at + + +The commands for decoding are + +```bash +# greedy search +for epoch in 64; do + for avg in 33; do + ./transducer_stateless_modified/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir transducer_stateless_modified/exp-4 \ + --max-duration 100 \ + --context-size 2 \ + --decoding-method greedy_search \ + --max-sym-per-frame 1 + done +done + +# modified beam search +for epoch in 64; do + for avg in 33; do + ./transducer_stateless_modified/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir transducer_stateless_modified/exp-4 \ + --max-duration 100 \ + --context-size 2 \ + --decoding-method modified_beam_search \ + --beam-size 4 + done +done +``` + +You can find a pre-trained model, decoding logs, and decoding results at + #### 2022-2-19 diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py index f34f8d1c6..8b851bd17 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/decode.py @@ -448,7 +448,9 @@ def main(): filenames.append(f"{params.exp_dir}/epoch-{i}.pt") logging.info(f"averaging {filenames}") model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) model.to(device) model.eval() diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py index 142a02190..31bab122c 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py @@ -129,7 +129,8 @@ def get_parser(): "--max-sym-per-frame", type=int, default=3, - help="Maximum number of symbols per frame", + help="Maximum number of symbols per frame. " + "Use only when --method is greedy_search", ) return parser diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py index 822f0a00f..5b5fe6ffa 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py @@ -19,8 +19,8 @@ Usage: (1) greedy search ./transducer_stateless_modified/decode.py \ - --epoch 14 \ - --avg 7 \ + --epoch 64 \ + --avg 33 \ --exp-dir ./transducer_stateless_modified/exp \ --max-duration 100 \ --decoding-method greedy_search diff --git a/egs/aishell/ASR/transducer_stateless_modified/export.py b/egs/aishell/ASR/transducer_stateless_modified/export.py new file mode 100755 index 000000000..9a20fab6f --- /dev/null +++ b/egs/aishell/ASR/transducer_stateless_modified/export.py @@ -0,0 +1,246 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" +Usage: +./transducer_stateless_modified/export.py \ + --exp-dir ./transducer_stateless_modified/exp \ + --epoch 64 \ + --avg 33 + +It will generate a file exp_dir/pretrained.pt + +To use the generated file with `transducer_stateless_modified/decode.py`, +you can do:: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/aishell/ASR + ./transducer_stateless_modified/decode.py \ + --exp-dir ./transducer_stateless_modified/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 100 \ + --lang-dir data/lang_char +""" + +import argparse +import logging +from pathlib import Path + +import torch +import torch.nn as nn +from conformer import Conformer +from decoder import Decoder +from joiner import Joiner +from model import Transducer + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=20, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=10, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default=Path("transducer_stateless_modified/exp"), + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + """, + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default=Path("data/lang_char"), + help="The lang dir", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + # parameters for conformer + "feature_dim": 80, + "encoder_out_dim": 512, + "subsampling_factor": 4, + "attention_dim": 512, + "nhead": 8, + "dim_feedforward": 2048, + "num_encoder_layers": 12, + "vgg_frontend": False, + "env_info": get_env_info(), + } + ) + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Conformer( + num_features=params.feature_dim, + output_dim=params.encoder_out_dim, + subsampling_factor=params.subsampling_factor, + d_model=params.attention_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + vgg_frontend=params.vgg_frontend, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + embedding_dim=params.encoder_out_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + input_dim=params.encoder_out_dim, + output_dim=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + ) + return model + + +def main(): + args = get_parser().parse_args() + + assert args.jit is False, "torchscript support will be added later" + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + lexicon = Lexicon(params.lang_dir) + + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py new file mode 100755 index 000000000..698594e92 --- /dev/null +++ b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py @@ -0,0 +1,331 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: + +# greedy search +./transducer_stateless_modified/pretrained.py \ + --checkpoint /path/to/pretrained.pt \ + --lang-dir /path/to/lang_char \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +# beam search +./transducer_stateless_modified/pretrained.py \ + --checkpoint /path/to/pretrained.pt \ + --lang-dir /path/to/lang_char \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +# modified beam search +./transducer_stateless_modified/pretrained.py \ + --checkpoint /path/to/pretrained.pt \ + --lang-dir /path/to/lang_char \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import List + +import kaldifeat +import torch +import torch.nn as nn +import torchaudio +from beam_search import beam_search, greedy_search, modified_beam_search +from conformer import Conformer +from decoder import Decoder +from joiner import Joiner +from model import Transducer +from torch.nn.utils.rnn import pad_sequence + +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default=Path("data/lang_char"), + help="The lang dir", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="Used only when --method is beam_search and modified_beam_search", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=3, + help="Maximum number of symbols per frame. " + "Use only when --method is greedy_search", + ) + return parser + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + # parameters for conformer + "feature_dim": 80, + "encoder_out_dim": 512, + "subsampling_factor": 4, + "attention_dim": 512, + "nhead": 8, + "dim_feedforward": 2048, + "num_encoder_layers": 12, + "vgg_frontend": False, + "env_info": get_env_info(), + "sample_rate": 16000, + } + ) + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Conformer( + num_features=params.feature_dim, + output_dim=params.encoder_out_dim, + subsampling_factor=params.subsampling_factor, + d_model=params.attention_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + vgg_frontend=params.vgg_frontend, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + embedding_dim=params.encoder_out_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + input_dim=params.encoder_out_dim, + output_dim=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + ) + return model + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + lexicon = Lexicon(params.lang_dir) + + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"]) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lens = [f.size(0) for f in features] + feature_lens = torch.tensor(feature_lens, device=device) + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + + hyps = [] + with torch.no_grad(): + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lens + ) + + for i in range(encoder_out.size(0)): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + elif params.method == "modified_beam_search": + hyp = modified_beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.method}" + ) + hyps.append([lexicon.token_table[i] for i in hyp]) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/aishell/ASR/transducer_stateless_modified/train.py b/egs/aishell/ASR/transducer_stateless_modified/train.py index 1ffa51c8e..524854b73 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/train.py +++ b/egs/aishell/ASR/transducer_stateless_modified/train.py @@ -21,16 +21,17 @@ """ Usage: -export CUDA_VISIBLE_DEVICES="0,1,2,3" +export CUDA_VISIBLE_DEVICES="0,1,2" ./transducer_stateless_modified/train.py \ - --world-size 4 \ - --num-epochs 30 \ + --world-size 3 \ + --num-epochs 65 \ --start-epoch 0 \ --exp-dir transducer_stateless_modified/exp \ - --full-libri 1 \ --max-duration 250 \ - --lr-factor 2.5 + --lr-factor 2.0 \ + --context-size 2 \ + --modified-transducer-prob 0.25 """