Minor fixes to the RNN-T Conformer model (#152)

* Disable weight decay.

* Remove input feature batchnorm.

* Replace BatchNorm in the Conformer model with LayerNorm.

* Use tanh in the joint network.

* Remove sos ID.

* Reduce the number of decoder layers from 4 to 2.

* Minor fixes.

* Fix typos.
Fangjun Kuang 2021-12-23 13:54:25 +08:00 committed by GitHub
parent fb6a57e9e0
commit 5b6699a835
19 changed files with 147 additions and 86 deletions
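With the SOS symbol gone (the "Remove sos ID" item above), both training and decoding now prime the prediction network with the blank ID instead of a dedicated `<sos/eos>` token. A minimal sketch of the idea, using plain tensors and made-up label IDs rather than the ragged-tensor helpers in the repo:

```python
import torch

blank_id = 0  # ID of <blk>, now reused as the start-of-sequence symbol
device = torch.device("cpu")

# Labels of one utterance, shape (1, U); the IDs are arbitrary.
y = torch.tensor([[5, 17, 3]], device=device)

# Prepend the blank ID where <sos/eos> used to go -> shape (1, U + 1).
# This mirrors what `add_sos(y, sos_id=blank_id)` now does in model.py.
sos = torch.full((y.size(0), 1), blank_id, dtype=y.dtype, device=device)
sos_y = torch.cat([sos, y], dim=1)

# At inference time the decoder is likewise primed with a single blank:
decoder_start = torch.tensor([blank_id], device=device).reshape(1, 1)
```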

View File

@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-name: run-pre-trained-tranducer-stateless
+name: run-pre-trained-transducer-stateless
 on:
   push:

View File

@@ -0,0 +1,109 @@
# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com)
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
name: run-pre-trained-transducer

on:
  push:
    branches:
      - master

  pull_request:
    types: [labeled]

jobs:
  run_pre_trained_transducer:
    if: github.event.label.name == 'ready' || github.event_name == 'push'
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-18.04]
        python-version: [3.7, 3.8, 3.9]
        torch: ["1.10.0"]
        torchaudio: ["0.10.0"]
        k2-version: ["1.9.dev20211101"]

      fail-fast: false

    steps:
      - uses: actions/checkout@v2
        with:
          fetch-depth: 0

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v1
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install Python dependencies
        run: |
          python3 -m pip install --upgrade pip pytest
          # numpy 1.20.x does not support python 3.6
          pip install numpy==1.19
          pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
          pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/

          python3 -m pip install git+https://github.com/lhotse-speech/lhotse
          python3 -m pip install kaldifeat
          # We are in ./icefall and there is a file: requirements.txt in it
          pip install -r requirements.txt

      - name: Install graphviz
        shell: bash
        run: |
          python3 -m pip install -qq graphviz
          sudo apt-get -qq install graphviz

      - name: Download pre-trained model
        shell: bash
        run: |
          sudo apt-get -qq install git-lfs tree sox
          cd egs/librispeech/ASR
          mkdir tmp
          cd tmp
          git lfs install
          git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-bpe-500-2021-12-23
          cd ..
          tree tmp
          soxi tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/*.wav
          ls -lh tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/*.wav
      - name: Run greedy search decoding
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          cd egs/librispeech/ASR
          ./transducer/pretrained.py \
            --method greedy_search \
            --checkpoint ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/exp/pretrained.pt \
            --bpe-model ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/data/lang_bpe_500/bpe.model \
            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1089-134686-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0002.wav
      - name: Run beam search decoding
        shell: bash
        run: |
          export PYTHONPATH=$PWD:$PYTHONPATH
          cd egs/librispeech/ASR
          ./transducer/pretrained.py \
            --method beam_search \
            --beam-size 4 \
            --checkpoint ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/exp/pretrained.pt \
            --bpe-model ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/data/lang_bpe_500/bpe.model \
            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1089-134686-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0001.wav \
            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0002.wav

View File

@@ -71,7 +71,7 @@ The best WER with greedy search is:

 |     | test-clean | test-other |
 |-----|------------|------------|
-| WER | 3.16       | 7.71       |
+| WER | 3.07       | 7.51       |

 We provide a Colab notebook to run a pre-trained RNN-T conformer model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1_u6yK9jDkPwG_NLrZMN2XK7Aeq4suMO2?usp=sharing)

View File

@@ -2,7 +2,10 @@
 ### LibriSpeech BPE training results (Transducer)

-#### 2021-12-22
+#### Conformer encoder + embedding decoder
+
+Using commit `fb6a57e9e01dd8aae2af2a6b4568daad8bc8ab32`.

 Conformer encoder + non-recurrent decoder. The decoder
 contains only an embedding layer and a Conv1d (with kernel size 2).
@ -60,8 +63,8 @@ avg=10
``` ```
#### 2021-12-17 #### Conformer encoder + LSTM decoder
Using commit `cb04c8a7509425ab45fae888b0ca71bbbd23f0de`. Using commit `TODO`.
Conformer encoder + LSTM decoder. Conformer encoder + LSTM decoder.
@ -69,9 +72,9 @@ The best WER is
| | test-clean | test-other | | | test-clean | test-other |
|-----|------------|------------| |-----|------------|------------|
| WER | 3.16 | 7.71 | | WER | 3.07 | 7.51 |
using `--epoch 26 --avg 12` with **greedy search**. using `--epoch 34 --avg 11` with **greedy search**.
The training command to reproduce the above WER is: The training command to reproduce the above WER is:
@ -80,19 +83,19 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
./transducer/train.py \ ./transducer/train.py \
--world-size 4 \ --world-size 4 \
--num-epochs 30 \ --num-epochs 35 \
--start-epoch 0 \ --start-epoch 0 \
--exp-dir transducer/exp-lr-2.5-full \ --exp-dir transducer/exp-lr-2.5-full \
--full-libri 1 \ --full-libri 1 \
--max-duration 250 \ --max-duration 180 \
--lr-factor 2.5 --lr-factor 2.5
``` ```
The decoding command is: The decoding command is:
``` ```
epoch=26 epoch=34
avg=12 avg=11
./transducer/decode.py \ ./transducer/decode.py \
--epoch $epoch \ --epoch $epoch \
@ -102,7 +105,7 @@ avg=12
--max-duration 100 --max-duration 100
``` ```
You can find the tensorboard log at: <https://tensorboard.dev/experiment/PYIbeD6zRJez1ViXaRqqeg/> You can find the tensorboard log at: <https://tensorboard.dev/experiment/D7NQc3xqTpyVmWi5FnWjrA>
### LibriSpeech BPE training results (Conformer-CTC) ### LibriSpeech BPE training results (Conformer-CTC)

View File

@@ -111,7 +111,6 @@ def beam_search(
     # support only batch_size == 1 for now
     assert encoder_out.size(0) == 1, encoder_out.size(0)
     blank_id = model.decoder.blank_id
-    sos_id = model.decoder.sos_id
     device = model.device

     sos = torch.tensor([blank_id], device=device).reshape(1, 1)
@@ -192,7 +191,7 @@ def beam_search(
         # Second, choose other labels
         for i, v in enumerate(log_prob.tolist()):
-            if i in (blank_id, sos_id):
+            if i == blank_id:
                 continue
             new_ys = y_star.ys + [i]
             new_log_prob = y_star.log_prob + v

View File

@@ -56,7 +56,6 @@ class Conformer(Transformer):
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        use_feat_batchnorm: bool = False,
     ) -> None:
         super(Conformer, self).__init__(
             num_features=num_features,
@@ -69,7 +68,6 @@ class Conformer(Transformer):
             dropout=dropout,
             normalize_before=normalize_before,
             vgg_frontend=vgg_frontend,
-            use_feat_batchnorm=use_feat_batchnorm,
         )

         self.encoder_pos = RelPositionalEncoding(d_model, dropout)
@@ -107,11 +105,6 @@ class Conformer(Transformer):
          - logit_lens, a tensor of shape (batch_size,) containing the number
            of frames in `logits` before padding.
         """
-        if self.use_feat_batchnorm:
-            x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
-            x = self.feat_batchnorm(x)
-            x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
         x = self.encoder_embed(x)
         x, pos_emb = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
@@ -873,7 +866,7 @@ class ConvolutionModule(nn.Module):
             groups=channels,
             bias=bias,
         )
-        self.norm = nn.BatchNorm1d(channels)
+        self.norm = nn.LayerNorm(channels)
         self.pointwise_conv2 = nn.Conv1d(
             channels,
             channels,
@@ -903,7 +896,12 @@ class ConvolutionModule(nn.Module):
         # 1D Depthwise Conv
         x = self.depthwise_conv(x)
-        x = self.activation(self.norm(x))
+        # x is (batch, channels, time)
+        x = x.permute(0, 2, 1)
+        x = self.norm(x)
+        x = x.permute(0, 2, 1)
+        x = self.activation(x)

         x = self.pointwise_conv2(x)  # (batch, channel, time)
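A note on the permutes introduced around `self.norm` above: `nn.BatchNorm1d` expects and normalizes a `(batch, channels, time)` tensor, whereas `nn.LayerNorm(channels)` normalizes over the last dimension. Since the convolution branch keeps its tensors channel-first, the data has to be moved to channel-last just for the normalization and then moved back. A small self-contained sketch (the shapes are arbitrary, not the model's real dimensions):

```python
import torch
import torch.nn as nn

batch, channels, time = 2, 8, 50
x = torch.randn(batch, channels, time)  # layout used inside the conv module

norm = nn.LayerNorm(channels)  # normalizes over the trailing dimension

# (batch, channels, time) -> (batch, time, channels) so `channels` is last,
# apply LayerNorm, then restore the channel-first layout.
y = norm(x.permute(0, 2, 1)).permute(0, 2, 1)

assert y.shape == x.shape  # each (batch, time) position normalized over its channel vector
```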

View File

@@ -70,14 +70,14 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=26,
+        default=34,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
-        default=12,
+        default=11,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
@@ -129,10 +129,9 @@ def get_params() -> AttributeDict:
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
-            "use_feat_batchnorm": True,
             # decoder params
             "decoder_embedding_dim": 1024,
-            "num_decoder_layers": 4,
+            "num_decoder_layers": 2,
             "decoder_hidden_dim": 512,
             "env_info": get_env_info(),
         }
@@ -151,7 +150,6 @@ def get_encoder_model(params: AttributeDict):
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
-        use_feat_batchnorm=params.use_feat_batchnorm,
     )
     return encoder
@@ -161,7 +159,6 @@ def get_decoder_model(params: AttributeDict):
         vocab_size=params.vocab_size,
         embedding_dim=params.decoder_embedding_dim,
         blank_id=params.blank_id,
-        sos_id=params.sos_id,
         num_layers=params.num_decoder_layers,
         hidden_dim=params.decoder_hidden_dim,
         output_dim=params.encoder_out_dim,
@@ -401,7 +398,6 @@ def main():
     # <blk> and <sos/eos> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()

     logging.info(params)

View File

@@ -27,7 +27,6 @@ class Decoder(nn.Module):
         vocab_size: int,
         embedding_dim: int,
         blank_id: int,
-        sos_id: int,
         num_layers: int,
         hidden_dim: int,
         output_dim: int,
@@ -42,8 +41,6 @@ class Decoder(nn.Module):
             Dimension of the input embedding.
           blank_id:
             The ID of the blank symbol.
-          sos_id:
-            The ID of the SOS symbol.
           num_layers:
             Number of LSTM layers.
           hidden_dim:
@@ -71,7 +68,6 @@ class Decoder(nn.Module):
             dropout=rnn_dropout,
         )
         self.blank_id = blank_id
-        self.sos_id = sos_id
         self.output_linear = nn.Linear(hidden_dim, output_dim)

     def forward(

View File

@@ -23,8 +23,8 @@ Usage:
 ./transducer/export.py \
   --exp-dir ./transducer/exp \
   --bpe-model data/lang_bpe_500/bpe.model \
-  --epoch 26 \
-  --avg 12
+  --epoch 34 \
+  --avg 11

 It will generate a file exp_dir/pretrained.pt
@@ -66,7 +66,7 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=26,
+        default=34,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
@@ -74,7 +74,7 @@ def get_parser():
     parser.add_argument(
         "--avg",
         type=int,
-        default=12,
+        default=11,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
@@ -119,10 +119,9 @@ def get_params() -> AttributeDict:
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
-            "use_feat_batchnorm": True,
             # decoder params
             "decoder_embedding_dim": 1024,
-            "num_decoder_layers": 4,
+            "num_decoder_layers": 2,
             "decoder_hidden_dim": 512,
             "env_info": get_env_info(),
         }
@@ -140,7 +139,6 @@ def get_encoder_model(params: AttributeDict):
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
-        use_feat_batchnorm=params.use_feat_batchnorm,
     )
     return encoder
@@ -150,7 +148,6 @@ def get_decoder_model(params: AttributeDict):
         vocab_size=params.vocab_size,
         embedding_dim=params.decoder_embedding_dim,
         blank_id=params.blank_id,
-        sos_id=params.sos_id,
         num_layers=params.num_decoder_layers,
         hidden_dim=params.decoder_hidden_dim,
         output_dim=params.encoder_out_dim,
@@ -199,7 +196,6 @@ def main():
     # <blk> and <sos/eos> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()

     logging.info(params)

View File

@@ -16,7 +16,6 @@

 import torch
 import torch.nn as nn
-import torch.nn.functional as F


 class Joiner(nn.Module):
@@ -48,7 +47,7 @@ class Joiner(nn.Module):
         # Now decoder_out is (N, 1, U, C)

         logit = encoder_out + decoder_out
-        logit = F.relu(logit)
+        logit = torch.tanh(logit)

         output = self.output_linear(logit)
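After this change the joiner adds the broadcast encoder and decoder outputs, squashes the sum with `tanh` instead of ReLU, and projects to the vocabulary. A hedged, self-contained sketch of that computation; it assumes the inputs are already expanded to `(N, T, 1, C)` and `(N, 1, U, C)`, which the real `Joiner.forward` does internally:

```python
import torch
import torch.nn as nn


class TinyJoiner(nn.Module):
    """Minimal RNN-T joint network using tanh, as in this commit."""

    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self.output_linear = nn.Linear(input_dim, output_dim)

    def forward(
        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
    ) -> torch.Tensor:
        # encoder_out: (N, T, 1, C), decoder_out: (N, 1, U, C); broadcasting
        # the sum produces the (N, T, U, C) joint representation.
        logit = encoder_out + decoder_out
        logit = torch.tanh(logit)  # replaces F.relu(logit)
        return self.output_linear(logit)  # (N, T, U, vocab_size)


# Example with N=2, T=10, U=4, C=16 and a 30-symbol vocabulary:
joiner = TinyJoiner(16, 30)
enc = torch.randn(2, 10, 1, 16)
dec = torch.randn(2, 1, 4, 16)
print(joiner(enc, dec).shape)  # torch.Size([2, 10, 4, 30])
```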

View File

@@ -49,7 +49,7 @@ class Transducer(nn.Module):
           decoder:
             It is the prediction network in the paper. Its input shape
             is (N, U) and its output shape is (N, U, C). It should contain
-            two attributes: `blank_id` and `sos_id`.
+            one attribute: `blank_id`.
           joiner:
             It has two inputs with shapes: (N, T, C) and (N, U, C). Its
             output shape is (N, T, U, C). Note that its output contains
@@ -58,7 +58,6 @@ class Transducer(nn.Module):
         super().__init__()
         assert isinstance(encoder, EncoderInterface)
         assert hasattr(decoder, "blank_id")
-        assert hasattr(decoder, "sos_id")

         self.encoder = encoder
         self.decoder = decoder
@@ -97,8 +96,7 @@ class Transducer(nn.Module):
         y_lens = row_splits[1:] - row_splits[:-1]

         blank_id = self.decoder.blank_id
-        sos_id = self.decoder.sos_id
-        sos_y = add_sos(y, sos_id=sos_id)
+        sos_y = add_sos(y, sos_id=blank_id)

         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)

View File

@@ -116,10 +116,9 @@ def get_params() -> AttributeDict:
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
-            "use_feat_batchnorm": True,
             # decoder params
             "decoder_embedding_dim": 1024,
-            "num_decoder_layers": 4,
+            "num_decoder_layers": 2,
             "decoder_hidden_dim": 512,
             "env_info": get_env_info(),
         }
@@ -137,7 +136,6 @@ def get_encoder_model(params: AttributeDict):
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
-        use_feat_batchnorm=params.use_feat_batchnorm,
     )
     return encoder
@@ -147,7 +145,6 @@ def get_decoder_model(params: AttributeDict):
         vocab_size=params.vocab_size,
         embedding_dim=params.decoder_embedding_dim,
         blank_id=params.blank_id,
-        sos_id=params.sos_id,
         num_layers=params.num_decoder_layers,
         hidden_dim=params.decoder_hidden_dim,
         output_dim=params.encoder_out_dim,
@@ -213,7 +210,6 @@ def main():
     # <blk> and <sos/eos> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()

     logging.info(f"{params}")

View File

@@ -36,7 +36,6 @@ def test_conformer():
         nhead=8,
         dim_feedforward=2048,
         num_encoder_layers=12,
-        use_feat_batchnorm=True,
     )
     N = 3
     T = 100

View File

@@ -29,7 +29,6 @@ from decoder import Decoder
 def test_decoder():
     vocab_size = 3
     blank_id = 0
-    sos_id = 2
     embedding_dim = 128
     num_layers = 2
     hidden_dim = 6
@@ -41,7 +40,6 @@ def test_decoder():
         vocab_size=vocab_size,
         embedding_dim=embedding_dim,
         blank_id=blank_id,
-        sos_id=sos_id,
         num_layers=num_layers,
         hidden_dim=hidden_dim,
         output_dim=output_dim,

View File

@@ -39,7 +39,6 @@ def test_transducer():
     # decoder params
     vocab_size = 3
     blank_id = 0
-    sos_id = 2
     embedding_dim = 128
     num_layers = 2
@@ -51,14 +50,12 @@ def test_transducer():
         nhead=8,
         dim_feedforward=2048,
         num_encoder_layers=12,
-        use_feat_batchnorm=True,
     )
     decoder = Decoder(
         vocab_size=vocab_size,
         embedding_dim=embedding_dim,
         blank_id=blank_id,
-        sos_id=sos_id,
         num_layers=num_layers,
         hidden_dim=output_dim,
         output_dim=output_dim,

View File

@@ -36,7 +36,6 @@ def test_transformer():
         nhead=8,
         dim_feedforward=2048,
         num_encoder_layers=12,
-        use_feat_batchnorm=True,
     )
     N = 3
     T = 100

View File

@@ -23,7 +23,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 ./transducer/train.py \
   --world-size 4 \
-  --num-epochs 30 \
+  --num-epochs 35 \
   --start-epoch 0 \
   --exp-dir transducer/exp \
   --full-libri 1 \
@@ -92,7 +92,7 @@ def get_parser():
     parser.add_argument(
         "--num-epochs",
         type=int,
-        default=30,
+        default=35,
         help="Number of epochs to train.",
     )
@@ -171,15 +171,10 @@ def get_params() -> AttributeDict:
         - subsampling_factor: The subsampling factor for the model.

-        - use_feat_batchnorm: Whether to do batch normalization for the
-          input features.
-
         - attention_dim: Hidden dim for multi-head attention model.

         - num_decoder_layers: Number of decoder layer of transformer decoder.

-        - weight_decay: The weight_decay for the optimizer.
-
         - warm_step: The warm_step for Noam optimizer.
     """
     params = AttributeDict(
@@ -201,13 +196,11 @@ def get_params() -> AttributeDict:
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
-            "use_feat_batchnorm": True,
             # decoder params
             "decoder_embedding_dim": 1024,
-            "num_decoder_layers": 4,
+            "num_decoder_layers": 2,
             "decoder_hidden_dim": 512,
             # parameters for Noam
-            "weight_decay": 1e-6,
             "warm_step": 80000,  # For the 100h subset, use 8k
             "env_info": get_env_info(),
         }
@@ -227,7 +220,6 @@ def get_encoder_model(params: AttributeDict):
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
-        use_feat_batchnorm=params.use_feat_batchnorm,
     )
     return encoder
@@ -237,7 +229,6 @@ def get_decoder_model(params: AttributeDict):
         vocab_size=params.vocab_size,
         embedding_dim=params.decoder_embedding_dim,
         blank_id=params.blank_id,
-        sos_id=params.sos_id,
         num_layers=params.num_decoder_layers,
         hidden_dim=params.decoder_hidden_dim,
         output_dim=params.encoder_out_dim,
@@ -575,7 +566,6 @@ def run(rank, world_size, args):
     # <blk> and <sos/eos> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()

     logging.info(params)
@@ -599,7 +589,6 @@ def run(rank, world_size, args):
         model_size=params.attention_dim,
         factor=params.lr_factor,
         warm_step=params.warm_step,
-        weight_decay=params.weight_decay,
     )

     if checkpoints and "optimizer" in checkpoints:
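The last hunk above drops `weight_decay` from the Noam optimizer wrapper, which is the "Disable weight decay" item in the commit message. Assuming the wrapper drives a standard Adam optimizer underneath (as Noam-style schedulers typically do), the effect is the same as leaving `weight_decay` at PyTorch's default of 0.0. A hedged sketch with a stand-in model, not the repo's actual Noam class:

```python
import torch
import torch.nn as nn

model = nn.Linear(80, 512)  # stand-in for the transducer model

# The removed setting was weight_decay=1e-6; omitting it keeps Adam's
# default of 0.0, i.e. no L2 regularization on the weights.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,  # a Noam-style wrapper would rescale this per step anyway
    betas=(0.9, 0.98),
    eps=1e-9,
    weight_decay=0.0,
)
```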

View File

@@ -39,7 +39,6 @@ class Transformer(EncoderInterface):
         dropout: float = 0.1,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        use_feat_batchnorm: bool = False,
     ) -> None:
         """
         Args:
@@ -65,13 +64,8 @@ class Transformer(EncoderInterface):
             If True, use pre-layer norm; False to use post-layer norm.
           vgg_frontend:
             True to use vgg style frontend for subsampling.
-          use_feat_batchnorm:
-            True to use batchnorm for the input layer.
         """
         super().__init__()
-        self.use_feat_batchnorm = use_feat_batchnorm
-        if use_feat_batchnorm:
-            self.feat_batchnorm = nn.BatchNorm1d(num_features)

         self.num_features = num_features
         self.output_dim = output_dim
@@ -131,11 +125,6 @@ class Transformer(EncoderInterface):
           - logit_lens, a tensor of shape (batch_size,) containing the number
             of frames in `logits` before padding.
         """
-        if self.use_feat_batchnorm:
-            x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
-            x = self.feat_batchnorm(x)
-            x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
         x = self.encoder_embed(x)
         x = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)