From bb7f6ed6b74df6e3c1b4ae31ae54c3f0cd32b705 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang <csukuangfj@gmail.com>
Date: Sat, 12 Mar 2022 16:16:55 +0800
Subject: [PATCH] Add modified beam search for pruned rnn-t. (#248)

* Add modified beam search for pruned rnn-t.

* Fix style issues.

* Update RESULTS.md.

* Fix typos.

* Minor fixes.

* Test the pre-trained model using GitHub actions.

* Let the user install optimized_transducer on her own.

* Fix errors in GitHub CI.
---
 .../workflows/run-librispeech-2022-03-12.yml  | 157 ++++++++++++++++++
 README.md                                     |   2 +-
 egs/librispeech/ASR/README.md                 |   1 +
 egs/librispeech/ASR/RESULTS.md                | 108 +++++++++++-
 .../beam_search.py                            | 144 ++++++++++++++--
 .../ASR/pruned_transducer_stateless/decode.py | 105 +++---------
 .../ASR/pruned_transducer_stateless/export.py |  76 +--------
 .../pruned_transducer_stateless/pretrained.py |  99 +++-------
 requirements.txt                              |   1 -
 9 files changed, 439 insertions(+), 254 deletions(-)
 create mode 100644 .github/workflows/run-librispeech-2022-03-12.yml

diff --git a/.github/workflows/run-librispeech-2022-03-12.yml b/.github/workflows/run-librispeech-2022-03-12.yml
new file mode 100644
index 000000000..74052312e
--- /dev/null
+++ b/.github/workflows/run-librispeech-2022-03-12.yml
@@ -0,0 +1,157 @@
+# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
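+
+# This workflow downloads the pre-trained model from Hugging Face and
+# decodes the bundled test wavs with greedy search, beam search, and
+# modified beam search.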
+
+name: run-librispeech-2022-03-12
+# stateless transducer + k2 pruned rnnt-loss
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+    types: [labeled]
+
+jobs:
+  run_librispeech_2022_03_12:
+    if: github.event.label.name == 'ready' || github.event_name == 'push'
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-18.04]
+        python-version: [3.7, 3.8, 3.9]
+        torch: ["1.10.0"]
+        torchaudio: ["0.10.0"]
+        k2-version: ["1.9.dev20211101"]
+
+      fail-fast: false
+
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v1
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install Python dependencies
+        run: |
+          python3 -m pip install --upgrade pip pytest
+          # numpy 1.20.x does not support python 3.6
+          pip install numpy==1.19
+          pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
+          pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
+
+          python3 -m pip install git+https://github.com/lhotse-speech/lhotse
+          python3 -m pip install kaldifeat
+          # We are in ./icefall and there is a file: requirements.txt in it
+          pip install -r requirements.txt
+
+      - name: Install graphviz
+        shell: bash
+        run: |
+          python3 -m pip install -qq graphviz
+          sudo apt-get -qq install graphviz
+
+      - name: Download pre-trained model
+        shell: bash
+        run: |
+          sudo apt-get -qq install git-lfs tree sox
+          cd egs/librispeech/ASR
+          mkdir tmp
+          cd tmp
+          git lfs install
+          git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
+          cd ..
+          tree tmp
+          soxi tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12/test_wavs/*.wav
+          ls -lh tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12/test_wavs/*.wav
+
+      - name: Run greedy search decoding (max-sym-per-frame 1)
+        shell: bash
+        run: |
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          dir=./tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
+          cd egs/librispeech/ASR
+          ./pruned_transducer_stateless/pretrained.py \
+            --method greedy_search \
+            --max-sym-per-frame 1 \
+            --checkpoint $dir/exp/pretrained.pt \
+            --bpe-model $dir/data/lang_bpe_500/bpe.model \
+            $dir/test_wavs/1089-134686-0001.wav \
+            $dir/test_wavs/1221-135766-0001.wav \
+            $dir/test_wavs/1221-135766-0002.wav
+
+      - name: Run greedy search decoding (max-sym-per-frame 2)
+        shell: bash
+        run: |
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          dir=./tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
+          cd egs/librispeech/ASR
+          ./pruned_transducer_stateless/pretrained.py \
+            --method greedy_search \
+            --max-sym-per-frame 2 \
+            --checkpoint $dir/exp/pretrained.pt \
+            --bpe-model $dir/data/lang_bpe_500/bpe.model \
+            $dir/test_wavs/1089-134686-0001.wav \
+            $dir/test_wavs/1221-135766-0001.wav \
+            $dir/test_wavs/1221-135766-0002.wav
+
+      - name: Run greedy search decoding (max-sym-per-frame 3)
+        shell: bash
+        run: |
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          dir=./tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
+          cd egs/librispeech/ASR
+          ./pruned_transducer_stateless/pretrained.py \
+            --method greedy_search \
+            --max-sym-per-frame 3 \
+            --checkpoint $dir/exp/pretrained.pt \
+            --bpe-model $dir/data/lang_bpe_500/bpe.model \
+            $dir/test_wavs/1089-134686-0001.wav \
+            $dir/test_wavs/1221-135766-0001.wav \
+            $dir/test_wavs/1221-135766-0002.wav
+
+      - name: Run beam search decoding
+        shell: bash
+        run: |
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          dir=./tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
+          cd egs/librispeech/ASR
+          ./pruned_transducer_stateless/pretrained.py \
+            --method beam_search \
+            --beam-size 4 \
+            --checkpoint $dir/exp/pretrained.pt \
+            --bpe-model $dir/data/lang_bpe_500/bpe.model \
+            $dir/test_wavs/1089-134686-0001.wav \
+            $dir/test_wavs/1221-135766-0001.wav \
+            $dir/test_wavs/1221-135766-0002.wav
+
+      - name: Run modified beam search decoding
+        shell: bash
+        run: |
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          dir=./tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
+          cd egs/librispeech/ASR
+          ./pruned_transducer_stateless/pretrained.py \
+            --method modified_beam_search \
+            --beam-size 4 \
+            --checkpoint $dir/exp/pretrained.pt \
+            --bpe-model $dir/data/lang_bpe_500/bpe.model \
+            $dir/test_wavs/1089-134686-0001.wav \
+            $dir/test_wavs/1221-135766-0001.wav \
+            $dir/test_wavs/1221-135766-0002.wav
diff --git a/README.md b/README.md
index a49b30df0..79d8039ff 100644
--- a/README.md
+++ b/README.md
@@ -84,7 +84,7 @@ The best WER using modified beam search with beam size 4 is:
 
 |     | test-clean | test-other |
 |-----|------------|------------|
-| WER | 2.61 | 6.46 |
+| WER | 2.56 | 6.27 |
 
 Note: No auxiliary losses are used in the training and no LMs are used
 in the decoding.
diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md
index 30b5c5c6f..a7b2e2c3b 100644
--- a/egs/librispeech/ASR/README.md
+++ b/egs/librispeech/ASR/README.md
@@ -15,6 +15,7 @@ The following table lists the differences among them.
 | `transducer_stateless` | Conformer | Embedding + Conv1d | |
 | `transducer_lstm` | LSTM | LSTM | |
 | `transducer_stateless_multi_datasets` | Conformer | Embedding + Conv1d | Using data from GigaSpeech as extra training data |
+| `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss |
 
 The decoder in `transducer_stateless` is modified from the paper
 [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index cc2aebac1..6dbc659f7 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -2,12 +2,111 @@
 
 ### LibriSpeech BPE training results (Pruned Transducer)
 
-#### Conformer encoder + embedding decoder
-
 Conformer encoder + non-recurrent decoder. The decoder
 contains only an embedding layer, a Conv1d (with kernel size 2) and a linear
 layer (to transform tensor dim).
 
+#### 2022-03-12
+
+[pruned_transducer_stateless](./pruned_transducer_stateless)
+
+Using commit `1603744469d167d848e074f2ea98c587153205fa`.
+See <https://github.com/k2-fsa/icefall/pull/248>
+
+The WERs are:
+
+| | test-clean | test-other | comment |
+|-------------------------------------|------------|------------|------------------------------------------|
+| greedy search (max sym per frame 1) | 2.62 | 6.37 | --epoch 42, --avg 11, --max-duration 100 |
+| greedy search (max sym per frame 2) | 2.62 | 6.37 | --epoch 42, --avg 11, --max-duration 100 |
+| greedy search (max sym per frame 3) | 2.62 | 6.37 | --epoch 42, --avg 11, --max-duration 100 |
+| modified beam search (beam size 4) | 2.56 | 6.27 | --epoch 42, --avg 11, --max-duration 100 |
+| beam search (beam size 4) | 2.57 | 6.27 | --epoch 42, --avg 11, --max-duration 100 |
+
+The decoding time for `test-clean` and `test-other` is given below:
+(A V100 GPU with 32 GB RAM is used for decoding. Note: Not all GPU RAM is used during decoding.)
+
+| decoding method | test-clean (seconds) | test-other (seconds) |
+|---|---:|---:|
+| greedy search (--max-sym-per-frame=1) | 160 | 159 |
+| greedy search (--max-sym-per-frame=2) | 184 | 177 |
+| greedy search (--max-sym-per-frame=3) | 210 | 213 |
+| modified beam search (--beam-size 4) | 273 | 269 |
+| beam search (--beam-size 4) | 2741 | 2221 |
+
+We recommend using `modified_beam_search`.
+
+Training command:
+
+```bash
+cd egs/librispeech/ASR/
+./prepare.sh
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+
+. path.sh
+
+./pruned_transducer_stateless/train.py \
+  --world-size 8 \
+  --num-epochs 60 \
+  --start-epoch 0 \
+  --exp-dir pruned_transducer_stateless/exp \
+  --full-libri 1 \
+  --max-duration 300 \
+  --prune-range 5 \
+  --lr-factor 5 \
+  --lm-scale 0.25
+```
+
+The tensorboard training log can be found at
+
+
+The command for decoding is:
+
+```bash
+epoch=42
+avg=11
+sym=1
+
+# greedy search
+
+./pruned_transducer_stateless/decode.py \
+  --epoch $epoch \
+  --avg $avg \
+  --exp-dir ./pruned_transducer_stateless/exp \
+  --max-duration 100 \
+  --decoding-method greedy_search \
+  --beam-size 4 \
+  --max-sym-per-frame $sym
+
+# modified beam search
+./pruned_transducer_stateless/decode.py \
+  --epoch $epoch \
+  --avg $avg \
+  --exp-dir ./pruned_transducer_stateless/exp \
+  --max-duration 100 \
+  --decoding-method modified_beam_search \
+  --beam-size 4
+
+# beam search
+# (not recommended)
+./pruned_transducer_stateless/decode.py \
+  --epoch $epoch \
+  --avg $avg \
+  --exp-dir ./pruned_transducer_stateless/exp \
+  --max-duration 100 \
+  --decoding-method beam_search \
+  --beam-size 4
+```
+
+You can find a pre-trained model, decoding logs, and decoding results at
+<https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12>
+
+#### 2022-02-18
+
+[pruned_transducer_stateless](./pruned_transducer_stateless)
+
+
 The WERs are
 
 | | test-clean | test-other | comment |
@@ -62,7 +161,7 @@ See
 
 ##### 2022-03-01
 
-Using commit `fill in it after merging`.
+Using commit `2332ba312d7ce72f08c7bac1e3312f7e3dd722dc`.
 
 It uses [GigaSpeech](https://github.com/SpeechColab/GigaSpeech)
 as extra training data. 20% of the time it selects a batch from L subset of
@@ -129,6 +228,9 @@ sym=1
   --beam-size 4
 ```
 
+You can find a pretrained model by visiting
+
+
 ##### 2022-02-07
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py
index 3d4818509..38ab16507 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py
@@ -17,7 +17,6 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional
 
-import numpy as np
 import torch
 from model import Transducer
 
@@ -48,7 +47,7 @@ def greedy_search(
     device = model.device
 
     decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
+        [blank_id] * context_size, device=device, dtype=torch.int64
     ).reshape(1, context_size)
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
@@ -103,8 +102,9 @@ class Hypothesis:
     # Newly predicted tokens are appended to `ys`.
     ys: List[int]
 
-    # The log prob of ys
-    log_prob: float
+    # The log prob of ys.
+    # It contains only one entry.
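+    # (It is kept as a 1-element torch.Tensor rather than a Python float
+    # so that `HypothesisList.add` can merge duplicate hypotheses in
+    # place via `torch.logaddexp(..., out=old_hyp.log_prob)` below.)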
+ log_prob: torch.Tensor @property def key(self) -> str: @@ -113,7 +113,7 @@ class Hypothesis: class HypothesisList(object): - def __init__(self, data: Optional[Dict[str, Hypothesis]] = None): + def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: """ Args: data: @@ -125,10 +125,10 @@ class HypothesisList(object): self._data = data @property - def data(self): + def data(self) -> Dict[str, Hypothesis]: return self._data - def add(self, hyp: Hypothesis): + def add(self, hyp: Hypothesis) -> None: """Add a Hypothesis to `self`. If `hyp` already exists in `self`, its probability is updated using @@ -140,8 +140,10 @@ class HypothesisList(object): """ key = hyp.key if key in self: - old_hyp = self._data[key] - old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob) + old_hyp = self._data[key] # shallow copy + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) else: self._data[key] = hyp @@ -153,7 +155,8 @@ class HypothesisList(object): length_norm: If True, the `log_prob` of a hypothesis is normalized by the number of tokens in it. - + Returns: + Return the hypothesis that has the largest `log_prob`. """ if length_norm: return max( @@ -165,6 +168,9 @@ class HypothesisList(object): def remove(self, hyp: Hypothesis) -> None: """Remove a given hypothesis. + Caution: + `self` is modified **in-place**. + Args: hyp: The hypothesis to be removed from `self`. @@ -175,7 +181,7 @@ class HypothesisList(object): assert key in self, f"{key} does not exist" del self._data[key] - def filter(self, threshold: float) -> "HypothesisList": + def filter(self, threshold: torch.Tensor) -> "HypothesisList": """Remove all Hypotheses whose log_prob is less than threshold. Caution: @@ -183,10 +189,10 @@ class HypothesisList(object): Returns: Return a new HypothesisList containing all hypotheses from `self` - that have `log_prob` being greater than the given `threshold`. + with `log_prob` being greater than the given `threshold`. """ ans = HypothesisList() - for key, hyp in self._data.items(): + for _, hyp in self._data.items(): if hyp.log_prob > threshold: ans.add(hyp) # shallow copy return ans @@ -216,6 +222,106 @@ class HypothesisList(object): return ", ".join(s) +def modified_beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, +) -> List[int]: + """It limits the maximum number of symbols per frame to 1. + + Args: + model: + An instance of `Transducer`. + encoder_out: + A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. + beam: + Beam size. + Returns: + Return the decoded result. 
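+
+    Note (a sketch of the algorithm): at each frame, all hypotheses in
+    the beam are expanded jointly. The decoder runs once on the batch of
+    hypotheses, the per-hypothesis log-probs of shape
+    (num_hyps, vocab_size) are flattened, and a single topk(beam) picks
+    the survivors. A non-blank token is appended to `ys`, while blank
+    leaves `ys` unchanged, so at most one symbol is emitted per frame
+    for each hypothesis.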
+ """ + + assert encoder_out.ndim == 3 + + # support only batch_size == 1 for now + assert encoder_out.size(0) == 1, encoder_out.size(0) + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + + device = model.device + + T = encoder_out.size(1) + + B = HypothesisList() + B.add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + + for t in range(T): + # fmt: off + current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) + # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) + # fmt: on + A = list(B) + B = HypothesisList() + + ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) + # ys_log_probs is of shape (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyp in A], + device=device, + dtype=torch.int64, + ) + # decoder_input is of shape (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + # decoder_output is of shape (num_hyps, 1, 1, decoder_output_dim) + + current_encoder_out = current_encoder_out.expand( + decoder_out.size(0), 1, 1, -1 + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits is of shape (num_hyps, 1, 1, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + # now logits is of shape (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) + + log_probs.add_(ys_log_probs) + + log_probs = log_probs.reshape(-1) + topk_log_probs, topk_indexes = log_probs.topk(beam) + + # topk_hyp_indexes are indexes into `A` + topk_hyp_indexes = topk_indexes // logits.size(-1) + topk_token_indexes = topk_indexes % logits.size(-1) + + topk_hyp_indexes = topk_hyp_indexes.tolist() + topk_token_indexes = topk_token_indexes.tolist() + + for i in range(len(topk_hyp_indexes)): + hyp = A[topk_hyp_indexes[i]] + new_ys = hyp.ys[:] + new_token = topk_token_indexes[i] + if new_token != blank_id: + new_ys.append(new_token) + new_log_prob = topk_log_probs[i] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B.add(new_hyp) + + best_hyp = B.get_most_probable(length_norm=True) + ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks + + return ys + + def beam_search( model: Transducer, encoder_out: torch.Tensor, @@ -246,7 +352,9 @@ def beam_search( device = model.device decoder_input = torch.tensor( - [blank_id] * context_size, device=device + [blank_id] * context_size, + device=device, + dtype=torch.int64, ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -283,7 +391,9 @@ def beam_search( if cached_key not in decoder_cache: decoder_input = torch.tensor( - [y_star.ys[-context_size:]], device=device + [y_star.ys[-context_size:]], + device=device, + dtype=torch.int64, ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -297,7 +407,7 @@ def beam_search( current_encoder_out, decoder_out.unsqueeze(1) ) - # TODO(fangjun): Cache the blank posterior + # TODO(fangjun): Scale the blank posterior log_prob = logits.log_softmax(dim=-1) # log_prob is (1, 1, 1, vocab_size) @@ -309,7 +419,7 @@ def beam_search( # First, process the blank symbol skip_log_prob = log_prob[blank_id] - new_y_star_log_prob = y_star.log_prob + skip_log_prob.item() + new_y_star_log_prob = y_star.log_prob + skip_log_prob # ys[:] returns a copy of ys B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob)) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py 
b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 9479d57a8..86ec6172f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -33,6 +33,15 @@ Usage: --max-duration 100 \ --decoding-method beam_search \ --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless/exp \ + --max-duration 100 \ + --decoding-method modified_beam_search \ + --beam-size 4 """ @@ -46,14 +55,10 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from beam_search import beam_search, greedy_search -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from model import Transducer +from beam_search import beam_search, greedy_search, modified_beam_search +from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.env import get_env_info from icefall.utils import ( AttributeDict, setup_logger, @@ -104,6 +109,7 @@ def get_parser(): help="""Possible values are: - greedy_search - beam_search + - modified_beam_search """, ) @@ -111,7 +117,8 @@ def get_parser(): "--beam-size", type=int, default=4, - help="Used only when --decoding-method is beam_search", + help="""Used only when --decoding-method is + beam_search or modified_beam_search""", ) parser.add_argument( @@ -125,78 +132,13 @@ def get_parser(): "--max-sym-per-frame", type=int, default=3, - help="Maximum number of symbols per frame", + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", ) return parser -def get_params() -> AttributeDict: - params = AttributeDict( - { - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - # parameters for decoder - "embedding_dim": 512, - "env_info": get_env_info(), - } - ) - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.vocab_size, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.embedding_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - input_dim=params.vocab_size, - inner_dim=params.embedding_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - def decode_one_batch( params: AttributeDict, model: nn.Module, @@ -258,6 +200,10 @@ def decode_one_batch( hyp = beam_search( model=model, encoder_out=encoder_out_i, beam=params.beam_size ) + elif params.decoding_method == 
"modified_beam_search": + hyp = modified_beam_search( + model=model, encoder_out=encoder_out_i, beam=params.beam_size + ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" @@ -391,11 +337,15 @@ def main(): params = get_params() params.update(vars(args)) - assert params.decoding_method in ("greedy_search", "beam_search") + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "modified_beam_search", + ) params.res_dir = params.exp_dir / params.decoding_method params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if params.decoding_method == "beam_search": + if "beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam_size}" else: params.suffix += f"-context-{params.context_size}" @@ -469,8 +419,5 @@ def main(): logging.info("Done!") -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - if __name__ == "__main__": main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py index 94987c39a..7d2a07817 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py @@ -39,7 +39,7 @@ you can do: --exp-dir ./pruned_transducer_stateless/exp \ --epoch 9999 \ --avg 1 \ - --max-duration 1 \ + --max-duration 100 \ --bpe-model data/lang_bpe_500/bpe.model """ @@ -49,15 +49,10 @@ from pathlib import Path import sentencepiece as spm import torch -import torch.nn as nn -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from model import Transducer +from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.env import get_env_info -from icefall.utils import AttributeDict, str2bool +from icefall.utils import str2bool def get_parser(): @@ -117,71 +112,6 @@ def get_parser(): return parser -def get_params() -> AttributeDict: - params = AttributeDict( - { - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - # parameters for decoder - "embedding_dim": 512, - "env_info": get_env_info(), - } - ) - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.vocab_size, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.embedding_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - input_dim=params.vocab_size, - inner_dim=params.embedding_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) diff --git 
a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index 73c5aee5c..e6528b8d7 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -49,17 +49,10 @@ from typing import List import kaldifeat import sentencepiece as spm import torch -import torch.nn as nn import torchaudio -from beam_search import beam_search, greedy_search -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from model import Transducer +from beam_search import beam_search, greedy_search, modified_beam_search from torch.nn.utils.rnn import pad_sequence - -from icefall.env import get_env_info -from icefall.utils import AttributeDict +from train import get_params, get_transducer_model def get_parser(): @@ -91,6 +84,7 @@ def get_parser(): help="""Possible values are: - greedy_search - beam_search + - modified_beam_search """, ) @@ -104,11 +98,18 @@ def get_parser(): "The sample rate has to be 16kHz.", ) + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + parser.add_argument( "--beam-size", type=int, default=4, - help="Used only when --method is beam_search", + help="Used only when --method is beam_search and modified_beam_search", ) parser.add_argument( @@ -130,72 +131,6 @@ def get_parser(): return parser -def get_params() -> AttributeDict: - params = AttributeDict( - { - "sample_rate": 16000, - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - # parameters for decoder - "embedding_dim": 512, - "env_info": get_env_info(), - } - ) - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.vocab_size, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.embedding_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - input_dim=params.vocab_size, - inner_dim=params.embedding_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - def read_sound_files( filenames: List[str], expected_sample_rate: float ) -> List[torch.Tensor]: @@ -220,6 +155,7 @@ def read_sound_files( return ans +@torch.no_grad() def main(): parser = get_parser() args = parser.parse_args() @@ -278,10 +214,9 @@ def main(): feature_lengths = torch.tensor(feature_lengths, device=device) - with torch.no_grad(): - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] @@ 
-303,6 +238,10 @@ def main(): hyp = beam_search( model=model, encoder_out=encoder_out_i, beam=params.beam_size ) + elif params.method == "modified_beam_search": + hyp = modified_beam_search( + model=model, encoder_out=encoder_out_i, beam=params.beam_size + ) else: raise ValueError(f"Unsupported method: {params.method}") diff --git a/requirements.txt b/requirements.txt index 09d9ef69f..4eaa86a67 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,3 @@ kaldialign sentencepiece>=0.1.96 tensorboard typeguard -optimized_transducer
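
For reference, here is a minimal usage sketch (not part of the patch) of how the new `modified_beam_search` is driven, mirroring the flow of `pretrained.py` and `decode.py` above. The checkpoint path, BPE model path, and the random input features are placeholders; real usage computes 80-dim log-mel fbank features with kaldifeat, as `pretrained.py` does.

```python
# A hedged sketch, assuming it runs from
# egs/librispeech/ASR/pruned_transducer_stateless/ so that the local
# imports below resolve. Paths and input features are placeholders.
import sentencepiece as spm
import torch
from beam_search import modified_beam_search
from train import get_params, get_transducer_model

checkpoint_path = "exp/pretrained.pt"           # placeholder
bpe_model_path = "data/lang_bpe_500/bpe.model"  # placeholder

params = get_params()
sp = spm.SentencePieceProcessor()
sp.load(bpe_model_path)
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()

model = get_transducer_model(params)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.eval()
model.device = torch.device("cpu")  # beam_search.py reads model.device

# Placeholder input; real usage computes (N, T, 80) log-mel fbank
# features with kaldifeat, as pretrained.py does.
features = torch.randn(1, 500, 80)
feature_lengths = torch.tensor([500])

with torch.no_grad():
    encoder_out, encoder_out_lens = model.encoder(
        x=features, x_lens=feature_lengths
    )
    # Decode the first (and only) utterance in the batch; the function
    # currently supports batch size 1.
    hyp = modified_beam_search(
        model=model,
        encoder_out=encoder_out[0:1, : encoder_out_lens[0]],
        beam=4,
    )
print(sp.decode(hyp))
```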