PingFeng Luo 2021-12-27 16:43:39 +08:00
commit 234307f33a
28 changed files with 255 additions and 172 deletions

View File

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-pre-trained-tranducer-stateless
name: run-pre-trained-transducer-stateless
on:
push:
@ -74,11 +74,11 @@ jobs:
mkdir tmp
cd tmp
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27
cd ..
tree tmp
soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/*.wav
ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/*.wav
soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/*.wav
ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/*.wav
- name: Run greedy search decoding
shell: bash
@ -87,11 +87,11 @@ jobs:
cd egs/librispeech/ASR
./transducer_stateless/pretrained.py \
--method greedy_search \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0002.wav
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0002.wav
- name: Run beam search decoding
shell: bash
@ -101,8 +101,8 @@ jobs:
./transducer_stateless/pretrained.py \
--method beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0002.wav
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0002.wav

View File

@ -0,0 +1,109 @@
# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-pre-trained-transducer
on:
push:
branches:
- master
pull_request:
types: [labeled]
jobs:
run_pre_trained_transducer:
if: github.event.label.name == 'ready' || github.event_name == 'push'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
python-version: [3.7, 3.8, 3.9]
torch: ["1.10.0"]
torchaudio: ["0.10.0"]
k2-version: ["1.9.dev20211101"]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Install Python dependencies
run: |
python3 -m pip install --upgrade pip pytest
# numpy 1.20.x does not support python 3.6
pip install numpy==1.19
pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
python3 -m pip install git+https://github.com/lhotse-speech/lhotse
python3 -m pip install kaldifeat
# We are in ./icefall and there is a file: requirements.txt in it
pip install -r requirements.txt
- name: Install graphviz
shell: bash
run: |
python3 -m pip install -qq graphviz
sudo apt-get -qq install graphviz
- name: Download pre-trained model
shell: bash
run: |
sudo apt-get -qq install git-lfs tree sox
cd egs/librispeech/ASR
mkdir tmp
cd tmp
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-bpe-500-2021-12-23
cd ..
tree tmp
soxi tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/*.wav
ls -lh tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/*.wav
- name: Run greedy search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
cd egs/librispeech/ASR
./transducer/pretrained.py \
--method greedy_search \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0002.wav
- name: Run beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
cd egs/librispeech/ASR
./transducer/pretrained.py \
--method beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0002.wav

View File

@ -71,7 +71,7 @@ The best WER with greedy search is:
| | test-clean | test-other |
|-----|------------|------------|
| WER | 3.16 | 7.71 |
| WER | 3.07 | 7.51 |
We provide a Colab notebook to run a pre-trained RNN-T conformer model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1_u6yK9jDkPwG_NLrZMN2XK7Aeq4suMO2?usp=sharing)
@ -84,7 +84,7 @@ The best WER using beam search with beam size 4 is:
| | test-clean | test-other |
|-----|------------|------------|
| WER | 2.92 | 7.37 |
| WER | 2.83 | 7.19 |
Note: No auxiliary losses are used in the training and no LMs are used
in the decoding.

View File

@ -2,7 +2,10 @@
### LibriSpeech BPE training results (Transducer)
#### 2021-12-22
#### Conformer encoder + embedding decoder
Using commit `TODO`.
Conformer encoder + non-recurrent decoder. The decoder
contains only an embedding layer and a Conv1d (with kernel size 2).
@ -10,12 +13,8 @@ The WERs are
| | test-clean | test-other | comment |
|---------------------------|------------|------------|------------------------------------------|
| greedy search | 2.99 | 7.52 | --epoch 20, --avg 10, --max-duration 100 |
| beam search (beam size 2) | 2.95 | 7.43 | |
| beam search (beam size 3) | 2.94 | 7.37 | |
| beam search (beam size 4) | 2.92 | 7.37 | |
| beam search (beam size 5) | 2.93 | 7.38 | |
| beam search (beam size 8) | 2.92 | 7.38 | |
| greedy search | 2.85 | 7.30 | --epoch 29, --avg 13, --max-duration 100 |
| beam search (beam size 4) | 2.83 | 7.19 | |
The training command for reproducing is given below:
@ -33,12 +32,12 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
```
The tensorboard training log can be found at
<https://tensorboard.dev/experiment/PsJ3LgkEQfOmzedAlYfVeg/#scalars&_smoothingWeight=0>
<https://tensorboard.dev/experiment/Mjx7MeTgR3Oyr1yBCwjozw/>
The decoding command is:
```
epoch=20
avg=10
epoch=29
avg=13
## greedy search
./transducer_stateless/decode.py \
@ -60,8 +59,8 @@ avg=10
```
#### 2021-12-17
Using commit `cb04c8a7509425ab45fae888b0ca71bbbd23f0de`.
#### Conformer encoder + LSTM decoder
Using commit `8187d6236c2926500da5ee854f758e621df803cc`.
Conformer encoder + LSTM decoder.
@ -69,9 +68,9 @@ The best WER is
| | test-clean | test-other |
|-----|------------|------------|
| WER | 3.16 | 7.71 |
| WER | 3.07 | 7.51 |
using `--epoch 26 --avg 12` with **greedy search**.
using `--epoch 34 --avg 11` with **greedy search**.
The training command to reproduce the above WER is:
@ -80,19 +79,19 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
./transducer/train.py \
--world-size 4 \
--num-epochs 30 \
--num-epochs 35 \
--start-epoch 0 \
--exp-dir transducer/exp-lr-2.5-full \
--full-libri 1 \
--max-duration 250 \
--max-duration 180 \
--lr-factor 2.5
```
The decoding command is:
```
epoch=26
avg=12
epoch=34
avg=11
./transducer/decode.py \
--epoch $epoch \
@ -102,7 +101,7 @@ avg=12
--max-duration 100
```
You can find the tensorboard log at: <https://tensorboard.dev/experiment/PYIbeD6zRJez1ViXaRqqeg/>
You can find the tensorboard log at: <https://tensorboard.dev/experiment/D7NQc3xqTpyVmWi5FnWjrA>
### LibriSpeech BPE training results (Conformer-CTC)

View File

@ -111,7 +111,6 @@ def beam_search(
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
sos_id = model.decoder.sos_id
device = model.device
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
@ -192,7 +191,7 @@ def beam_search(
# Second, choose other labels
for i, v in enumerate(log_prob.tolist()):
if i in (blank_id, sos_id):
if i == blank_id:
continue
new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + v
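With the dedicated `<sos/eos>` symbol gone, the prediction network is primed with the blank ID and the label loop only needs to skip blank. A minimal sketch of the changed pieces; the 500-token vocabulary is an assumption mirroring the lang_bpe_500 setup used elsewhere in this commit:

```python
import torch

blank_id = 0  # assumption: <blk> maps to 0, as in the BPE models above
device = torch.device("cpu")

# Start-of-sequence input to the decoder: a single blank, shape (1, 1).
sos = torch.tensor([blank_id], device=device).reshape(1, 1)

log_prob = torch.randn(500).log_softmax(dim=-1)  # stand-in for one decode step
for i, v in enumerate(log_prob.tolist()):
    if i == blank_id:  # was: if i in (blank_id, sos_id)
        continue
    # ...extend the current hypothesis with label i and score v...
```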

View File

@ -56,7 +56,6 @@ class Conformer(Transformer):
cnn_module_kernel: int = 31,
normalize_before: bool = True,
vgg_frontend: bool = False,
use_feat_batchnorm: bool = False,
) -> None:
super(Conformer, self).__init__(
num_features=num_features,
@ -69,7 +68,6 @@ class Conformer(Transformer):
dropout=dropout,
normalize_before=normalize_before,
vgg_frontend=vgg_frontend,
use_feat_batchnorm=use_feat_batchnorm,
)
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
@ -107,11 +105,6 @@ class Conformer(Transformer):
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
"""
if self.use_feat_batchnorm:
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
@ -873,7 +866,7 @@ class ConvolutionModule(nn.Module):
groups=channels,
bias=bias,
)
self.norm = nn.BatchNorm1d(channels)
self.norm = nn.LayerNorm(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
@ -903,7 +896,12 @@ class ConvolutionModule(nn.Module):
# 1D Depthwise Conv
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))
# x is (batch, channels, time)
x = x.permute(0, 2, 1)
x = self.norm(x)
x = x.permute(0, 2, 1)
x = self.activation(x)
x = self.pointwise_conv2(x) # (batch, channel, time)
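Why the extra permutes: `nn.BatchNorm1d` consumes `(N, C, T)` directly, whereas `nn.LayerNorm(channels)` normalizes the last dimension, so the tensor is moved to `(N, T, C)` and back. A self-contained sketch of the before/after; the channel count and sequence length are arbitrary:

```python
import torch
import torch.nn as nn

channels = 256
x = torch.randn(8, channels, 100)  # (batch, channels, time), as after depthwise_conv

# Old: BatchNorm1d normalizes each channel over (batch, time) and
# accepts (N, C, T) as-is.
bn = nn.BatchNorm1d(channels)
y_old = bn(x)

# New: LayerNorm normalizes the trailing channel dimension, hence the
# permute to (N, T, C) and back.
ln = nn.LayerNorm(channels)
y_new = ln(x.permute(0, 2, 1)).permute(0, 2, 1)

assert y_old.shape == y_new.shape == x.shape
```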

View File

@ -70,14 +70,14 @@ def get_parser():
parser.add_argument(
"--epoch",
type=int,
default=26,
default=34,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=12,
default=11,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
@ -129,10 +129,9 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# decoder params
"decoder_embedding_dim": 1024,
"num_decoder_layers": 4,
"num_decoder_layers": 2,
"decoder_hidden_dim": 512,
"env_info": get_env_info(),
}
@ -151,7 +150,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder
@ -161,7 +159,6 @@ def get_decoder_model(params: AttributeDict):
vocab_size=params.vocab_size,
embedding_dim=params.decoder_embedding_dim,
blank_id=params.blank_id,
sos_id=params.sos_id,
num_layers=params.num_decoder_layers,
hidden_dim=params.decoder_hidden_dim,
output_dim=params.encoder_out_dim,
@ -399,9 +396,8 @@ def main():
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.sos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size()
logging.info(params)

View File

@ -27,7 +27,6 @@ class Decoder(nn.Module):
vocab_size: int,
embedding_dim: int,
blank_id: int,
sos_id: int,
num_layers: int,
hidden_dim: int,
output_dim: int,
@ -42,8 +41,6 @@ class Decoder(nn.Module):
Dimension of the input embedding.
blank_id:
The ID of the blank symbol.
sos_id:
The ID of the SOS symbol.
num_layers:
Number of LSTM layers.
hidden_dim:
@ -71,7 +68,6 @@ class Decoder(nn.Module):
dropout=rnn_dropout,
)
self.blank_id = blank_id
self.sos_id = sos_id
self.output_linear = nn.Linear(hidden_dim, output_dim)
def forward(

View File

@ -23,8 +23,8 @@ Usage:
./transducer/export.py \
--exp-dir ./transducer/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 26 \
--avg 12
--epoch 34 \
--avg 11
It will generate a file exp_dir/pretrained.pt
@ -66,7 +66,7 @@ def get_parser():
parser.add_argument(
"--epoch",
type=int,
default=26,
default=34,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
@ -74,7 +74,7 @@ def get_parser():
parser.add_argument(
"--avg",
type=int,
default=12,
default=11,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
@ -119,10 +119,9 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# decoder params
"decoder_embedding_dim": 1024,
"num_decoder_layers": 4,
"num_decoder_layers": 2,
"decoder_hidden_dim": 512,
"env_info": get_env_info(),
}
@ -140,7 +139,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder
@ -150,7 +148,6 @@ def get_decoder_model(params: AttributeDict):
vocab_size=params.vocab_size,
embedding_dim=params.decoder_embedding_dim,
blank_id=params.blank_id,
sos_id=params.sos_id,
num_layers=params.num_decoder_layers,
hidden_dim=params.decoder_hidden_dim,
output_dim=params.encoder_out_dim,
@ -197,9 +194,8 @@ def main():
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.sos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size()
logging.info(params)

View File

@ -16,7 +16,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class Joiner(nn.Module):
@ -48,7 +47,7 @@ class Joiner(nn.Module):
# Now decoder_out is (N, 1, U, C)
logit = encoder_out + decoder_out
logit = F.relu(logit)
logit = torch.tanh(logit)
output = self.output_linear(logit)
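For reference, the joiner adds the broadcast encoder and decoder outputs and now squashes the sum with `tanh` instead of `relu`, which also makes the `torch.nn.functional` import above unnecessary. A shape-level sketch with all dimensions assumed:

```python
import torch

N, T, U, C = 2, 50, 10, 512
encoder_out = torch.randn(N, T, 1, C)  # (N, T, 1, C)
decoder_out = torch.randn(N, 1, U, C)  # (N, 1, U, C)

logit = encoder_out + decoder_out      # (N, T, U, C) via broadcasting
logit = torch.tanh(logit)              # was: F.relu(logit)
# output = self.output_linear(logit)   # projection to the vocabulary follows
```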

View File

@ -49,7 +49,7 @@ class Transducer(nn.Module):
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, C). It should contain
two attributes: `blank_id` and `sos_id`.
one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
output shape is (N, T, U, C). Note that its output contains
@ -58,7 +58,6 @@ class Transducer(nn.Module):
super().__init__()
assert isinstance(encoder, EncoderInterface)
assert hasattr(decoder, "blank_id")
assert hasattr(decoder, "sos_id")
self.encoder = encoder
self.decoder = decoder
@ -97,8 +96,7 @@ class Transducer(nn.Module):
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_id = self.decoder.sos_id
sos_y = add_sos(y, sos_id=sos_id)
sos_y = add_sos(y, sos_id=blank_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
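The real code builds `sos_y` from a `k2.RaggedTensor`; the plain-tensor sketch below only illustrates the substitution this hunk makes, reusing blank as the start-of-sequence symbol:

```python
import torch

blank_id = 0                                # assumption: <blk> is ID 0
y = torch.tensor([[5, 9, 3]])               # one transcript of token IDs
sos = torch.full((y.size(0), 1), blank_id)  # blank now doubles as SOS
sos_y = torch.cat([sos, y], dim=1)          # decoder input: [blank, 5, 9, 3]
print(sos_y)  # tensor([[0, 5, 9, 3]])
```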

View File

@ -116,10 +116,9 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# decoder params
"decoder_embedding_dim": 1024,
"num_decoder_layers": 4,
"num_decoder_layers": 2,
"decoder_hidden_dim": 512,
"env_info": get_env_info(),
}
@ -137,7 +136,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder
@ -147,7 +145,6 @@ def get_decoder_model(params: AttributeDict):
vocab_size=params.vocab_size,
embedding_dim=params.decoder_embedding_dim,
blank_id=params.blank_id,
sos_id=params.sos_id,
num_layers=params.num_decoder_layers,
hidden_dim=params.decoder_hidden_dim,
output_dim=params.encoder_out_dim,
@ -211,9 +208,8 @@ def main():
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.sos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")

View File

@ -36,7 +36,6 @@ def test_conformer():
nhead=8,
dim_feedforward=2048,
num_encoder_layers=12,
use_feat_batchnorm=True,
)
N = 3
T = 100

View File

@ -29,7 +29,6 @@ from decoder import Decoder
def test_decoder():
vocab_size = 3
blank_id = 0
sos_id = 2
embedding_dim = 128
num_layers = 2
hidden_dim = 6
@ -41,7 +40,6 @@ def test_decoder():
vocab_size=vocab_size,
embedding_dim=embedding_dim,
blank_id=blank_id,
sos_id=sos_id,
num_layers=num_layers,
hidden_dim=hidden_dim,
output_dim=output_dim,

View File

@ -39,7 +39,6 @@ def test_transducer():
# decoder params
vocab_size = 3
blank_id = 0
sos_id = 2
embedding_dim = 128
num_layers = 2
@ -51,14 +50,12 @@ def test_transducer():
nhead=8,
dim_feedforward=2048,
num_encoder_layers=12,
use_feat_batchnorm=True,
)
decoder = Decoder(
vocab_size=vocab_size,
embedding_dim=embedding_dim,
blank_id=blank_id,
sos_id=sos_id,
num_layers=num_layers,
hidden_dim=output_dim,
output_dim=output_dim,

View File

@ -36,7 +36,6 @@ def test_transformer():
nhead=8,
dim_feedforward=2048,
num_encoder_layers=12,
use_feat_batchnorm=True,
)
N = 3
T = 100

View File

@ -23,7 +23,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
./transducer/train.py \
--world-size 4 \
--num-epochs 30 \
--num-epochs 35 \
--start-epoch 0 \
--exp-dir transducer/exp \
--full-libri 1 \
@ -92,7 +92,7 @@ def get_parser():
parser.add_argument(
"--num-epochs",
type=int,
default=30,
default=35,
help="Number of epochs to train.",
)
@ -171,15 +171,10 @@ def get_params() -> AttributeDict:
- subsampling_factor: The subsampling factor for the model.
- use_feat_batchnorm: Whether to do batch normalization for the
input features.
- attention_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
- weight_decay: The weight_decay for the optimizer.
- warm_step: The warm_step for Noam optimizer.
"""
params = AttributeDict(
@ -201,13 +196,11 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# decoder params
"decoder_embedding_dim": 1024,
"num_decoder_layers": 4,
"num_decoder_layers": 2,
"decoder_hidden_dim": 512,
# parameters for Noam
"weight_decay": 1e-6,
"warm_step": 80000, # For the 100h subset, use 8k
"env_info": get_env_info(),
}
@ -227,7 +220,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder
@ -237,7 +229,6 @@ def get_decoder_model(params: AttributeDict):
vocab_size=params.vocab_size,
embedding_dim=params.decoder_embedding_dim,
blank_id=params.blank_id,
sos_id=params.sos_id,
num_layers=params.num_decoder_layers,
hidden_dim=params.decoder_hidden_dim,
output_dim=params.encoder_out_dim,
@ -573,9 +564,8 @@ def run(rank, world_size, args):
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.sos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
@ -599,7 +589,6 @@ def run(rank, world_size, args):
model_size=params.attention_dim,
factor=params.lr_factor,
warm_step=params.warm_step,
weight_decay=params.weight_decay,
)
if checkpoints and "optimizer" in checkpoints:

View File

@ -39,7 +39,6 @@ class Transformer(EncoderInterface):
dropout: float = 0.1,
normalize_before: bool = True,
vgg_frontend: bool = False,
use_feat_batchnorm: bool = False,
) -> None:
"""
Args:
@ -65,13 +64,8 @@ class Transformer(EncoderInterface):
If True, use pre-layer norm; False to use post-layer norm.
vgg_frontend:
True to use vgg style frontend for subsampling.
use_feat_batchnorm:
True to use batchnorm for the input layer.
"""
super().__init__()
self.use_feat_batchnorm = use_feat_batchnorm
if use_feat_batchnorm:
self.feat_batchnorm = nn.BatchNorm1d(num_features)
self.num_features = num_features
self.output_dim = output_dim
@ -131,11 +125,6 @@ class Transformer(EncoderInterface):
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
"""
if self.use_feat_batchnorm:
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
x = self.encoder_embed(x)
x = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)

View File

@ -22,13 +22,18 @@ import torch
from model import Transducer
def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
def greedy_search(
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
) -> List[int]:
"""
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
max_sym_per_frame:
Maximum number of symbols per frame. If it is set to 0, the WER
would be 100%.
Returns:
Return the decoded result.
"""
@ -55,10 +60,6 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
# Maximum symbols per utterance.
max_sym_per_utt = 1000
# If at frame t, it decodes more than this number of symbols,
# it will move to the next step t+1
max_sym_per_frame = 3
# symbols per frame
sym_per_frame = 0
@ -66,6 +67,11 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
sym_per_utt = 0
while t < T and sym_per_utt < max_sym_per_utt:
if sym_per_frame >= max_sym_per_frame:
sym_per_frame = 0
t += 1
continue
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
# fmt: on
@ -83,8 +89,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
sym_per_utt += 1
sym_per_frame += 1
if y == blank_id or sym_per_frame > max_sym_per_frame:
else:
sym_per_frame = 0
t += 1
hyp = hyp[context_size:] # remove blanks
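The control flow added here is easiest to see in isolation: emitting a non-blank symbol increments `sym_per_frame`, while predicting blank or hitting the cap resets it and advances `t`. A model-free sketch with the network replaced by a stub:

```python
def capped_greedy_steps(emits_blank, max_sym_per_frame: int, T: int) -> int:
    """Count emitted symbols; `emits_blank(t, k)` is a stub for the model
    predicting blank at frame t after k symbols on that frame."""
    t, sym_per_frame, total = 0, 0, 0
    max_sym_per_utt = 1000
    while t < T and total < max_sym_per_utt:
        if sym_per_frame >= max_sym_per_frame:
            # Cap reached: force decoding on to the next frame.
            sym_per_frame = 0
            t += 1
            continue
        if emits_blank(t, sym_per_frame):
            sym_per_frame = 0
            t += 1
        else:
            total += 1
            sym_per_frame += 1
    return total

# Even a model that never predicts blank terminates now:
print(capped_greedy_steps(lambda t, k: False, max_sym_per_frame=3, T=10))  # 30
```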

View File

@ -56,7 +56,6 @@ class Conformer(Transformer):
cnn_module_kernel: int = 31,
normalize_before: bool = True,
vgg_frontend: bool = False,
use_feat_batchnorm: bool = False,
) -> None:
super(Conformer, self).__init__(
num_features=num_features,
@ -69,7 +68,6 @@ class Conformer(Transformer):
dropout=dropout,
normalize_before=normalize_before,
vgg_frontend=vgg_frontend,
use_feat_batchnorm=use_feat_batchnorm,
)
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
@ -107,11 +105,6 @@ class Conformer(Transformer):
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
"""
if self.use_feat_batchnorm:
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
@ -873,7 +866,7 @@ class ConvolutionModule(nn.Module):
groups=channels,
bias=bias,
)
self.norm = nn.BatchNorm1d(channels)
self.norm = nn.LayerNorm(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
@ -903,7 +896,12 @@ class ConvolutionModule(nn.Module):
# 1D Depthwise Conv
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))
# x is (batch, channels, time)
x = x.permute(0, 2, 1)
x = self.norm(x)
x = x.permute(0, 2, 1)
x = self.activation(x)
x = self.pointwise_conv2(x) # (batch, channel, time)

View File

@ -70,14 +70,14 @@ def get_parser():
parser.add_argument(
"--epoch",
type=int,
default=20,
default=29,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=10,
default=13,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
@ -114,6 +114,20 @@ def get_parser():
help="Used only when --decoding-method is beam_search",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=3,
help="Maximum number of symbols per frame",
)
return parser
@ -129,9 +143,6 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for decoder
"context_size": 2, # tri-gram
"env_info": get_env_info(),
}
)
@ -149,7 +160,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder
@ -237,7 +247,11 @@ def decode_one_batch(
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(model=model, encoder_out=encoder_out_i)
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
@ -381,6 +395,9 @@ def main():
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.decoding_method == "beam_search":
params.suffix += f"-beam-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")

View File

@ -20,13 +20,14 @@ import torch.nn.functional as F
class Decoder(nn.Module):
"""This class implements the stateless decoder from the following paper:
"""This class modifies the stateless decoder from the following paper:
RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
It removes the recurrent connection from the decoder, i.e., the prediction
network.
network. Different from the above paper, it adds an extra Conv1d
right after the embedding layer.
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
"""

View File

@ -104,6 +104,14 @@ def get_parser():
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
return parser
@ -119,9 +127,6 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for decoder
"context_size": 2, # tri-gram
"env_info": get_env_info(),
}
)
@ -138,7 +143,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder

View File

@ -16,7 +16,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class Joiner(nn.Module):
@ -48,7 +47,7 @@ class Joiner(nn.Module):
# Now decoder_out is (N, 1, U, C)
logit = encoder_out + decoder_out
logit = F.relu(logit)
logit = torch.tanh(logit)
output = self.output_linear(logit)

View File

@ -110,6 +110,22 @@ def get_parser():
help="Used only when --method is beam_search",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=3,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
return parser
@ -126,9 +142,6 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for decoder
"context_size": 2, # tri-gram
"env_info": get_env_info(),
}
)
@ -145,7 +158,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder
@ -279,7 +291,11 @@ def main():
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(model=model, encoder_out=encoder_out_i)
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size

View File

@ -130,6 +130,14 @@ def get_parser():
help="The lr_factor for Noam optimizer",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
return parser
@ -171,15 +179,10 @@ def get_params() -> AttributeDict:
- subsampling_factor: The subsampling factor for the model.
- use_feat_batchnorm: Whether to do batch normalization for the
input features.
- attention_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
- weight_decay: The weight_decay for the optimizer.
- warm_step: The warm_step for Noam optimizer.
"""
params = AttributeDict(
@ -201,11 +204,7 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for decoder
"context_size": 2, # tri-gram
# parameters for Noam
"weight_decay": 1e-6,
"warm_step": 80000, # For the 100h subset, use 8k
"env_info": get_env_info(),
}
@ -225,7 +224,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder
@ -568,7 +566,7 @@ def run(rank, world_size, args):
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
@ -593,7 +591,6 @@ def run(rank, world_size, args):
model_size=params.attention_dim,
factor=params.lr_factor,
warm_step=params.warm_step,
weight_decay=params.weight_decay,
)
if checkpoints and "optimizer" in checkpoints:

View File

@ -39,7 +39,6 @@ class Transformer(EncoderInterface):
dropout: float = 0.1,
normalize_before: bool = True,
vgg_frontend: bool = False,
use_feat_batchnorm: bool = False,
) -> None:
"""
Args:
@ -65,13 +64,8 @@ class Transformer(EncoderInterface):
If True, use pre-layer norm; False to use post-layer norm.
vgg_frontend:
True to use vgg style frontend for subsampling.
use_feat_batchnorm:
True to use batchnorm for the input layer.
"""
super().__init__()
self.use_feat_batchnorm = use_feat_batchnorm
if use_feat_batchnorm:
self.feat_batchnorm = nn.BatchNorm1d(num_features)
self.num_features = num_features
self.output_dim = output_dim
@ -131,11 +125,6 @@ class Transformer(EncoderInterface):
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
"""
if self.use_feat_batchnorm:
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
x = self.encoder_embed(x)
x = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)