mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-09 14:05:33 +00:00
Merge branch 'k2-fsa:master' into master
This commit is contained in:
commit
0baa7026f0
@ -74,24 +74,53 @@ jobs:
|
||||
mkdir tmp
|
||||
cd tmp
|
||||
git lfs install
|
||||
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10
|
||||
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07
|
||||
cd ..
|
||||
tree tmp
|
||||
soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/*.wav
|
||||
ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/*.wav
|
||||
soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/*.wav
|
||||
ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/*.wav
|
||||
|
||||
- name: Run greedy search decoding
|
||||
- name: Run greedy search decoding (max-sym-per-frame 1)
|
||||
shell: bash
|
||||
run: |
|
||||
export PYTHONPATH=$PWD:PYTHONPATH
|
||||
cd egs/librispeech/ASR
|
||||
./transducer_stateless/pretrained.py \
|
||||
--method greedy_search \
|
||||
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/exp/pretrained.pt \
|
||||
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/data/lang_bpe_500/bpe.model \
|
||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1089-134686-0001.wav \
|
||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0001.wav \
|
||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0002.wav
|
||||
--max-sym-per-frame 1 \
|
||||
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
|
||||
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
|
||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
|
||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
|
||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
|
||||
|
||||
- name: Run greedy search decoding (max-sym-per-frame 2)
|
||||
shell: bash
|
||||
run: |
|
||||
export PYTHONPATH=$PWD:PYTHONPATH
|
||||
cd egs/librispeech/ASR
|
||||
./transducer_stateless/pretrained.py \
|
||||
--method greedy_search \
|
||||
--max-sym-per-frame 2 \
|
||||
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
|
||||
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
|
||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
|
||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
|
||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
|
||||
|
||||
- name: Run greedy search decoding (max-sym-per-frame 3)
|
||||
shell: bash
|
||||
run: |
|
||||
export PYTHONPATH=$PWD:PYTHONPATH
|
||||
cd egs/librispeech/ASR
|
||||
./transducer_stateless/pretrained.py \
|
||||
--method greedy_search \
|
||||
--max-sym-per-frame 3 \
|
||||
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
|
||||
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
|
||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
|
||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
|
||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
|
||||
|
||||
- name: Run beam search decoding
|
||||
shell: bash
|
||||
@ -101,8 +130,22 @@ jobs:
|
||||
./transducer_stateless/pretrained.py \
|
||||
--method beam_search \
|
||||
--beam-size 4 \
|
||||
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/exp/pretrained.pt \
|
||||
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/data/lang_bpe_500/bpe.model \
|
||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1089-134686-0001.wav \
|
||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0001.wav \
|
||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0002.wav
|
||||
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
|
||||
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
|
||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
|
||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
|
||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
|
||||
|
||||
- name: Run modified beam search decoding
|
||||
shell: bash
|
||||
run: |
|
||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||
cd egs/librispeech/ASR
|
||||
./transducer_stateless/pretrained.py \
|
||||
--method modified_beam_search \
|
||||
--beam-size 4 \
|
||||
--checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
|
||||
--bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
|
||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
|
||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
|
||||
./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
|
||||
|
||||
@ -80,16 +80,16 @@ We provide a Colab notebook to run a pre-trained RNN-T conformer model: [](https://colab.research.google.com/drive/1Rc4Is-3Yp9LbcEz_Iy8hfyenyHsyjvqE?usp=sharing)
|
||||
We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [](https://colab.research.google.com/drive/1CO1bXJ-2khDckZIW8zjOPHGSKLHpTDlp?usp=sharing)
|
||||
|
||||
### Aishell
|
||||
|
||||
|
||||
@ -82,17 +82,17 @@ class Decoder(nn.Module):
|
||||
Returns:
|
||||
Return a tensor of shape (N, U, embedding_dim).
|
||||
"""
|
||||
embeding_out = self.embedding(y)
|
||||
embedding_out = self.embedding(y)
|
||||
if self.context_size > 1:
|
||||
embeding_out = embeding_out.permute(0, 2, 1)
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
if need_pad is True:
|
||||
embeding_out = F.pad(
|
||||
embeding_out, pad=(self.context_size - 1, 0)
|
||||
embedding_out = F.pad(
|
||||
embedding_out, pad=(self.context_size - 1, 0)
|
||||
)
|
||||
else:
|
||||
# During inference time, there is no need to do extra padding
|
||||
# as we only need one output
|
||||
assert embeding_out.size(-1) == self.context_size
|
||||
embeding_out = self.conv(embeding_out)
|
||||
embeding_out = embeding_out.permute(0, 2, 1)
|
||||
return embeding_out
|
||||
assert embedding_out.size(-1) == self.context_size
|
||||
embedding_out = self.conv(embedding_out)
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
return embedding_out
|
||||
|
||||
@ -48,6 +48,7 @@ from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
@ -133,7 +134,7 @@ def get_params() -> AttributeDict:
|
||||
return params
|
||||
|
||||
|
||||
def get_encoder_model(params: AttributeDict):
|
||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
encoder = Conformer(
|
||||
num_features=params.feature_dim,
|
||||
output_dim=params.encoder_out_dim,
|
||||
@ -147,7 +148,7 @@ def get_encoder_model(params: AttributeDict):
|
||||
return encoder
|
||||
|
||||
|
||||
def get_decoder_model(params: AttributeDict):
|
||||
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||
decoder = Decoder(
|
||||
vocab_size=params.vocab_size,
|
||||
embedding_dim=params.encoder_out_dim,
|
||||
@ -157,7 +158,7 @@ def get_decoder_model(params: AttributeDict):
|
||||
return decoder
|
||||
|
||||
|
||||
def get_joiner_model(params: AttributeDict):
|
||||
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
joiner = Joiner(
|
||||
input_dim=params.encoder_out_dim,
|
||||
output_dim=params.vocab_size,
|
||||
@ -165,7 +166,7 @@ def get_joiner_model(params: AttributeDict):
|
||||
return joiner
|
||||
|
||||
|
||||
def get_transducer_model(params: AttributeDict):
|
||||
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||
encoder = get_encoder_model(params)
|
||||
decoder = get_decoder_model(params)
|
||||
joiner = get_joiner_model(params)
|
||||
|
||||
@ -44,11 +44,12 @@ Note: ./transducer_stateless/exp/pretrained.pt is generated by
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import kaldifeat
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchaudio
|
||||
from beam_search import beam_search, greedy_search
|
||||
from conformer import Conformer
|
||||
@ -57,10 +58,10 @@ from joiner import Joiner
|
||||
from model import Transducer
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -150,7 +151,7 @@ def get_params() -> AttributeDict:
|
||||
return params
|
||||
|
||||
|
||||
def get_encoder_model(params: AttributeDict):
|
||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
encoder = Conformer(
|
||||
num_features=params.feature_dim,
|
||||
output_dim=params.encoder_out_dim,
|
||||
@ -164,7 +165,7 @@ def get_encoder_model(params: AttributeDict):
|
||||
return encoder
|
||||
|
||||
|
||||
def get_decoder_model(params: AttributeDict):
|
||||
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||
decoder = Decoder(
|
||||
vocab_size=params.vocab_size,
|
||||
embedding_dim=params.encoder_out_dim,
|
||||
@ -174,7 +175,7 @@ def get_decoder_model(params: AttributeDict):
|
||||
return decoder
|
||||
|
||||
|
||||
def get_joiner_model(params: AttributeDict):
|
||||
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
joiner = Joiner(
|
||||
input_dim=params.encoder_out_dim,
|
||||
output_dim=params.vocab_size,
|
||||
@ -182,7 +183,7 @@ def get_joiner_model(params: AttributeDict):
|
||||
return joiner
|
||||
|
||||
|
||||
def get_transducer_model(params: AttributeDict):
|
||||
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||
encoder = get_encoder_model(params)
|
||||
decoder = get_decoder_model(params)
|
||||
joiner = get_joiner_model(params)
|
||||
|
||||
@ -204,7 +204,7 @@ def get_params() -> AttributeDict:
|
||||
return params
|
||||
|
||||
|
||||
def get_encoder_model(params: AttributeDict):
|
||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
# TODO: We can add an option to switch between Conformer and Transformer
|
||||
encoder = Conformer(
|
||||
num_features=params.feature_dim,
|
||||
@ -219,7 +219,7 @@ def get_encoder_model(params: AttributeDict):
|
||||
return encoder
|
||||
|
||||
|
||||
def get_decoder_model(params: AttributeDict):
|
||||
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||
decoder = Decoder(
|
||||
vocab_size=params.vocab_size,
|
||||
embedding_dim=params.encoder_out_dim,
|
||||
@ -229,7 +229,7 @@ def get_decoder_model(params: AttributeDict):
|
||||
return decoder
|
||||
|
||||
|
||||
def get_joiner_model(params: AttributeDict):
|
||||
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
joiner = Joiner(
|
||||
input_dim=params.encoder_out_dim,
|
||||
output_dim=params.vocab_size,
|
||||
@ -237,7 +237,7 @@ def get_joiner_model(params: AttributeDict):
|
||||
return joiner
|
||||
|
||||
|
||||
def get_transducer_model(params: AttributeDict):
|
||||
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||
encoder = get_encoder_model(params)
|
||||
decoder = get_decoder_model(params)
|
||||
joiner = get_joiner_model(params)
|
||||
|
||||
@ -4,62 +4,73 @@
|
||||
|
||||
#### Conformer encoder + embedding decoder
|
||||
|
||||
Using commit `4c1b3665ee6efb935f4dd93a80ff0e154b13efb6`.
|
||||
Using commit `a8150021e01d34ecbd6198fe03a57eacf47a16f2`.
|
||||
|
||||
Conformer encoder + non-current decoder. The decoder
|
||||
Conformer encoder + non-recurrent decoder. The decoder
|
||||
contains only an embedding layer and a Conv1d (with kernel size 2).
|
||||
|
||||
The WERs are
|
||||
|
||||
| | test-clean | test-other | comment |
|
||||
|---------------------------|------------|------------|------------------------------------------|
|
||||
| greedy search | 2.69 | 6.81 | --epoch 71, --avg 15, --max-duration 100 |
|
||||
| beam search (beam size 4) | 2.68 | 6.72 | --epoch 71, --avg 15, --max-duration 100 |
|
||||
| | test-clean | test-other | comment |
|
||||
|-------------------------------------|------------|------------|------------------------------------------|
|
||||
| greedy search (max sym per frame 1) | 2.68 | 6.71 | --epoch 61, --avg 18, --max-duration 100 |
|
||||
| greedy search (max sym per frame 2) | 2.69 | 6.71 | --epoch 61, --avg 18, --max-duration 100 |
|
||||
| greedy search (max sym per frame 3) | 2.69 | 6.71 | --epoch 61, --avg 18, --max-duration 100 |
|
||||
| modified beam search (beam size 4) | 2.67 | 6.64 | --epoch 61, --avg 18, --max-duration 100 |
|
||||
|
||||
|
||||
The training command for reproducing is given below:
|
||||
|
||||
```
|
||||
cd egs/librispeech/ASR/
|
||||
./prepare.sh
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
./transducer_stateless/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 76 \
|
||||
--start-epoch 0 \
|
||||
--exp-dir transducer_stateless/exp-full \
|
||||
--full-libri 1 \
|
||||
--max-duration 250 \
|
||||
--lr-factor 3
|
||||
--max-duration 300 \
|
||||
--lr-factor 5 \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--modified-transducer-prob 0.25
|
||||
```
|
||||
|
||||
The tensorboard training log can be found at
|
||||
<https://tensorboard.dev/experiment/qGdqzHnxS0WJ695OXfZDzA/#scalars&_smoothingWeight=0>
|
||||
<https://tensorboard.dev/experiment/qgvWkbF2R46FYA6ZMNmOjA/#scalars>
|
||||
|
||||
The decoding command is:
|
||||
```
|
||||
epoch=71
|
||||
avg=15
|
||||
epoch=61
|
||||
avg=18
|
||||
|
||||
## greedy search
|
||||
./transducer_stateless/decode.py \
|
||||
--epoch $epoch \
|
||||
--avg $avg \
|
||||
--exp-dir transducer_stateless/exp-full \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--max-duration 100
|
||||
for sym in 1 2 3; do
|
||||
./transducer_stateless/decode.py \
|
||||
--epoch $epoch \
|
||||
--avg $avg \
|
||||
--exp-dir transducer_stateless/exp-full \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--max-duration 100 \
|
||||
--max-sym-per-frame $sym
|
||||
done
|
||||
|
||||
## modified beam search
|
||||
|
||||
## beam search
|
||||
./transducer_stateless/decode.py \
|
||||
--epoch $epoch \
|
||||
--avg $avg \
|
||||
--exp-dir transducer_stateless/exp-full \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--max-duration 100 \
|
||||
--decoding-method beam_search \
|
||||
--context-size 2 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
```
|
||||
|
||||
You can find a pretrained model by visiting
|
||||
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10>
|
||||
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07>
|
||||
|
||||
|
||||
#### Conformer encoder + LSTM decoder
|
||||
|
||||
@ -41,6 +41,7 @@ from icefall.checkpoint import load_checkpoint
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
@ -123,6 +124,15 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-decoder-layers",
|
||||
type=int,
|
||||
default=6,
|
||||
help="""Number of decoder layer of transformer decoder.
|
||||
Setting this to 0 will not create the decoder at all (pure CTC model)
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lr-factor",
|
||||
type=float,
|
||||
@ -210,7 +220,6 @@ def get_params() -> AttributeDict:
|
||||
"use_feat_batchnorm": True,
|
||||
"attention_dim": 512,
|
||||
"nhead": 8,
|
||||
"num_decoder_layers": 6,
|
||||
# parameters for loss
|
||||
"beam_size": 10,
|
||||
"reduction": "sum",
|
||||
@ -357,9 +366,17 @@ def compute_loss(
|
||||
supervisions, subsampling_factor=params.subsampling_factor
|
||||
)
|
||||
|
||||
token_ids = graph_compiler.texts_to_ids(texts)
|
||||
|
||||
decoding_graph = graph_compiler.compile(token_ids)
|
||||
if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler):
|
||||
# Works with a BPE model
|
||||
token_ids = graph_compiler.texts_to_ids(texts)
|
||||
decoding_graph = graph_compiler.compile(token_ids)
|
||||
elif isinstance(graph_compiler, CtcTrainingGraphCompiler):
|
||||
# Works with a phone lexicon
|
||||
decoding_graph = graph_compiler.compile(texts)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported type of graph compiler: {type(graph_compiler)}"
|
||||
)
|
||||
|
||||
dense_fsa_vec = k2.DenseFsaVec(
|
||||
nnet_output,
|
||||
@ -584,12 +601,38 @@ def run(rank, world_size, args):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", rank)
|
||||
|
||||
graph_compiler = BpeCtcTrainingGraphCompiler(
|
||||
params.lang_dir,
|
||||
device=device,
|
||||
sos_token="<sos/eos>",
|
||||
eos_token="<sos/eos>",
|
||||
)
|
||||
if "lang_bpe" in params.lang_dir:
|
||||
graph_compiler = BpeCtcTrainingGraphCompiler(
|
||||
params.lang_dir,
|
||||
device=device,
|
||||
sos_token="<sos/eos>",
|
||||
eos_token="<sos/eos>",
|
||||
)
|
||||
elif "lang_phone" in params.lang_dir:
|
||||
assert params.att_rate == 0, (
|
||||
"Attention decoder training does not support phone lang dirs "
|
||||
"at this time due to a missing <sos/eos> symbol. Set --att-rate=0 "
|
||||
"for pure CTC training when using a phone-based lang dir."
|
||||
)
|
||||
assert params.num_decoder_layers == 0, (
|
||||
"Attention decoder training does not support phone lang dirs "
|
||||
"at this time due to a missing <sos/eos> symbol. "
|
||||
"Set --num-decoder-layers=0 for pure CTC training when using "
|
||||
"a phone-based lang dir."
|
||||
)
|
||||
graph_compiler = CtcTrainingGraphCompiler(
|
||||
lexicon,
|
||||
device=device,
|
||||
)
|
||||
# Manually add the sos/eos ID with their default values
|
||||
# from the BPE recipe which we're adapting here.
|
||||
graph_compiler.sos_id = 1
|
||||
graph_compiler.eos_id = 1
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported type of lang dir (we expected it to have "
|
||||
f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}"
|
||||
)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = Conformer(
|
||||
@ -607,7 +650,9 @@ def run(rank, world_size, args):
|
||||
|
||||
model.to(device)
|
||||
if world_size > 1:
|
||||
model = DDP(model, device_ids=[rank])
|
||||
# Note: find_unused_parameters=True is needed in case we
|
||||
# want to set params.att_rate = 0 (i.e. att decoder is not trained)
|
||||
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
|
||||
|
||||
optimizer = Noam(
|
||||
model.parameters(),
|
||||
|
||||
@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
||||
blank_id = model.decoder.blank_id
|
||||
device = model.device
|
||||
|
||||
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
|
||||
sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
|
||||
1, 1
|
||||
)
|
||||
decoder_out, (h, c) = model.decoder(sos)
|
||||
T = encoder_out.size(1)
|
||||
t = 0
|
||||
|
||||
@ -99,6 +99,7 @@ class Transducer(nn.Module):
|
||||
sos_y = add_sos(y, sos_id=blank_id)
|
||||
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
sos_y_padded = sos_y_padded.to(torch.int64)
|
||||
|
||||
decoder_out, _ = self.decoder(sos_y_padded)
|
||||
|
||||
|
||||
@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
|
||||
blank_id = model.decoder.blank_id
|
||||
device = model.device
|
||||
|
||||
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
|
||||
sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
|
||||
1, 1
|
||||
)
|
||||
decoder_out, (h, c) = model.decoder(sos)
|
||||
T = encoder_out.size(1)
|
||||
t = 0
|
||||
|
||||
@ -101,6 +101,7 @@ class Transducer(nn.Module):
|
||||
sos_y = add_sos(y, sos_id=sos_id)
|
||||
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
sos_y_padded = sos_y_padded.to(torch.int64)
|
||||
|
||||
decoder_out, _ = self.decoder(sos_y_padded)
|
||||
|
||||
|
||||
@ -17,7 +17,6 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from model import Transducer
|
||||
|
||||
@ -48,7 +47,7 @@ def greedy_search(
|
||||
device = model.device
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[blank_id] * context_size, device=device
|
||||
[blank_id] * context_size, device=device, dtype=torch.int64
|
||||
).reshape(1, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
@ -108,8 +107,9 @@ class Hypothesis:
|
||||
# Newly predicted tokens are appended to `ys`.
|
||||
ys: List[int]
|
||||
|
||||
# The log prob of ys
|
||||
log_prob: float
|
||||
# The log prob of ys.
|
||||
# It contains only one entry.
|
||||
log_prob: torch.Tensor
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
@ -145,8 +145,10 @@ class HypothesisList(object):
|
||||
"""
|
||||
key = hyp.key
|
||||
if key in self:
|
||||
old_hyp = self._data[key]
|
||||
old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob)
|
||||
old_hyp = self._data[key] # shallow copy
|
||||
torch.logaddexp(
|
||||
old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
|
||||
)
|
||||
else:
|
||||
self._data[key] = hyp
|
||||
|
||||
@ -184,7 +186,7 @@ class HypothesisList(object):
|
||||
assert key in self, f"{key} does not exist"
|
||||
del self._data[key]
|
||||
|
||||
def filter(self, threshold: float) -> "HypothesisList":
|
||||
def filter(self, threshold: torch.Tensor) -> "HypothesisList":
|
||||
"""Remove all Hypotheses whose log_prob is less than threshold.
|
||||
|
||||
Caution:
|
||||
@ -312,6 +314,113 @@ def run_joiner(
|
||||
return log_prob
|
||||
|
||||
|
||||
def modified_beam_search(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
beam: int = 4,
|
||||
) -> List[int]:
|
||||
"""It limits the maximum number of symbols per frame to 1.
|
||||
|
||||
Args:
|
||||
model:
|
||||
An instance of `Transducer`.
|
||||
encoder_out:
|
||||
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
|
||||
beam:
|
||||
Beam size.
|
||||
Returns:
|
||||
Return the decoded result.
|
||||
"""
|
||||
|
||||
assert encoder_out.ndim == 3
|
||||
|
||||
# support only batch_size == 1 for now
|
||||
assert encoder_out.size(0) == 1, encoder_out.size(0)
|
||||
blank_id = model.decoder.blank_id
|
||||
context_size = model.decoder.context_size
|
||||
|
||||
device = model.device
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[blank_id] * context_size, device=device
|
||||
).reshape(1, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
|
||||
T = encoder_out.size(1)
|
||||
|
||||
B = HypothesisList()
|
||||
B.add(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
)
|
||||
)
|
||||
|
||||
encoder_out_len = torch.tensor([1])
|
||||
decoder_out_len = torch.tensor([1])
|
||||
|
||||
for t in range(T):
|
||||
# fmt: off
|
||||
current_encoder_out = encoder_out[:, t:t+1, :]
|
||||
# current_encoder_out is of shape (1, 1, encoder_out_dim)
|
||||
# fmt: on
|
||||
A = list(B)
|
||||
B = HypothesisList()
|
||||
|
||||
ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A])
|
||||
# ys_log_probs is of shape (num_hyps, 1)
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[hyp.ys[-context_size:] for hyp in A],
|
||||
device=device,
|
||||
)
|
||||
# decoder_input is of shape (num_hyps, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
# decoder_output is of shape (num_hyps, 1, decoder_output_dim)
|
||||
|
||||
current_encoder_out = current_encoder_out.expand(
|
||||
decoder_out.size(0), 1, -1
|
||||
)
|
||||
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
decoder_out,
|
||||
encoder_out_len.expand(decoder_out.size(0)),
|
||||
decoder_out_len.expand(decoder_out.size(0)),
|
||||
)
|
||||
# logits is of shape (num_hyps, vocab_size)
|
||||
log_probs = logits.log_softmax(dim=-1)
|
||||
|
||||
log_probs.add_(ys_log_probs)
|
||||
|
||||
log_probs = log_probs.reshape(-1)
|
||||
topk_log_probs, topk_indexes = log_probs.topk(beam)
|
||||
|
||||
# topk_hyp_indexes are indexes into `A`
|
||||
topk_hyp_indexes = topk_indexes // logits.size(-1)
|
||||
topk_token_indexes = topk_indexes % logits.size(-1)
|
||||
|
||||
topk_hyp_indexes = topk_hyp_indexes.tolist()
|
||||
topk_token_indexes = topk_token_indexes.tolist()
|
||||
|
||||
for i in range(len(topk_hyp_indexes)):
|
||||
hyp = A[topk_hyp_indexes[i]]
|
||||
new_ys = hyp.ys[:]
|
||||
new_token = topk_token_indexes[i]
|
||||
if new_token != blank_id:
|
||||
new_ys.append(new_token)
|
||||
new_log_prob = topk_log_probs[i]
|
||||
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
|
||||
B.add(new_hyp)
|
||||
|
||||
best_hyp = B.get_most_probable(length_norm=True)
|
||||
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
|
||||
|
||||
return ys
|
||||
|
||||
|
||||
def beam_search(
|
||||
model: Transducer,
|
||||
encoder_out: torch.Tensor,
|
||||
@ -351,7 +460,12 @@ def beam_search(
|
||||
t = 0
|
||||
|
||||
B = HypothesisList()
|
||||
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
|
||||
B.add(
|
||||
Hypothesis(
|
||||
ys=[blank_id] * context_size,
|
||||
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
|
||||
)
|
||||
)
|
||||
|
||||
max_sym_per_utt = 20000
|
||||
|
||||
@ -371,9 +485,6 @@ def beam_search(
|
||||
|
||||
joint_cache: Dict[str, torch.Tensor] = {}
|
||||
|
||||
# TODO(fangjun): Implement prefix search to update the `log_prob`
|
||||
# of hypotheses in A
|
||||
|
||||
while True:
|
||||
y_star = A.get_most_probable()
|
||||
A.remove(y_star)
|
||||
@ -396,18 +507,21 @@ def beam_search(
|
||||
|
||||
# First, process the blank symbol
|
||||
skip_log_prob = log_prob[blank_id]
|
||||
new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
|
||||
new_y_star_log_prob = y_star.log_prob + skip_log_prob
|
||||
|
||||
# ys[:] returns a copy of ys
|
||||
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
|
||||
|
||||
# Second, process other non-blank labels
|
||||
values, indices = log_prob.topk(beam + 1)
|
||||
for i, v in zip(indices.tolist(), values.tolist()):
|
||||
for idx in range(values.size(0)):
|
||||
i = indices[idx].item()
|
||||
if i == blank_id:
|
||||
continue
|
||||
|
||||
new_ys = y_star.ys + [i]
|
||||
new_log_prob = y_star.log_prob + v
|
||||
|
||||
new_log_prob = y_star.log_prob + values[idx]
|
||||
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
|
||||
|
||||
# Check whether B contains more than "beam" elements more probable
|
||||
|
||||
@ -46,7 +46,7 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import beam_search, greedy_search
|
||||
from beam_search import beam_search, greedy_search, modified_beam_search
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
@ -104,6 +104,7 @@ def get_parser():
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
@ -111,7 +112,8 @@ def get_parser():
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Used only when --decoding-method is beam_search",
|
||||
help="""Used only when --decoding-method is
|
||||
beam_search or modified_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -125,7 +127,8 @@ def get_parser():
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Maximum number of symbols per frame",
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
return parser
|
||||
@ -256,6 +259,10 @@ def decode_one_batch(
|
||||
hyp = beam_search(
|
||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp = modified_beam_search(
|
||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
@ -389,11 +396,15 @@ def main():
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in ("greedy_search", "beam_search")
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"modified_beam_search",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
if params.decoding_method == "beam_search":
|
||||
if "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
|
||||
@ -75,24 +75,24 @@ class Decoder(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, U) with blank prepended.
|
||||
A 2-D tensor of shape (N, U).
|
||||
need_pad:
|
||||
True to left pad the input. Should be True during training.
|
||||
False to not pad the input. Should be False during inference.
|
||||
Returns:
|
||||
Return a tensor of shape (N, U, embedding_dim).
|
||||
"""
|
||||
embeding_out = self.embedding(y)
|
||||
embedding_out = self.embedding(y)
|
||||
if self.context_size > 1:
|
||||
embeding_out = embeding_out.permute(0, 2, 1)
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
if need_pad is True:
|
||||
embeding_out = F.pad(
|
||||
embeding_out, pad=(self.context_size - 1, 0)
|
||||
embedding_out = F.pad(
|
||||
embedding_out, pad=(self.context_size - 1, 0)
|
||||
)
|
||||
else:
|
||||
# During inference time, there is no need to do extra padding
|
||||
# as we only need one output
|
||||
assert embeding_out.size(-1) == self.context_size
|
||||
embeding_out = self.conv(embeding_out)
|
||||
embeding_out = embeding_out.permute(0, 2, 1)
|
||||
return embeding_out
|
||||
assert embedding_out.size(-1) == self.context_size
|
||||
embedding_out = self.conv(embedding_out)
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
return embedding_out
|
||||
|
||||
@ -48,6 +48,7 @@ from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
@ -133,7 +134,7 @@ def get_params() -> AttributeDict:
|
||||
return params
|
||||
|
||||
|
||||
def get_encoder_model(params: AttributeDict):
|
||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
encoder = Conformer(
|
||||
num_features=params.feature_dim,
|
||||
output_dim=params.encoder_out_dim,
|
||||
@ -147,7 +148,7 @@ def get_encoder_model(params: AttributeDict):
|
||||
return encoder
|
||||
|
||||
|
||||
def get_decoder_model(params: AttributeDict):
|
||||
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||
decoder = Decoder(
|
||||
vocab_size=params.vocab_size,
|
||||
embedding_dim=params.encoder_out_dim,
|
||||
@ -157,7 +158,7 @@ def get_decoder_model(params: AttributeDict):
|
||||
return decoder
|
||||
|
||||
|
||||
def get_joiner_model(params: AttributeDict):
|
||||
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
joiner = Joiner(
|
||||
input_dim=params.encoder_out_dim,
|
||||
output_dim=params.vocab_size,
|
||||
@ -165,7 +166,7 @@ def get_joiner_model(params: AttributeDict):
|
||||
return joiner
|
||||
|
||||
|
||||
def get_transducer_model(params: AttributeDict):
|
||||
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||
encoder = get_encoder_model(params)
|
||||
decoder = get_decoder_model(params)
|
||||
joiner = get_joiner_model(params)
|
||||
|
||||
@ -14,6 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -62,6 +64,7 @@ class Transducer(nn.Module):
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
modified_transducer_prob: float = 0.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
@ -73,6 +76,8 @@ class Transducer(nn.Module):
|
||||
y:
|
||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||
utterance.
|
||||
modified_transducer_prob:
|
||||
The probability to use modified transducer loss.
|
||||
Returns:
|
||||
Return the transducer loss.
|
||||
"""
|
||||
@ -93,6 +98,7 @@ class Transducer(nn.Module):
|
||||
sos_y = add_sos(y, sos_id=blank_id)
|
||||
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
sos_y_padded = sos_y_padded.to(torch.int64)
|
||||
|
||||
decoder_out = self.decoder(sos_y_padded)
|
||||
|
||||
@ -113,6 +119,16 @@ class Transducer(nn.Module):
|
||||
# reference stage
|
||||
import optimized_transducer
|
||||
|
||||
assert 0 <= modified_transducer_prob <= 1
|
||||
|
||||
if modified_transducer_prob == 0:
|
||||
one_sym_per_frame = False
|
||||
elif random.random() < modified_transducer_prob:
|
||||
# random.random() returns a float in the range [0, 1)
|
||||
one_sym_per_frame = True
|
||||
else:
|
||||
one_sym_per_frame = False
|
||||
|
||||
loss = optimized_transducer.transducer_loss(
|
||||
logits=logits,
|
||||
targets=y_padded,
|
||||
@ -120,6 +136,8 @@ class Transducer(nn.Module):
|
||||
target_lengths=y_lens,
|
||||
blank=blank_id,
|
||||
reduction="sum",
|
||||
one_sym_per_frame=one_sym_per_frame,
|
||||
from_log_softmax=False,
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
@ -22,10 +22,11 @@ Usage:
|
||||
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method greedy_search \
|
||||
--max-sym-per-frame 1 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav \
|
||||
|
||||
(1) beam search
|
||||
(2) beam search
|
||||
./transducer_stateless/pretrained.py \
|
||||
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
@ -34,6 +35,15 @@ Usage:
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav \
|
||||
|
||||
(3) modified beam search
|
||||
./transducer_stateless/pretrained.py \
|
||||
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method modified_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav \
|
||||
|
||||
You can also use `./transducer_stateless/exp/epoch-xx.pt`.
|
||||
|
||||
Note: ./transducer_stateless/exp/pretrained.pt is generated by
|
||||
@ -49,8 +59,9 @@ from typing import List
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchaudio
|
||||
from beam_search import beam_search, greedy_search
|
||||
from beam_search import beam_search, greedy_search, modified_beam_search
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
@ -90,6 +101,7 @@ def get_parser():
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
@ -107,7 +119,7 @@ def get_parser():
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Used only when --method is beam_search",
|
||||
help="Used only when --method is beam_search and modified_beam_search ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -148,7 +160,7 @@ def get_params() -> AttributeDict:
|
||||
return params
|
||||
|
||||
|
||||
def get_encoder_model(params: AttributeDict):
|
||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
encoder = Conformer(
|
||||
num_features=params.feature_dim,
|
||||
output_dim=params.encoder_out_dim,
|
||||
@ -162,7 +174,7 @@ def get_encoder_model(params: AttributeDict):
|
||||
return encoder
|
||||
|
||||
|
||||
def get_decoder_model(params: AttributeDict):
|
||||
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||
decoder = Decoder(
|
||||
vocab_size=params.vocab_size,
|
||||
embedding_dim=params.encoder_out_dim,
|
||||
@ -172,7 +184,7 @@ def get_decoder_model(params: AttributeDict):
|
||||
return decoder
|
||||
|
||||
|
||||
def get_joiner_model(params: AttributeDict):
|
||||
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
joiner = Joiner(
|
||||
input_dim=params.encoder_out_dim,
|
||||
output_dim=params.vocab_size,
|
||||
@ -180,7 +192,7 @@ def get_joiner_model(params: AttributeDict):
|
||||
return joiner
|
||||
|
||||
|
||||
def get_transducer_model(params: AttributeDict):
|
||||
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||
encoder = get_encoder_model(params)
|
||||
decoder = get_decoder_model(params)
|
||||
joiner = get_joiner_model(params)
|
||||
@ -217,6 +229,7 @@ def read_sound_files(
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
@ -300,6 +313,10 @@ def main():
|
||||
hyp = beam_search(
|
||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||
)
|
||||
elif params.method == "modified_beam_search":
|
||||
hyp = modified_beam_search(
|
||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported method: {params.method}")
|
||||
|
||||
|
||||
@ -138,6 +138,17 @@ def get_parser():
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--modified-transducer-prob",
|
||||
type=float,
|
||||
default=0.25,
|
||||
help="""The probability to use modified transducer loss.
|
||||
In modified transduer, it limits the maximum number of symbols
|
||||
per frame to 1. See also the option --max-sym-per-frame in
|
||||
transducer_stateless/decode.py
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -213,7 +224,7 @@ def get_params() -> AttributeDict:
|
||||
return params
|
||||
|
||||
|
||||
def get_encoder_model(params: AttributeDict):
|
||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
# TODO: We can add an option to switch between Conformer and Transformer
|
||||
encoder = Conformer(
|
||||
num_features=params.feature_dim,
|
||||
@ -228,7 +239,7 @@ def get_encoder_model(params: AttributeDict):
|
||||
return encoder
|
||||
|
||||
|
||||
def get_decoder_model(params: AttributeDict):
|
||||
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||
decoder = Decoder(
|
||||
vocab_size=params.vocab_size,
|
||||
embedding_dim=params.encoder_out_dim,
|
||||
@ -238,7 +249,7 @@ def get_decoder_model(params: AttributeDict):
|
||||
return decoder
|
||||
|
||||
|
||||
def get_joiner_model(params: AttributeDict):
|
||||
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
joiner = Joiner(
|
||||
input_dim=params.encoder_out_dim,
|
||||
output_dim=params.vocab_size,
|
||||
@ -246,7 +257,7 @@ def get_joiner_model(params: AttributeDict):
|
||||
return joiner
|
||||
|
||||
|
||||
def get_transducer_model(params: AttributeDict):
|
||||
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||
encoder = get_encoder_model(params)
|
||||
decoder = get_decoder_model(params)
|
||||
joiner = get_joiner_model(params)
|
||||
@ -383,7 +394,12 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y).to(device)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
loss = model(x=feature, x_lens=feature_lens, y=y)
|
||||
loss = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
modified_transducer_prob=params.modified_transducer_prob,
|
||||
)
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
|
||||
@ -89,6 +89,29 @@ class CtcTrainingGraphCompiler(object):
|
||||
|
||||
return decoding_graph
|
||||
|
||||
def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
|
||||
"""Convert a list of texts to a list-of-list of word IDs.
|
||||
|
||||
Args:
|
||||
texts:
|
||||
It is a list of strings. Each string consists of space(s)
|
||||
separated words. An example containing two strings is given below:
|
||||
|
||||
['HELLO ICEFALL', 'HELLO k2']
|
||||
Returns:
|
||||
Return a list-of-list of word IDs.
|
||||
"""
|
||||
word_ids_list = []
|
||||
for text in texts:
|
||||
word_ids = []
|
||||
for word in text.split():
|
||||
if word in self.word_table:
|
||||
word_ids.append(self.word_table[word])
|
||||
else:
|
||||
word_ids.append(self.oov_id)
|
||||
word_ids_list.append(word_ids)
|
||||
return word_ids_list
|
||||
|
||||
def convert_transcript_to_fsa(self, texts: List[str]) -> k2.Fsa:
|
||||
"""Convert a list of transcript texts to an FsaVec.
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user