This commit is contained in:
PingFeng Luo 2021-12-27 16:43:39 +08:00
commit 234307f33a
28 changed files with 255 additions and 172 deletions

View File

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
name: run-pre-trained-tranducer-stateless name: run-pre-trained-transducer-stateless
on: on:
push: push:
@ -74,11 +74,11 @@ jobs:
mkdir tmp mkdir tmp
cd tmp cd tmp
git lfs install git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22 git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27
cd .. cd ..
tree tmp tree tmp
soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/*.wav soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/*.wav
ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/*.wav ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/*.wav
- name: Run greedy search decoding - name: Run greedy search decoding
shell: bash shell: bash
@ -87,11 +87,11 @@ jobs:
cd egs/librispeech/ASR cd egs/librispeech/ASR
./transducer_stateless/pretrained.py \ ./transducer_stateless/pretrained.py \
--method greedy_search \ --method greedy_search \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/exp/pretrained.pt \ --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/data/lang_bpe_500/bpe.model \ --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1089-134686-0001.wav \ ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0001.wav \ ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0002.wav ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0002.wav
- name: Run beam search decoding - name: Run beam search decoding
shell: bash shell: bash
@ -101,8 +101,8 @@ jobs:
./transducer_stateless/pretrained.py \ ./transducer_stateless/pretrained.py \
--method beam_search \ --method beam_search \
--beam-size 4 \ --beam-size 4 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/exp/pretrained.pt \ --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/data/lang_bpe_500/bpe.model \ --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1089-134686-0001.wav \ ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0001.wav \ ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0002.wav ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0002.wav

View File

@ -0,0 +1,109 @@
# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-pre-trained-transducer
on:
push:
branches:
- master
pull_request:
types: [labeled]
jobs:
run_pre_trained_transducer:
if: github.event.label.name == 'ready' || github.event_name == 'push'
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-18.04]
python-version: [3.7, 3.8, 3.9]
torch: ["1.10.0"]
torchaudio: ["0.10.0"]
k2-version: ["1.9.dev20211101"]
fail-fast: false
steps:
- uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Install Python dependencies
run: |
python3 -m pip install --upgrade pip pytest
# numpy 1.20.x does not support python 3.6
pip install numpy==1.19
pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
python3 -m pip install git+https://github.com/lhotse-speech/lhotse
python3 -m pip install kaldifeat
# We are in ./icefall and there is a file: requirements.txt in it
pip install -r requirements.txt
- name: Install graphviz
shell: bash
run: |
python3 -m pip install -qq graphviz
sudo apt-get -qq install graphviz
- name: Download pre-trained model
shell: bash
run: |
sudo apt-get -qq install git-lfs tree sox
cd egs/librispeech/ASR
mkdir tmp
cd tmp
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-bpe-500-2021-12-23
cd ..
tree tmp
soxi tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/*.wav
ls -lh tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/*.wav
- name: Run greedy search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
cd egs/librispeech/ASR
./transducer/pretrained.py \
--method greedy_search \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0002.wav
- name: Run beam search decoding
shell: bash
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
cd egs/librispeech/ASR
./transducer/pretrained.py \
--method beam_search \
--beam-size 4 \
--checkpoint ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/exp/pretrained.pt \
--bpe-model ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/data/lang_bpe_500/bpe.model \
./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1089-134686-0001.wav \
./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0001.wav \
./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0002.wav

View File

@ -71,7 +71,7 @@ The best WER with greedy search is:
| | test-clean | test-other | | | test-clean | test-other |
|-----|------------|------------| |-----|------------|------------|
| WER | 3.16 | 7.71 | | WER | 3.07 | 7.51 |
We provide a Colab notebook to run a pre-trained RNN-T conformer model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1_u6yK9jDkPwG_NLrZMN2XK7Aeq4suMO2?usp=sharing) We provide a Colab notebook to run a pre-trained RNN-T conformer model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1_u6yK9jDkPwG_NLrZMN2XK7Aeq4suMO2?usp=sharing)
@ -84,7 +84,7 @@ The best WER using beam search with beam size 4 is:
| | test-clean | test-other | | | test-clean | test-other |
|-----|------------|------------| |-----|------------|------------|
| WER | 2.92 | 7.37 | | WER | 2.83 | 7.19 |
Note: No auxiliary losses are used in the training and no LMs are used Note: No auxiliary losses are used in the training and no LMs are used
in the decoding. in the decoding.

View File

@ -2,7 +2,10 @@
### LibriSpeech BPE training results (Transducer) ### LibriSpeech BPE training results (Transducer)
#### 2021-12-22 #### Conformer encoder + embedding decoder
Using commit `TODO`.
Conformer encoder + non-recurrent decoder. The decoder Conformer encoder + non-recurrent decoder. The decoder
contains only an embedding layer and a Conv1d (with kernel size 2). contains only an embedding layer and a Conv1d (with kernel size 2).
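A minimal sketch of such a stateless prediction network is given below. It is only an illustration of the embedding + Conv1d idea with made-up dimensions and class names, not the actual `transducer_stateless/decoder.py` implementation; the left padding keeps the convolution causal so that kernel size 2 gives a tri-gram-like context.
```
import torch
import torch.nn as nn
import torch.nn.functional as F


class StatelessDecoder(nn.Module):
    """Embedding + causal Conv1d prediction network (illustration only)."""

    def __init__(self, vocab_size: int, embedding_dim: int, context_size: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.context_size = context_size
        # kernel_size == context_size, so each output position sees the
        # previous `context_size` labels (2 -> tri-gram-like context).
        self.conv = nn.Conv1d(embedding_dim, embedding_dim, kernel_size=context_size)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        emb = self.embedding(y).permute(0, 2, 1)      # (N, C, U)
        emb = F.pad(emb, (self.context_size - 1, 0))  # left-pad: keep it causal
        return self.conv(emb).permute(0, 2, 1)        # (N, U, C)


decoder = StatelessDecoder(vocab_size=500, embedding_dim=512)
print(decoder(torch.randint(0, 500, (2, 7))).shape)  # torch.Size([2, 7, 512])
```
Because there is no recurrent state, decoding only ever needs the last `context_size` labels as decoder input.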
@ -10,12 +13,8 @@ The WERs are
| | test-clean | test-other | comment | | | test-clean | test-other | comment |
|---------------------------|------------|------------|------------------------------------------| |---------------------------|------------|------------|------------------------------------------|
| greedy search | 2.99 | 7.52 | --epoch 20, --avg 10, --max-duration 100 | | greedy search | 2.85 | 7.30 | --epoch 29, --avg 13, --max-duration 100 |
| beam search (beam size 2) | 2.95 | 7.43 | | | beam search (beam size 4) | 2.83 | 7.19 | |
| beam search (beam size 3) | 2.94 | 7.37 | |
| beam search (beam size 4) | 2.92 | 7.37 | |
| beam search (beam size 5) | 2.93 | 7.38 | |
| beam search (beam size 8) | 2.92 | 7.38 | |
The training command for reproducing is given below: The training command for reproducing is given below:
@ -33,12 +32,12 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
``` ```
The tensorboard training log can be found at The tensorboard training log can be found at
<https://tensorboard.dev/experiment/PsJ3LgkEQfOmzedAlYfVeg/#scalars&_smoothingWeight=0> <https://tensorboard.dev/experiment/Mjx7MeTgR3Oyr1yBCwjozw/>
The decoding command is: The decoding command is:
``` ```
epoch=20 epoch=29
avg=10 avg=13
## greedy search ## greedy search
./transducer_stateless/decode.py \ ./transducer_stateless/decode.py \
@ -60,8 +59,8 @@ avg=10
``` ```
#### 2021-12-17 #### Conformer encoder + LSTM decoder
Using commit `cb04c8a7509425ab45fae888b0ca71bbbd23f0de`. Using commit `8187d6236c2926500da5ee854f758e621df803cc`.
Conformer encoder + LSTM decoder. Conformer encoder + LSTM decoder.
@ -69,9 +68,9 @@ The best WER is
| | test-clean | test-other | | | test-clean | test-other |
|-----|------------|------------| |-----|------------|------------|
| WER | 3.16 | 7.71 | | WER | 3.07 | 7.51 |
using `--epoch 26 --avg 12` with **greedy search**. using `--epoch 34 --avg 11` with **greedy search**.
The training command to reproduce the above WER is: The training command to reproduce the above WER is:
@ -80,19 +79,19 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
./transducer/train.py \ ./transducer/train.py \
--world-size 4 \ --world-size 4 \
--num-epochs 30 \ --num-epochs 35 \
--start-epoch 0 \ --start-epoch 0 \
--exp-dir transducer/exp-lr-2.5-full \ --exp-dir transducer/exp-lr-2.5-full \
--full-libri 1 \ --full-libri 1 \
--max-duration 250 \ --max-duration 180 \
--lr-factor 2.5 --lr-factor 2.5
``` ```
The decoding command is: The decoding command is:
``` ```
epoch=26 epoch=34
avg=12 avg=11
./transducer/decode.py \ ./transducer/decode.py \
--epoch $epoch \ --epoch $epoch \
@ -102,7 +101,7 @@ avg=12
--max-duration 100 --max-duration 100
``` ```
You can find the tensorboard log at: <https://tensorboard.dev/experiment/PYIbeD6zRJez1ViXaRqqeg/> You can find the tensorboard log at: <https://tensorboard.dev/experiment/D7NQc3xqTpyVmWi5FnWjrA>
### LibriSpeech BPE training results (Conformer-CTC) ### LibriSpeech BPE training results (Conformer-CTC)

View File

@ -111,7 +111,6 @@ def beam_search(
# support only batch_size == 1 for now # support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0) assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id blank_id = model.decoder.blank_id
sos_id = model.decoder.sos_id
device = model.device device = model.device
sos = torch.tensor([blank_id], device=device).reshape(1, 1) sos = torch.tensor([blank_id], device=device).reshape(1, 1)
@ -192,7 +191,7 @@ def beam_search(
# Second, choose other labels # Second, choose other labels
for i, v in enumerate(log_prob.tolist()): for i, v in enumerate(log_prob.tolist()):
if i in (blank_id, sos_id): if i == blank_id:
continue continue
new_ys = y_star.ys + [i] new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + v new_log_prob = y_star.log_prob + v

View File

@ -56,7 +56,6 @@ class Conformer(Transformer):
cnn_module_kernel: int = 31, cnn_module_kernel: int = 31,
normalize_before: bool = True, normalize_before: bool = True,
vgg_frontend: bool = False, vgg_frontend: bool = False,
use_feat_batchnorm: bool = False,
) -> None: ) -> None:
super(Conformer, self).__init__( super(Conformer, self).__init__(
num_features=num_features, num_features=num_features,
@ -69,7 +68,6 @@ class Conformer(Transformer):
dropout=dropout, dropout=dropout,
normalize_before=normalize_before, normalize_before=normalize_before,
vgg_frontend=vgg_frontend, vgg_frontend=vgg_frontend,
use_feat_batchnorm=use_feat_batchnorm,
) )
self.encoder_pos = RelPositionalEncoding(d_model, dropout) self.encoder_pos = RelPositionalEncoding(d_model, dropout)
@ -107,11 +105,6 @@ class Conformer(Transformer):
- logit_lens, a tensor of shape (batch_size,) containing the number - logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding. of frames in `logits` before padding.
""" """
if self.use_feat_batchnorm:
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
x = self.encoder_embed(x) x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x) x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
@ -873,7 +866,7 @@ class ConvolutionModule(nn.Module):
groups=channels, groups=channels,
bias=bias, bias=bias,
) )
self.norm = nn.BatchNorm1d(channels) self.norm = nn.LayerNorm(channels)
self.pointwise_conv2 = nn.Conv1d( self.pointwise_conv2 = nn.Conv1d(
channels, channels,
channels, channels,
@ -903,7 +896,12 @@ class ConvolutionModule(nn.Module):
# 1D Depthwise Conv # 1D Depthwise Conv
x = self.depthwise_conv(x) x = self.depthwise_conv(x)
x = self.activation(self.norm(x)) # x is (batch, channels, time)
x = x.permute(0, 2, 1)
x = self.norm(x)
x = x.permute(0, 2, 1)
x = self.activation(x)
x = self.pointwise_conv2(x) # (batch, channel, time) x = self.pointwise_conv2(x) # (batch, channel, time)
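Since `nn.LayerNorm` normalizes over the trailing dimension, the convolution module has to permute its `(batch, channels, time)` activations before and after the norm, as the added lines above show. A minimal self-contained sketch of that pattern (illustrative module, not the icefall class):
```
import torch
import torch.nn as nn


class ChannelLayerNorm(nn.Module):
    """LayerNorm over the channel dim of a (batch, channels, time) tensor."""

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 1)       # (batch, time, channels): channels last
        x = self.norm(x)
        return x.permute(0, 2, 1)    # back to (batch, channels, time)


x = torch.randn(2, 256, 100)         # (batch, channels, time)
print(ChannelLayerNorm(256)(x).shape)  # torch.Size([2, 256, 100])
```
Unlike `BatchNorm1d`, this normalization is computed per utterance, so the output does not depend on which other utterances happen to share the batch.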

View File

@ -70,14 +70,14 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--epoch", "--epoch",
type=int, type=int,
default=26, default=34,
help="It specifies the checkpoint to use for decoding." help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.", "Note: Epoch counts from 0.",
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=12, default=11,
help="Number of checkpoints to average. Automatically select " help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ", "'--epoch'. ",
@ -129,10 +129,9 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048, "dim_feedforward": 2048,
"num_encoder_layers": 12, "num_encoder_layers": 12,
"vgg_frontend": False, "vgg_frontend": False,
"use_feat_batchnorm": True,
# decoder params # decoder params
"decoder_embedding_dim": 1024, "decoder_embedding_dim": 1024,
"num_decoder_layers": 4, "num_decoder_layers": 2,
"decoder_hidden_dim": 512, "decoder_hidden_dim": 512,
"env_info": get_env_info(), "env_info": get_env_info(),
} }
@ -151,7 +150,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward, dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers, num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend, vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
) )
return encoder return encoder
@ -161,7 +159,6 @@ def get_decoder_model(params: AttributeDict):
vocab_size=params.vocab_size, vocab_size=params.vocab_size,
embedding_dim=params.decoder_embedding_dim, embedding_dim=params.decoder_embedding_dim,
blank_id=params.blank_id, blank_id=params.blank_id,
sos_id=params.sos_id,
num_layers=params.num_decoder_layers, num_layers=params.num_decoder_layers,
hidden_dim=params.decoder_hidden_dim, hidden_dim=params.decoder_hidden_dim,
output_dim=params.encoder_out_dim, output_dim=params.encoder_out_dim,
@ -399,9 +396,8 @@ def main():
sp = spm.SentencePieceProcessor() sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model) sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = sp.piece_to_id("<blk>")
params.sos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
logging.info(params) logging.info(params)

View File

@ -27,7 +27,6 @@ class Decoder(nn.Module):
vocab_size: int, vocab_size: int,
embedding_dim: int, embedding_dim: int,
blank_id: int, blank_id: int,
sos_id: int,
num_layers: int, num_layers: int,
hidden_dim: int, hidden_dim: int,
output_dim: int, output_dim: int,
@ -42,8 +41,6 @@ class Decoder(nn.Module):
Dimension of the input embedding. Dimension of the input embedding.
blank_id: blank_id:
The ID of the blank symbol. The ID of the blank symbol.
sos_id:
The ID of the SOS symbol.
num_layers: num_layers:
Number of LSTM layers. Number of LSTM layers.
hidden_dim: hidden_dim:
@ -71,7 +68,6 @@ class Decoder(nn.Module):
dropout=rnn_dropout, dropout=rnn_dropout,
) )
self.blank_id = blank_id self.blank_id = blank_id
self.sos_id = sos_id
self.output_linear = nn.Linear(hidden_dim, output_dim) self.output_linear = nn.Linear(hidden_dim, output_dim)
def forward( def forward(

View File

@ -23,8 +23,8 @@ Usage:
./transducer/export.py \ ./transducer/export.py \
--exp-dir ./transducer/exp \ --exp-dir ./transducer/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --bpe-model data/lang_bpe_500/bpe.model \
--epoch 26 \ --epoch 34 \
--avg 12 --avg 11
It will generate a file exp_dir/pretrained.pt It will generate a file exp_dir/pretrained.pt
@ -66,7 +66,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--epoch", "--epoch",
type=int, type=int,
default=26, default=34,
help="It specifies the checkpoint to use for decoding." help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.", "Note: Epoch counts from 0.",
) )
@ -74,7 +74,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=12, default=11,
help="Number of checkpoints to average. Automatically select " help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ", "'--epoch'. ",
@ -119,10 +119,9 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048, "dim_feedforward": 2048,
"num_encoder_layers": 12, "num_encoder_layers": 12,
"vgg_frontend": False, "vgg_frontend": False,
"use_feat_batchnorm": True,
# decoder params # decoder params
"decoder_embedding_dim": 1024, "decoder_embedding_dim": 1024,
"num_decoder_layers": 4, "num_decoder_layers": 2,
"decoder_hidden_dim": 512, "decoder_hidden_dim": 512,
"env_info": get_env_info(), "env_info": get_env_info(),
} }
@ -140,7 +139,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward, dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers, num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend, vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
) )
return encoder return encoder
@ -150,7 +148,6 @@ def get_decoder_model(params: AttributeDict):
vocab_size=params.vocab_size, vocab_size=params.vocab_size,
embedding_dim=params.decoder_embedding_dim, embedding_dim=params.decoder_embedding_dim,
blank_id=params.blank_id, blank_id=params.blank_id,
sos_id=params.sos_id,
num_layers=params.num_decoder_layers, num_layers=params.num_decoder_layers,
hidden_dim=params.decoder_hidden_dim, hidden_dim=params.decoder_hidden_dim,
output_dim=params.encoder_out_dim, output_dim=params.encoder_out_dim,
@ -197,9 +194,8 @@ def main():
sp = spm.SentencePieceProcessor() sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model) sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = sp.piece_to_id("<blk>")
params.sos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
logging.info(params) logging.info(params)

View File

@ -16,7 +16,6 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
class Joiner(nn.Module): class Joiner(nn.Module):
@ -48,7 +47,7 @@ class Joiner(nn.Module):
# Now decoder_out is (N, 1, U, C) # Now decoder_out is (N, 1, U, C)
logit = encoder_out + decoder_out logit = encoder_out + decoder_out
logit = F.relu(logit) logit = torch.tanh(logit)
output = self.output_linear(logit) output = self.output_linear(logit)
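The joiner change replaces `F.relu` with `torch.tanh` on the sum of encoder and decoder outputs. A rough sketch of a joiner with the shapes noted in the comments above (hypothetical dimensions and class name, not the exact icefall module):
```
import torch
import torch.nn as nn


class ToyJoiner(nn.Module):
    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self.output_linear = nn.Linear(input_dim, output_dim)

    def forward(self, encoder_out: torch.Tensor, decoder_out: torch.Tensor):
        # encoder_out: (N, T, C), decoder_out: (N, U, C)
        logit = encoder_out.unsqueeze(2) + decoder_out.unsqueeze(1)  # (N, T, U, C)
        logit = torch.tanh(logit)   # bounded activation instead of ReLU
        return self.output_linear(logit)  # (N, T, U, vocab_size)


out = ToyJoiner(512, 500)(torch.randn(2, 50, 512), torch.randn(2, 10, 512))
print(out.shape)  # torch.Size([2, 50, 10, 500])
```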

View File

@ -49,7 +49,7 @@ class Transducer(nn.Module):
decoder: decoder:
It is the prediction network in the paper. Its input shape It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, C). It should contain is (N, U) and its output shape is (N, U, C). It should contain
two attributes: `blank_id` and `sos_id`. one attribute: `blank_id`.
joiner: joiner:
It has two inputs with shapes: (N, T, C) and (N, U, C). Its It has two inputs with shapes: (N, T, C) and (N, U, C). Its
output shape is (N, T, U, C). Note that its output contains output shape is (N, T, U, C). Note that its output contains
@ -58,7 +58,6 @@ class Transducer(nn.Module):
super().__init__() super().__init__()
assert isinstance(encoder, EncoderInterface) assert isinstance(encoder, EncoderInterface)
assert hasattr(decoder, "blank_id") assert hasattr(decoder, "blank_id")
assert hasattr(decoder, "sos_id")
self.encoder = encoder self.encoder = encoder
self.decoder = decoder self.decoder = decoder
@ -97,8 +96,7 @@ class Transducer(nn.Module):
y_lens = row_splits[1:] - row_splits[:-1] y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id blank_id = self.decoder.blank_id
sos_id = self.decoder.sos_id sos_y = add_sos(y, sos_id=blank_id)
sos_y = add_sos(y, sos_id=sos_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
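With `sos_id` removed, the blank token itself is prepended as the start-of-sequence symbol before the targets are fed to the decoder. `add_sos` in icefall operates on k2 ragged tensors; the toy snippet below only mirrors the idea with plain Python lists, to show what `sos_y_padded` ends up containing.
```
import torch

blank_id = 0
ys = [[5, 13, 7], [9, 2]]  # two toy target sequences

# Prepend blank as the SOS token, then pad (also with blank) to equal length.
sos_ys = [[blank_id] + y for y in ys]
max_len = max(len(y) for y in sos_ys)
sos_y_padded = torch.tensor(
    [y + [blank_id] * (max_len - len(y)) for y in sos_ys]
)
# Each row now starts with blank_id, so the decoder's first input symbol
# is the same token that also marks "emit nothing" in the RNN-T lattice.
```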

View File

@ -116,10 +116,9 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048, "dim_feedforward": 2048,
"num_encoder_layers": 12, "num_encoder_layers": 12,
"vgg_frontend": False, "vgg_frontend": False,
"use_feat_batchnorm": True,
# decoder params # decoder params
"decoder_embedding_dim": 1024, "decoder_embedding_dim": 1024,
"num_decoder_layers": 4, "num_decoder_layers": 2,
"decoder_hidden_dim": 512, "decoder_hidden_dim": 512,
"env_info": get_env_info(), "env_info": get_env_info(),
} }
@ -137,7 +136,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward, dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers, num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend, vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
) )
return encoder return encoder
@ -147,7 +145,6 @@ def get_decoder_model(params: AttributeDict):
vocab_size=params.vocab_size, vocab_size=params.vocab_size,
embedding_dim=params.decoder_embedding_dim, embedding_dim=params.decoder_embedding_dim,
blank_id=params.blank_id, blank_id=params.blank_id,
sos_id=params.sos_id,
num_layers=params.num_decoder_layers, num_layers=params.num_decoder_layers,
hidden_dim=params.decoder_hidden_dim, hidden_dim=params.decoder_hidden_dim,
output_dim=params.encoder_out_dim, output_dim=params.encoder_out_dim,
@ -211,9 +208,8 @@ def main():
sp = spm.SentencePieceProcessor() sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model) sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = sp.piece_to_id("<blk>")
params.sos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
logging.info(f"{params}") logging.info(f"{params}")

View File

@ -36,7 +36,6 @@ def test_conformer():
nhead=8, nhead=8,
dim_feedforward=2048, dim_feedforward=2048,
num_encoder_layers=12, num_encoder_layers=12,
use_feat_batchnorm=True,
) )
N = 3 N = 3
T = 100 T = 100

View File

@ -29,7 +29,6 @@ from decoder import Decoder
def test_decoder(): def test_decoder():
vocab_size = 3 vocab_size = 3
blank_id = 0 blank_id = 0
sos_id = 2
embedding_dim = 128 embedding_dim = 128
num_layers = 2 num_layers = 2
hidden_dim = 6 hidden_dim = 6
@ -41,7 +40,6 @@ def test_decoder():
vocab_size=vocab_size, vocab_size=vocab_size,
embedding_dim=embedding_dim, embedding_dim=embedding_dim,
blank_id=blank_id, blank_id=blank_id,
sos_id=sos_id,
num_layers=num_layers, num_layers=num_layers,
hidden_dim=hidden_dim, hidden_dim=hidden_dim,
output_dim=output_dim, output_dim=output_dim,

View File

@ -39,7 +39,6 @@ def test_transducer():
# decoder params # decoder params
vocab_size = 3 vocab_size = 3
blank_id = 0 blank_id = 0
sos_id = 2
embedding_dim = 128 embedding_dim = 128
num_layers = 2 num_layers = 2
@ -51,14 +50,12 @@ def test_transducer():
nhead=8, nhead=8,
dim_feedforward=2048, dim_feedforward=2048,
num_encoder_layers=12, num_encoder_layers=12,
use_feat_batchnorm=True,
) )
decoder = Decoder( decoder = Decoder(
vocab_size=vocab_size, vocab_size=vocab_size,
embedding_dim=embedding_dim, embedding_dim=embedding_dim,
blank_id=blank_id, blank_id=blank_id,
sos_id=sos_id,
num_layers=num_layers, num_layers=num_layers,
hidden_dim=output_dim, hidden_dim=output_dim,
output_dim=output_dim, output_dim=output_dim,

View File

@ -36,7 +36,6 @@ def test_transformer():
nhead=8, nhead=8,
dim_feedforward=2048, dim_feedforward=2048,
num_encoder_layers=12, num_encoder_layers=12,
use_feat_batchnorm=True,
) )
N = 3 N = 3
T = 100 T = 100

View File

@ -23,7 +23,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
./transducer/train.py \ ./transducer/train.py \
--world-size 4 \ --world-size 4 \
--num-epochs 30 \ --num-epochs 35 \
--start-epoch 0 \ --start-epoch 0 \
--exp-dir transducer/exp \ --exp-dir transducer/exp \
--full-libri 1 \ --full-libri 1 \
@ -92,7 +92,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--num-epochs", "--num-epochs",
type=int, type=int,
default=30, default=35,
help="Number of epochs to train.", help="Number of epochs to train.",
) )
@ -171,15 +171,10 @@ def get_params() -> AttributeDict:
- subsampling_factor: The subsampling factor for the model. - subsampling_factor: The subsampling factor for the model.
- use_feat_batchnorm: Whether to do batch normalization for the
input features.
- attention_dim: Hidden dim for multi-head attention model. - attention_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder. - num_decoder_layers: Number of decoder layer of transformer decoder.
- weight_decay: The weight_decay for the optimizer.
- warm_step: The warm_step for Noam optimizer. - warm_step: The warm_step for Noam optimizer.
""" """
params = AttributeDict( params = AttributeDict(
@ -201,13 +196,11 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048, "dim_feedforward": 2048,
"num_encoder_layers": 12, "num_encoder_layers": 12,
"vgg_frontend": False, "vgg_frontend": False,
"use_feat_batchnorm": True,
# decoder params # decoder params
"decoder_embedding_dim": 1024, "decoder_embedding_dim": 1024,
"num_decoder_layers": 4, "num_decoder_layers": 2,
"decoder_hidden_dim": 512, "decoder_hidden_dim": 512,
# parameters for Noam # parameters for Noam
"weight_decay": 1e-6,
"warm_step": 80000, # For the 100h subset, use 8k "warm_step": 80000, # For the 100h subset, use 8k
"env_info": get_env_info(), "env_info": get_env_info(),
} }
@ -227,7 +220,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward, dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers, num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend, vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
) )
return encoder return encoder
@ -237,7 +229,6 @@ def get_decoder_model(params: AttributeDict):
vocab_size=params.vocab_size, vocab_size=params.vocab_size,
embedding_dim=params.decoder_embedding_dim, embedding_dim=params.decoder_embedding_dim,
blank_id=params.blank_id, blank_id=params.blank_id,
sos_id=params.sos_id,
num_layers=params.num_decoder_layers, num_layers=params.num_decoder_layers,
hidden_dim=params.decoder_hidden_dim, hidden_dim=params.decoder_hidden_dim,
output_dim=params.encoder_out_dim, output_dim=params.encoder_out_dim,
@ -573,9 +564,8 @@ def run(rank, world_size, args):
sp = spm.SentencePieceProcessor() sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model) sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = sp.piece_to_id("<blk>")
params.sos_id = sp.piece_to_id("<sos/eos>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
logging.info(params) logging.info(params)
@ -599,7 +589,6 @@ def run(rank, world_size, args):
model_size=params.attention_dim, model_size=params.attention_dim,
factor=params.lr_factor, factor=params.lr_factor,
warm_step=params.warm_step, warm_step=params.warm_step,
weight_decay=params.weight_decay,
) )
if checkpoints and "optimizer" in checkpoints: if checkpoints and "optimizer" in checkpoints:

View File

@ -39,7 +39,6 @@ class Transformer(EncoderInterface):
dropout: float = 0.1, dropout: float = 0.1,
normalize_before: bool = True, normalize_before: bool = True,
vgg_frontend: bool = False, vgg_frontend: bool = False,
use_feat_batchnorm: bool = False,
) -> None: ) -> None:
""" """
Args: Args:
@ -65,13 +64,8 @@ class Transformer(EncoderInterface):
If True, use pre-layer norm; False to use post-layer norm. If True, use pre-layer norm; False to use post-layer norm.
vgg_frontend: vgg_frontend:
True to use vgg style frontend for subsampling. True to use vgg style frontend for subsampling.
use_feat_batchnorm:
True to use batchnorm for the input layer.
""" """
super().__init__() super().__init__()
self.use_feat_batchnorm = use_feat_batchnorm
if use_feat_batchnorm:
self.feat_batchnorm = nn.BatchNorm1d(num_features)
self.num_features = num_features self.num_features = num_features
self.output_dim = output_dim self.output_dim = output_dim
@ -131,11 +125,6 @@ class Transformer(EncoderInterface):
- logit_lens, a tensor of shape (batch_size,) containing the number - logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding. of frames in `logits` before padding.
""" """
if self.use_feat_batchnorm:
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
x = self.encoder_embed(x) x = self.encoder_embed(x)
x = self.encoder_pos(x) x = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)

View File

@ -22,13 +22,18 @@ import torch
from model import Transducer from model import Transducer
def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: def greedy_search(
model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
) -> List[int]:
""" """
Args: Args:
model: model:
An instance of `Transducer`. An instance of `Transducer`.
encoder_out: encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
max_sym_per_frame:
Maximum number of symbols per frame. If it is set to 0, the WER
would be 100%.
Returns: Returns:
Return the decoded result. Return the decoded result.
""" """
@ -55,10 +60,6 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
# Maximum symbols per utterance. # Maximum symbols per utterance.
max_sym_per_utt = 1000 max_sym_per_utt = 1000
# If at frame t, it decodes more than this number of symbols,
# it will move to the next step t+1
max_sym_per_frame = 3
# symbols per frame # symbols per frame
sym_per_frame = 0 sym_per_frame = 0
@ -66,6 +67,11 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
sym_per_utt = 0 sym_per_utt = 0
while t < T and sym_per_utt < max_sym_per_utt: while t < T and sym_per_utt < max_sym_per_utt:
if sym_per_frame >= max_sym_per_frame:
sym_per_frame = 0
t += 1
continue
# fmt: off # fmt: off
current_encoder_out = encoder_out[:, t:t+1, :] current_encoder_out = encoder_out[:, t:t+1, :]
# fmt: on # fmt: on
@ -83,8 +89,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
sym_per_utt += 1 sym_per_utt += 1
sym_per_frame += 1 sym_per_frame += 1
else:
if y == blank_id or sym_per_frame > max_sym_per_frame:
sym_per_frame = 0 sym_per_frame = 0
t += 1 t += 1
hyp = hyp[context_size:] # remove blanks hyp = hyp[context_size:] # remove blanks
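The reworked greedy search caps the number of symbols emitted per encoder frame. Stripped of the model calls, the frame-advance logic looks roughly like the sketch below, where `next_symbol` is a stand-in for running the decoder and joiner on frame `t` given the current hypothesis (an illustrative callback, not a real icefall function):
```
from typing import Callable, List


def capped_greedy_search(
    num_frames: int,
    next_symbol: Callable[[int, List[int]], int],
    blank_id: int = 0,
    max_sym_per_frame: int = 3,
    max_sym_per_utt: int = 1000,
) -> List[int]:
    hyp: List[int] = []
    t = 0
    sym_per_frame = 0
    sym_per_utt = 0
    while t < num_frames and sym_per_utt < max_sym_per_utt:
        if sym_per_frame >= max_sym_per_frame:
            # Already emitted the maximum for this frame; move on.
            sym_per_frame = 0
            t += 1
            continue
        y = next_symbol(t, hyp)
        if y != blank_id:
            hyp.append(y)
            sym_per_utt += 1
            sym_per_frame += 1
        else:
            # Blank ends the current frame.
            sym_per_frame = 0
            t += 1
    return hyp


# Toy scorer: emit one non-blank symbol per frame, then blank.
print(capped_greedy_search(4, lambda t, hyp: 0 if len(hyp) > t else t + 1))
# [1, 2, 3, 4]
```
Setting `max_sym_per_frame` to 0 never lets a symbol through, which is why the docstring warns the WER would be 100% in that case.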

View File

@ -56,7 +56,6 @@ class Conformer(Transformer):
cnn_module_kernel: int = 31, cnn_module_kernel: int = 31,
normalize_before: bool = True, normalize_before: bool = True,
vgg_frontend: bool = False, vgg_frontend: bool = False,
use_feat_batchnorm: bool = False,
) -> None: ) -> None:
super(Conformer, self).__init__( super(Conformer, self).__init__(
num_features=num_features, num_features=num_features,
@ -69,7 +68,6 @@ class Conformer(Transformer):
dropout=dropout, dropout=dropout,
normalize_before=normalize_before, normalize_before=normalize_before,
vgg_frontend=vgg_frontend, vgg_frontend=vgg_frontend,
use_feat_batchnorm=use_feat_batchnorm,
) )
self.encoder_pos = RelPositionalEncoding(d_model, dropout) self.encoder_pos = RelPositionalEncoding(d_model, dropout)
@ -107,11 +105,6 @@ class Conformer(Transformer):
- logit_lens, a tensor of shape (batch_size,) containing the number - logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding. of frames in `logits` before padding.
""" """
if self.use_feat_batchnorm:
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
x = self.encoder_embed(x) x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x) x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
@ -873,7 +866,7 @@ class ConvolutionModule(nn.Module):
groups=channels, groups=channels,
bias=bias, bias=bias,
) )
self.norm = nn.BatchNorm1d(channels) self.norm = nn.LayerNorm(channels)
self.pointwise_conv2 = nn.Conv1d( self.pointwise_conv2 = nn.Conv1d(
channels, channels,
channels, channels,
@ -903,7 +896,12 @@ class ConvolutionModule(nn.Module):
# 1D Depthwise Conv # 1D Depthwise Conv
x = self.depthwise_conv(x) x = self.depthwise_conv(x)
x = self.activation(self.norm(x)) # x is (batch, channels, time)
x = x.permute(0, 2, 1)
x = self.norm(x)
x = x.permute(0, 2, 1)
x = self.activation(x)
x = self.pointwise_conv2(x) # (batch, channel, time) x = self.pointwise_conv2(x) # (batch, channel, time)

View File

@ -70,14 +70,14 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--epoch", "--epoch",
type=int, type=int,
default=20, default=29,
help="It specifies the checkpoint to use for decoding." help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.", "Note: Epoch counts from 0.",
) )
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=10, default=13,
help="Number of checkpoints to average. Automatically select " help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ", "'--epoch'. ",
@ -114,6 +114,20 @@ def get_parser():
help="Used only when --decoding-method is beam_search", help="Used only when --decoding-method is beam_search",
) )
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=3,
help="Maximum number of symbols per frame",
)
return parser return parser
@ -129,9 +143,6 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048, "dim_feedforward": 2048,
"num_encoder_layers": 12, "num_encoder_layers": 12,
"vgg_frontend": False, "vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for decoder
"context_size": 2, # tri-gram
"env_info": get_env_info(), "env_info": get_env_info(),
} }
) )
@ -149,7 +160,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward, dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers, num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend, vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
) )
return encoder return encoder
@ -237,7 +247,11 @@ def decode_one_batch(
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on # fmt: on
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
hyp = greedy_search(model=model, encoder_out=encoder_out_i) hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search": elif params.decoding_method == "beam_search":
hyp = beam_search( hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size model=model, encoder_out=encoder_out_i, beam=params.beam_size
@ -381,6 +395,9 @@ def main():
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.decoding_method == "beam_search": if params.decoding_method == "beam_search":
params.suffix += f"-beam-{params.beam_size}" params.suffix += f"-beam-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started") logging.info("Decoding started")

View File

@ -20,13 +20,14 @@ import torch.nn.functional as F
class Decoder(nn.Module): class Decoder(nn.Module):
"""This class implements the stateless decoder from the following paper: """This class modifies the stateless decoder from the following paper:
RNN-transducer with stateless prediction network RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
It removes the recurrent connection from the decoder, i.e., the prediction It removes the recurrent connection from the decoder, i.e., the prediction
network. network. Different from the above paper, it adds an extra Conv1d
right after the embedding layer.
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
""" """

View File

@ -104,6 +104,14 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
return parser return parser
@ -119,9 +127,6 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048, "dim_feedforward": 2048,
"num_encoder_layers": 12, "num_encoder_layers": 12,
"vgg_frontend": False, "vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for decoder
"context_size": 2, # tri-gram
"env_info": get_env_info(), "env_info": get_env_info(),
} }
) )
@ -138,7 +143,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward, dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers, num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend, vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
) )
return encoder return encoder

View File

@ -16,7 +16,6 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
class Joiner(nn.Module): class Joiner(nn.Module):
@ -48,7 +47,7 @@ class Joiner(nn.Module):
# Now decoder_out is (N, 1, U, C) # Now decoder_out is (N, 1, U, C)
logit = encoder_out + decoder_out logit = encoder_out + decoder_out
logit = F.relu(logit) logit = torch.tanh(logit)
output = self.output_linear(logit) output = self.output_linear(logit)

View File

@ -110,6 +110,22 @@ def get_parser():
help="Used only when --method is beam_search", help="Used only when --method is beam_search",
) )
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=3,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
return parser return parser
@ -126,9 +142,6 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048, "dim_feedforward": 2048,
"num_encoder_layers": 12, "num_encoder_layers": 12,
"vgg_frontend": False, "vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for decoder
"context_size": 2, # tri-gram
"env_info": get_env_info(), "env_info": get_env_info(),
} }
) )
@ -145,7 +158,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward, dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers, num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend, vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
) )
return encoder return encoder
@ -279,7 +291,11 @@ def main():
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on # fmt: on
if params.method == "greedy_search": if params.method == "greedy_search":
hyp = greedy_search(model=model, encoder_out=encoder_out_i) hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search": elif params.method == "beam_search":
hyp = beam_search( hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size model=model, encoder_out=encoder_out_i, beam=params.beam_size

View File

@ -130,6 +130,14 @@ def get_parser():
help="The lr_factor for Noam optimizer", help="The lr_factor for Noam optimizer",
) )
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
return parser return parser
@ -171,15 +179,10 @@ def get_params() -> AttributeDict:
- subsampling_factor: The subsampling factor for the model. - subsampling_factor: The subsampling factor for the model.
- use_feat_batchnorm: Whether to do batch normalization for the
input features.
- attention_dim: Hidden dim for multi-head attention model. - attention_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder. - num_decoder_layers: Number of decoder layer of transformer decoder.
- weight_decay: The weight_decay for the optimizer.
- warm_step: The warm_step for Noam optimizer. - warm_step: The warm_step for Noam optimizer.
""" """
params = AttributeDict( params = AttributeDict(
@ -201,11 +204,7 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048, "dim_feedforward": 2048,
"num_encoder_layers": 12, "num_encoder_layers": 12,
"vgg_frontend": False, "vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for decoder
"context_size": 2, # tri-gram
# parameters for Noam # parameters for Noam
"weight_decay": 1e-6,
"warm_step": 80000, # For the 100h subset, use 8k "warm_step": 80000, # For the 100h subset, use 8k
"env_info": get_env_info(), "env_info": get_env_info(),
} }
@ -225,7 +224,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward, dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers, num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend, vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
) )
return encoder return encoder
@ -568,7 +566,7 @@ def run(rank, world_size, args):
sp = spm.SentencePieceProcessor() sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model) sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
@ -593,7 +591,6 @@ def run(rank, world_size, args):
model_size=params.attention_dim, model_size=params.attention_dim,
factor=params.lr_factor, factor=params.lr_factor,
warm_step=params.warm_step, warm_step=params.warm_step,
weight_decay=params.weight_decay,
) )
if checkpoints and "optimizer" in checkpoints: if checkpoints and "optimizer" in checkpoints:

View File

@ -39,7 +39,6 @@ class Transformer(EncoderInterface):
dropout: float = 0.1, dropout: float = 0.1,
normalize_before: bool = True, normalize_before: bool = True,
vgg_frontend: bool = False, vgg_frontend: bool = False,
use_feat_batchnorm: bool = False,
) -> None: ) -> None:
""" """
Args: Args:
@ -65,13 +64,8 @@ class Transformer(EncoderInterface):
If True, use pre-layer norm; False to use post-layer norm. If True, use pre-layer norm; False to use post-layer norm.
vgg_frontend: vgg_frontend:
True to use vgg style frontend for subsampling. True to use vgg style frontend for subsampling.
use_feat_batchnorm:
True to use batchnorm for the input layer.
""" """
super().__init__() super().__init__()
self.use_feat_batchnorm = use_feat_batchnorm
if use_feat_batchnorm:
self.feat_batchnorm = nn.BatchNorm1d(num_features)
self.num_features = num_features self.num_features = num_features
self.output_dim = output_dim self.output_dim = output_dim
@ -131,11 +125,6 @@ class Transformer(EncoderInterface):
- logit_lens, a tensor of shape (batch_size,) containing the number - logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding. of frames in `logits` before padding.
""" """
if self.use_feat_batchnorm:
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
x = self.encoder_embed(x) x = self.encoder_embed(x)
x = self.encoder_pos(x) x = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)