Add modified beam search for pruned rnn-t. (#248)

* Add modified beam search for pruned rnn-t.

* Fix style issues.

* Update RESULTS.md.

* Fix typos.

* Minor fixes.

* Test the pre-trained model using GitHub actions.

* Let the user install optimized_transducer on her own.

* Fix errors in GitHub CI.
Fangjun Kuang 2022-03-12 16:16:55 +08:00 committed by GitHub
parent 2f4e71f433
commit bb7f6ed6b7
9 changed files with 439 additions and 254 deletions

View File

@@ -0,0 +1,157 @@
# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-librispeech-2022-03-12
# stateless transducer + k2 pruned rnnt-loss

on:
  push:
    branches:
      - master

  pull_request:
    types: [labeled]

jobs:
  run_librispeech_2022_03_12:
    if: github.event.label.name == 'ready' || github.event_name == 'push'
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-18.04]
        python-version: [3.7, 3.8, 3.9]
        torch: ["1.10.0"]
        torchaudio: ["0.10.0"]
        k2-version: ["1.9.dev20211101"]
      fail-fast: false

    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v1
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install Python dependencies
        run: |
          python3 -m pip install --upgrade pip pytest
          # numpy 1.20.x does not support python 3.6
          pip install numpy==1.19
          pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
          pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/

          python3 -m pip install git+https://github.com/lhotse-speech/lhotse
          python3 -m pip install kaldifeat
          # We are in ./icefall and there is a file: requirements.txt in it
          pip install -r requirements.txt

      - name: Install graphviz
        shell: bash
        run: |
          python3 -m pip install -qq graphviz
          sudo apt-get -qq install graphviz

      - name: Download pre-trained model
        shell: bash
        run: |
          sudo apt-get -qq install git-lfs tree sox
          cd egs/librispeech/ASR
          mkdir tmp
          cd tmp
          git lfs install
          git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
          cd ..
          tree tmp
          soxi tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12/test_wavs/*.wav
          ls -lh tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12/test_wavs/*.wav

      - name: Run greedy search decoding (max-sym-per-frame 1)
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          dir=./tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
          cd egs/librispeech/ASR
          ./pruned_transducer_stateless/pretrained.py \
            --method greedy_search \
            --max-sym-per-frame 1 \
            --checkpoint $dir/exp/pretrained.pt \
            --bpe-model $dir/data/lang_bpe_500/bpe.model \
            $dir/test_wavs/1089-134686-0001.wav \
            $dir/test_wavs/1221-135766-0001.wav \
            $dir/test_wavs/1221-135766-0002.wav

      - name: Run greedy search decoding (max-sym-per-frame 2)
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          dir=./tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
          cd egs/librispeech/ASR
          ./pruned_transducer_stateless/pretrained.py \
            --method greedy_search \
            --max-sym-per-frame 2 \
            --checkpoint $dir/exp/pretrained.pt \
            --bpe-model $dir/data/lang_bpe_500/bpe.model \
            $dir/test_wavs/1089-134686-0001.wav \
            $dir/test_wavs/1221-135766-0001.wav \
            $dir/test_wavs/1221-135766-0002.wav

      - name: Run greedy search decoding (max-sym-per-frame 3)
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          dir=./tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
          cd egs/librispeech/ASR
          ./pruned_transducer_stateless/pretrained.py \
            --method greedy_search \
            --max-sym-per-frame 3 \
            --checkpoint $dir/exp/pretrained.pt \
            --bpe-model $dir/data/lang_bpe_500/bpe.model \
            $dir/test_wavs/1089-134686-0001.wav \
            $dir/test_wavs/1221-135766-0001.wav \
            $dir/test_wavs/1221-135766-0002.wav

      - name: Run beam search decoding
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          dir=./tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
          cd egs/librispeech/ASR
          ./pruned_transducer_stateless/pretrained.py \
            --method beam_search \
            --beam-size 4 \
            --checkpoint $dir/exp/pretrained.pt \
            --bpe-model $dir/data/lang_bpe_500/bpe.model \
            $dir/test_wavs/1089-134686-0001.wav \
            $dir/test_wavs/1221-135766-0001.wav \
            $dir/test_wavs/1221-135766-0002.wav

      - name: Run modified beam search decoding
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          dir=./tmp/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12
          cd egs/librispeech/ASR
          ./pruned_transducer_stateless/pretrained.py \
            --method modified_beam_search \
            --beam-size 4 \
            --checkpoint $dir/exp/pretrained.pt \
            --bpe-model $dir/data/lang_bpe_500/bpe.model \
            $dir/test_wavs/1089-134686-0001.wav \
            $dir/test_wavs/1221-135766-0001.wav \
            $dir/test_wavs/1221-135766-0002.wav
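The three greedy-search steps above differ only in `--max-sym-per-frame`, which caps how many non-blank symbols greedy RNN-T decoding may emit before moving to the next acoustic frame. A toy sketch of what that cap controls (purely illustrative, not repo code; random draws stand in for the model's argmax):

```python
import torch

# Illustrative only: real greedy decoding takes argmax over the joiner's
# logits at each step; here random tokens stand in for that choice.
torch.manual_seed(0)
blank_id, vocab_size, num_frames, max_sym_per_frame = 0, 6, 4, 2

hyp = []
for t in range(num_frames):
    emitted = 0
    while emitted < max_sym_per_frame:
        token = int(torch.randint(vocab_size, (1,)))  # pretend argmax
        if token == blank_id:
            break  # blank: advance to the next frame
        hyp.append(token)
        emitted += 1  # at most max_sym_per_frame non-blanks per frame
print(hyp)
```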

View File

@@ -84,7 +84,7 @@ The best WER using modified beam search with beam size 4 is:

|     | test-clean | test-other |
|-----|------------|------------|
-| WER | 2.61 | 6.46 |
+| WER | 2.56 | 6.27 |

Note: No auxiliary losses are used in the training and no LMs are used
in the decoding.

View File

@@ -15,6 +15,7 @@ The following table lists the differences among them.

| `transducer_stateless` | Conformer | Embedding + Conv1d | |
| `transducer_lstm` | LSTM | LSTM | |
| `transducer_stateless_multi_datasets` | Conformer | Embedding + Conv1d | Using data from GigaSpeech as extra training data |
| `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss |

The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).

View File

@@ -2,12 +2,111 @@

### LibriSpeech BPE training results (Pruned Transducer)

#### Conformer encoder + embedding decoder

Conformer encoder + non-recurrent decoder. The decoder
contains only an embedding layer, a Conv1d (with kernel size 2) and a linear
layer (to transform tensor dim).
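A minimal sketch of such a decoder (the dimensions are illustrative assumptions; the recipe's actual implementation lives in `pruned_transducer_stateless/decoder.py`):

```python
import torch
import torch.nn as nn

# Sketch of the stateless decoder described above: embedding -> Conv1d over
# the last `context_size` tokens -> linear projection. No recurrence.
class StatelessDecoder(nn.Module):
    def __init__(self, vocab_size=500, embedding_dim=512, context_size=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.conv = nn.Conv1d(embedding_dim, embedding_dim, kernel_size=context_size)
        self.out = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, context_size) ids of the most recent tokens
        emb = self.embedding(y).permute(0, 2, 1)  # (N, embedding_dim, context_size)
        h = self.conv(emb).permute(0, 2, 1)       # (N, 1, embedding_dim)
        return self.out(h)

print(StatelessDecoder()(torch.zeros(1, 2, dtype=torch.long)).shape)
```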
#### 2022-03-12
[pruned_transducer_stateless](./pruned_transducer_stateless)
Using commit `1603744469d167d848e074f2ea98c587153205fa`.
See <https://github.com/k2-fsa/icefall/pull/248>
The WERs are:
| | test-clean | test-other | comment |
|-------------------------------------|------------|------------|------------------------------------------|
| greedy search (max sym per frame 1) | 2.62 | 6.37 | --epoch 42, --avg 11, --max-duration 100 |
| greedy search (max sym per frame 2) | 2.62 | 6.37 | --epoch 42, --avg 11, --max-duration 100 |
| greedy search (max sym per frame 3) | 2.62 | 6.37 | --epoch 42, --avg 11, --max-duration 100 |
| modified beam search (beam size 4) | 2.56 | 6.27 | --epoch 42, --avg 11, --max-duration 100 |
| beam search (beam size 4) | 2.57 | 6.27 | --epoch 42, --avg 11, --max-duration 100 |
The decoding time for `test-clean` and `test-other` is given below:
(A V100 GPU with 32 GB RAM is used for decoding. Note: Not all GPU RAM is used during decoding.)
| decoding method | test-clean (seconds) | test-other (seconds)|
|---|---:|---:|
| greedy search (--max-sym-per-frame=1) | 160 | 159 |
| greedy search (--max-sym-per-frame=2) | 184 | 177 |
| greedy search (--max-sym-per-frame=3) | 210 | 213 |
| modified beam search (--beam-size 4) | 273 | 269 |
| beam search (--beam-size 4) | 2741 | 2221 |

We recommend using `modified_beam_search`.
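The gap in the last two rows comes largely from how hypotheses are expanded: `beam_search` runs the decoder and joiner separately per hypothesis expansion, while `modified_beam_search` evaluates all hypotheses in one batched call per frame. A toy illustration (a bare `nn.Linear` stands in for the model; not a real benchmark):

```python
import time
import torch
import torch.nn as nn

decoder = nn.Linear(512, 512)  # stand-in for decoder + joiner
hyps = [torch.randn(1, 512) for _ in range(4)]  # beam size 4

t0 = time.time()
for _ in range(1000):        # beam_search style: one call per hypothesis
    for h in hyps:
        decoder(h)
t_per_hyp = time.time() - t0

batch = torch.cat(hyps)      # modified_beam_search style: one batched call
t0 = time.time()
for _ in range(1000):
    decoder(batch)
t_batched = time.time() - t0

print(f"per-hypothesis: {t_per_hyp:.3f}s, batched: {t_batched:.3f}s")
```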
Training command:
```bash
cd egs/librispeech/ASR/
./prepare.sh
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
. path.sh
./pruned_transducer_stateless/train.py \
--world-size 8 \
--num-epochs 60 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless/exp \
--full-libri 1 \
--max-duration 300 \
--prune-range 5 \
--lr-factor 5 \
--lm-scale 0.25
```
The tensorboard training log can be found at
<https://tensorboard.dev/experiment/WKRFY5fYSzaVBHahenpNlA/>
The command for decoding is:
```bash
epoch=42
avg=11
sym=1
# greedy search
./pruned_transducer_stateless/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 100 \
--decoding-method greedy_search \
--beam-size 4 \
--max-sym-per-frame $sym
# modified beam search
./pruned_transducer_stateless/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
# beam search
# (not recommended)
./pruned_transducer_stateless/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
```
You can find a pre-trained model, decoding logs, and decoding results at
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless-2022-03-12>
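`--epoch 42 --avg 11` averages the weights of the checkpoints saved at epochs 32 through 42 (inclusive). A hedged sketch of the idea (illustrative only; the recipe uses `icefall.checkpoint.average_checkpoints`):

```python
import torch

# Element-wise average of several saved state dicts; averaging the last
# `avg` checkpoints typically gives a small WER gain over any single epoch.
def average_state_dicts(state_dicts):
    avg = {k: v.clone().float() for k, v in state_dicts[0].items()}
    for sd in state_dicts[1:]:
        for k in avg:
            avg[k] += sd[k].float()
    for k in avg:
        avg[k] /= len(state_dicts)
    return avg

# e.g. --epoch 42 --avg 11 corresponds to:
# sds = [torch.load(f"exp/epoch-{e}.pt")["model"] for e in range(32, 43)]
# model.load_state_dict(average_state_dicts(sds))
```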
#### 2022-02-18
[pruned_transducer_stateless](./pruned_transducer_stateless)
The WERs are

| | test-clean | test-other | comment |

@@ -62,7 +161,7 @@ See

##### 2022-03-01

-Using commit `fill in it after merging`.
+Using commit `2332ba312d7ce72f08c7bac1e3312f7e3dd722dc`.

It uses [GigaSpeech](https://github.com/SpeechColab/GigaSpeech)
as extra training data. 20% of the time it selects a batch from L subset of

@@ -129,6 +228,9 @@ sym=1

  --beam-size 4
```

You can find a pretrained model by visiting
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01>

##### 2022-02-07
View File

@@ -17,7 +17,6 @@

from dataclasses import dataclass
from typing import Dict, List, Optional

-import numpy as np
import torch
from model import Transducer

@@ -48,7 +47,7 @@ def greedy_search(

    device = model.device

    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
+        [blank_id] * context_size, device=device, dtype=torch.int64
    ).reshape(1, context_size)

    decoder_out = model.decoder(decoder_input, need_pad=False)
@@ -103,8 +102,9 @@ class Hypothesis:

    # Newly predicted tokens are appended to `ys`.
    ys: List[int]

-    # The log prob of ys
-    log_prob: float
+    # The log prob of ys.
+    # It contains only one entry.
+    log_prob: torch.Tensor

    @property
    def key(self) -> str:
@@ -113,7 +113,7 @@

class HypothesisList(object):
-    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None):
+    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
        """
        Args:
          data:

@@ -125,10 +125,10 @@ class HypothesisList(object):

        self._data = data

    @property
-    def data(self):
+    def data(self) -> Dict[str, Hypothesis]:
        return self._data

-    def add(self, hyp: Hypothesis):
+    def add(self, hyp: Hypothesis) -> None:
        """Add a Hypothesis to `self`.

        If `hyp` already exists in `self`, its probability is updated using
@@ -140,8 +140,10 @@ class HypothesisList(object):

        """
        key = hyp.key
        if key in self:
-            old_hyp = self._data[key]
-            old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob)
+            old_hyp = self._data[key]  # shallow copy
+            torch.logaddexp(
+                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
+            )
        else:
            self._data[key] = hyp
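The hunk above swaps NumPy's `logaddexp` for `torch.logaddexp` with `out=`, which updates the stored tensor in place (this is why `log_prob` is now a one-element `torch.Tensor` rather than a float). What it computes, on toy numbers:

```python
import torch

# torch.logaddexp(a, b) == log(exp(a) + exp(b)): combining the probabilities
# of two identical hypotheses in log space. Toy check:
a = torch.tensor([-1.2])
b = torch.tensor([-0.7])
combined = torch.logaddexp(a, b)
assert torch.allclose(combined, (a.exp() + b.exp()).log())
print(combined.item())  # ~ -0.23
```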
@@ -153,7 +155,8 @@ class HypothesisList(object):

          length_norm:
            If True, the `log_prob` of a hypothesis is normalized by the
            number of tokens in it.
        Returns:
          Return the hypothesis that has the largest `log_prob`.
        """
        if length_norm:
            return max(
@@ -165,6 +168,9 @@

    def remove(self, hyp: Hypothesis) -> None:
        """Remove a given hypothesis.

        Caution:
          `self` is modified **in-place**.

        Args:
          hyp:
            The hypothesis to be removed from `self`.
@@ -175,7 +181,7 @@

        assert key in self, f"{key} does not exist"
        del self._data[key]

-    def filter(self, threshold: float) -> "HypothesisList":
+    def filter(self, threshold: torch.Tensor) -> "HypothesisList":
        """Remove all Hypotheses whose log_prob is less than threshold.

        Caution:

@@ -183,10 +189,10 @@

        Returns:
          Return a new HypothesisList containing all hypotheses from `self`
-          that have `log_prob` being greater than the given `threshold`.
+          with `log_prob` being greater than the given `threshold`.
        """
        ans = HypothesisList()
-        for key, hyp in self._data.items():
+        for _, hyp in self._data.items():
            if hyp.log_prob > threshold:
                ans.add(hyp)  # shallow copy
        return ans
@@ -216,6 +222,106 @@

        return ", ".join(s)
def modified_beam_search(
    model: Transducer,
    encoder_out: torch.Tensor,
    beam: int = 4,
) -> List[int]:
    """It limits the maximum number of symbols per frame to 1.

    Args:
      model:
        An instance of `Transducer`.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
      beam:
        Beam size.
    Returns:
      Return the decoded result.
    """
    assert encoder_out.ndim == 3

    # support only batch_size == 1 for now
    assert encoder_out.size(0) == 1, encoder_out.size(0)
    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size

    device = model.device

    T = encoder_out.size(1)

    B = HypothesisList()
    B.add(
        Hypothesis(
            ys=[blank_id] * context_size,
            log_prob=torch.zeros(1, dtype=torch.float32, device=device),
        )
    )

    for t in range(T):
        # fmt: off
        current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
        # current_encoder_out is of shape (1, 1, 1, encoder_out_dim)
        # fmt: on
        A = list(B)
        B = HypothesisList()

        ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A])
        # ys_log_probs is of shape (num_hyps, 1)

        decoder_input = torch.tensor(
            [hyp.ys[-context_size:] for hyp in A],
            device=device,
            dtype=torch.int64,
        )
        # decoder_input is of shape (num_hyps, context_size)

        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
        # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)

        current_encoder_out = current_encoder_out.expand(
            decoder_out.size(0), 1, 1, -1
        )  # (num_hyps, 1, 1, encoder_out_dim)

        logits = model.joiner(
            current_encoder_out,
            decoder_out,
        )
        # logits is of shape (num_hyps, 1, 1, vocab_size)
        logits = logits.squeeze(1).squeeze(1)

        # now logits is of shape (num_hyps, vocab_size)
        log_probs = logits.log_softmax(dim=-1)

        log_probs.add_(ys_log_probs)

        log_probs = log_probs.reshape(-1)
        topk_log_probs, topk_indexes = log_probs.topk(beam)

        # topk_hyp_indexes are indexes into `A`
        topk_hyp_indexes = topk_indexes // logits.size(-1)
        topk_token_indexes = topk_indexes % logits.size(-1)

        topk_hyp_indexes = topk_hyp_indexes.tolist()
        topk_token_indexes = topk_token_indexes.tolist()

        for i in range(len(topk_hyp_indexes)):
            hyp = A[topk_hyp_indexes[i]]
            new_ys = hyp.ys[:]
            new_token = topk_token_indexes[i]
            if new_token != blank_id:
                new_ys.append(new_token)
            new_log_prob = topk_log_probs[i]
            new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
            B.add(new_hyp)

    best_hyp = B.get_most_probable(length_norm=True)
    ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks

    return ys
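The core of the loop above is a single flattened top-k over `num_hyps * vocab_size` joint scores; integer division and modulo then recover which hypothesis and which token each of the `beam` winners corresponds to. The same bookkeeping on toy sizes:

```python
import torch

# Toy sizes; mirrors the topk / // / % step in modified_beam_search.
num_hyps, vocab_size, beam = 3, 5, 4
ys_log_probs = torch.randn(num_hyps, 1)  # running score of each hypothesis
token_log_probs = torch.randn(num_hyps, vocab_size).log_softmax(dim=-1)

scores = (ys_log_probs + token_log_probs).reshape(-1)  # (num_hyps * vocab_size,)
topk_scores, topk_indexes = scores.topk(beam)

topk_hyp_indexes = topk_indexes // vocab_size   # which hypothesis to extend
topk_token_indexes = topk_indexes % vocab_size  # which token extends it
print(topk_hyp_indexes.tolist(), topk_token_indexes.tolist())
```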
def beam_search(
    model: Transducer,
    encoder_out: torch.Tensor,

@@ -246,7 +352,9 @@ def beam_search(

    device = model.device

    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
+        [blank_id] * context_size,
+        device=device,
+        dtype=torch.int64,
    ).reshape(1, context_size)

    decoder_out = model.decoder(decoder_input, need_pad=False)

@@ -283,7 +391,9 @@ def beam_search(

        if cached_key not in decoder_cache:
            decoder_input = torch.tensor(
-                [y_star.ys[-context_size:]], device=device
+                [y_star.ys[-context_size:]],
+                device=device,
+                dtype=torch.int64,
            ).reshape(1, context_size)

            decoder_out = model.decoder(decoder_input, need_pad=False)

@@ -297,7 +407,7 @@

                current_encoder_out, decoder_out.unsqueeze(1)
            )

-            # TODO(fangjun): Cache the blank posterior
+            # TODO(fangjun): Scale the blank posterior

            log_prob = logits.log_softmax(dim=-1)
            # log_prob is (1, 1, 1, vocab_size)

@@ -309,7 +419,7 @@

        # First, process the blank symbol
        skip_log_prob = log_prob[blank_id]
-        new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
+        new_y_star_log_prob = y_star.log_prob + skip_log_prob

        # ys[:] returns a copy of ys
        B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))

View File

@@ -33,6 +33,15 @@ Usage:

    --max-duration 100 \
    --decoding-method beam_search \
    --beam-size 4

(3) modified beam search
./pruned_transducer_stateless/decode.py \
    --epoch 28 \
    --avg 15 \
    --exp-dir ./pruned_transducer_stateless/exp \
    --max-duration 100 \
    --decoding-method modified_beam_search \
    --beam-size 4
"""
@@ -46,14 +55,10 @@ import sentencepiece as spm

import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
-from beam_search import beam_search, greedy_search
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
-from model import Transducer
+from beam_search import beam_search, greedy_search, modified_beam_search
+from train import get_params, get_transducer_model

from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.env import get_env_info
from icefall.utils import (
    AttributeDict,
    setup_logger,
@@ -104,6 +109,7 @@ def get_parser():

        help="""Possible values are:
          - greedy_search
          - beam_search
          - modified_beam_search
        """,
    )
@@ -111,7 +117,8 @@ def get_parser():

        "--beam-size",
        type=int,
        default=4,
-        help="Used only when --decoding-method is beam_search",
+        help="""Used only when --decoding-method is
+        beam_search or modified_beam_search""",
    )

    parser.add_argument(
@@ -125,78 +132,13 @@ def get_parser():

        "--max-sym-per-frame",
        type=int,
        default=3,
-        help="Maximum number of symbols per frame",
+        help="""Maximum number of symbols per frame.
+        Used only when --decoding-method is greedy_search""",
    )

    return parser
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            # parameters for conformer
-            "feature_dim": 80,
-            "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            "vgg_frontend": False,
-            # parameters for decoder
-            "embedding_dim": 512,
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-
-
-def get_encoder_model(params: AttributeDict) -> nn.Module:
-    # TODO: We can add an option to switch between Conformer and Transformer
-    encoder = Conformer(
-        num_features=params.feature_dim,
-        output_dim=params.vocab_size,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict) -> nn.Module:
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.embedding_dim,
-        blank_id=params.blank_id,
-        context_size=params.context_size,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict) -> nn.Module:
-    joiner = Joiner(
-        input_dim=params.vocab_size,
-        inner_dim=params.embedding_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-
-def get_transducer_model(params: AttributeDict) -> nn.Module:
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model


def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
@@ -258,6 +200,10 @@ def decode_one_batch(

            hyp = beam_search(
                model=model, encoder_out=encoder_out_i, beam=params.beam_size
            )
        elif params.decoding_method == "modified_beam_search":
            hyp = modified_beam_search(
                model=model, encoder_out=encoder_out_i, beam=params.beam_size
            )
        else:
            raise ValueError(
                f"Unsupported decoding method: {params.decoding_method}"
            )
@@ -391,11 +337,15 @@ def main():

    params = get_params()
    params.update(vars(args))

-    assert params.decoding_method in ("greedy_search", "beam_search")
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "modified_beam_search",
+    )
    params.res_dir = params.exp_dir / params.decoding_method

    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    if params.decoding_method == "beam_search":
+    if "beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam_size}"
    else:
        params.suffix += f"-context-{params.context_size}"
@@ -469,8 +419,5 @@ def main():

    logging.info("Done!")


-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)

if __name__ == "__main__":
    main()

View File

@@ -39,7 +39,7 @@ you can do:

    --exp-dir ./pruned_transducer_stateless/exp \
    --epoch 9999 \
    --avg 1 \
-    --max-duration 1 \
+    --max-duration 100 \
    --bpe-model data/lang_bpe_500/bpe.model
"""
@@ -49,15 +49,10 @@ from pathlib import Path

import sentencepiece as spm
import torch
-import torch.nn as nn
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
-from model import Transducer
+from train import get_params, get_transducer_model

from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.env import get_env_info
-from icefall.utils import AttributeDict, str2bool
+from icefall.utils import str2bool
def get_parser(): def get_parser():
@@ -117,71 +112,6 @@ def get_parser():

    return parser
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            # parameters for conformer
-            "feature_dim": 80,
-            "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            "vgg_frontend": False,
-            # parameters for decoder
-            "embedding_dim": 512,
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-
-
-def get_encoder_model(params: AttributeDict) -> nn.Module:
-    encoder = Conformer(
-        num_features=params.feature_dim,
-        output_dim=params.vocab_size,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict) -> nn.Module:
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.embedding_dim,
-        blank_id=params.blank_id,
-        context_size=params.context_size,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict) -> nn.Module:
-    joiner = Joiner(
-        input_dim=params.vocab_size,
-        inner_dim=params.embedding_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-
-def get_transducer_model(params: AttributeDict) -> nn.Module:
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model


def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

View File

@@ -49,17 +49,10 @@ from typing import List

import kaldifeat
import sentencepiece as spm
import torch
-import torch.nn as nn
import torchaudio
-from beam_search import beam_search, greedy_search
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
-from model import Transducer
+from beam_search import beam_search, greedy_search, modified_beam_search
from torch.nn.utils.rnn import pad_sequence
+from train import get_params, get_transducer_model

-from icefall.env import get_env_info
-from icefall.utils import AttributeDict
def get_parser(): def get_parser():
@@ -91,6 +84,7 @@ def get_parser():

        help="""Possible values are:
          - greedy_search
          - beam_search
          - modified_beam_search
        """,
    )
@@ -104,11 +98,18 @@ def get_parser():

        "The sample rate has to be 16kHz.",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    parser.add_argument(
        "--beam-size",
        type=int,
        default=4,
-        help="Used only when --method is beam_search",
+        help="Used only when --method is beam_search or modified_beam_search",
    )

    parser.add_argument(
@@ -130,72 +131,6 @@ def get_parser():

    return parser
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            "sample_rate": 16000,
-            # parameters for conformer
-            "feature_dim": 80,
-            "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            "vgg_frontend": False,
-            # parameters for decoder
-            "embedding_dim": 512,
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-
-
-def get_encoder_model(params: AttributeDict) -> nn.Module:
-    encoder = Conformer(
-        num_features=params.feature_dim,
-        output_dim=params.vocab_size,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict) -> nn.Module:
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.embedding_dim,
-        blank_id=params.blank_id,
-        context_size=params.context_size,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict) -> nn.Module:
-    joiner = Joiner(
-        input_dim=params.vocab_size,
-        inner_dim=params.embedding_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-
-def get_transducer_model(params: AttributeDict) -> nn.Module:
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model


def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
@@ -220,6 +155,7 @@ def read_sound_files(

    return ans


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()
@@ -278,7 +214,6 @@ def main():

    feature_lengths = torch.tensor(feature_lengths, device=device)

-    with torch.no_grad():
    encoder_out, encoder_out_lens = model.encoder(
        x=features, x_lens=feature_lengths
    )
@@ -303,6 +238,10 @@

            hyp = beam_search(
                model=model, encoder_out=encoder_out_i, beam=params.beam_size
            )
        elif params.method == "modified_beam_search":
            hyp = modified_beam_search(
                model=model, encoder_out=encoder_out_i, beam=params.beam_size
            )
        else:
            raise ValueError(f"Unsupported method: {params.method}")

View File

@@ -3,4 +3,3 @@ kaldialign

sentencepiece>=0.1.96
tensorboard
typeguard
-optimized_transducer
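With `optimized_transducer` dropped from requirements.txt (per the commit message, users now install it on their own), code that still needs it would typically guard the import. A hypothetical pattern, not taken from the repo:

```python
# Hypothetical guard for the now-optional dependency; icefall itself may
# handle this differently.
try:
    import optimized_transducer  # noqa: F401
except ImportError as e:
    raise ImportError(
        "optimized_transducer is not installed. "
        "Please run: pip install optimized_transducer"
    ) from e
```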