mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Merge branch 'k2-fsa:master' into master
This commit is contained in:
commit
0baa7026f0
@ -74,24 +74,53 @@ jobs:
|
|||||||
mkdir tmp
|
mkdir tmp
|
||||||
cd tmp
|
cd tmp
|
||||||
git lfs install
|
git lfs install
|
||||||
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10
|
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07
|
||||||
cd ..
|
cd ..
|
||||||
tree tmp
|
tree tmp
|
||||||
soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/*.wav
|
soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/*.wav
|
||||||
ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/*.wav
|
ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/*.wav
|
||||||
|
|
||||||
- name: Run greedy search decoding
|
- name: Run greedy search decoding (max-sym-per-frame 1)
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
export PYTHONPATH=$PWD:PYTHONPATH
|
export PYTHONPATH=$PWD:PYTHONPATH
|
||||||
cd egs/librispeech/ASR
|
cd egs/librispeech/ASR
|
||||||
./transducer_stateless/pretrained.py \
|
./transducer_stateless/pretrained.py \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/exp/pretrained.pt \
|
--max-sym-per-frame 1 \
|
||||||
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/data/lang_bpe_500/bpe.model \
|
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
|
||||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1089-134686-0001.wav \
|
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
|
||||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0001.wav \
|
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
|
||||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0002.wav
|
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
|
||||||
|
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
|
||||||
|
|
||||||
|
- name: Run greedy search decoding (max-sym-per-frame 2)
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PYTHONPATH=$PWD:PYTHONPATH
|
||||||
|
cd egs/librispeech/ASR
|
||||||
|
./transducer_stateless/pretrained.py \
|
||||||
|
--method greedy_search \
|
||||||
|
--max-sym-per-frame 2 \
|
||||||
|
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
|
||||||
|
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
|
||||||
|
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
|
||||||
|
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
|
||||||
|
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
|
||||||
|
|
||||||
|
- name: Run greedy search decoding (max-sym-per-frame 3)
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PYTHONPATH=$PWD:PYTHONPATH
|
||||||
|
cd egs/librispeech/ASR
|
||||||
|
./transducer_stateless/pretrained.py \
|
||||||
|
--method greedy_search \
|
||||||
|
--max-sym-per-frame 3 \
|
||||||
|
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
|
||||||
|
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
|
||||||
|
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
|
||||||
|
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
|
||||||
|
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
|
||||||
|
|
||||||
- name: Run beam search decoding
|
- name: Run beam search decoding
|
||||||
shell: bash
|
shell: bash
|
||||||
@ -101,8 +130,22 @@ jobs:
|
|||||||
./transducer_stateless/pretrained.py \
|
./transducer_stateless/pretrained.py \
|
||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/exp/pretrained.pt \
|
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
|
||||||
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/data/lang_bpe_500/bpe.model \
|
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
|
||||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1089-134686-0001.wav \
|
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
|
||||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0001.wav \
|
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
|
||||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0002.wav
|
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
|
||||||
|
|
||||||
|
- name: Run modified beam search decoding
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||||
|
cd egs/librispeech/ASR
|
||||||
|
./transducer_stateless/pretrained.py \
|
||||||
|
--method modified_beam_search \
|
||||||
|
--beam-size 4 \
|
||||||
|
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
|
||||||
|
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
|
||||||
|
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
|
||||||
|
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
|
||||||
|
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
|
||||||
|
|||||||
@ -80,16 +80,16 @@ We provide a Colab notebook to run a pre-trained RNN-T conformer model: [](https://colab.research.google.com/drive/1Rc4Is-3Yp9LbcEz_Iy8hfyenyHsyjvqE?usp=sharing)
|
We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [](https://colab.research.google.com/drive/1CO1bXJ-2khDckZIW8zjOPHGSKLHpTDlp?usp=sharing)
|
||||||
|
|
||||||
### Aishell
|
### Aishell
|
||||||
|
|
||||||
|
|||||||
@ -82,17 +82,17 @@ class Decoder(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
Return a tensor of shape (N, U, embedding_dim).
|
Return a tensor of shape (N, U, embedding_dim).
|
||||||
"""
|
"""
|
||||||
embeding_out = self.embedding(y)
|
embedding_out = self.embedding(y)
|
||||||
if self.context_size > 1:
|
if self.context_size > 1:
|
||||||
embeding_out = embeding_out.permute(0, 2, 1)
|
embedding_out = embedding_out.permute(0, 2, 1)
|
||||||
if need_pad is True:
|
if need_pad is True:
|
||||||
embeding_out = F.pad(
|
embedding_out = F.pad(
|
||||||
embeding_out, pad=(self.context_size - 1, 0)
|
embedding_out, pad=(self.context_size - 1, 0)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# During inference time, there is no need to do extra padding
|
# During inference time, there is no need to do extra padding
|
||||||
# as we only need one output
|
# as we only need one output
|
||||||
assert embeding_out.size(-1) == self.context_size
|
assert embedding_out.size(-1) == self.context_size
|
||||||
embeding_out = self.conv(embeding_out)
|
embedding_out = self.conv(embedding_out)
|
||||||
embeding_out = embeding_out.permute(0, 2, 1)
|
embedding_out = embedding_out.permute(0, 2, 1)
|
||||||
return embeding_out
|
return embedding_out
|
||||||
|
|||||||
@ -48,6 +48,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
@ -133,7 +134,7 @@ def get_params() -> AttributeDict:
|
|||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
def get_encoder_model(params: AttributeDict):
|
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||||
encoder = Conformer(
|
encoder = Conformer(
|
||||||
num_features=params.feature_dim,
|
num_features=params.feature_dim,
|
||||||
output_dim=params.encoder_out_dim,
|
output_dim=params.encoder_out_dim,
|
||||||
@ -147,7 +148,7 @@ def get_encoder_model(params: AttributeDict):
|
|||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
def get_decoder_model(params: AttributeDict):
|
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||||
decoder = Decoder(
|
decoder = Decoder(
|
||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
embedding_dim=params.encoder_out_dim,
|
embedding_dim=params.encoder_out_dim,
|
||||||
@ -157,7 +158,7 @@ def get_decoder_model(params: AttributeDict):
|
|||||||
return decoder
|
return decoder
|
||||||
|
|
||||||
|
|
||||||
def get_joiner_model(params: AttributeDict):
|
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||||
joiner = Joiner(
|
joiner = Joiner(
|
||||||
input_dim=params.encoder_out_dim,
|
input_dim=params.encoder_out_dim,
|
||||||
output_dim=params.vocab_size,
|
output_dim=params.vocab_size,
|
||||||
@ -165,7 +166,7 @@ def get_joiner_model(params: AttributeDict):
|
|||||||
return joiner
|
return joiner
|
||||||
|
|
||||||
|
|
||||||
def get_transducer_model(params: AttributeDict):
|
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||||
encoder = get_encoder_model(params)
|
encoder = get_encoder_model(params)
|
||||||
decoder = get_decoder_model(params)
|
decoder = get_decoder_model(params)
|
||||||
joiner = get_joiner_model(params)
|
joiner = get_joiner_model(params)
|
||||||
|
|||||||
@ -44,11 +44,12 @@ Note: ./transducer_stateless/exp/pretrained.pt is generated by
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import List
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import beam_search, greedy_search
|
from beam_search import beam_search, greedy_search
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
@ -57,10 +58,10 @@ from joiner import Joiner
|
|||||||
from model import Transducer
|
from model import Transducer
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
from icefall.env import get_env_info
|
|
||||||
from icefall.utils import AttributeDict
|
|
||||||
from icefall.lexicon import Lexicon
|
|
||||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||||
|
from icefall.env import get_env_info
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.utils import AttributeDict
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -150,7 +151,7 @@ def get_params() -> AttributeDict:
|
|||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
def get_encoder_model(params: AttributeDict):
|
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||||
encoder = Conformer(
|
encoder = Conformer(
|
||||||
num_features=params.feature_dim,
|
num_features=params.feature_dim,
|
||||||
output_dim=params.encoder_out_dim,
|
output_dim=params.encoder_out_dim,
|
||||||
@ -164,7 +165,7 @@ def get_encoder_model(params: AttributeDict):
|
|||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
def get_decoder_model(params: AttributeDict):
|
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||||
decoder = Decoder(
|
decoder = Decoder(
|
||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
embedding_dim=params.encoder_out_dim,
|
embedding_dim=params.encoder_out_dim,
|
||||||
@ -174,7 +175,7 @@ def get_decoder_model(params: AttributeDict):
|
|||||||
return decoder
|
return decoder
|
||||||
|
|
||||||
|
|
||||||
def get_joiner_model(params: AttributeDict):
|
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||||
joiner = Joiner(
|
joiner = Joiner(
|
||||||
input_dim=params.encoder_out_dim,
|
input_dim=params.encoder_out_dim,
|
||||||
output_dim=params.vocab_size,
|
output_dim=params.vocab_size,
|
||||||
@ -182,7 +183,7 @@ def get_joiner_model(params: AttributeDict):
|
|||||||
return joiner
|
return joiner
|
||||||
|
|
||||||
|
|
||||||
def get_transducer_model(params: AttributeDict):
|
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||||
encoder = get_encoder_model(params)
|
encoder = get_encoder_model(params)
|
||||||
decoder = get_decoder_model(params)
|
decoder = get_decoder_model(params)
|
||||||
joiner = get_joiner_model(params)
|
joiner = get_joiner_model(params)
|
||||||
|
|||||||
@ -204,7 +204,7 @@ def get_params() -> AttributeDict:
|
|||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
def get_encoder_model(params: AttributeDict):
|
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||||
# TODO: We can add an option to switch between Conformer and Transformer
|
# TODO: We can add an option to switch between Conformer and Transformer
|
||||||
encoder = Conformer(
|
encoder = Conformer(
|
||||||
num_features=params.feature_dim,
|
num_features=params.feature_dim,
|
||||||
@ -219,7 +219,7 @@ def get_encoder_model(params: AttributeDict):
|
|||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
def get_decoder_model(params: AttributeDict):
|
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||||
decoder = Decoder(
|
decoder = Decoder(
|
||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
embedding_dim=params.encoder_out_dim,
|
embedding_dim=params.encoder_out_dim,
|
||||||
@ -229,7 +229,7 @@ def get_decoder_model(params: AttributeDict):
|
|||||||
return decoder
|
return decoder
|
||||||
|
|
||||||
|
|
||||||
def get_joiner_model(params: AttributeDict):
|
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||||
joiner = Joiner(
|
joiner = Joiner(
|
||||||
input_dim=params.encoder_out_dim,
|
input_dim=params.encoder_out_dim,
|
||||||
output_dim=params.vocab_size,
|
output_dim=params.vocab_size,
|
||||||
@ -237,7 +237,7 @@ def get_joiner_model(params: AttributeDict):
|
|||||||
return joiner
|
return joiner
|
||||||
|
|
||||||
|
|
||||||
def get_transducer_model(params: AttributeDict):
|
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||||
encoder = get_encoder_model(params)
|
encoder = get_encoder_model(params)
|
||||||
decoder = get_decoder_model(params)
|
decoder = get_decoder_model(params)
|
||||||
joiner = get_joiner_model(params)
|
joiner = get_joiner_model(params)
|
||||||
|
|||||||
@ -4,62 +4,73 @@
|
|||||||
|
|
||||||
#### Conformer encoder + embedding decoder
|
#### Conformer encoder + embedding decoder
|
||||||
|
|
||||||
Using commit `4c1b3665ee6efb935f4dd93a80ff0e154b13efb6`.
|
Using commit `a8150021e01d34ecbd6198fe03a57eacf47a16f2`.
|
||||||
|
|
||||||
Conformer encoder + non-current decoder. The decoder
|
Conformer encoder + non-recurrent decoder. The decoder
|
||||||
contains only an embedding layer and a Conv1d (with kernel size 2).
|
contains only an embedding layer and a Conv1d (with kernel size 2).
|
||||||
|
|
||||||
The WERs are
|
The WERs are
|
||||||
|
|
||||||
| | test-clean | test-other | comment |
|
| | test-clean | test-other | comment |
|
||||||
|---------------------------|------------|------------|------------------------------------------|
|
|-------------------------------------|------------|------------|------------------------------------------|
|
||||||
| greedy search | 2.69 | 6.81 | --epoch 71, --avg 15, --max-duration 100 |
|
| greedy search (max sym per frame 1) | 2.68 | 6.71 | --epoch 61, --avg 18, --max-duration 100 |
|
||||||
| beam search (beam size 4) | 2.68 | 6.72 | --epoch 71, --avg 15, --max-duration 100 |
|
| greedy search (max sym per frame 2) | 2.69 | 6.71 | --epoch 61, --avg 18, --max-duration 100 |
|
||||||
|
| greedy search (max sym per frame 3) | 2.69 | 6.71 | --epoch 61, --avg 18, --max-duration 100 |
|
||||||
|
| modified beam search (beam size 4) | 2.67 | 6.64 | --epoch 61, --avg 18, --max-duration 100 |
|
||||||
|
|
||||||
|
|
||||||
The training command for reproducing is given below:
|
The training command for reproducing is given below:
|
||||||
|
|
||||||
```
|
```
|
||||||
|
cd egs/librispeech/ASR/
|
||||||
|
./prepare.sh
|
||||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||||
|
|
||||||
./transducer_stateless/train.py \
|
./transducer_stateless/train.py \
|
||||||
--world-size 4 \
|
--world-size 4 \
|
||||||
--num-epochs 76 \
|
--num-epochs 76 \
|
||||||
--start-epoch 0 \
|
--start-epoch 0 \
|
||||||
--exp-dir transducer_stateless/exp-full \
|
--exp-dir transducer_stateless/exp-full \
|
||||||
--full-libri 1 \
|
--full-libri 1 \
|
||||||
--max-duration 250 \
|
--max-duration 300 \
|
||||||
--lr-factor 3
|
--lr-factor 5 \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
|
--modified-transducer-prob 0.25
|
||||||
```
|
```
|
||||||
|
|
||||||
The tensorboard training log can be found at
|
The tensorboard training log can be found at
|
||||||
<https://tensorboard.dev/experiment/qGdqzHnxS0WJ695OXfZDzA/#scalars&_smoothingWeight=0>
|
<https://tensorboard.dev/experiment/qgvWkbF2R46FYA6ZMNmOjA/#scalars>
|
||||||
|
|
||||||
The decoding command is:
|
The decoding command is:
|
||||||
```
|
```
|
||||||
epoch=71
|
epoch=61
|
||||||
avg=15
|
avg=18
|
||||||
|
|
||||||
## greedy search
|
## greedy search
|
||||||
./transducer_stateless/decode.py \
|
for sym in 1 2 3; do
|
||||||
--epoch $epoch \
|
./transducer_stateless/decode.py \
|
||||||
--avg $avg \
|
--epoch $epoch \
|
||||||
--exp-dir transducer_stateless/exp-full \
|
--avg $avg \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--exp-dir transducer_stateless/exp-full \
|
||||||
--max-duration 100
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
|
--max-duration 100 \
|
||||||
|
--max-sym-per-frame $sym
|
||||||
|
done
|
||||||
|
|
||||||
|
## modified beam search
|
||||||
|
|
||||||
## beam search
|
|
||||||
./transducer_stateless/decode.py \
|
./transducer_stateless/decode.py \
|
||||||
--epoch $epoch \
|
--epoch $epoch \
|
||||||
--avg $avg \
|
--avg $avg \
|
||||||
--exp-dir transducer_stateless/exp-full \
|
--exp-dir transducer_stateless/exp-full \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
--max-duration 100 \
|
--max-duration 100 \
|
||||||
--decoding-method beam_search \
|
--context-size 2 \
|
||||||
|
--decoding-method modified_beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
```
|
```
|
||||||
|
|
||||||
You can find a pretrained model by visiting
|
You can find a pretrained model by visiting
|
||||||
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10>
|
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07>
|
||||||
|
|
||||||
|
|
||||||
#### Conformer encoder + LSTM decoder
|
#### Conformer encoder + LSTM decoder
|
||||||
|
|||||||
@ -41,6 +41,7 @@ from icefall.checkpoint import load_checkpoint
|
|||||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
from icefall.dist import cleanup_dist, setup_dist
|
from icefall.dist import cleanup_dist, setup_dist
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
|
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
@ -123,6 +124,15 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-decoder-layers",
|
||||||
|
type=int,
|
||||||
|
default=6,
|
||||||
|
help="""Number of decoder layer of transformer decoder.
|
||||||
|
Setting this to 0 will not create the decoder at all (pure CTC model)
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lr-factor",
|
"--lr-factor",
|
||||||
type=float,
|
type=float,
|
||||||
@ -210,7 +220,6 @@ def get_params() -> AttributeDict:
|
|||||||
"use_feat_batchnorm": True,
|
"use_feat_batchnorm": True,
|
||||||
"attention_dim": 512,
|
"attention_dim": 512,
|
||||||
"nhead": 8,
|
"nhead": 8,
|
||||||
"num_decoder_layers": 6,
|
|
||||||
# parameters for loss
|
# parameters for loss
|
||||||
"beam_size": 10,
|
"beam_size": 10,
|
||||||
"reduction": "sum",
|
"reduction": "sum",
|
||||||
@ -357,9 +366,17 @@ def compute_loss(
|
|||||||
supervisions, subsampling_factor=params.subsampling_factor
|
supervisions, subsampling_factor=params.subsampling_factor
|
||||||
)
|
)
|
||||||
|
|
||||||
token_ids = graph_compiler.texts_to_ids(texts)
|
if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler):
|
||||||
|
# Works with a BPE model
|
||||||
decoding_graph = graph_compiler.compile(token_ids)
|
token_ids = graph_compiler.texts_to_ids(texts)
|
||||||
|
decoding_graph = graph_compiler.compile(token_ids)
|
||||||
|
elif isinstance(graph_compiler, CtcTrainingGraphCompiler):
|
||||||
|
# Works with a phone lexicon
|
||||||
|
decoding_graph = graph_compiler.compile(texts)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported type of graph compiler: {type(graph_compiler)}"
|
||||||
|
)
|
||||||
|
|
||||||
dense_fsa_vec = k2.DenseFsaVec(
|
dense_fsa_vec = k2.DenseFsaVec(
|
||||||
nnet_output,
|
nnet_output,
|
||||||
@ -584,12 +601,38 @@ def run(rank, world_size, args):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda", rank)
|
device = torch.device("cuda", rank)
|
||||||
|
|
||||||
graph_compiler = BpeCtcTrainingGraphCompiler(
|
if "lang_bpe" in params.lang_dir:
|
||||||
params.lang_dir,
|
graph_compiler = BpeCtcTrainingGraphCompiler(
|
||||||
device=device,
|
params.lang_dir,
|
||||||
sos_token="<sos/eos>",
|
device=device,
|
||||||
eos_token="<sos/eos>",
|
sos_token="<sos/eos>",
|
||||||
)
|
eos_token="<sos/eos>",
|
||||||
|
)
|
||||||
|
elif "lang_phone" in params.lang_dir:
|
||||||
|
assert params.att_rate == 0, (
|
||||||
|
"Attention decoder training does not support phone lang dirs "
|
||||||
|
"at this time due to a missing <sos/eos> symbol. Set --att-rate=0 "
|
||||||
|
"for pure CTC training when using a phone-based lang dir."
|
||||||
|
)
|
||||||
|
assert params.num_decoder_layers == 0, (
|
||||||
|
"Attention decoder training does not support phone lang dirs "
|
||||||
|
"at this time due to a missing <sos/eos> symbol. "
|
||||||
|
"Set --num-decoder-layers=0 for pure CTC training when using "
|
||||||
|
"a phone-based lang dir."
|
||||||
|
)
|
||||||
|
graph_compiler = CtcTrainingGraphCompiler(
|
||||||
|
lexicon,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
# Manually add the sos/eos ID with their default values
|
||||||
|
# from the BPE recipe which we're adapting here.
|
||||||
|
graph_compiler.sos_id = 1
|
||||||
|
graph_compiler.eos_id = 1
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported type of lang dir (we expected it to have "
|
||||||
|
f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}"
|
||||||
|
)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
model = Conformer(
|
model = Conformer(
|
||||||
@ -607,7 +650,9 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
model = DDP(model, device_ids=[rank])
|
# Note: find_unused_parameters=True is needed in case we
|
||||||
|
# want to set params.att_rate = 0 (i.e. att decoder is not trained)
|
||||||
|
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
|
||||||
|
|
||||||
optimizer = Noam(
|
optimizer = Noam(
|
||||||
model.parameters(),
|
model.parameters(),
|
||||||
|
|||||||
@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
|||||||
blank_id = model.decoder.blank_id
|
blank_id = model.decoder.blank_id
|
||||||
device = model.device
|
device = model.device
|
||||||
|
|
||||||
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
|
sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
|
||||||
|
1, 1
|
||||||
|
)
|
||||||
decoder_out, (h, c) = model.decoder(sos)
|
decoder_out, (h, c) = model.decoder(sos)
|
||||||
T = encoder_out.size(1)
|
T = encoder_out.size(1)
|
||||||
t = 0
|
t = 0
|
||||||
|
|||||||
@ -99,6 +99,7 @@ class Transducer(nn.Module):
|
|||||||
sos_y = add_sos(y, sos_id=blank_id)
|
sos_y = add_sos(y, sos_id=blank_id)
|
||||||
|
|
||||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||||
|
sos_y_padded = sos_y_padded.to(torch.int64)
|
||||||
|
|
||||||
decoder_out, _ = self.decoder(sos_y_padded)
|
decoder_out, _ = self.decoder(sos_y_padded)
|
||||||
|
|
||||||
|
|||||||
@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
|||||||
blank_id = model.decoder.blank_id
|
blank_id = model.decoder.blank_id
|
||||||
device = model.device
|
device = model.device
|
||||||
|
|
||||||
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
|
sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
|
||||||
|
1, 1
|
||||||
|
)
|
||||||
decoder_out, (h, c) = model.decoder(sos)
|
decoder_out, (h, c) = model.decoder(sos)
|
||||||
T = encoder_out.size(1)
|
T = encoder_out.size(1)
|
||||||
t = 0
|
t = 0
|
||||||
|
|||||||
@ -101,6 +101,7 @@ class Transducer(nn.Module):
|
|||||||
sos_y = add_sos(y, sos_id=sos_id)
|
sos_y = add_sos(y, sos_id=sos_id)
|
||||||
|
|
||||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||||
|
sos_y_padded = sos_y_padded.to(torch.int64)
|
||||||
|
|
||||||
decoder_out, _ = self.decoder(sos_y_padded)
|
decoder_out, _ = self.decoder(sos_y_padded)
|
||||||
|
|
||||||
|
|||||||
@ -17,7 +17,6 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from model import Transducer
|
from model import Transducer
|
||||||
|
|
||||||
@ -48,7 +47,7 @@ def greedy_search(
|
|||||||
device = model.device
|
device = model.device
|
||||||
|
|
||||||
decoder_input = torch.tensor(
|
decoder_input = torch.tensor(
|
||||||
[blank_id] * context_size, device=device
|
[blank_id] * context_size, device=device, dtype=torch.int64
|
||||||
).reshape(1, context_size)
|
).reshape(1, context_size)
|
||||||
|
|
||||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
@ -108,8 +107,9 @@ class Hypothesis:
|
|||||||
# Newly predicted tokens are appended to `ys`.
|
# Newly predicted tokens are appended to `ys`.
|
||||||
ys: List[int]
|
ys: List[int]
|
||||||
|
|
||||||
# The log prob of ys
|
# The log prob of ys.
|
||||||
log_prob: float
|
# It contains only one entry.
|
||||||
|
log_prob: torch.Tensor
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def key(self) -> str:
|
def key(self) -> str:
|
||||||
@ -145,8 +145,10 @@ class HypothesisList(object):
|
|||||||
"""
|
"""
|
||||||
key = hyp.key
|
key = hyp.key
|
||||||
if key in self:
|
if key in self:
|
||||||
old_hyp = self._data[key]
|
old_hyp = self._data[key] # shallow copy
|
||||||
old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob)
|
torch.logaddexp(
|
||||||
|
old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self._data[key] = hyp
|
self._data[key] = hyp
|
||||||
|
|
||||||
@ -184,7 +186,7 @@ class HypothesisList(object):
|
|||||||
assert key in self, f"{key} does not exist"
|
assert key in self, f"{key} does not exist"
|
||||||
del self._data[key]
|
del self._data[key]
|
||||||
|
|
||||||
def filter(self, threshold: float) -> "HypothesisList":
|
def filter(self, threshold: torch.Tensor) -> "HypothesisList":
|
||||||
"""Remove all Hypotheses whose log_prob is less than threshold.
|
"""Remove all Hypotheses whose log_prob is less than threshold.
|
||||||
|
|
||||||
Caution:
|
Caution:
|
||||||
@ -312,6 +314,113 @@ def run_joiner(
|
|||||||
return log_prob
|
return log_prob
|
||||||
|
|
||||||
|
|
||||||
|
def modified_beam_search(
|
||||||
|
model: Transducer,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
beam: int = 4,
|
||||||
|
) -> List[int]:
|
||||||
|
"""It limits the maximum number of symbols per frame to 1.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
An instance of `Transducer`.
|
||||||
|
encoder_out:
|
||||||
|
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
||||||
|
beam:
|
||||||
|
Beam size.
|
||||||
|
Returns:
|
||||||
|
Return the decoded result.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert encoder_out.ndim == 3
|
||||||
|
|
||||||
|
# support only batch_size == 1 for now
|
||||||
|
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
||||||
|
blank_id = model.decoder.blank_id
|
||||||
|
context_size = model.decoder.context_size
|
||||||
|
|
||||||
|
device = model.device
|
||||||
|
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
[blank_id] * context_size, device=device
|
||||||
|
).reshape(1, context_size)
|
||||||
|
|
||||||
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
|
|
||||||
|
T = encoder_out.size(1)
|
||||||
|
|
||||||
|
B = HypothesisList()
|
||||||
|
B.add(
|
||||||
|
Hypothesis(
|
||||||
|
ys=[blank_id] * context_size,
|
||||||
|
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_out_len = torch.tensor([1])
|
||||||
|
decoder_out_len = torch.tensor([1])
|
||||||
|
|
||||||
|
for t in range(T):
|
||||||
|
# fmt: off
|
||||||
|
current_encoder_out = encoder_out[:, t:t+1, :]
|
||||||
|
# current_encoder_out is of shape (1, 1, encoder_out_dim)
|
||||||
|
# fmt: on
|
||||||
|
A = list(B)
|
||||||
|
B = HypothesisList()
|
||||||
|
|
||||||
|
ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A])
|
||||||
|
# ys_log_probs is of shape (num_hyps, 1)
|
||||||
|
|
||||||
|
decoder_input = torch.tensor(
|
||||||
|
[hyp.ys[-context_size:] for hyp in A],
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
# decoder_input is of shape (num_hyps, context_size)
|
||||||
|
|
||||||
|
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||||
|
# decoder_output is of shape (num_hyps, 1, decoder_output_dim)
|
||||||
|
|
||||||
|
current_encoder_out = current_encoder_out.expand(
|
||||||
|
decoder_out.size(0), 1, -1
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = model.joiner(
|
||||||
|
current_encoder_out,
|
||||||
|
decoder_out,
|
||||||
|
encoder_out_len.expand(decoder_out.size(0)),
|
||||||
|
decoder_out_len.expand(decoder_out.size(0)),
|
||||||
|
)
|
||||||
|
# logits is of shape (num_hyps, vocab_size)
|
||||||
|
log_probs = logits.log_softmax(dim=-1)
|
||||||
|
|
||||||
|
log_probs.add_(ys_log_probs)
|
||||||
|
|
||||||
|
log_probs = log_probs.reshape(-1)
|
||||||
|
topk_log_probs, topk_indexes = log_probs.topk(beam)
|
||||||
|
|
||||||
|
# topk_hyp_indexes are indexes into `A`
|
||||||
|
topk_hyp_indexes = topk_indexes // logits.size(-1)
|
||||||
|
topk_token_indexes = topk_indexes % logits.size(-1)
|
||||||
|
|
||||||
|
topk_hyp_indexes = topk_hyp_indexes.tolist()
|
||||||
|
topk_token_indexes = topk_token_indexes.tolist()
|
||||||
|
|
||||||
|
for i in range(len(topk_hyp_indexes)):
|
||||||
|
hyp = A[topk_hyp_indexes[i]]
|
||||||
|
new_ys = hyp.ys[:]
|
||||||
|
new_token = topk_token_indexes[i]
|
||||||
|
if new_token != blank_id:
|
||||||
|
new_ys.append(new_token)
|
||||||
|
new_log_prob = topk_log_probs[i]
|
||||||
|
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||||
|
B.add(new_hyp)
|
||||||
|
|
||||||
|
best_hyp = B.get_most_probable(length_norm=True)
|
||||||
|
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
|
||||||
|
|
||||||
|
return ys
|
||||||
|
|
||||||
|
|
||||||
def beam_search(
|
def beam_search(
|
||||||
model: Transducer,
|
model: Transducer,
|
||||||
encoder_out: torch.Tensor,
|
encoder_out: torch.Tensor,
|
||||||
@ -351,7 +460,12 @@ def beam_search(
|
|||||||
t = 0
|
t = 0
|
||||||
|
|
||||||
B = HypothesisList()
|
B = HypothesisList()
|
||||||
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
|
B.add(
|
||||||
|
Hypothesis(
|
||||||
|
ys=[blank_id] * context_size,
|
||||||
|
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
max_sym_per_utt = 20000
|
max_sym_per_utt = 20000
|
||||||
|
|
||||||
@ -371,9 +485,6 @@ def beam_search(
|
|||||||
|
|
||||||
joint_cache: Dict[str, torch.Tensor] = {}
|
joint_cache: Dict[str, torch.Tensor] = {}
|
||||||
|
|
||||||
# TODO(fangjun): Implement prefix search to update the `log_prob`
|
|
||||||
# of hypotheses in A
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
y_star = A.get_most_probable()
|
y_star = A.get_most_probable()
|
||||||
A.remove(y_star)
|
A.remove(y_star)
|
||||||
@ -396,18 +507,21 @@ def beam_search(
|
|||||||
|
|
||||||
# First, process the blank symbol
|
# First, process the blank symbol
|
||||||
skip_log_prob = log_prob[blank_id]
|
skip_log_prob = log_prob[blank_id]
|
||||||
new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
|
new_y_star_log_prob = y_star.log_prob + skip_log_prob
|
||||||
|
|
||||||
# ys[:] returns a copy of ys
|
# ys[:] returns a copy of ys
|
||||||
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
|
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
|
||||||
|
|
||||||
# Second, process other non-blank labels
|
# Second, process other non-blank labels
|
||||||
values, indices = log_prob.topk(beam + 1)
|
values, indices = log_prob.topk(beam + 1)
|
||||||
for i, v in zip(indices.tolist(), values.tolist()):
|
for idx in range(values.size(0)):
|
||||||
|
i = indices[idx].item()
|
||||||
if i == blank_id:
|
if i == blank_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
new_ys = y_star.ys + [i]
|
new_ys = y_star.ys + [i]
|
||||||
new_log_prob = y_star.log_prob + v
|
|
||||||
|
new_log_prob = y_star.log_prob + values[idx]
|
||||||
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
|
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
|
||||||
|
|
||||||
# Check whether B contains more than "beam" elements more probable
|
# Check whether B contains more than "beam" elements more probable
|
||||||
|
|||||||
@ -46,7 +46,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from beam_search import beam_search, greedy_search
|
from beam_search import beam_search, greedy_search, modified_beam_search
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
@ -104,6 +104,7 @@ def get_parser():
|
|||||||
help="""Possible values are:
|
help="""Possible values are:
|
||||||
- greedy_search
|
- greedy_search
|
||||||
- beam_search
|
- beam_search
|
||||||
|
- modified_beam_search
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -111,7 +112,8 @@ def get_parser():
|
|||||||
"--beam-size",
|
"--beam-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
help="Used only when --decoding-method is beam_search",
|
help="""Used only when --decoding-method is
|
||||||
|
beam_search or modified_beam_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -125,7 +127,8 @@ def get_parser():
|
|||||||
"--max-sym-per-frame",
|
"--max-sym-per-frame",
|
||||||
type=int,
|
type=int,
|
||||||
default=3,
|
default=3,
|
||||||
help="Maximum number of symbols per frame",
|
help="""Maximum number of symbols per frame.
|
||||||
|
Used only when --decoding_method is greedy_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -256,6 +259,10 @@ def decode_one_batch(
|
|||||||
hyp = beam_search(
|
hyp = beam_search(
|
||||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||||
)
|
)
|
||||||
|
elif params.decoding_method == "modified_beam_search":
|
||||||
|
hyp = modified_beam_search(
|
||||||
|
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported decoding method: {params.decoding_method}"
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
@ -389,11 +396,15 @@ def main():
|
|||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
assert params.decoding_method in ("greedy_search", "beam_search")
|
assert params.decoding_method in (
|
||||||
|
"greedy_search",
|
||||||
|
"beam_search",
|
||||||
|
"modified_beam_search",
|
||||||
|
)
|
||||||
params.res_dir = params.exp_dir / params.decoding_method
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
if params.decoding_method == "beam_search":
|
if "beam_search" in params.decoding_method:
|
||||||
params.suffix += f"-beam-{params.beam_size}"
|
params.suffix += f"-beam-{params.beam_size}"
|
||||||
else:
|
else:
|
||||||
params.suffix += f"-context-{params.context_size}"
|
params.suffix += f"-context-{params.context_size}"
|
||||||
|
|||||||
@ -75,24 +75,24 @@ class Decoder(nn.Module):
|
|||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
y:
|
y:
|
||||||
A 2-D tensor of shape (N, U) with blank prepended.
|
A 2-D tensor of shape (N, U).
|
||||||
need_pad:
|
need_pad:
|
||||||
True to left pad the input. Should be True during training.
|
True to left pad the input. Should be True during training.
|
||||||
False to not pad the input. Should be False during inference.
|
False to not pad the input. Should be False during inference.
|
||||||
Returns:
|
Returns:
|
||||||
Return a tensor of shape (N, U, embedding_dim).
|
Return a tensor of shape (N, U, embedding_dim).
|
||||||
"""
|
"""
|
||||||
embeding_out = self.embedding(y)
|
embedding_out = self.embedding(y)
|
||||||
if self.context_size > 1:
|
if self.context_size > 1:
|
||||||
embeding_out = embeding_out.permute(0, 2, 1)
|
embedding_out = embedding_out.permute(0, 2, 1)
|
||||||
if need_pad is True:
|
if need_pad is True:
|
||||||
embeding_out = F.pad(
|
embedding_out = F.pad(
|
||||||
embeding_out, pad=(self.context_size - 1, 0)
|
embedding_out, pad=(self.context_size - 1, 0)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# During inference time, there is no need to do extra padding
|
# During inference time, there is no need to do extra padding
|
||||||
# as we only need one output
|
# as we only need one output
|
||||||
assert embeding_out.size(-1) == self.context_size
|
assert embedding_out.size(-1) == self.context_size
|
||||||
embeding_out = self.conv(embeding_out)
|
embedding_out = self.conv(embedding_out)
|
||||||
embeding_out = embeding_out.permute(0, 2, 1)
|
embedding_out = embedding_out.permute(0, 2, 1)
|
||||||
return embeding_out
|
return embedding_out
|
||||||
|
|||||||
@ -48,6 +48,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
@ -133,7 +134,7 @@ def get_params() -> AttributeDict:
|
|||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
def get_encoder_model(params: AttributeDict):
|
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||||
encoder = Conformer(
|
encoder = Conformer(
|
||||||
num_features=params.feature_dim,
|
num_features=params.feature_dim,
|
||||||
output_dim=params.encoder_out_dim,
|
output_dim=params.encoder_out_dim,
|
||||||
@ -147,7 +148,7 @@ def get_encoder_model(params: AttributeDict):
|
|||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
def get_decoder_model(params: AttributeDict):
|
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||||
decoder = Decoder(
|
decoder = Decoder(
|
||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
embedding_dim=params.encoder_out_dim,
|
embedding_dim=params.encoder_out_dim,
|
||||||
@ -157,7 +158,7 @@ def get_decoder_model(params: AttributeDict):
|
|||||||
return decoder
|
return decoder
|
||||||
|
|
||||||
|
|
||||||
def get_joiner_model(params: AttributeDict):
|
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||||
joiner = Joiner(
|
joiner = Joiner(
|
||||||
input_dim=params.encoder_out_dim,
|
input_dim=params.encoder_out_dim,
|
||||||
output_dim=params.vocab_size,
|
output_dim=params.vocab_size,
|
||||||
@ -165,7 +166,7 @@ def get_joiner_model(params: AttributeDict):
|
|||||||
return joiner
|
return joiner
|
||||||
|
|
||||||
|
|
||||||
def get_transducer_model(params: AttributeDict):
|
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||||
encoder = get_encoder_model(params)
|
encoder = get_encoder_model(params)
|
||||||
decoder = get_decoder_model(params)
|
decoder = get_decoder_model(params)
|
||||||
joiner = get_joiner_model(params)
|
joiner = get_joiner_model(params)
|
||||||
|
|||||||
@ -14,6 +14,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -62,6 +64,7 @@ class Transducer(nn.Module):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
x_lens: torch.Tensor,
|
x_lens: torch.Tensor,
|
||||||
y: k2.RaggedTensor,
|
y: k2.RaggedTensor,
|
||||||
|
modified_transducer_prob: float = 0.0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -73,6 +76,8 @@ class Transducer(nn.Module):
|
|||||||
y:
|
y:
|
||||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||||
utterance.
|
utterance.
|
||||||
|
modified_transducer_prob:
|
||||||
|
The probability to use modified transducer loss.
|
||||||
Returns:
|
Returns:
|
||||||
Return the transducer loss.
|
Return the transducer loss.
|
||||||
"""
|
"""
|
||||||
@ -93,6 +98,7 @@ class Transducer(nn.Module):
|
|||||||
sos_y = add_sos(y, sos_id=blank_id)
|
sos_y = add_sos(y, sos_id=blank_id)
|
||||||
|
|
||||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||||
|
sos_y_padded = sos_y_padded.to(torch.int64)
|
||||||
|
|
||||||
decoder_out = self.decoder(sos_y_padded)
|
decoder_out = self.decoder(sos_y_padded)
|
||||||
|
|
||||||
@ -113,6 +119,16 @@ class Transducer(nn.Module):
|
|||||||
# reference stage
|
# reference stage
|
||||||
import optimized_transducer
|
import optimized_transducer
|
||||||
|
|
||||||
|
assert 0 <= modified_transducer_prob <= 1
|
||||||
|
|
||||||
|
if modified_transducer_prob == 0:
|
||||||
|
one_sym_per_frame = False
|
||||||
|
elif random.random() < modified_transducer_prob:
|
||||||
|
# random.random() returns a float in the range [0, 1)
|
||||||
|
one_sym_per_frame = True
|
||||||
|
else:
|
||||||
|
one_sym_per_frame = False
|
||||||
|
|
||||||
loss = optimized_transducer.transducer_loss(
|
loss = optimized_transducer.transducer_loss(
|
||||||
logits=logits,
|
logits=logits,
|
||||||
targets=y_padded,
|
targets=y_padded,
|
||||||
@ -120,6 +136,8 @@ class Transducer(nn.Module):
|
|||||||
target_lengths=y_lens,
|
target_lengths=y_lens,
|
||||||
blank=blank_id,
|
blank=blank_id,
|
||||||
reduction="sum",
|
reduction="sum",
|
||||||
|
one_sym_per_frame=one_sym_per_frame,
|
||||||
|
from_log_softmax=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|||||||
@ -22,10 +22,11 @@ Usage:
|
|||||||
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
|
--max-sym-per-frame 1 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav \
|
/path/to/bar.wav \
|
||||||
|
|
||||||
(1) beam search
|
(2) beam search
|
||||||
./transducer_stateless/pretrained.py \
|
./transducer_stateless/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
@ -34,6 +35,15 @@ Usage:
|
|||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav \
|
/path/to/bar.wav \
|
||||||
|
|
||||||
|
(3) modified beam search
|
||||||
|
./transducer_stateless/pretrained.py \
|
||||||
|
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
||||||
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||||
|
--method modified_beam_search \
|
||||||
|
--beam-size 4 \
|
||||||
|
/path/to/foo.wav \
|
||||||
|
/path/to/bar.wav \
|
||||||
|
|
||||||
You can also use `./transducer_stateless/exp/epoch-xx.pt`.
|
You can also use `./transducer_stateless/exp/epoch-xx.pt`.
|
||||||
|
|
||||||
Note: ./transducer_stateless/exp/pretrained.pt is generated by
|
Note: ./transducer_stateless/exp/pretrained.pt is generated by
|
||||||
@ -49,8 +59,9 @@ from typing import List
|
|||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import beam_search, greedy_search
|
from beam_search import beam_search, greedy_search, modified_beam_search
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
@ -90,6 +101,7 @@ def get_parser():
|
|||||||
help="""Possible values are:
|
help="""Possible values are:
|
||||||
- greedy_search
|
- greedy_search
|
||||||
- beam_search
|
- beam_search
|
||||||
|
- modified_beam_search
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -107,7 +119,7 @@ def get_parser():
|
|||||||
"--beam-size",
|
"--beam-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=4,
|
||||||
help="Used only when --method is beam_search",
|
help="Used only when --method is beam_search and modified_beam_search ",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -148,7 +160,7 @@ def get_params() -> AttributeDict:
|
|||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
def get_encoder_model(params: AttributeDict):
|
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||||
encoder = Conformer(
|
encoder = Conformer(
|
||||||
num_features=params.feature_dim,
|
num_features=params.feature_dim,
|
||||||
output_dim=params.encoder_out_dim,
|
output_dim=params.encoder_out_dim,
|
||||||
@ -162,7 +174,7 @@ def get_encoder_model(params: AttributeDict):
|
|||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
def get_decoder_model(params: AttributeDict):
|
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||||
decoder = Decoder(
|
decoder = Decoder(
|
||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
embedding_dim=params.encoder_out_dim,
|
embedding_dim=params.encoder_out_dim,
|
||||||
@ -172,7 +184,7 @@ def get_decoder_model(params: AttributeDict):
|
|||||||
return decoder
|
return decoder
|
||||||
|
|
||||||
|
|
||||||
def get_joiner_model(params: AttributeDict):
|
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||||
joiner = Joiner(
|
joiner = Joiner(
|
||||||
input_dim=params.encoder_out_dim,
|
input_dim=params.encoder_out_dim,
|
||||||
output_dim=params.vocab_size,
|
output_dim=params.vocab_size,
|
||||||
@ -180,7 +192,7 @@ def get_joiner_model(params: AttributeDict):
|
|||||||
return joiner
|
return joiner
|
||||||
|
|
||||||
|
|
||||||
def get_transducer_model(params: AttributeDict):
|
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||||
encoder = get_encoder_model(params)
|
encoder = get_encoder_model(params)
|
||||||
decoder = get_decoder_model(params)
|
decoder = get_decoder_model(params)
|
||||||
joiner = get_joiner_model(params)
|
joiner = get_joiner_model(params)
|
||||||
@ -217,6 +229,7 @@ def read_sound_files(
|
|||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@ -300,6 +313,10 @@ def main():
|
|||||||
hyp = beam_search(
|
hyp = beam_search(
|
||||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||||
)
|
)
|
||||||
|
elif params.method == "modified_beam_search":
|
||||||
|
hyp = modified_beam_search(
|
||||||
|
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
|
|
||||||
|
|||||||
@ -138,6 +138,17 @@ def get_parser():
|
|||||||
"2 means tri-gram",
|
"2 means tri-gram",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--modified-transducer-prob",
|
||||||
|
type=float,
|
||||||
|
default=0.25,
|
||||||
|
help="""The probability to use modified transducer loss.
|
||||||
|
In modified transduer, it limits the maximum number of symbols
|
||||||
|
per frame to 1. See also the option --max-sym-per-frame in
|
||||||
|
transducer_stateless/decode.py
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -213,7 +224,7 @@ def get_params() -> AttributeDict:
|
|||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
def get_encoder_model(params: AttributeDict):
|
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||||
# TODO: We can add an option to switch between Conformer and Transformer
|
# TODO: We can add an option to switch between Conformer and Transformer
|
||||||
encoder = Conformer(
|
encoder = Conformer(
|
||||||
num_features=params.feature_dim,
|
num_features=params.feature_dim,
|
||||||
@ -228,7 +239,7 @@ def get_encoder_model(params: AttributeDict):
|
|||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
def get_decoder_model(params: AttributeDict):
|
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||||
decoder = Decoder(
|
decoder = Decoder(
|
||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
embedding_dim=params.encoder_out_dim,
|
embedding_dim=params.encoder_out_dim,
|
||||||
@ -238,7 +249,7 @@ def get_decoder_model(params: AttributeDict):
|
|||||||
return decoder
|
return decoder
|
||||||
|
|
||||||
|
|
||||||
def get_joiner_model(params: AttributeDict):
|
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||||
joiner = Joiner(
|
joiner = Joiner(
|
||||||
input_dim=params.encoder_out_dim,
|
input_dim=params.encoder_out_dim,
|
||||||
output_dim=params.vocab_size,
|
output_dim=params.vocab_size,
|
||||||
@ -246,7 +257,7 @@ def get_joiner_model(params: AttributeDict):
|
|||||||
return joiner
|
return joiner
|
||||||
|
|
||||||
|
|
||||||
def get_transducer_model(params: AttributeDict):
|
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||||
encoder = get_encoder_model(params)
|
encoder = get_encoder_model(params)
|
||||||
decoder = get_decoder_model(params)
|
decoder = get_decoder_model(params)
|
||||||
joiner = get_joiner_model(params)
|
joiner = get_joiner_model(params)
|
||||||
@ -383,7 +394,12 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
loss = model(x=feature, x_lens=feature_lens, y=y)
|
loss = model(
|
||||||
|
x=feature,
|
||||||
|
x_lens=feature_lens,
|
||||||
|
y=y,
|
||||||
|
modified_transducer_prob=params.modified_transducer_prob,
|
||||||
|
)
|
||||||
|
|
||||||
assert loss.requires_grad == is_training
|
assert loss.requires_grad == is_training
|
||||||
|
|
||||||
|
|||||||
@ -89,6 +89,29 @@ class CtcTrainingGraphCompiler(object):
|
|||||||
|
|
||||||
return decoding_graph
|
return decoding_graph
|
||||||
|
|
||||||
|
def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
|
||||||
|
"""Convert a list of texts to a list-of-list of word IDs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts:
|
||||||
|
It is a list of strings. Each string consists of space(s)
|
||||||
|
separated words. An example containing two strings is given below:
|
||||||
|
|
||||||
|
['HELLO ICEFALL', 'HELLO k2']
|
||||||
|
Returns:
|
||||||
|
Return a list-of-list of word IDs.
|
||||||
|
"""
|
||||||
|
word_ids_list = []
|
||||||
|
for text in texts:
|
||||||
|
word_ids = []
|
||||||
|
for word in text.split():
|
||||||
|
if word in self.word_table:
|
||||||
|
word_ids.append(self.word_table[word])
|
||||||
|
else:
|
||||||
|
word_ids.append(self.oov_id)
|
||||||
|
word_ids_list.append(word_ids)
|
||||||
|
return word_ids_list
|
||||||
|
|
||||||
def convert_transcript_to_fsa(self, texts: List[str]) -> k2.Fsa:
|
def convert_transcript_to_fsa(self, texts: List[str]) -> k2.Fsa:
|
||||||
"""Convert a list of transcript texts to an FsaVec.
|
"""Convert a list of transcript texts to an FsaVec.
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user