commit 234307f33a

Merge branch 'master' of https://github.com/k2-fsa/icefall
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-name: run-pre-trained-tranducer-stateless
+name: run-pre-trained-trandsucer-stateless
 
 on:
   push:
@@ -74,11 +74,11 @@ jobs:
           mkdir tmp
           cd tmp
           git lfs install
-          git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22
+          git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27
           cd ..
           tree tmp
-          soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/*.wav
-          ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/*.wav
+          soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/*.wav
+          ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/*.wav
 
       - name: Run greedy search decoding
         shell: bash
@@ -87,11 +87,11 @@ jobs:
           cd egs/librispeech/ASR
           ./transducer_stateless/pretrained.py \
             --method greedy_search \
-            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/exp/pretrained.pt \
-            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/data/lang_bpe_500/bpe.model \
-            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1089-134686-0001.wav \
-            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0001.wav \
-            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0002.wav
+            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/exp/pretrained.pt \
+            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/data/lang_bpe_500/bpe.model \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1089-134686-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0002.wav
 
       - name: Run beam search decoding
         shell: bash
@@ -101,8 +101,8 @@ jobs:
           ./transducer_stateless/pretrained.py \
             --method beam_search \
             --beam-size 4 \
-            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/exp/pretrained.pt \
-            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/data/lang_bpe_500/bpe.model \
-            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1089-134686-0001.wav \
-            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0001.wav \
-            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-22/test_wavs/1221-135766-0002.wav
+            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/exp/pretrained.pt \
+            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/data/lang_bpe_500/bpe.model \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1089-134686-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2021-12-27/test_wavs/1221-135766-0002.wav
.github/workflows/run-pretrained-transducer.yml (new file, vendored, 109 lines)
@@ -0,0 +1,109 @@
+# Copyright      2021  Fangjun Kuang (csukuangfj@gmail.com)
+
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-pre-trained-transducer
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+    types: [labeled]
+
+jobs:
+  run_pre_trained_transducer:
+    if: github.event.label.name == 'ready' || github.event_name == 'push'
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-18.04]
+        python-version: [3.7, 3.8, 3.9]
+        torch: ["1.10.0"]
+        torchaudio: ["0.10.0"]
+        k2-version: ["1.9.dev20211101"]
+
+      fail-fast: false
+
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v1
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install Python dependencies
+        run: |
+          python3 -m pip install --upgrade pip pytest
+          # numpy 1.20.x does not support python 3.6
+          pip install numpy==1.19
+          pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
+          pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
+
+          python3 -m pip install git+https://github.com/lhotse-speech/lhotse
+          python3 -m pip install kaldifeat
+          # We are in ./icefall and there is a file: requirements.txt in it
+          pip install -r requirements.txt
+
+      - name: Install graphviz
+        shell: bash
+        run: |
+          python3 -m pip install -qq graphviz
+          sudo apt-get -qq install graphviz
+
+      - name: Download pre-trained model
+        shell: bash
+        run: |
+          sudo apt-get -qq install git-lfs tree sox
+          cd egs/librispeech/ASR
+          mkdir tmp
+          cd tmp
+          git lfs install
+          git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-bpe-500-2021-12-23
+
+          cd ..
+          tree tmp
+          soxi tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/*.wav
+          ls -lh tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/*.wav
+
+      - name: Run greedy search decoding
+        shell: bash
+        run: |
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          cd egs/librispeech/ASR
+          ./transducer/pretrained.py \
+            --method greedy_search \
+            --checkpoint ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/exp/pretrained.pt \
+            --bpe-model ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/data/lang_bpe_500/bpe.model \
+            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1089-134686-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0002.wav
+
+      - name: Run beam search decoding
+        shell: bash
+        run: |
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          cd egs/librispeech/ASR
+          ./transducer/pretrained.py \
+            --method beam_search \
+            --beam-size 4 \
+            --checkpoint ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/exp/pretrained.pt \
+            --bpe-model ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/data/lang_bpe_500/bpe.model \
+            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1089-134686-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-bpe-500-2021-12-23/test_wavs/1221-135766-0002.wav
@@ -71,7 +71,7 @@ The best WER with greedy search is:
 
 |     | test-clean | test-other |
 |-----|------------|------------|
-| WER | 3.16       | 7.71       |
+| WER | 3.07       | 7.51       |
 
 We provide a Colab notebook to run a pre-trained RNN-T conformer model: [](https://colab.research.google.com/drive/1_u6yK9jDkPwG_NLrZMN2XK7Aeq4suMO2?usp=sharing)
 
@@ -84,7 +84,7 @@ The best WER using beam search with beam size 4 is:
 
 |     | test-clean | test-other |
 |-----|------------|------------|
-| WER | 2.92       | 7.37       |
+| WER | 2.83       | 7.19       |
 
 Note: No auxiliary losses are used in the training and no LMs are used
 in the decoding.
@@ -2,7 +2,10 @@
 
 ### LibriSpeech BPE training results (Transducer)
 
-#### 2021-12-22
+#### Conformer encoder + embedding decoder
+
+Using commit `TODO`.
 
 Conformer encoder + non-recurrent decoder. The decoder
 contains only an embedding layer and a Conv1d (with kernel size 2).
@@ -10,12 +13,8 @@ The WERs are
 
 |                           | test-clean | test-other | comment                                  |
 |---------------------------|------------|------------|------------------------------------------|
-| greedy search             | 2.99       | 7.52       | --epoch 20, --avg 10, --max-duration 100 |
-| beam search (beam size 2) | 2.95       | 7.43       |                                          |
-| beam search (beam size 3) | 2.94       | 7.37       |                                          |
-| beam search (beam size 4) | 2.92       | 7.37       |                                          |
-| beam search (beam size 5) | 2.93       | 7.38       |                                          |
-| beam search (beam size 8) | 2.92       | 7.38       |                                          |
+| greedy search             | 2.85       | 7.30       | --epoch 29, --avg 13, --max-duration 100 |
+| beam search (beam size 4) | 2.83       | 7.19       |                                          |
 
 The training command for reproducing is given below:
@@ -33,12 +32,12 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 ```
 
 The tensorboard training log can be found at
-<https://tensorboard.dev/experiment/PsJ3LgkEQfOmzedAlYfVeg/#scalars&_smoothingWeight=0>
+<https://tensorboard.dev/experiment/Mjx7MeTgR3Oyr1yBCwjozw/>
 
 The decoding command is:
 ```
-epoch=20
-avg=10
+epoch=29
+avg=13
 
 ## greedy search
 ./transducer_stateless/decode.py \
@@ -60,8 +59,8 @@ avg=10
 ```
 
 
-#### 2021-12-17
-Using commit `cb04c8a7509425ab45fae888b0ca71bbbd23f0de`.
+#### Conformer encoder + LSTM decoder
+Using commit `8187d6236c2926500da5ee854f758e621df803cc`.
 
 Conformer encoder + LSTM decoder.
@@ -69,9 +68,9 @@ The best WER is
 
 |     | test-clean | test-other |
 |-----|------------|------------|
-| WER | 3.16       | 7.71       |
+| WER | 3.07       | 7.51       |
 
-using `--epoch 26 --avg 12` with **greedy search**.
+using `--epoch 34 --avg 11` with **greedy search**.
 
 The training command to reproduce the above WER is:
@@ -80,19 +79,19 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 ./transducer/train.py \
   --world-size 4 \
-  --num-epochs 30 \
+  --num-epochs 35 \
   --start-epoch 0 \
   --exp-dir transducer/exp-lr-2.5-full \
   --full-libri 1 \
-  --max-duration 250 \
+  --max-duration 180 \
   --lr-factor 2.5
 ```
 
 The decoding command is:
 
 ```
-epoch=26
-avg=12
+epoch=34
+avg=11
 
 ./transducer/decode.py \
   --epoch $epoch \
@@ -102,7 +101,7 @@ avg=12
   --max-duration 100
 ```
 
-You can find the tensorboard log at: <https://tensorboard.dev/experiment/PYIbeD6zRJez1ViXaRqqeg/>
+You can find the tensorboard log at: <https://tensorboard.dev/experiment/D7NQc3xqTpyVmWi5FnWjrA>
 
 
 ### LibriSpeech BPE training results (Conformer-CTC)
@@ -111,7 +111,6 @@ def beam_search(
     # support only batch_size == 1 for now
     assert encoder_out.size(0) == 1, encoder_out.size(0)
     blank_id = model.decoder.blank_id
-    sos_id = model.decoder.sos_id
     device = model.device
 
     sos = torch.tensor([blank_id], device=device).reshape(1, 1)
@@ -192,7 +191,7 @@ def beam_search(
 
         # Second, choose other labels
         for i, v in enumerate(log_prob.tolist()):
-            if i in (blank_id, sos_id):
+            if i == blank_id:
                 continue
             new_ys = y_star.ys + [i]
             new_log_prob = y_star.log_prob + v
@@ -56,7 +56,6 @@ class Conformer(Transformer):
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        use_feat_batchnorm: bool = False,
     ) -> None:
         super(Conformer, self).__init__(
             num_features=num_features,
@@ -69,7 +68,6 @@ class Conformer(Transformer):
             dropout=dropout,
             normalize_before=normalize_before,
             vgg_frontend=vgg_frontend,
-            use_feat_batchnorm=use_feat_batchnorm,
         )
 
         self.encoder_pos = RelPositionalEncoding(d_model, dropout)
@@ -107,11 +105,6 @@ class Conformer(Transformer):
           - logit_lens, a tensor of shape (batch_size,) containing the number
             of frames in `logits` before padding.
         """
-        if self.use_feat_batchnorm:
-            x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
-            x = self.feat_batchnorm(x)
-            x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
-
         x = self.encoder_embed(x)
         x, pos_emb = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
@@ -873,7 +866,7 @@ class ConvolutionModule(nn.Module):
             groups=channels,
             bias=bias,
         )
-        self.norm = nn.BatchNorm1d(channels)
+        self.norm = nn.LayerNorm(channels)
         self.pointwise_conv2 = nn.Conv1d(
             channels,
             channels,
@@ -903,7 +896,12 @@ class ConvolutionModule(nn.Module):
 
         # 1D Depthwise Conv
        x = self.depthwise_conv(x)
-        x = self.activation(self.norm(x))
+        # x is (batch, channels, time)
+        x = x.permute(0, 2, 1)
+        x = self.norm(x)
+        x = x.permute(0, 2, 1)
+
+        x = self.activation(x)
 
        x = self.pointwise_conv2(x)  # (batch, channel, time)
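Note on the `BatchNorm1d` → `LayerNorm` swap above: the two added `permute` calls are needed because `nn.LayerNorm` normalizes over the last dimension, while the convolution stack keeps tensors as `(batch, channels, time)`. A minimal sketch of the pattern (shapes follow the hunk; this is not the full icefall module):

```python
import torch
import torch.nn as nn

channels = 4
norm = nn.LayerNorm(channels)  # normalizes over the last dimension only

x = torch.randn(2, channels, 10)  # (batch, channels, time), as in the conv module
x = x.permute(0, 2, 1)            # -> (batch, time, channels): channels last
x = norm(x)                       # per-frame normalization across channels
x = x.permute(0, 2, 1)            # -> back to (batch, channels, time)
```

Unlike batch norm, this normalization does not mix statistics across the batch, which also makes the module behave identically in train and eval modes.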
@@ -70,14 +70,14 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=26,
+        default=34,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
-        default=12,
+        default=11,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
@@ -129,10 +129,9 @@ def get_params() -> AttributeDict:
         "dim_feedforward": 2048,
         "num_encoder_layers": 12,
         "vgg_frontend": False,
-        "use_feat_batchnorm": True,
         # decoder params
         "decoder_embedding_dim": 1024,
-        "num_decoder_layers": 4,
+        "num_decoder_layers": 2,
         "decoder_hidden_dim": 512,
         "env_info": get_env_info(),
     }
@@ -151,7 +150,6 @@ def get_encoder_model(params: AttributeDict):
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
-        use_feat_batchnorm=params.use_feat_batchnorm,
     )
     return encoder
 
@@ -161,7 +159,6 @@ def get_decoder_model(params: AttributeDict):
         vocab_size=params.vocab_size,
         embedding_dim=params.decoder_embedding_dim,
         blank_id=params.blank_id,
-        sos_id=params.sos_id,
         num_layers=params.num_decoder_layers,
         hidden_dim=params.decoder_hidden_dim,
         output_dim=params.encoder_out_dim,
@@ -399,9 +396,8 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
 
-    # <blk> and <sos/eos> are defined in local/train_bpe_model.py
+    # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()
 
     logging.info(params)
@@ -27,7 +27,6 @@ class Decoder(nn.Module):
         vocab_size: int,
         embedding_dim: int,
         blank_id: int,
-        sos_id: int,
         num_layers: int,
         hidden_dim: int,
         output_dim: int,
@@ -42,8 +41,6 @@ class Decoder(nn.Module):
             Dimension of the input embedding.
           blank_id:
             The ID of the blank symbol.
-          sos_id:
-            The ID of the SOS symbol.
           num_layers:
             Number of LSTM layers.
           hidden_dim:
@@ -71,7 +68,6 @@ class Decoder(nn.Module):
             dropout=rnn_dropout,
         )
         self.blank_id = blank_id
-        self.sos_id = sos_id
         self.output_linear = nn.Linear(hidden_dim, output_dim)
 
     def forward(
@@ -23,8 +23,8 @@ Usage:
 ./transducer/export.py \
   --exp-dir ./transducer/exp \
   --bpe-model data/lang_bpe_500/bpe.model \
-  --epoch 26 \
-  --avg 12
+  --epoch 34 \
+  --avg 11
 
 It will generate a file exp_dir/pretrained.pt
 
@@ -66,7 +66,7 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=26,
+        default=34,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
@@ -74,7 +74,7 @@ def get_parser():
     parser.add_argument(
         "--avg",
         type=int,
-        default=12,
+        default=11,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
@@ -119,10 +119,9 @@ def get_params() -> AttributeDict:
         "dim_feedforward": 2048,
         "num_encoder_layers": 12,
         "vgg_frontend": False,
-        "use_feat_batchnorm": True,
         # decoder params
         "decoder_embedding_dim": 1024,
-        "num_decoder_layers": 4,
+        "num_decoder_layers": 2,
         "decoder_hidden_dim": 512,
         "env_info": get_env_info(),
     }
@@ -140,7 +139,6 @@ def get_encoder_model(params: AttributeDict):
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
-        use_feat_batchnorm=params.use_feat_batchnorm,
     )
     return encoder
 
@@ -150,7 +148,6 @@ def get_decoder_model(params: AttributeDict):
         vocab_size=params.vocab_size,
         embedding_dim=params.decoder_embedding_dim,
         blank_id=params.blank_id,
-        sos_id=params.sos_id,
         num_layers=params.num_decoder_layers,
         hidden_dim=params.decoder_hidden_dim,
         output_dim=params.encoder_out_dim,
@@ -197,9 +194,8 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
 
-    # <blk> and <sos/eos> are defined in local/train_bpe_model.py
+    # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()
 
     logging.info(params)
@@ -16,7 +16,6 @@
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 
 class Joiner(nn.Module):
@@ -48,7 +47,7 @@ class Joiner(nn.Module):
         # Now decoder_out is (N, 1, U, C)
 
         logit = encoder_out + decoder_out
-        logit = F.relu(logit)
+        logit = torch.tanh(logit)
 
         output = self.output_linear(logit)
 
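For context, a minimal sketch of the joiner after this hunk: a broadcast add of encoder and decoder activations, `tanh` (previously `F.relu`) as the non-linearity, then the output projection. The class body below is an illustration, not the exact icefall file:

```python
import torch
import torch.nn as nn


class JoinerSketch(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.output_linear = nn.Linear(input_dim, output_dim)

    def forward(self, encoder_out: torch.Tensor, decoder_out: torch.Tensor):
        # encoder_out: (N, T, C) -> (N, T, 1, C); decoder_out: (N, U, C) -> (N, 1, U, C)
        logit = encoder_out.unsqueeze(2) + decoder_out.unsqueeze(1)
        logit = torch.tanh(logit)  # was F.relu(logit) before this commit
        return self.output_linear(logit)  # (N, T, U, output_dim)
```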
@@ -49,7 +49,7 @@ class Transducer(nn.Module):
           decoder:
             It is the prediction network in the paper. Its input shape
             is (N, U) and its output shape is (N, U, C). It should contain
-            two attributes: `blank_id` and `sos_id`.
+            one attribute: `blank_id`.
           joiner:
             It has two inputs with shapes: (N, T, C) and (N, U, C). Its
             output shape is (N, T, U, C). Note that its output contains
@@ -58,7 +58,6 @@ class Transducer(nn.Module):
         super().__init__()
         assert isinstance(encoder, EncoderInterface)
         assert hasattr(decoder, "blank_id")
-        assert hasattr(decoder, "sos_id")
 
         self.encoder = encoder
         self.decoder = decoder
@@ -97,8 +96,7 @@ class Transducer(nn.Module):
         y_lens = row_splits[1:] - row_splits[:-1]
 
         blank_id = self.decoder.blank_id
-        sos_id = self.decoder.sos_id
-        sos_y = add_sos(y, sos_id=sos_id)
+        sos_y = add_sos(y, sos_id=blank_id)
 
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
 
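The hunk above drops the dedicated `<sos/eos>` symbol: the blank ID now doubles as the start-of-sequence token passed to `add_sos`. A hedged, plain-Python stand-in for what `add_sos` does here (the real helper operates on k2 ragged tensors):

```python
from typing import List

blank_id = 0  # <blk> from the BPE model


def add_sos(ys: List[List[int]], sos_id: int) -> List[List[int]]:
    # Hypothetical list-based equivalent of icefall's ragged add_sos:
    # prepend the SOS symbol (here the blank ID) to every token sequence.
    return [[sos_id] + y for y in ys]


print(add_sos([[4, 7], [9]], sos_id=blank_id))  # [[0, 4, 7], [0, 9]]
```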
@@ -116,10 +116,9 @@ def get_params() -> AttributeDict:
         "dim_feedforward": 2048,
         "num_encoder_layers": 12,
         "vgg_frontend": False,
-        "use_feat_batchnorm": True,
         # decoder params
         "decoder_embedding_dim": 1024,
-        "num_decoder_layers": 4,
+        "num_decoder_layers": 2,
         "decoder_hidden_dim": 512,
         "env_info": get_env_info(),
     }
@@ -137,7 +136,6 @@ def get_encoder_model(params: AttributeDict):
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
-        use_feat_batchnorm=params.use_feat_batchnorm,
     )
     return encoder
 
@@ -147,7 +145,6 @@ def get_decoder_model(params: AttributeDict):
         vocab_size=params.vocab_size,
         embedding_dim=params.decoder_embedding_dim,
         blank_id=params.blank_id,
-        sos_id=params.sos_id,
         num_layers=params.num_decoder_layers,
         hidden_dim=params.decoder_hidden_dim,
         output_dim=params.encoder_out_dim,
@@ -211,9 +208,8 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
 
-    # <blk> and <sos/eos> are defined in local/train_bpe_model.py
+    # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()
 
     logging.info(f"{params}")
@@ -36,7 +36,6 @@ def test_conformer():
         nhead=8,
         dim_feedforward=2048,
         num_encoder_layers=12,
-        use_feat_batchnorm=True,
     )
     N = 3
     T = 100
@@ -29,7 +29,6 @@ from decoder import Decoder
 def test_decoder():
     vocab_size = 3
     blank_id = 0
-    sos_id = 2
     embedding_dim = 128
     num_layers = 2
     hidden_dim = 6
@@ -41,7 +40,6 @@ def test_decoder():
         vocab_size=vocab_size,
         embedding_dim=embedding_dim,
         blank_id=blank_id,
-        sos_id=sos_id,
         num_layers=num_layers,
         hidden_dim=hidden_dim,
         output_dim=output_dim,
@@ -39,7 +39,6 @@ def test_transducer():
     # decoder params
     vocab_size = 3
     blank_id = 0
-    sos_id = 2
     embedding_dim = 128
     num_layers = 2
 
@@ -51,14 +50,12 @@ def test_transducer():
         nhead=8,
         dim_feedforward=2048,
         num_encoder_layers=12,
-        use_feat_batchnorm=True,
     )
 
     decoder = Decoder(
         vocab_size=vocab_size,
         embedding_dim=embedding_dim,
         blank_id=blank_id,
-        sos_id=sos_id,
         num_layers=num_layers,
         hidden_dim=output_dim,
         output_dim=output_dim,
@@ -36,7 +36,6 @@ def test_transformer():
         nhead=8,
         dim_feedforward=2048,
         num_encoder_layers=12,
-        use_feat_batchnorm=True,
     )
     N = 3
     T = 100
@@ -23,7 +23,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 ./transducer/train.py \
   --world-size 4 \
-  --num-epochs 30 \
+  --num-epochs 35 \
   --start-epoch 0 \
   --exp-dir transducer/exp \
   --full-libri 1 \
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-epochs",
|
"--num-epochs",
|
||||||
type=int,
|
type=int,
|
||||||
default=30,
|
default=35,
|
||||||
help="Number of epochs to train.",
|
help="Number of epochs to train.",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -171,15 +171,10 @@ def get_params() -> AttributeDict:
|
|||||||
|
|
||||||
- subsampling_factor: The subsampling factor for the model.
|
- subsampling_factor: The subsampling factor for the model.
|
||||||
|
|
||||||
- use_feat_batchnorm: Whether to do batch normalization for the
|
|
||||||
input features.
|
|
||||||
|
|
||||||
- attention_dim: Hidden dim for multi-head attention model.
|
- attention_dim: Hidden dim for multi-head attention model.
|
||||||
|
|
||||||
- num_decoder_layers: Number of decoder layer of transformer decoder.
|
- num_decoder_layers: Number of decoder layer of transformer decoder.
|
||||||
|
|
||||||
- weight_decay: The weight_decay for the optimizer.
|
|
||||||
|
|
||||||
- warm_step: The warm_step for Noam optimizer.
|
- warm_step: The warm_step for Noam optimizer.
|
||||||
"""
|
"""
|
||||||
params = AttributeDict(
|
params = AttributeDict(
|
||||||
@ -201,13 +196,11 @@ def get_params() -> AttributeDict:
|
|||||||
"dim_feedforward": 2048,
|
"dim_feedforward": 2048,
|
||||||
"num_encoder_layers": 12,
|
"num_encoder_layers": 12,
|
||||||
"vgg_frontend": False,
|
"vgg_frontend": False,
|
||||||
"use_feat_batchnorm": True,
|
|
||||||
# decoder params
|
# decoder params
|
||||||
"decoder_embedding_dim": 1024,
|
"decoder_embedding_dim": 1024,
|
||||||
"num_decoder_layers": 4,
|
"num_decoder_layers": 2,
|
||||||
"decoder_hidden_dim": 512,
|
"decoder_hidden_dim": 512,
|
||||||
# parameters for Noam
|
# parameters for Noam
|
||||||
"weight_decay": 1e-6,
|
|
||||||
"warm_step": 80000, # For the 100h subset, use 8k
|
"warm_step": 80000, # For the 100h subset, use 8k
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
}
|
}
|
||||||
@ -227,7 +220,6 @@ def get_encoder_model(params: AttributeDict):
|
|||||||
dim_feedforward=params.dim_feedforward,
|
dim_feedforward=params.dim_feedforward,
|
||||||
num_encoder_layers=params.num_encoder_layers,
|
num_encoder_layers=params.num_encoder_layers,
|
||||||
vgg_frontend=params.vgg_frontend,
|
vgg_frontend=params.vgg_frontend,
|
||||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
|
||||||
)
|
)
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
@ -237,7 +229,6 @@ def get_decoder_model(params: AttributeDict):
|
|||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
embedding_dim=params.decoder_embedding_dim,
|
embedding_dim=params.decoder_embedding_dim,
|
||||||
blank_id=params.blank_id,
|
blank_id=params.blank_id,
|
||||||
sos_id=params.sos_id,
|
|
||||||
num_layers=params.num_decoder_layers,
|
num_layers=params.num_decoder_layers,
|
||||||
hidden_dim=params.decoder_hidden_dim,
|
hidden_dim=params.decoder_hidden_dim,
|
||||||
output_dim=params.encoder_out_dim,
|
output_dim=params.encoder_out_dim,
|
||||||
@ -573,9 +564,8 @@ def run(rank, world_size, args):
|
|||||||
sp = spm.SentencePieceProcessor()
|
sp = spm.SentencePieceProcessor()
|
||||||
sp.load(params.bpe_model)
|
sp.load(params.bpe_model)
|
||||||
|
|
||||||
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
params.sos_id = sp.piece_to_id("<sos/eos>")
|
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
@ -599,7 +589,6 @@ def run(rank, world_size, args):
|
|||||||
model_size=params.attention_dim,
|
model_size=params.attention_dim,
|
||||||
factor=params.lr_factor,
|
factor=params.lr_factor,
|
||||||
warm_step=params.warm_step,
|
warm_step=params.warm_step,
|
||||||
weight_decay=params.weight_decay,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if checkpoints and "optimizer" in checkpoints:
|
if checkpoints and "optimizer" in checkpoints:
|
||||||
|
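`weight_decay` disappears from the Noam optimizer call above; the learning-rate schedule itself is untouched. For reference, a sketch of the standard Noam formula that `factor`, `model_size`, and `warm_step` feed into (assuming icefall's wrapper follows the original Transformer-paper schedule):

```python
def noam_lr(step: int, model_size: int, factor: float, warm_step: int) -> float:
    # Linear warm-up for warm_step steps, then inverse-square-root decay.
    # step counts from 1; step == 0 would divide by zero.
    return factor * model_size ** -0.5 * min(step ** -0.5, step * warm_step ** -1.5)


# The peak learning rate is reached at step == warm_step:
print(noam_lr(80000, model_size=512, factor=2.5, warm_step=80000))
```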
@@ -39,7 +39,6 @@ class Transformer(EncoderInterface):
         dropout: float = 0.1,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        use_feat_batchnorm: bool = False,
     ) -> None:
         """
         Args:
@@ -65,13 +64,8 @@ class Transformer(EncoderInterface):
             If True, use pre-layer norm; False to use post-layer norm.
           vgg_frontend:
             True to use vgg style frontend for subsampling.
-          use_feat_batchnorm:
-            True to use batchnorm for the input layer.
         """
         super().__init__()
-        self.use_feat_batchnorm = use_feat_batchnorm
-        if use_feat_batchnorm:
-            self.feat_batchnorm = nn.BatchNorm1d(num_features)
 
         self.num_features = num_features
         self.output_dim = output_dim
@@ -131,11 +125,6 @@ class Transformer(EncoderInterface):
          - logit_lens, a tensor of shape (batch_size,) containing the number
            of frames in `logits` before padding.
        """
-        if self.use_feat_batchnorm:
-            x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
-            x = self.feat_batchnorm(x)
-            x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
-
        x = self.encoder_embed(x)
        x = self.encoder_pos(x)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
@@ -22,13 +22,18 @@ import torch
 from model import Transducer
 
 
-def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
+def greedy_search(
+    model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
+) -> List[int]:
     """
     Args:
       model:
         An instance of `Transducer`.
       encoder_out:
         A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
+      max_sym_per_frame:
+        Maximum number of symbols per frame. If it is set to 0, the WER
+        would be 100%.
     Returns:
       Return the decoded result.
     """
@@ -55,10 +60,6 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     # Maximum symbols per utterance.
     max_sym_per_utt = 1000
 
-    # If at frame t, it decodes more than this number of symbols,
-    # it will move to the next step t+1
-    max_sym_per_frame = 3
-
     # symbols per frame
     sym_per_frame = 0
 
@@ -66,6 +67,11 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     sym_per_utt = 0
 
     while t < T and sym_per_utt < max_sym_per_utt:
+        if sym_per_frame >= max_sym_per_frame:
+            sym_per_frame = 0
+            t += 1
+            continue
+
         # fmt: off
         current_encoder_out = encoder_out[:, t:t+1, :]
         # fmt: on
@@ -83,8 +89,7 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
 
             sym_per_utt += 1
             sym_per_frame += 1
-        if y == blank_id or sym_per_frame > max_sym_per_frame:
+        else:
             sym_per_frame = 0
             t += 1
     hyp = hyp[context_size:]  # remove blanks
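A self-contained sketch of the control flow this hunk introduces: at most `max_sym_per_frame` non-blank symbols are emitted per encoder frame before `t` is forced forward. The model call is stubbed with a random choice; real icefall runs the decoder and joiner there:

```python
import random


def greedy_search_sketch(T: int, max_sym_per_frame: int = 3) -> list:
    blank_id = 0

    def predict(t: int) -> int:
        # Stub for argmax over the joiner output at frame t.
        return random.choice([blank_id, blank_id, 1, 2])

    hyp, t = [], 0
    sym_per_frame = 0
    sym_per_utt, max_sym_per_utt = 0, 1000
    while t < T and sym_per_utt < max_sym_per_utt:
        if sym_per_frame >= max_sym_per_frame:
            # Cap reached: move to the next frame without emitting.
            sym_per_frame = 0
            t += 1
            continue
        y = predict(t)
        if y != blank_id:  # non-blank: emit and stay on this frame
            hyp.append(y)
            sym_per_utt += 1
            sym_per_frame += 1
        else:  # blank: advance to the next frame
            sym_per_frame = 0
            t += 1
    return hyp


print(greedy_search_sketch(T=10))
```

Setting `max_sym_per_frame=0` makes the first branch fire on every iteration, so nothing is ever emitted; that is why the docstring warns the WER would be 100%.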
@@ -56,7 +56,6 @@ class Conformer(Transformer):
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        use_feat_batchnorm: bool = False,
     ) -> None:
         super(Conformer, self).__init__(
             num_features=num_features,
@@ -69,7 +68,6 @@ class Conformer(Transformer):
             dropout=dropout,
             normalize_before=normalize_before,
             vgg_frontend=vgg_frontend,
-            use_feat_batchnorm=use_feat_batchnorm,
         )
 
         self.encoder_pos = RelPositionalEncoding(d_model, dropout)
@@ -107,11 +105,6 @@ class Conformer(Transformer):
          - logit_lens, a tensor of shape (batch_size,) containing the number
            of frames in `logits` before padding.
        """
-        if self.use_feat_batchnorm:
-            x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
-            x = self.feat_batchnorm(x)
-            x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
-
        x = self.encoder_embed(x)
        x, pos_emb = self.encoder_pos(x)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
@@ -873,7 +866,7 @@ class ConvolutionModule(nn.Module):
             groups=channels,
             bias=bias,
         )
-        self.norm = nn.BatchNorm1d(channels)
+        self.norm = nn.LayerNorm(channels)
         self.pointwise_conv2 = nn.Conv1d(
             channels,
             channels,
@@ -903,7 +896,12 @@ class ConvolutionModule(nn.Module):
 
         # 1D Depthwise Conv
        x = self.depthwise_conv(x)
-        x = self.activation(self.norm(x))
+        # x is (batch, channels, time)
+        x = x.permute(0, 2, 1)
+        x = self.norm(x)
+        x = x.permute(0, 2, 1)
+
+        x = self.activation(x)
 
        x = self.pointwise_conv2(x)  # (batch, channel, time)
@@ -70,14 +70,14 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=20,
+        default=29,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
     parser.add_argument(
         "--avg",
         type=int,
-        default=10,
+        default=13,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
@@ -114,6 +114,20 @@ def get_parser():
         help="Used only when --decoding-method is beam_search",
     )
 
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=3,
+        help="Maximum number of symbols per frame",
+    )
+
     return parser
 
 
@@ -129,9 +143,6 @@ def get_params() -> AttributeDict:
         "dim_feedforward": 2048,
         "num_encoder_layers": 12,
         "vgg_frontend": False,
-        "use_feat_batchnorm": True,
-        # parameters for decoder
-        "context_size": 2,  # tri-gram
         "env_info": get_env_info(),
     }
 )
@@ -149,7 +160,6 @@ def get_encoder_model(params: AttributeDict):
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
-        use_feat_batchnorm=params.use_feat_batchnorm,
     )
     return encoder
 
@@ -237,7 +247,11 @@ def decode_one_batch(
         encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
         # fmt: on
         if params.decoding_method == "greedy_search":
-            hyp = greedy_search(model=model, encoder_out=encoder_out_i)
+            hyp = greedy_search(
+                model=model,
+                encoder_out=encoder_out_i,
+                max_sym_per_frame=params.max_sym_per_frame,
+            )
         elif params.decoding_method == "beam_search":
             hyp = beam_search(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
@@ -381,6 +395,9 @@ def main():
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
     if params.decoding_method == "beam_search":
         params.suffix += f"-beam-{params.beam_size}"
+    else:
+        params.suffix += f"-context-{params.context_size}"
+        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
 
     setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
     logging.info("Decoding started")
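With the two new flags wired in, a decoding run can be invoked like this (an illustrative command assembled from the defaults in the hunks above):

```
./transducer_stateless/decode.py \
  --epoch 29 \
  --avg 13 \
  --decoding-method greedy_search \
  --context-size 2 \
  --max-sym-per-frame 3 \
  --max-duration 100
```

For greedy search the log-file suffix now records these values, e.g. `epoch-29-avg-13-context-2-max-sym-per-frame-3`.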
@@ -20,13 +20,14 @@ import torch.nn.functional as F
 
 
 class Decoder(nn.Module):
-    """This class implements the stateless decoder from the following paper:
+    """This class modifies the stateless decoder from the following paper:
 
         RNN-transducer with stateless prediction network
         https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
 
     It removes the recurrent connection from the decoder, i.e., the prediction
-    network.
+    network. Different from the above paper, it adds an extra Conv1d
+    right after the embedding layer.
 
     TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
     """
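For orientation, a minimal sketch of the decoder this docstring now describes: an embedding followed by a `Conv1d` over the label axis with kernel size equal to `context_size`, so each prediction depends only on a small window of previous labels instead of recurrent state. Module names and the missing left-padding are simplifications, not the exact icefall file:

```python
import torch
import torch.nn as nn


class StatelessDecoderSketch(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int, context_size: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Conv1d over the label axis; kernel_size == context_size, so the
        # output at position u sees only a fixed, small window of labels.
        self.conv = nn.Conv1d(embedding_dim, embedding_dim, kernel_size=context_size)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U) label IDs
        emb = self.embedding(y).permute(0, 2, 1)  # (N, C, U)
        out = self.conv(emb).permute(0, 2, 1)     # (N, U - context_size + 1, C)
        return out


dec = StatelessDecoderSketch(vocab_size=500, embedding_dim=16)
print(dec(torch.tensor([[0, 4, 7, 9]])).shape)  # torch.Size([1, 3, 16])
```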
@@ -104,6 +104,14 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
+
     return parser
 
 
@@ -119,9 +127,6 @@ def get_params() -> AttributeDict:
         "dim_feedforward": 2048,
         "num_encoder_layers": 12,
         "vgg_frontend": False,
-        "use_feat_batchnorm": True,
-        # parameters for decoder
-        "context_size": 2,  # tri-gram
         "env_info": get_env_info(),
     }
 )
@@ -138,7 +143,6 @@ def get_encoder_model(params: AttributeDict):
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
-        use_feat_batchnorm=params.use_feat_batchnorm,
     )
     return encoder
 
@@ -16,7 +16,6 @@
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 
 class Joiner(nn.Module):
@@ -48,7 +47,7 @@ class Joiner(nn.Module):
         # Now decoder_out is (N, 1, U, C)
 
         logit = encoder_out + decoder_out
-        logit = F.relu(logit)
+        logit = torch.tanh(logit)
 
         output = self.output_linear(logit)
 
@@ -110,6 +110,22 @@ def get_parser():
         help="Used only when --method is beam_search",
     )
 
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=3,
+        help="""Maximum number of symbols per frame. Used only when
+        --method is greedy_search.
+        """,
+    )
+
     return parser
 
 
@@ -126,9 +142,6 @@ def get_params() -> AttributeDict:
         "dim_feedforward": 2048,
         "num_encoder_layers": 12,
         "vgg_frontend": False,
-        "use_feat_batchnorm": True,
-        # parameters for decoder
-        "context_size": 2,  # tri-gram
         "env_info": get_env_info(),
     }
 )
@@ -145,7 +158,6 @@ def get_encoder_model(params: AttributeDict):
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
-        use_feat_batchnorm=params.use_feat_batchnorm,
     )
     return encoder
 
@@ -279,7 +291,11 @@ def main():
         encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
         # fmt: on
         if params.method == "greedy_search":
-            hyp = greedy_search(model=model, encoder_out=encoder_out_i)
+            hyp = greedy_search(
+                model=model,
+                encoder_out=encoder_out_i,
+                max_sym_per_frame=params.max_sym_per_frame,
+            )
         elif params.method == "beam_search":
             hyp = beam_search(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
@@ -130,6 +130,14 @@ def get_parser():
         help="The lr_factor for Noam optimizer",
     )
 
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
+
     return parser
 
 
@@ -171,15 +179,10 @@ def get_params() -> AttributeDict:
 
         - subsampling_factor: The subsampling factor for the model.
 
-        - use_feat_batchnorm: Whether to do batch normalization for the
-          input features.
-
         - attention_dim: Hidden dim for multi-head attention model.
 
         - num_decoder_layers: Number of decoder layer of transformer decoder.
 
-        - weight_decay: The weight_decay for the optimizer.
-
         - warm_step: The warm_step for Noam optimizer.
     """
     params = AttributeDict(
@@ -201,11 +204,7 @@ def get_params() -> AttributeDict:
         "dim_feedforward": 2048,
         "num_encoder_layers": 12,
         "vgg_frontend": False,
-        "use_feat_batchnorm": True,
-        # parameters for decoder
-        "context_size": 2,  # tri-gram
         # parameters for Noam
-        "weight_decay": 1e-6,
         "warm_step": 80000,  # For the 100h subset, use 8k
         "env_info": get_env_info(),
     }
@@ -225,7 +224,6 @@ def get_encoder_model(params: AttributeDict):
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
         vgg_frontend=params.vgg_frontend,
-        use_feat_batchnorm=params.use_feat_batchnorm,
     )
     return encoder
 
@@ -568,7 +566,7 @@ def run(rank, world_size, args):
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
 
-    # <blk> and <sos/eos> are defined in local/train_bpe_model.py
+    # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
     params.vocab_size = sp.get_piece_size()
 
@@ -593,7 +591,6 @@ def run(rank, world_size, args):
         model_size=params.attention_dim,
         factor=params.lr_factor,
         warm_step=params.warm_step,
-        weight_decay=params.weight_decay,
     )
 
     if checkpoints and "optimizer" in checkpoints:
@@ -39,7 +39,6 @@ class Transformer(EncoderInterface):
         dropout: float = 0.1,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        use_feat_batchnorm: bool = False,
     ) -> None:
         """
         Args:
@@ -65,13 +64,8 @@ class Transformer(EncoderInterface):
             If True, use pre-layer norm; False to use post-layer norm.
           vgg_frontend:
             True to use vgg style frontend for subsampling.
-          use_feat_batchnorm:
-            True to use batchnorm for the input layer.
         """
         super().__init__()
-        self.use_feat_batchnorm = use_feat_batchnorm
-        if use_feat_batchnorm:
-            self.feat_batchnorm = nn.BatchNorm1d(num_features)
 
         self.num_features = num_features
         self.output_dim = output_dim
@@ -131,11 +125,6 @@ class Transformer(EncoderInterface):
          - logit_lens, a tensor of shape (batch_size,) containing the number
            of frames in `logits` before padding.
        """
-        if self.use_feat_batchnorm:
-            x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
-            x = self.feat_batchnorm(x)
-            x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
-
        x = self.encoder_embed(x)
        x = self.encoder_pos(x)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)