Increase the size of the context in the RNN-T decoder. (#153)

Fangjun Kuang 2021-12-23 07:55:02 +08:00 committed by GitHub
parent cb04c8a750
commit fb6a57e9e0
16 changed files with 1101 additions and 129 deletions

View File

@@ -0,0 +1,108 @@
# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-pre-trained-transducer-stateless
on:
push:
branches:
- master
pull_request:
types: [labeled]
jobs:
run_pre_trained_transducer_stateless:
if: github.event.label.name == 'ready' || github.event_name == 'push'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
python-version: [3.7, 3.8, 3.9]
torch: ["1.10.0"]
torchaudio: ["0.10.0"]
k2-version: ["1.9.dev20211101"]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Install Python dependencies
run: |
python3 -m pip install --upgrade pip pytest
# numpy 1.20.x does not support python 3.6
pip install numpy==1.19
pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
python3 -m pip install git+https://github.com/lhotse-speech/lhotse
python3 -m pip install kaldifeat
# We are in ./icefall and there is a file: requirements.txt in it
pip install -r requirements.txt
- name: Install graphviz
shell: bash
run: |
python3 -m pip install -qq graphviz
sudo apt-get -qq install graphviz
- name: Download pre-trained model
shell: bash
run: |
sudo apt-get -qq install git-lfs tree sox
cd egs/librispeech/ASR
mkdir tmp
cd tmp
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22
cd ..
tree tmp
soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/*.wav
ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/*.wav
- name: Run greedy search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless/pretrained.py \
--method greedy_search \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0002.wav
- name: Run beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
cd egs/librispeech/ASR
./transducer_stateless/pretrained.py \
--method beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0002.wav

View File

@@ -30,7 +30,7 @@ jobs:
     strategy:
       matrix:
         os: [ubuntu-18.04]
-        python-version: [3.6, 3.7, 3.8, 3.9]
+        python-version: [3.7, 3.8, 3.9]
         torch: ["1.10.0"]
        torchaudio: ["0.10.0"]
        k2-version: ["1.9.dev20211101"]

View File

@@ -32,7 +32,7 @@ jobs:
         # os: [ubuntu-18.04, macos-10.15]
         # disable macOS test for now.
         os: [ubuntu-18.04]
-        python-version: [3.6, 3.7, 3.8, 3.9]
+        python-version: [3.7, 3.8]
         torch: ["1.8.0", "1.10.0"]
         torchaudio: ["0.8.0", "0.10.0"]
         k2-version: ["1.9.dev20211101"]
@@ -106,6 +106,12 @@ jobs:
           if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
             cd ../transducer
             pytest -v -s
+
+            cd ../transducer_stateless
+            pytest -v -s
+
+            cd ../transducer_lstm
+            pytest -v -s
           fi
       - name: Run tests
@@ -125,4 +131,10 @@ jobs:
           if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
             cd ../transducer
             pytest -v -s
+
+            cd ../transducer_stateless
+            pytest -v -s
+
+            cd ../transducer_lstm
+            pytest -v -s
           fi

View File

@@ -34,11 +34,12 @@ We do provide a Colab notebook for this recipe.

 ### LibriSpeech

-We provide 3 models for this recipe:
+We provide 4 models for this recipe:

 - [conformer CTC model][LibriSpeech_conformer_ctc]
 - [TDNN LSTM CTC model][LibriSpeech_tdnn_lstm_ctc]
-- [RNN-T Conformer model][LibriSpeech_transducer]
+- [Transducer: Conformer encoder + LSTM decoder][LibriSpeech_transducer]
+- [Transducer: Conformer encoder + Embedding decoder][LibriSpeech_transducer_stateless]

 #### Conformer CTC Model

@@ -62,9 +63,9 @@ The WER for this model is:

 We provide a Colab notebook to run a pre-trained TDNN LSTM CTC model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kNmDXNMwREi0rZGAOIAOJo93REBuOTcd?usp=sharing)

-#### RNN-T Conformer model
+#### Transducer: Conformer encoder + LSTM decoder

-Using Conformer as encoder.
+Using Conformer as encoder and LSTM as decoder.

 The best WER with greedy search is:

@@ -74,6 +75,21 @@ The best WER with greedy search is:

 We provide a Colab notebook to run a pre-trained RNN-T conformer model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1_u6yK9jDkPwG_NLrZMN2XK7Aeq4suMO2?usp=sharing)

+#### Transducer: Conformer encoder + Embedding decoder
+
+Using Conformer as encoder. The decoder consists of 1 embedding layer
+and 1 convolutional layer.
+
+The best WER using beam search with beam size 4 is:
+
+|     | test-clean | test-other |
+|-----|------------|------------|
+| WER | 2.92       | 7.37       |
+
+Note: No auxiliary losses are used in the training and no LMs are used
+in the decoding.
+
+We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Lm37sNajIpkV4HTzMDF7sn9l0JpfmekN?usp=sharing)
+
 ### Aishell

@@ -143,6 +159,7 @@ Please see:

 [LibriSpeech_tdnn_lstm_ctc]: egs/librispeech/ASR/tdnn_lstm_ctc
 [LibriSpeech_conformer_ctc]: egs/librispeech/ASR/conformer_ctc
 [LibriSpeech_transducer]: egs/librispeech/ASR/transducer
+[LibriSpeech_transducer_stateless]: egs/librispeech/ASR/transducer_stateless
 [Aishell_tdnn_lstm_ctc]: egs/aishell/ASR/tdnn_lstm_ctc
 [Aishell_conformer_ctc]: egs/aishell/ASR/conformer_ctc
 [TIMIT_tdnn_lstm_ctc]: egs/timit/ASR/tdnn_lstm_ctc

View File

@@ -1,3 +1,20 @@
+# Introduction
+
 Please refer to <https://icefall.readthedocs.io/en/latest/recipes/librispeech.html>
 for how to run models in this recipe.
+
+# Transducers
+
+There are various folders containing the name `transducer` in this folder.
+The following table lists the differences among them.
+
+|                        | Encoder   | Decoder            |
+|------------------------|-----------|--------------------|
+| `transducer`           | Conformer | LSTM               |
+| `transducer_stateless` | Conformer | Embedding + Conv1d |
+| `transducer_lstm`      | LSTM      | LSTM               |
+
+The decoder in `transducer_stateless` is modified from the paper
+[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
+We place an additional Conv1d layer right after the input embedding layer.
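
Aside: for readers skimming this commit, the following is a minimal, self-contained sketch of the "Embedding + Conv1d" decoder summarized in the table above. The class name and dimensions here are illustrative only; the actual implementation (including the `need_pad` handling used at inference time) is in `transducer_stateless/decoder.py` further down in this commit.

```
import torch
import torch.nn as nn
import torch.nn.functional as F


class StatelessDecoderSketch(nn.Module):
    """Embedding of the last tokens followed by a depthwise Conv1d."""

    def __init__(self, vocab_size: int, embedding_dim: int, context_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.context_size = context_size
        self.conv = nn.Conv1d(
            embedding_dim,
            embedding_dim,
            kernel_size=context_size,
            groups=embedding_dim,  # depthwise: one filter per embedding channel
            bias=False,
        )

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U) token IDs; returns (N, U, embedding_dim)
        out = self.embedding(y).permute(0, 2, 1)       # (N, C, U)
        out = F.pad(out, pad=(self.context_size - 1, 0))  # left-pad in time
        return self.conv(out).permute(0, 2, 1)


decoder = StatelessDecoderSketch(vocab_size=500, embedding_dim=512, context_size=2)
print(decoder(torch.randint(0, 500, (4, 10))).shape)  # torch.Size([4, 10, 512])
```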

View File

@@ -1,10 +1,69 @@
 ## Results

-### LibriSpeech BPE training results (RNN-T)
+### LibriSpeech BPE training results (Transducer)
+
+#### 2021-12-22
+
+Conformer encoder + non-recurrent decoder. The decoder
+contains only an embedding layer and a Conv1d (with kernel size 2).
+
+The WERs are
+
+|                           | test-clean | test-other | comment                                  |
+|---------------------------|------------|------------|------------------------------------------|
+| greedy search             | 2.99       | 7.52       | --epoch 20, --avg 10, --max-duration 100 |
+| beam search (beam size 2) | 2.95       | 7.43       |                                          |
+| beam search (beam size 3) | 2.94       | 7.37       |                                          |
+| beam search (beam size 4) | 2.92       | 7.37       |                                          |
+| beam search (beam size 5) | 2.93       | 7.38       |                                          |
+| beam search (beam size 8) | 2.92       | 7.38       |                                          |
+
+The training command for reproducing is given below:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./transducer_stateless/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 0 \
+  --exp-dir transducer_stateless/exp-full \
+  --full-libri 1 \
+  --max-duration 250 \
+  --lr-factor 3
+```
+
+The tensorboard training log can be found at
+<https://tensorboard.dev/experiment/PsJ3LgkEQfOmzedAlYfVeg/#scalars&_smoothingWeight=0>
+
+The decoding command is:
+
+```
+epoch=20
+avg=10
+
+## greedy search
+./transducer_stateless/decode.py \
+  --epoch $epoch \
+  --avg $avg \
+  --exp-dir transducer_stateless/exp-full \
+  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --max-duration 100
+
+## beam search
+./transducer_stateless/decode.py \
+  --epoch $epoch \
+  --avg $avg \
+  --exp-dir transducer_stateless/exp-full \
+  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --max-duration 100 \
+  --decoding-method beam_search \
+  --beam-size 4
+```

 #### 2021-12-17

-RNN-T + Conformer encoder
+Using commit `cb04c8a7509425ab45fae888b0ca71bbbd23f0de`.
+
+Conformer encoder + LSTM decoder.

 The best WER is

@@ -12,7 +71,7 @@ The best WER is
 |-----|------------|------------|
 | WER | 3.16       | 7.71       |

-using `--epoch 26 --avg 12` during decoding with greedy search.
+using `--epoch 26 --avg 12` with **greedy search**.

 The training command to reproduce the above WER is:
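
A note on the `--epoch 20 --avg 10` setting above: checkpoint averaging in this recipe takes the last `avg` epoch checkpoints ending at `--epoch` and averages their parameters (see `decode.py` and `export.py` in this commit, which compute `start = epoch - avg + 1`). A small sketch of which files that selects, assuming the default `epoch-N.pt` naming:

```
# Hypothetical paths; mirrors the start = epoch - avg + 1 logic used by the recipe.
epoch, avg = 20, 10
start = epoch - avg + 1
filenames = [
    f"transducer_stateless/exp-full/epoch-{i}.pt" for i in range(start, epoch + 1)
]
print(filenames[0], "...", filenames[-1])  # epoch-11.pt ... epoch-20.pt
```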

View File

@@ -27,11 +27,6 @@ from encoder_interface import EncoderInterface

 from icefall.utils import add_sos

-assert hasattr(torchaudio.functional, "rnnt_loss"), (
-    f"Current torchaudio version: {torchaudio.__version__}\n"
-    "Please install a version >= 0.10.0"
-)
-

 class Transducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
@@ -115,6 +110,11 @@ class Transducer(nn.Module):
         # Note: y does not start with SOS
         y_padded = y.pad(mode="constant", padding_value=0)

+        assert hasattr(torchaudio.functional, "rnnt_loss"), (
+            f"Current torchaudio version: {torchaudio.__version__}\n"
+            "Please install a version >= 0.10.0"
+        )
+
         loss = torchaudio.functional.rnnt_loss(
             logits=logits,
             targets=y_padded,

View File

@@ -27,11 +27,6 @@ from encoder_interface import EncoderInterface

 from icefall.utils import add_sos

-assert hasattr(torchaudio.functional, "rnnt_loss"), (
-    f"Current torchaudio version: {torchaudio.__version__}\n"
-    "Please install a version >= 0.10.0"
-)
-

 class Transducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
@@ -115,6 +110,11 @@ class Transducer(nn.Module):
         # Note: y does not start with SOS
         y_padded = y.pad(mode="constant", padding_value=0)

+        assert hasattr(torchaudio.functional, "rnnt_loss"), (
+            f"Current torchaudio version: {torchaudio.__version__}\n"
+            "Please install a version >= 0.10.0"
+        )
+
         loss = torchaudio.functional.rnnt_loss(
             logits=logits,
             targets=y_padded,

View File

@@ -15,8 +15,9 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional

+import numpy as np
 import torch
 from model import Transducer
@@ -35,21 +36,35 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     # support only batch_size == 1 for now
     assert encoder_out.size(0) == 1, encoder_out.size(0)

     blank_id = model.decoder.blank_id
+    context_size = model.decoder.context_size
     device = model.device

-    sos = torch.tensor([blank_id], device=device).reshape(1, 1)
-    decoder_out = model.decoder(sos)
+    decoder_input = torch.tensor(
+        [blank_id] * context_size, device=device
+    ).reshape(1, context_size)
+
+    decoder_out = model.decoder(decoder_input, need_pad=False)
+
     T = encoder_out.size(1)
     t = 0
-    hyp = []
-
-    sym_per_frame = 0
-    sym_per_utt = 0
+    hyp = [blank_id] * context_size

+    # Maximum symbols per utterance.
     max_sym_per_utt = 1000
+
+    # If at frame t, it decodes more than this number of symbols,
+    # it will move to the next step t+1
     max_sym_per_frame = 3
+
+    # symbols per frame
+    sym_per_frame = 0
+
+    # symbols per utterance decoded so far
+    sym_per_utt = 0

     while t < T and sym_per_utt < max_sym_per_utt:
         # fmt: off
         current_encoder_out = encoder_out[:, t:t+1, :]
@@ -57,14 +72,14 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
         logits = model.joiner(current_encoder_out, decoder_out)
         # logits is (1, 1, 1, vocab_size)
-        log_prob = logits.log_softmax(dim=-1)
-        # log_prob is (1, 1, 1, vocab_size)
-        # TODO: Use logits.argmax()
-        y = log_prob.argmax()
+        y = logits.argmax().item()
         if y != blank_id:
-            hyp.append(y.item())
-            y = y.reshape(1, 1)
-            decoder_out = model.decoder(y)
+            hyp.append(y)
+            decoder_input = torch.tensor(
+                [hyp[-context_size:]], device=device
+            ).reshape(1, context_size)
+
+            decoder_out = model.decoder(decoder_input, need_pad=False)

             sym_per_utt += 1
             sym_per_frame += 1
@@ -72,24 +87,135 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
         if y == blank_id or sym_per_frame > max_sym_per_frame:
             sym_per_frame = 0
             t += 1
+
+    hyp = hyp[context_size:]  # remove blanks

     return hyp


 @dataclass
 class Hypothesis:
-    ys: List[int]  # the predicted sequences so far
-    log_prob: float  # The log prob of ys
-
-    # Optional decoder state. We assume it is LSTM for now,
-    # so the state is a tuple (h, c)
-    decoder_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+    # The predicted tokens so far.
+    # Newly predicted tokens are appended to `ys`.
+    ys: List[int]
+
+    # The log prob of ys
+    log_prob: float
+
+    @property
+    def key(self) -> str:
+        """Return a string representation of self.ys"""
+        return "_".join(map(str, self.ys))
+
+
+class HypothesisList(object):
+    def __init__(self, data: Optional[Dict[str, Hypothesis]] = None):
+        """
+        Args:
+          data:
+            A dict of Hypotheses. Its key is its `value.key`.
+        """
+        if data is None:
+            self._data = {}
+        else:
+            self._data = data
+
+    @property
+    def data(self):
+        return self._data
+
+    def add(self, hyp: Hypothesis):
+        """Add a Hypothesis to `self`.
+
+        If `hyp` already exists in `self`, its probability is updated using
+        `log-sum-exp` with the existed one.
+
+        Args:
+          hyp:
+            The hypothesis to be added.
+        """
+        key = hyp.key
+        if key in self:
+            old_hyp = self._data[key]
+            old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob)
+        else:
+            self._data[key] = hyp
+
+    def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
+        """Get the most probable hypothesis, i.e., the one with
+        the largest `log_prob`.
+
+        Args:
+          length_norm:
+            If True, the `log_prob` of a hypothesis is normalized by the
+            number of tokens in it.
+        """
+        if length_norm:
+            return max(
+                self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)
+            )
+        else:
+            return max(self._data.values(), key=lambda hyp: hyp.log_prob)
+
+    def remove(self, hyp: Hypothesis) -> None:
+        """Remove a given hypothesis.
+
+        Args:
+          hyp:
+            The hypothesis to be removed from `self`.
+
+            Note: It must be contained in `self`. Otherwise,
+            an exception is raised.
+        """
+        key = hyp.key
+        assert key in self, f"{key} does not exist"
+        del self._data[key]
+
+    def filter(self, threshold: float) -> "HypothesisList":
+        """Remove all Hypotheses whose log_prob is less than threshold.
+
+        Caution:
+          `self` is not modified. Instead, a new HypothesisList is returned.
+
+        Returns:
+          Return a new HypothesisList containing all hypotheses from `self`
+          that have `log_prob` being greater than the given `threshold`.
+        """
+        ans = HypothesisList()
+        for key, hyp in self._data.items():
+            if hyp.log_prob > threshold:
+                ans.add(hyp)  # shallow copy
+        return ans
+
+    def topk(self, k: int) -> "HypothesisList":
+        """Return the top-k hypothesis."""
+        hyps = list(self._data.items())
+
+        hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
+        ans = HypothesisList(dict(hyps))
+        return ans
+
+    def __contains__(self, key: str):
+        return key in self._data
+
+    def __iter__(self):
+        return iter(self._data.values())
+
+    def __len__(self) -> int:
+        return len(self._data)
+
+    def __str__(self) -> str:
+        s = []
+        for key in self:
+            s.append(key)
+        return ", ".join(s)


 def beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
-    beam: int = 5,
+    beam: int = 4,
 ) -> List[int]:
     """
     It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
@@ -111,110 +237,98 @@ def beam_search(
     # support only batch_size == 1 for now
     assert encoder_out.size(0) == 1, encoder_out.size(0)

     blank_id = model.decoder.blank_id
-    sos_id = model.decoder.sos_id
+    context_size = model.decoder.context_size
     device = model.device

-    sos = torch.tensor([blank_id], device=device).reshape(1, 1)
-    decoder_out, (h, c) = model.decoder(sos)
+    decoder_input = torch.tensor(
+        [blank_id] * context_size, device=device
+    ).reshape(1, context_size)
+
+    decoder_out = model.decoder(decoder_input, need_pad=False)
+
     T = encoder_out.size(1)
     t = 0
-    B = [Hypothesis(ys=[blank_id], log_prob=0.0, decoder_state=None)]
-    max_u = 20000  # terminate after this number of steps
-    u = 0
-
-    cache: Dict[
-        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
-    ] = {}
-
-    while t < T and u < max_u:
+
+    B = HypothesisList()
+    B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
+
+    max_sym_per_utt = 20000
+    sym_per_utt = 0
+
+    decoder_cache: Dict[str, torch.Tensor] = {}
+
+    while t < T and sym_per_utt < max_sym_per_utt:
         # fmt: off
         current_encoder_out = encoder_out[:, t:t+1, :]
         # fmt: on
         A = B
-        B = []
-
-        # for hyp in A:
-        #     for h in A:
-        #         if h.ys == hyp.ys[:-1]:
-        #             # update the score of hyp
-        #             decoder_input = torch.tensor(
-        #                 [h.ys[-1]], device=device
-        #             ).reshape(1, 1)
-        #             decoder_out, _ = model.decoder(
-        #                 decoder_input, h.decoder_state
-        #             )
-        #             logits = model.joiner(current_encoder_out, decoder_out)
-        #             log_prob = logits.log_softmax(dim=-1)
-        #             log_prob = log_prob.squeeze()
-        #             hyp.log_prob += h.log_prob + log_prob[hyp.ys[-1]].item()
-
-        while u < max_u:
-            y_star = max(A, key=lambda hyp: hyp.log_prob)
+        B = HypothesisList()
+
+        joint_cache: Dict[str, torch.Tensor] = {}
+
+        # TODO(fangjun): Implement prefix search to update the `log_prob`
+        # of hypotheses in A
+        while True:
+            y_star = A.get_most_probable()
             A.remove(y_star)

-            # Note: y_star.ys is unhashable, i.e., cannot be used
-            # as a key into a dict
-            cached_key = "_".join(map(str, y_star.ys))
+            cached_key = y_star.key

-            if cached_key not in cache:
+            if cached_key not in decoder_cache:
                 decoder_input = torch.tensor(
-                    [y_star.ys[-1]], device=device
-                ).reshape(1, 1)
-                decoder_out, decoder_state = model.decoder(
-                    decoder_input,
-                    y_star.decoder_state,
-                )
-                cache[cached_key] = (decoder_out, decoder_state)
+                    [y_star.ys[-context_size:]], device=device
+                ).reshape(1, context_size)
+
+                decoder_out = model.decoder(decoder_input, need_pad=False)
+                decoder_cache[cached_key] = decoder_out
             else:
-                decoder_out, decoder_state = cache[cached_key]
+                decoder_out = decoder_cache[cached_key]

-            logits = model.joiner(current_encoder_out, decoder_out)
-            log_prob = logits.log_softmax(dim=-1)
-            # log_prob is (1, 1, 1, vocab_size)
-            log_prob = log_prob.squeeze()
-            # Now log_prob is (vocab_size,)
+            cached_key += f"-t-{t}"
+            if cached_key not in joint_cache:
+                logits = model.joiner(current_encoder_out, decoder_out)
+
+                # TODO(fangjun): Scale the blank posterior
+
+                log_prob = logits.log_softmax(dim=-1)
+                # log_prob is (1, 1, 1, vocab_size)
+                log_prob = log_prob.squeeze()
+                # Now log_prob is (vocab_size,)
+                joint_cache[cached_key] = log_prob
+            else:
+                log_prob = joint_cache[cached_key]

-            # If we choose blank here, add the new hypothesis to B.
-            # Otherwise, add the new hypothesis to A
-
-            # First, choose blank
+            # First, process the blank symbol
             skip_log_prob = log_prob[blank_id]
             new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
+
             # ys[:] returns a copy of ys
-            new_y_star = Hypothesis(
-                ys=y_star.ys[:],
-                log_prob=new_y_star_log_prob,
-                # Caution: Use y_star.decoder_state here
-                decoder_state=y_star.decoder_state,
-            )
-            B.append(new_y_star)
+            B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))

-            # Second, choose other labels
-            for i, v in enumerate(log_prob.tolist()):
-                if i in (blank_id, sos_id):
+            # Second, process other non-blank labels
+            values, indices = log_prob.topk(beam + 1)
+            for i, v in zip(indices.tolist(), values.tolist()):
+                if i == blank_id:
                     continue
                 new_ys = y_star.ys + [i]
                 new_log_prob = y_star.log_prob + v
-                new_hyp = Hypothesis(
-                    ys=new_ys,
-                    log_prob=new_log_prob,
-                    decoder_state=decoder_state,
-                )
-                A.append(new_hyp)
-
-            u += 1
-            # check whether B contains more than "beam" elements more probable
+                A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
+
+            # Check whether B contains more than "beam" elements more probable
             # than the most probable in A
-            A_most_probable = max(A, key=lambda hyp: hyp.log_prob)
-            B = sorted(
-                [hyp for hyp in B if hyp.log_prob > A_most_probable.log_prob],
-                key=lambda hyp: hyp.log_prob,
-                reverse=True,
-            )
-            if len(B) >= beam:
-                B = B[:beam]
+            A_most_probable = A.get_most_probable()
+
+            kept_B = B.filter(A_most_probable.log_prob)
+
+            if len(kept_B) >= beam:
+                B = kept_B.topk(beam)
                 break

         t += 1
-    best_hyp = max(B, key=lambda hyp: hyp.log_prob / len(hyp.ys[1:]))
-    ys = best_hyp.ys[1:]  # [1:] to remove the blank
+
+    best_hyp = B.get_most_probable(length_norm=True)
+    ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks
+
     return ys
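
To make the new `HypothesisList` bookkeeping above easier to follow, here is a small standalone usage sketch (the token IDs and probabilities are made up). It exercises the same `add`/`get_most_probable`/`filter`/`topk` calls that `beam_search` relies on, assuming `Hypothesis` and `HypothesisList` can be imported from this `beam_search.py`:

```
import math

from beam_search import Hypothesis, HypothesisList  # assumed import path

hyps = HypothesisList()
hyps.add(Hypothesis(ys=[0, 0, 23, 7], log_prob=math.log(0.4)))
hyps.add(Hypothesis(ys=[0, 0, 23], log_prob=math.log(0.3)))
# Adding the same ys again merges the scores via log-add-exp: 0.3 + 0.2 = 0.5
hyps.add(Hypothesis(ys=[0, 0, 23], log_prob=math.log(0.2)))

best = hyps.get_most_probable()          # ys=[0, 0, 23], log_prob ~ log(0.5)
survivors = hyps.filter(math.log(0.35))  # keeps hypotheses with log_prob > log(0.35)
top1 = hyps.topk(1)                      # a HypothesisList holding only the best entry
print(best.ys, len(survivors), len(top1))  # [0, 0, 23] 2 1
```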

View File

@@ -24,15 +24,15 @@ Usage:
     --exp-dir ./transducer_stateless/exp \
     --max-duration 100 \
     --decoding-method greedy_search
-(2) beam search
+(2) beam search
 ./transducer_stateless/decode.py \
     --epoch 14 \
     --avg 7 \
     --exp-dir ./transducer_stateless/exp \
     --max-duration 100 \
     --decoding-method beam_search \
-    --beam-size 8
+    --beam-size 4
 """
@@ -70,14 +70,14 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=77,
+        default=20,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )

     parser.add_argument(
         "--avg",
         type=int,
-        default=55,
+        default=10,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--beam-size",
         type=int,
-        default=5,
+        default=4,
         help="Used only when --decoding-method is beam_search",
     )
@@ -130,7 +130,8 @@ def get_params() -> AttributeDict:
             "num_encoder_layers": 12,
             "vgg_frontend": False,
             "use_feat_batchnorm": True,
-            # decoder params
+            # parameters for decoder
+            "context_size": 2,  # tri-gram
             "env_info": get_env_info(),
         }
     )
@@ -158,6 +159,7 @@ def get_decoder_model(params: AttributeDict):
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
         blank_id=params.blank_id,
+        context_size=params.context_size,
     )
     return decoder
@@ -392,9 +394,8 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)

-    # <blk> and <sos/eos> are defined in local/train_bpe_model.py
+    # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()

     logging.info(params)

View File

@@ -16,6 +16,7 @@

 import torch
 import torch.nn as nn
+import torch.nn.functional as F


 class Decoder(nn.Module):
@@ -35,6 +36,7 @@ class Decoder(nn.Module):
         vocab_size: int,
         embedding_dim: int,
         blank_id: int,
+        context_size: int,
     ):
         """
         Args:
@@ -44,6 +46,9 @@
             Dimension of the input embedding.
           blank_id:
             The ID of the blank symbol.
+          context_size:
+            Number of previous words to use to predict the next word.
+            1 means bigram; 2 means trigram. n means (n+1)-gram.
         """
         super().__init__()
         self.embedding = nn.Embedding(
@@ -53,13 +58,40 @@ class Decoder(nn.Module):
         )
         self.blank_id = blank_id

-    def forward(self, y: torch.Tensor) -> torch.Tensor:
+        assert context_size >= 1, context_size
+        self.context_size = context_size
+        if context_size > 1:
+            self.conv = nn.Conv1d(
+                in_channels=embedding_dim,
+                out_channels=embedding_dim,
+                kernel_size=context_size,
+                padding=0,
+                groups=embedding_dim,
+                bias=False,
+            )
+
+    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
         """
         Args:
           y:
             A 2-D tensor of shape (N, U) with blank prepended.
+          need_pad:
+            True to left pad the input. Should be True during training.
+            False to not pad the input. Should be False during inference.
         Returns:
           Return a tensor of shape (N, U, embedding_dim).
         """
         embeding_out = self.embedding(y)
+        if self.context_size > 1:
+            embeding_out = embeding_out.permute(0, 2, 1)
+            if need_pad is True:
+                embeding_out = F.pad(
+                    embeding_out, pad=(self.context_size - 1, 0)
+                )
+            else:
+                # During inference time, there is no need to do extra padding
+                # as we only need one output
+                assert embeding_out.size(-1) == self.context_size
+            embeding_out = self.conv(embeding_out)
+            embeding_out = embeding_out.permute(0, 2, 1)
         return embeding_out

View File

@@ -0,0 +1,244 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
./transducer_stateless/export.py \
--exp-dir ./transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10
It will generate a file exp_dir/pretrained.pt
To use the generated file with `transducer_stateless/decode.py`, you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./transducer_stateless/decode.py \
--exp-dir ./transducer_stateless/exp \
--epoch 9999 \
--avg 1 \
--max-duration 1 \
--bpe-model data/lang_bpe_500/bpe.model
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import AttributeDict, str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=20,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=10,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_stateless/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
""",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for decoder
"context_size": 2, # tri-gram
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict):
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder
def get_decoder_model(params: AttributeDict):
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict):
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict):
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.eval()
model.to("cpu")
model.eval()
if params.jit:
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@@ -27,11 +27,6 @@ from encoder_interface import EncoderInterface

 from icefall.utils import add_sos

-assert hasattr(torchaudio.functional, "rnnt_loss"), (
-    f"Current torchaudio version: {torchaudio.__version__}\n"
-    "Please install a version >= 0.10.0"
-)
-

 class Transducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
@@ -113,6 +108,11 @@ class Transducer(nn.Module):
         # Note: y does not start with SOS
         y_padded = y.pad(mode="constant", padding_value=0)

+        assert hasattr(torchaudio.functional, "rnnt_loss"), (
+            f"Current torchaudio version: {torchaudio.__version__}\n"
+            "Please install a version >= 0.10.0"
+        )
+
         loss = torchaudio.functional.rnnt_loss(
             logits=logits,
             targets=y_padded,

View File

@@ -0,0 +1,307 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav \
(2) beam search
./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav \
You can also use `./transducer_stateless/exp/epoch-xx.pt`.
Note: ./transducer_stateless/exp/pretrained.pt is generated by
./transducer_stateless/export.py
"""
import argparse
import logging
import math
from typing import List
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import beam_search, greedy_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from torch.nn.utils.rnn import pad_sequence
from icefall.env import get_env_info
from icefall.utils import AttributeDict
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.
Used only when method is ctc-decoding.
""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="Used only when --method is beam_search",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"sample_rate": 16000,
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for decoder
"context_size": 2, # tri-gram
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict):
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder
def get_decoder_model(params: AttributeDict):
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict):
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict):
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
with torch.no_grad():
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0)
hyps = []
msg = f"Using {params.method}"
if params.method == "beam_search":
msg += f" with beam size {params.beam_size}"
logging.info(msg)
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(model=model, encoder_out=encoder_out_i)
elif params.method == "beam_search":
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@@ -0,0 +1,58 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./transducer_stateless/test_decoder.py
"""
import torch
from decoder import Decoder
def test_decoder():
vocab_size = 3
blank_id = 0
embedding_dim = 128
context_size = 4
decoder = Decoder(
vocab_size=vocab_size,
embedding_dim=embedding_dim,
blank_id=blank_id,
context_size=context_size,
)
N = 100
U = 20
x = torch.randint(low=0, high=vocab_size, size=(N, U))
y = decoder(x)
assert y.shape == (N, U, embedding_dim)
# for inference
x = torch.randint(low=0, high=vocab_size, size=(N, context_size))
y = decoder(x, need_pad=False)
assert y.shape == (N, 1, embedding_dim)
def main():
test_decoder()
if __name__ == "__main__":
main()

View File

@@ -92,7 +92,7 @@ def get_parser():
     parser.add_argument(
         "--num-epochs",
         type=int,
-        default=78,
+        default=30,
         help="Number of epochs to train.",
     )
@@ -202,6 +202,8 @@ def get_params() -> AttributeDict:
             "num_encoder_layers": 12,
             "vgg_frontend": False,
             "use_feat_batchnorm": True,
+            # parameters for decoder
+            "context_size": 2,  # tri-gram
             # parameters for Noam
             "weight_decay": 1e-6,
             "warm_step": 80000,  # For the 100h subset, use 8k
@@ -233,6 +235,7 @@ def get_decoder_model(params: AttributeDict):
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
         blank_id=params.blank_id,
+        context_size=params.context_size,
     )
     return decoder