Merge branch 'k2-fsa:master' into master

Mingshuang Luo 2022-02-08 14:22:44 +08:00 committed by GitHub
commit 0baa7026f0
20 changed files with 426 additions and 119 deletions

View File

@@ -74,24 +74,53 @@ jobs:
           mkdir tmp
           cd tmp
           git lfs install
-          git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10
+          git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07
           cd ..
           tree tmp
-          soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/*.wav
-          ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/*.wav
+          soxi tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/*.wav
+          ls -lh tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/*.wav
-      - name: Run greedy search decoding
+      - name: Run greedy search decoding (max-sym-per-frame 1)
         shell: bash
         run: |
           export PYTHONPATH=$PWD:PYTHONPATH
           cd egs/librispeech/ASR
           ./transducer_stateless/pretrained.py \
             --method greedy_search \
+            --max-sym-per-frame 1 \
-            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/exp/pretrained.pt \
-            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/data/lang_bpe_500/bpe.model \
-            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1089-134686-0001.wav \
-            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0001.wav \
-            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0002.wav
+            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
+            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
+      - name: Run greedy search decoding (max-sym-per-frame 2)
+        shell: bash
+        run: |
+          export PYTHONPATH=$PWD:PYTHONPATH
+          cd egs/librispeech/ASR
+          ./transducer_stateless/pretrained.py \
+            --method greedy_search \
+            --max-sym-per-frame 2 \
+            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
+            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
+      - name: Run greedy search decoding (max-sym-per-frame 3)
+        shell: bash
+        run: |
+          export PYTHONPATH=$PWD:PYTHONPATH
+          cd egs/librispeech/ASR
+          ./transducer_stateless/pretrained.py \
+            --method greedy_search \
+            --max-sym-per-frame 3 \
+            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
+            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
       - name: Run beam search decoding
         shell: bash
@@ -101,8 +130,22 @@ jobs:
           ./transducer_stateless/pretrained.py \
             --method beam_search \
             --beam-size 4 \
-            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/exp/pretrained.pt \
-            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/data/lang_bpe_500/bpe.model \
-            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1089-134686-0001.wav \
-            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0001.wav \
-            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10/test_wavs/1221-135766-0002.wav
+            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
+            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav
+      - name: Run modified beam search decoding
+        shell: bash
+        run: |
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          cd egs/librispeech/ASR
+          ./transducer_stateless/pretrained.py \
+            --method modified_beam_search \
+            --beam-size 4 \
+            --checkpoint ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/exp/pretrained.pt \
+            --bpe-model ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/data/lang_bpe_500/bpe.model \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1089-134686-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0001.wav \
+            ./tmp/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07/test_wavs/1221-135766-0002.wav

View File

@@ -80,16 +80,16 @@ We provide a Colab notebook to run a pre-trained RNN-T conformer model: [![Open
 Using Conformer as encoder. The decoder consists of 1 embedding layer
 and 1 convolutional layer.
-The best WER using beam search with beam size 4 is:
+The best WER using modified beam search with beam size 4 is:
 |     | test-clean | test-other |
 |-----|------------|------------|
-| WER | 2.68       | 6.72       |
+| WER | 2.67       | 6.64       |
 Note: No auxiliary losses are used in the training and no LMs are used
 in the decoding.
-We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Rc4Is-3Yp9LbcEz_Iy8hfyenyHsyjvqE?usp=sharing)
+We provide a Colab notebook to run a pre-trained transducer conformer + stateless decoder model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1CO1bXJ-2khDckZIW8zjOPHGSKLHpTDlp?usp=sharing)
 ### Aishell

View File

@@ -82,17 +82,17 @@ class Decoder(nn.Module):
         Returns:
           Return a tensor of shape (N, U, embedding_dim).
         """
-        embeding_out = self.embedding(y)
+        embedding_out = self.embedding(y)
         if self.context_size > 1:
-            embeding_out = embeding_out.permute(0, 2, 1)
+            embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embeding_out = F.pad(
-                    embeding_out, pad=(self.context_size - 1, 0)
+                embedding_out = F.pad(
+                    embedding_out, pad=(self.context_size - 1, 0)
                 )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
-                assert embeding_out.size(-1) == self.context_size
-            embeding_out = self.conv(embeding_out)
-            embeding_out = embeding_out.permute(0, 2, 1)
-        return embeding_out
+                assert embedding_out.size(-1) == self.context_size
+            embedding_out = self.conv(embedding_out)
+            embedding_out = embedding_out.permute(0, 2, 1)
+        return embedding_out

View File

@@ -48,6 +48,7 @@ from pathlib import Path
 import sentencepiece as spm
 import torch
+import torch.nn as nn
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
@@ -133,7 +134,7 @@ def get_params() -> AttributeDict:
     return params
-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Conformer(
         num_features=params.feature_dim,
         output_dim=params.encoder_out_dim,
@@ -147,7 +148,7 @@ def get_encoder_model(params: AttributeDict):
     return encoder
-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
@@ -157,7 +158,7 @@ def get_decoder_model(params: AttributeDict):
     return decoder
-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.encoder_out_dim,
         output_dim=params.vocab_size,
@@ -165,7 +166,7 @@ def get_joiner_model(params: AttributeDict):
     return joiner
-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)

View File

@@ -44,11 +44,12 @@ Note: ./transducer_stateless/exp/pretrained.pt is generated by
 import argparse
 import logging
 import math
-from typing import List
 from pathlib import Path
+from typing import List
 import kaldifeat
 import torch
+import torch.nn as nn
 import torchaudio
 from beam_search import beam_search, greedy_search
 from conformer import Conformer
@@ -57,10 +58,10 @@ from joiner import Joiner
 from model import Transducer
 from torch.nn.utils.rnn import pad_sequence
-from icefall.env import get_env_info
-from icefall.utils import AttributeDict
-from icefall.lexicon import Lexicon
 from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict
 def get_parser():
@@ -150,7 +151,7 @@ def get_params() -> AttributeDict:
     return params
-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Conformer(
         num_features=params.feature_dim,
         output_dim=params.encoder_out_dim,
@@ -164,7 +165,7 @@ def get_encoder_model(params: AttributeDict):
     return encoder
-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
@@ -174,7 +175,7 @@ def get_decoder_model(params: AttributeDict):
     return decoder
-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.encoder_out_dim,
         output_dim=params.vocab_size,
@@ -182,7 +183,7 @@ def get_joiner_model(params: AttributeDict):
     return joiner
-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)

View File

@@ -204,7 +204,7 @@ def get_params() -> AttributeDict:
     return params
-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
     encoder = Conformer(
         num_features=params.feature_dim,
@@ -219,7 +219,7 @@ def get_encoder_model(params: AttributeDict):
     return encoder
-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
@@ -229,7 +229,7 @@ def get_decoder_model(params: AttributeDict):
     return decoder
-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.encoder_out_dim,
         output_dim=params.vocab_size,
@@ -237,7 +237,7 @@ def get_joiner_model(params: AttributeDict):
     return joiner
-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)

View File

@@ -4,62 +4,73 @@
 #### Conformer encoder + embedding decoder
-Using commit `4c1b3665ee6efb935f4dd93a80ff0e154b13efb6`.
+Using commit `a8150021e01d34ecbd6198fe03a57eacf47a16f2`.
-Conformer encoder + non-current decoder. The decoder
+Conformer encoder + non-recurrent decoder. The decoder
 contains only an embedding layer and a Conv1d (with kernel size 2).
 The WERs are
-|                           | test-clean | test-other | comment                                  |
-|---------------------------|------------|------------|------------------------------------------|
-| greedy search             | 2.69       | 6.81       | --epoch 71, --avg 15, --max-duration 100 |
-| beam search (beam size 4) | 2.68       | 6.72       | --epoch 71, --avg 15, --max-duration 100 |
+|                                     | test-clean | test-other | comment                                  |
+|-------------------------------------|------------|------------|------------------------------------------|
+| greedy search (max sym per frame 1) | 2.68       | 6.71       | --epoch 61, --avg 18, --max-duration 100 |
+| greedy search (max sym per frame 2) | 2.69       | 6.71       | --epoch 61, --avg 18, --max-duration 100 |
+| greedy search (max sym per frame 3) | 2.69       | 6.71       | --epoch 61, --avg 18, --max-duration 100 |
+| modified beam search (beam size 4)  | 2.67       | 6.64       | --epoch 61, --avg 18, --max-duration 100 |
 The training command for reproducing is given below:
 ```
+cd egs/librispeech/ASR/
+./prepare.sh
 export CUDA_VISIBLE_DEVICES="0,1,2,3"
 ./transducer_stateless/train.py \
   --world-size 4 \
   --num-epochs 76 \
   --start-epoch 0 \
   --exp-dir transducer_stateless/exp-full \
   --full-libri 1 \
-  --max-duration 250 \
-  --lr-factor 3
+  --max-duration 300 \
+  --lr-factor 5 \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --modified-transducer-prob 0.25
 ```
 The tensorboard training log can be found at
-<https://tensorboard.dev/experiment/qGdqzHnxS0WJ695OXfZDzA/#scalars&_smoothingWeight=0>
+<https://tensorboard.dev/experiment/qgvWkbF2R46FYA6ZMNmOjA/#scalars>
 The decoding command is:
 ```
-epoch=71
-avg=15
+epoch=61
+avg=18
 ## greedy search
-./transducer_stateless/decode.py \
-  --epoch $epoch \
-  --avg $avg \
-  --exp-dir transducer_stateless/exp-full \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
-  --max-duration 100
+for sym in 1 2 3; do
+  ./transducer_stateless/decode.py \
+    --epoch $epoch \
+    --avg $avg \
+    --exp-dir transducer_stateless/exp-full \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --max-duration 100 \
+    --max-sym-per-frame $sym
+done
-## beam search
+## modified beam search
 ./transducer_stateless/decode.py \
   --epoch $epoch \
   --avg $avg \
   --exp-dir transducer_stateless/exp-full \
   --bpe-model ./data/lang_bpe_500/bpe.model \
   --max-duration 100 \
-  --decoding-method beam_search \
+  --context-size 2 \
+  --decoding-method modified_beam_search \
   --beam-size 4
 ```
 You can find a pretrained model by visiting
-<https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-01-10>
+<https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-bpe-500-2022-02-07>
 #### Conformer encoder + LSTM decoder
#### Conformer encoder + LSTM decoder #### Conformer encoder + LSTM decoder

View File

@@ -41,6 +41,7 @@ from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
+from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -123,6 +124,15 @@ def get_parser():
         """,
     )
+    parser.add_argument(
+        "--num-decoder-layers",
+        type=int,
+        default=6,
+        help="""Number of decoder layer of transformer decoder.
+        Setting this to 0 will not create the decoder at all (pure CTC model)
+        """,
+    )
     parser.add_argument(
         "--lr-factor",
         type=float,
@@ -210,7 +220,6 @@ def get_params() -> AttributeDict:
             "use_feat_batchnorm": True,
             "attention_dim": 512,
             "nhead": 8,
-            "num_decoder_layers": 6,
             # parameters for loss
             "beam_size": 10,
             "reduction": "sum",
@@ -357,9 +366,17 @@ def compute_loss(
         supervisions, subsampling_factor=params.subsampling_factor
     )
-    token_ids = graph_compiler.texts_to_ids(texts)
-    decoding_graph = graph_compiler.compile(token_ids)
+    if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler):
+        # Works with a BPE model
+        token_ids = graph_compiler.texts_to_ids(texts)
+        decoding_graph = graph_compiler.compile(token_ids)
+    elif isinstance(graph_compiler, CtcTrainingGraphCompiler):
+        # Works with a phone lexicon
+        decoding_graph = graph_compiler.compile(texts)
+    else:
+        raise ValueError(
+            f"Unsupported type of graph compiler: {type(graph_compiler)}"
+        )
     dense_fsa_vec = k2.DenseFsaVec(
         nnet_output,
@@ -584,12 +601,38 @@ def run(rank, world_size, args):
     if torch.cuda.is_available():
         device = torch.device("cuda", rank)
-    graph_compiler = BpeCtcTrainingGraphCompiler(
-        params.lang_dir,
-        device=device,
-        sos_token="<sos/eos>",
-        eos_token="<sos/eos>",
-    )
+    if "lang_bpe" in params.lang_dir:
+        graph_compiler = BpeCtcTrainingGraphCompiler(
+            params.lang_dir,
+            device=device,
+            sos_token="<sos/eos>",
+            eos_token="<sos/eos>",
+        )
+    elif "lang_phone" in params.lang_dir:
+        assert params.att_rate == 0, (
+            "Attention decoder training does not support phone lang dirs "
+            "at this time due to a missing <sos/eos> symbol. Set --att-rate=0 "
+            "for pure CTC training when using a phone-based lang dir."
+        )
+        assert params.num_decoder_layers == 0, (
+            "Attention decoder training does not support phone lang dirs "
+            "at this time due to a missing <sos/eos> symbol. "
+            "Set --num-decoder-layers=0 for pure CTC training when using "
+            "a phone-based lang dir."
+        )
+        graph_compiler = CtcTrainingGraphCompiler(
+            lexicon,
+            device=device,
+        )
+        # Manually add the sos/eos ID with their default values
+        # from the BPE recipe which we're adapting here.
+        graph_compiler.sos_id = 1
+        graph_compiler.eos_id = 1
+    else:
+        raise ValueError(
+            f"Unsupported type of lang dir (we expected it to have "
+            f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}"
+        )
     logging.info("About to create model")
     model = Conformer(
@@ -607,7 +650,9 @@ def run(rank, world_size, args):
     model.to(device)
     if world_size > 1:
-        model = DDP(model, device_ids=[rank])
+        # Note: find_unused_parameters=True is needed in case we
+        # want to set params.att_rate = 0 (i.e. att decoder is not trained)
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
     optimizer = Noam(
         model.parameters(),
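
Editor's note: the lang-dir branch added above dispatches purely on the directory name. A minimal standalone sketch of that convention (the paths are examples, and strings stand in for the two graph-compiler classes imported above):

```python
# Sketch of the name-based dispatch used in run() above; paths are illustrative.
def pick_graph_compiler(lang_dir: str) -> str:
    if "lang_bpe" in lang_dir:
        return "BpeCtcTrainingGraphCompiler"  # BPE tokens
    elif "lang_phone" in lang_dir:
        return "CtcTrainingGraphCompiler"  # phone lexicon
    raise ValueError(f"Unsupported type of lang dir: {lang_dir}")

print(pick_graph_compiler("data/lang_bpe_500"))  # BpeCtcTrainingGraphCompiler
print(pick_graph_compiler("data/lang_phone"))    # CtcTrainingGraphCompiler
```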

View File

@@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
-    sos = torch.tensor([blank_id], device=device).reshape(1, 1)
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
+        1, 1
+    )
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0

View File

@@ -99,6 +99,7 @@ class Transducer(nn.Module):
         sos_y = add_sos(y, sos_id=blank_id)
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+        sos_y_padded = sos_y_padded.to(torch.int64)
         decoder_out, _ = self.decoder(sos_y_padded)
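
Editor's note: the new cast matters because the decoder's first operation is an embedding lookup, and embedding indices generally need to be int64, while k2's ragged padding can yield int32. A small sketch, not tied to this recipe:

```python
import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=500, embedding_dim=8)
y = torch.tensor([[0, 13, 25]], dtype=torch.int32)  # e.g. what ragged padding may yield
out = emb(y.to(torch.int64))  # explicit cast before the lookup
print(out.shape)  # torch.Size([1, 3, 8])
```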

View File

@@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
     blank_id = model.decoder.blank_id
     device = model.device
-    sos = torch.tensor([blank_id], device=device).reshape(1, 1)
+    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
+        1, 1
+    )
     decoder_out, (h, c) = model.decoder(sos)
     T = encoder_out.size(1)
     t = 0
t = 0 t = 0

View File

@@ -101,6 +101,7 @@ class Transducer(nn.Module):
         sos_y = add_sos(y, sos_id=sos_id)
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+        sos_y_padded = sos_y_padded.to(torch.int64)
         decoder_out, _ = self.decoder(sos_y_padded)

View File

@@ -17,7 +17,6 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional
-import numpy as np
 import torch
 from model import Transducer
@@ -48,7 +47,7 @@ def greedy_search(
     device = model.device
     decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
+        [blank_id] * context_size, device=device, dtype=torch.int64
     ).reshape(1, context_size)
     decoder_out = model.decoder(decoder_input, need_pad=False)
@@ -108,8 +107,9 @@ class Hypothesis:
     # Newly predicted tokens are appended to `ys`.
     ys: List[int]
-    # The log prob of ys
-    log_prob: float
+    # The log prob of ys.
+    # It contains only one entry.
+    log_prob: torch.Tensor
     @property
     def key(self) -> str:
@@ -145,8 +145,10 @@ class HypothesisList(object):
         """
         key = hyp.key
         if key in self:
-            old_hyp = self._data[key]
-            old_hyp.log_prob = np.logaddexp(old_hyp.log_prob, hyp.log_prob)
+            old_hyp = self._data[key]  # shallow copy
+            torch.logaddexp(
+                old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob
+            )
         else:
             self._data[key] = hyp
@@ -184,7 +186,7 @@ class HypothesisList(object):
         assert key in self, f"{key} does not exist"
         del self._data[key]
-    def filter(self, threshold: float) -> "HypothesisList":
+    def filter(self, threshold: torch.Tensor) -> "HypothesisList":
         """Remove all Hypotheses whose log_prob is less than threshold.
         Caution:
@@ -312,6 +314,113 @@ def run_joiner(
     return log_prob
+def modified_beam_search(
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    beam: int = 4,
+) -> List[int]:
+    """It limits the maximum number of symbols per frame to 1.
+
+    Args:
+      model:
+        An instance of `Transducer`.
+      encoder_out:
+        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
+      beam:
+        Beam size.
+    Returns:
+      Return the decoded result.
+    """
+    assert encoder_out.ndim == 3
+
+    # support only batch_size == 1 for now
+    assert encoder_out.size(0) == 1, encoder_out.size(0)
+    blank_id = model.decoder.blank_id
+    context_size = model.decoder.context_size
+
+    device = model.device
+
+    decoder_input = torch.tensor(
+        [blank_id] * context_size, device=device
+    ).reshape(1, context_size)
+
+    decoder_out = model.decoder(decoder_input, need_pad=False)
+
+    T = encoder_out.size(1)
+
+    B = HypothesisList()
+    B.add(
+        Hypothesis(
+            ys=[blank_id] * context_size,
+            log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+        )
+    )
+
+    encoder_out_len = torch.tensor([1])
+    decoder_out_len = torch.tensor([1])
+
+    for t in range(T):
+        # fmt: off
+        current_encoder_out = encoder_out[:, t:t+1, :]
+        # current_encoder_out is of shape (1, 1, encoder_out_dim)
+        # fmt: on
+        A = list(B)
+        B = HypothesisList()
+
+        ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A])
+        # ys_log_probs is of shape (num_hyps, 1)
+
+        decoder_input = torch.tensor(
+            [hyp.ys[-context_size:] for hyp in A],
+            device=device,
+        )
+        # decoder_input is of shape (num_hyps, context_size)
+
+        decoder_out = model.decoder(decoder_input, need_pad=False)
+        # decoder_output is of shape (num_hyps, 1, decoder_output_dim)
+
+        current_encoder_out = current_encoder_out.expand(
+            decoder_out.size(0), 1, -1
+        )
+
+        logits = model.joiner(
+            current_encoder_out,
+            decoder_out,
+            encoder_out_len.expand(decoder_out.size(0)),
+            decoder_out_len.expand(decoder_out.size(0)),
+        )
+        # logits is of shape (num_hyps, vocab_size)
+        log_probs = logits.log_softmax(dim=-1)
+
+        log_probs.add_(ys_log_probs)
+
+        log_probs = log_probs.reshape(-1)
+        topk_log_probs, topk_indexes = log_probs.topk(beam)
+
+        # topk_hyp_indexes are indexes into `A`
+        topk_hyp_indexes = topk_indexes // logits.size(-1)
+        topk_token_indexes = topk_indexes % logits.size(-1)
+
+        topk_hyp_indexes = topk_hyp_indexes.tolist()
+        topk_token_indexes = topk_token_indexes.tolist()
+
+        for i in range(len(topk_hyp_indexes)):
+            hyp = A[topk_hyp_indexes[i]]
+            new_ys = hyp.ys[:]
+            new_token = topk_token_indexes[i]
+            if new_token != blank_id:
+                new_ys.append(new_token)
+            new_log_prob = topk_log_probs[i]
+            new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
+            B.add(new_hyp)
+
+    best_hyp = B.get_most_probable(length_norm=True)
+    ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks
+
+    return ys
+
 def beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
@@ -351,7 +460,12 @@ def beam_search(
     t = 0
     B = HypothesisList()
-    B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0))
+    B.add(
+        Hypothesis(
+            ys=[blank_id] * context_size,
+            log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+        )
+    )
     max_sym_per_utt = 20000
@@ -371,9 +485,6 @@ def beam_search(
     joint_cache: Dict[str, torch.Tensor] = {}
-    # TODO(fangjun): Implement prefix search to update the `log_prob`
-    # of hypotheses in A
     while True:
         y_star = A.get_most_probable()
         A.remove(y_star)
@@ -396,18 +507,21 @@ def beam_search(
         # First, process the blank symbol
         skip_log_prob = log_prob[blank_id]
-        new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
+        new_y_star_log_prob = y_star.log_prob + skip_log_prob
         # ys[:] returns a copy of ys
         B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
         # Second, process other non-blank labels
         values, indices = log_prob.topk(beam + 1)
-        for i, v in zip(indices.tolist(), values.tolist()):
+        for idx in range(values.size(0)):
+            i = indices[idx].item()
             if i == blank_id:
                 continue
             new_ys = y_star.ys + [i]
-            new_log_prob = y_star.log_prob + v
+            new_log_prob = y_star.log_prob + values[idx]
             A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
         # Check whether B contains more than "beam" elements more probable
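
Editor's note: one step in the new modified_beam_search deserves a remark. The per-frame top-k is taken over the flattened (num_hyps, vocab_size) score matrix, and the surviving pairs are mapped back to (hypothesis, token) with integer division and modulo. A tiny numeric sketch with invented sizes and scores:

```python
import torch

# 2 live hypotheses, vocab of 5; scores[h, v] = joint log prob of
# extending hypothesis h with token v (numbers are made up).
scores = torch.tensor(
    [[0.10, 0.70, 0.02, 0.03, 0.15],
     [0.60, 0.12, 0.08, 0.11, 0.09]]
).log()
flat = scores.reshape(-1)           # shape (2 * 5,)
topk_vals, topk_idx = flat.topk(3)  # best 3 (hyp, token) pairs overall
hyp_idx = topk_idx // scores.size(-1)  # which hypothesis each pair extends
tok_idx = topk_idx % scores.size(-1)   # which token extends it
print(hyp_idx.tolist(), tok_idx.tolist())  # [0, 1, 0] [1, 0, 4]
```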

View File

@@ -46,7 +46,7 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from beam_search import beam_search, greedy_search
+from beam_search import beam_search, greedy_search, modified_beam_search
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
@@ -104,6 +104,7 @@ def get_parser():
         help="""Possible values are:
         - greedy_search
         - beam_search
+        - modified_beam_search
         """,
     )
@@ -111,7 +112,8 @@ def get_parser():
         "--beam-size",
         type=int,
         default=4,
-        help="Used only when --decoding-method is beam_search",
+        help="""Used only when --decoding-method is
+        beam_search or modified_beam_search""",
     )
     parser.add_argument(
@@ -125,7 +127,8 @@ def get_parser():
         "--max-sym-per-frame",
         type=int,
         default=3,
-        help="Maximum number of symbols per frame",
+        help="""Maximum number of symbols per frame.
+        Used only when --decoding_method is greedy_search""",
     )
     return parser
@@ -256,6 +259,10 @@ def decode_one_batch(
             hyp = beam_search(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
+        elif params.decoding_method == "modified_beam_search":
+            hyp = modified_beam_search(
+                model=model, encoder_out=encoder_out_i, beam=params.beam_size
+            )
         else:
             raise ValueError(
                 f"Unsupported decoding method: {params.decoding_method}"
@@ -389,11 +396,15 @@ def main():
     params = get_params()
     params.update(vars(args))
-    assert params.decoding_method in ("greedy_search", "beam_search")
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "modified_beam_search",
+    )
     params.res_dir = params.exp_dir / params.decoding_method
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    if params.decoding_method == "beam_search":
+    if "beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"

View File

@@ -75,24 +75,24 @@ class Decoder(nn.Module):
         """
         Args:
           y:
-            A 2-D tensor of shape (N, U) with blank prepended.
+            A 2-D tensor of shape (N, U).
           need_pad:
             True to left pad the input. Should be True during training.
             False to not pad the input. Should be False during inference.
         Returns:
           Return a tensor of shape (N, U, embedding_dim).
         """
-        embeding_out = self.embedding(y)
+        embedding_out = self.embedding(y)
         if self.context_size > 1:
-            embeding_out = embeding_out.permute(0, 2, 1)
+            embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embeding_out = F.pad(
-                    embeding_out, pad=(self.context_size - 1, 0)
+                embedding_out = F.pad(
+                    embedding_out, pad=(self.context_size - 1, 0)
                 )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
-                assert embeding_out.size(-1) == self.context_size
-            embeding_out = self.conv(embeding_out)
-            embeding_out = embeding_out.permute(0, 2, 1)
-        return embeding_out
+                assert embedding_out.size(-1) == self.context_size
+            embedding_out = self.conv(embedding_out)
+            embedding_out = embedding_out.permute(0, 2, 1)
+        return embedding_out
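
Editor's note: the need_pad logic above keeps the Conv1d over previous tokens causal. During training the input is left-padded by context_size - 1 so the output at position u sees only tokens up to u and the sequence length U is preserved; at inference exactly context_size tokens are fed and a single output is taken. A shape-only sketch with made-up dimensions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

context_size, embedding_dim = 2, 4
conv = nn.Conv1d(embedding_dim, embedding_dim, kernel_size=context_size)

# Training: left pad by (context_size - 1) so the output length stays U.
x = torch.randn(1, embedding_dim, 5)           # (N, C, U)
y = conv(F.pad(x, pad=(context_size - 1, 0)))
print(y.shape)                                 # torch.Size([1, 4, 5])

# Inference: feed exactly context_size tokens, take the single output.
x1 = torch.randn(1, embedding_dim, context_size)
print(conv(x1).shape)                          # torch.Size([1, 4, 1])
```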

View File

@@ -48,6 +48,7 @@ from pathlib import Path
 import sentencepiece as spm
 import torch
+import torch.nn as nn
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
@@ -133,7 +134,7 @@ def get_params() -> AttributeDict:
     return params
-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Conformer(
         num_features=params.feature_dim,
         output_dim=params.encoder_out_dim,
@@ -147,7 +148,7 @@ def get_encoder_model(params: AttributeDict):
     return encoder
-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
@@ -157,7 +158,7 @@ def get_decoder_model(params: AttributeDict):
     return decoder
-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.encoder_out_dim,
         output_dim=params.vocab_size,
@@ -165,7 +166,7 @@ def get_joiner_model(params: AttributeDict):
     return joiner
-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)

View File

@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import random
+
 import k2
 import torch
 import torch.nn as nn
@@ -62,6 +64,7 @@ class Transducer(nn.Module):
         x: torch.Tensor,
         x_lens: torch.Tensor,
         y: k2.RaggedTensor,
+        modified_transducer_prob: float = 0.0,
     ) -> torch.Tensor:
         """
         Args:
@@ -73,6 +76,8 @@ class Transducer(nn.Module):
           y:
             A ragged tensor with 2 axes [utt][label]. It contains labels of each
             utterance.
+          modified_transducer_prob:
+            The probability to use modified transducer loss.
         Returns:
           Return the transducer loss.
         """
@@ -93,6 +98,7 @@ class Transducer(nn.Module):
         sos_y = add_sos(y, sos_id=blank_id)
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+        sos_y_padded = sos_y_padded.to(torch.int64)
         decoder_out = self.decoder(sos_y_padded)
@@ -113,6 +119,16 @@ class Transducer(nn.Module):
         # reference stage
         import optimized_transducer
+        assert 0 <= modified_transducer_prob <= 1
+
+        if modified_transducer_prob == 0:
+            one_sym_per_frame = False
+        elif random.random() < modified_transducer_prob:
+            # random.random() returns a float in the range [0, 1)
+            one_sym_per_frame = True
+        else:
+            one_sym_per_frame = False
+
         loss = optimized_transducer.transducer_loss(
             logits=logits,
             targets=y_padded,
@@ -120,6 +136,8 @@ class Transducer(nn.Module):
             target_lengths=y_lens,
             blank=blank_id,
             reduction="sum",
+            one_sym_per_frame=one_sym_per_frame,
+            from_log_softmax=False,
         )
         return loss
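
Editor's note: the gating above draws one Bernoulli sample per forward call, so with the default probability about a quarter of training batches use the one-symbol-per-frame (modified) loss. A quick standalone check of that behavior:

```python
import random

random.seed(0)
prob = 0.25  # mirrors modified_transducer_prob
flags = [random.random() < prob for _ in range(100_000)]
print(sum(flags) / len(flags))  # close to 0.25
```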

View File

@@ -22,10 +22,11 @@ Usage:
     --checkpoint ./transducer_stateless/exp/pretrained.pt \
     --bpe-model ./data/lang_bpe_500/bpe.model \
     --method greedy_search \
+    --max-sym-per-frame 1 \
     /path/to/foo.wav \
     /path/to/bar.wav \
-(1) beam search
+(2) beam search
 ./transducer_stateless/pretrained.py \
     --checkpoint ./transducer_stateless/exp/pretrained.pt \
     --bpe-model ./data/lang_bpe_500/bpe.model \
@@ -34,6 +35,15 @@ Usage:
     /path/to/foo.wav \
     /path/to/bar.wav \
+(3) modified beam search
+./transducer_stateless/pretrained.py \
+    --checkpoint ./transducer_stateless/exp/pretrained.pt \
+    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --method modified_beam_search \
+    --beam-size 4 \
+    /path/to/foo.wav \
+    /path/to/bar.wav \
 You can also use `./transducer_stateless/exp/epoch-xx.pt`.
 Note: ./transducer_stateless/exp/pretrained.pt is generated by
@@ -49,8 +59,9 @@ from typing import List
 import kaldifeat
 import sentencepiece as spm
 import torch
+import torch.nn as nn
 import torchaudio
-from beam_search import beam_search, greedy_search
+from beam_search import beam_search, greedy_search, modified_beam_search
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
@@ -90,6 +101,7 @@ def get_parser():
         help="""Possible values are:
         - greedy_search
         - beam_search
+        - modified_beam_search
         """,
     )
@@ -107,7 +119,7 @@ def get_parser():
         "--beam-size",
         type=int,
         default=4,
-        help="Used only when --method is beam_search",
+        help="Used only when --method is beam_search and modified_beam_search",
     )
     parser.add_argument(
@@ -148,7 +160,7 @@ def get_params() -> AttributeDict:
     return params
-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Conformer(
         num_features=params.feature_dim,
         output_dim=params.encoder_out_dim,
@@ -162,7 +174,7 @@ def get_encoder_model(params: AttributeDict):
     return encoder
-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
@@ -172,7 +184,7 @@ def get_decoder_model(params: AttributeDict):
     return decoder
-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.encoder_out_dim,
         output_dim=params.vocab_size,
@@ -180,7 +192,7 @@ def get_joiner_model(params: AttributeDict):
     return joiner
-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
@@ -217,6 +229,7 @@ def read_sound_files(
     return ans
+@torch.no_grad()
 def main():
     parser = get_parser()
     args = parser.parse_args()
@@ -300,6 +313,10 @@ def main():
             hyp = beam_search(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
+        elif params.method == "modified_beam_search":
+            hyp = modified_beam_search(
+                model=model, encoder_out=encoder_out_i, beam=params.beam_size
+            )
         else:
             raise ValueError(f"Unsupported method: {params.method}")
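
Editor's note: the @torch.no_grad() decorator added to main() above is equivalent to wrapping the whole body in the context manager, so no autograd graph is recorded and memory use drops during pure inference. A minimal illustration (model and shapes invented):

```python
import torch

def run_inference(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():  # same effect as decorating the function
        return model(x)

out = run_inference(torch.nn.Linear(4, 2), torch.randn(3, 4))
print(out.requires_grad)  # False: no gradient bookkeeping was done
```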

View File

@@ -138,6 +138,17 @@ def get_parser():
         "2 means tri-gram",
     )
+    parser.add_argument(
+        "--modified-transducer-prob",
+        type=float,
+        default=0.25,
+        help="""The probability to use modified transducer loss.
+        In modified transducer, it limits the maximum number of symbols
+        per frame to 1. See also the option --max-sym-per-frame in
+        transducer_stateless/decode.py
+        """,
+    )
     return parser
@@ -213,7 +224,7 @@ def get_params() -> AttributeDict:
     return params
-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
     encoder = Conformer(
         num_features=params.feature_dim,
@@ -228,7 +239,7 @@ def get_encoder_model(params: AttributeDict):
     return encoder
-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
@@ -238,7 +249,7 @@ def get_decoder_model(params: AttributeDict):
     return decoder
-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.encoder_out_dim,
         output_dim=params.vocab_size,
@@ -246,7 +257,7 @@ def get_joiner_model(params: AttributeDict):
     return joiner
-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
@@ -383,7 +394,12 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
     with torch.set_grad_enabled(is_training):
-        loss = model(x=feature, x_lens=feature_lens, y=y)
+        loss = model(
+            x=feature,
+            x_lens=feature_lens,
+            y=y,
+            modified_transducer_prob=params.modified_transducer_prob,
+        )
     assert loss.requires_grad == is_training

View File

@@ -89,6 +89,29 @@ class CtcTrainingGraphCompiler(object):
         return decoding_graph
+    def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
+        """Convert a list of texts to a list-of-list of word IDs.
+
+        Args:
+          texts:
+            It is a list of strings. Each string consists of space(s)
+            separated words. An example containing two strings is given below:
+
+                ['HELLO ICEFALL', 'HELLO k2']
+        Returns:
+          Return a list-of-list of word IDs.
+        """
+        word_ids_list = []
+        for text in texts:
+            word_ids = []
+            for word in text.split():
+                if word in self.word_table:
+                    word_ids.append(self.word_table[word])
+                else:
+                    word_ids.append(self.oov_id)
+            word_ids_list.append(word_ids)
+        return word_ids_list
+
     def convert_transcript_to_fsa(self, texts: List[str]) -> k2.Fsa:
         """Convert a list of transcript texts to an FsaVec.