update result

This commit is contained in:
Yuekai Zhang 2022-07-11 14:35:39 +00:00
parent dc40220951
commit a2b54cca10
5 changed files with 181 additions and 232 deletions


@ -0,0 +1,19 @@
# Introduction
This recipe includes several ASR models trained on Aishell2.
[./RESULTS.md](./RESULTS.md) contains the latest results.
# Transducers
Several folders in this directory contain `transducer` in their names.
The following table lists the differences among them.
| | Encoder | Decoder | Comment |
|---------------------------------------|---------------------|--------------------|-----------------------------|
| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + more layers + random combiner |
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
We place an additional Conv1d layer right after the input embedding layer.
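
To make the "Embedding + Conv1d" decoder more concrete, here is a minimal, dependency-free sketch of one decoder step. It is not the code in `pruned_transducer_stateless5`; the function name, the depthwise (per-channel) kernel layout, and the blank left-padding are assumptions made for illustration. The point it shows is that the output depends only on the last `context_size` tokens, which is why the decoder is called stateless.

```python
from typing import List


def stateless_decoder_step(
    tokens: List[int],
    embedding: List[List[float]],
    kernel: List[List[float]],
    blank_id: int = 0,
) -> List[float]:
    """Hypothetical helper: compute one decoder output vector.

    embedding[t] is the D-dimensional vector of token t.
    kernel[k] holds D per-channel weights for context position k
    (oldest token first), so len(kernel) is the context size.
    """
    context_size = len(kernel)
    # Left-pad with blanks so the very first step also sees a full window.
    padded = [blank_id] * context_size + tokens
    window = padded[-context_size:]

    dim = len(embedding[0])
    out = [0.0] * dim
    # Depthwise 1-D convolution over the window: each channel d is a
    # weighted sum of that channel across the context positions.
    for k, tok in enumerate(window):
        for d in range(dim):
            out[d] += kernel[k][d] * embedding[tok][d]
    return out


# Toy values, invented for the example.
EMBEDDING = [[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]]
KERNEL = [[1.0, 1.0], [10.0, 10.0]]  # context size 2

short_history = stateless_decoder_step([1, 2], EMBEDDING, KERNEL)
long_history = stateless_decoder_step([2, 1, 2], EMBEDDING, KERNEL)
```

Both calls return the same vector because only the last two tokens enter the computation; older history is ignored.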


@ -0,0 +1,72 @@
## Results
### Aishell2 char-based training results (Pruned Transducer 5)
#### 2022-07-11
Using the code from this PR: https://github.com/k2-fsa/icefall/pull/461.
When training with a context size of 1, the WERs are:
| | dev-ios | test-ios | comment |
|------------------------------------|-------|----------|----------------------------------|
| greedy search | | | --epoch 10, --avg 2, --max-duration 100 |
| modified beam search (beam size 4) | | | --epoch 10, --avg 2, --max-duration 100 |
| fast beam search (set as default) | | | --epoch 10, --avg 2, --max-duration 1500 |
The training command to reproduce these results is:
```
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless5/train.py \
  --world-size 4 \
  --lang-dir data/lang_char \
  --num-epochs 40 \
  --start-epoch 1 \
  --exp-dir /result \
  --max-duration 300 \
  --use-fp16 0 \
  --num-encoder-layers 24 \
  --dim-feedforward 1536 \
  --nhead 8 \
  --encoder-dim 384 \
  --decoder-dim 512 \
  --joiner-dim 512
```
The decoding command is:
```
for method in greedy_search modified_beam_search fast_beam_search; do
  ./pruned_transducer_stateless5/decode.py \
    --epoch 25 \
    --avg 5 \
    --exp-dir /result \
    --max-duration 600 \
    --decoding-method $method \
    --max-sym-per-frame 1 \
    --num-encoder-layers 24 \
    --dim-feedforward 1536 \
    --nhead 8 \
    --encoder-dim 384 \
    --decoder-dim 512 \
    --joiner-dim 512 \
    --context-size 1 \
    --use-averaged-model True
done
```
The tensorboard training log can be found at
https:
A pre-trained model and decoding logs can be found at <https:>
When training with a context size of 2, the WERs are:
| | dev-ios | test-ios | comment |
|------------------------------------|-------|----------|----------------------------------|
| greedy search | 5.47 | 5.81 | --epoch 25, --avg 5, --max-duration 600 |
| modified beam search (beam size 4) | 5.38 | 5.61 | --epoch 25, --avg 5, --max-duration 600 |
| fast beam search (set as default) | 5.36 | 5.61 | --epoch 25, --avg 5, --max-duration 600 |
The tensorboard training log can be found at
https://tensorboard.dev/experiment/5AxJ8LHoSre8kDAuLp4L7Q/#scalars
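
Because training is char-based, the WER figures in the tables above are counted over characters rather than space-separated words. The sketch below shows how such a number can be computed, assuming plain Levenshtein distance over characters. The recipe relies on icefall's own scoring utilities, which this does not reproduce; `edit_distance` and `error_rate` are hypothetical helper names.

```python
from typing import Iterable, Sequence, Tuple


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance between two token sequences."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i]
        for j, h in enumerate(hyp, start=1):
            cur.append(
                min(
                    prev[j] + 1,             # deletion
                    cur[j - 1] + 1,          # insertion
                    prev[j - 1] + (r != h),  # substitution or match
                )
            )
        prev = cur
    return prev[-1]


def error_rate(pairs: Iterable[Tuple[Sequence[str], Sequence[str]]]) -> float:
    """Total errors divided by total reference length, in percent."""
    pairs = list(pairs)
    errors = sum(edit_distance(ref, hyp) for ref, hyp in pairs)
    total = sum(len(ref) for ref, _ in pairs)
    return 100.0 * errors / total


# A reference string is iterated character by character, so it can be
# compared directly with a list of hypothesis characters.
rate = error_rate([("今天天气", ["今", "天", "气"])])
```

In this invented example one of the four reference characters is missing from the hypothesis, so the rate is 25%.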


@ -20,77 +20,32 @@
Usage: Usage:
(1) greedy search (1) greedy search
./pruned_transducer_stateless5/decode.py \ ./pruned_transducer_stateless5/decode.py \
--epoch 28 \ --epoch 25 \
--avg 15 \ --avg 5 \
--exp-dir ./pruned_transducer_stateless5/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \ --lang-dir data/lang_char \
--decoding-method greedy_search --max-duration 100 \
--decoding-method greedy_search
(2) beam search (not recommended) (2) modified beam search
./pruned_transducer_stateless5/decode.py \ ./pruned_transducer_stateless5/decode.py \
--epoch 28 \ --epoch 25 \
--avg 15 \ --avg 5 \
--exp-dir ./pruned_transducer_stateless5/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \ --lang-dir data/lang_char \
--decoding-method beam_search \ --max-duration 100 \
--beam-size 4 --decoding-method modified_beam_search \
--beam-size 4
(3) modified beam search (3) fast beam search
./pruned_transducer_stateless5/decode.py \ ./pruned_transducer_stateless5/decode.py \
--epoch 28 \ --epoch 25 \
--avg 15 \ --avg 5 \
--exp-dir ./pruned_transducer_stateless5/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \ --lang-dir data/lang_char \
--decoding-method modified_beam_search \ --max-duration 1500 \
--beam-size 4 --decoding-method fast_beam_search \
--beam 4 \
(4) fast beam search (one best) --max-contexts 4 \
./pruned_transducer_stateless5/decode.py \ --max-states 8
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest)
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless5/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
""" """
@ -101,15 +56,11 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import k2 import k2
import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import AiShell2AsrDataModule
from beam_search import ( from beam_search import (
beam_search, beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best, fast_beam_search_one_best,
greedy_search, greedy_search,
greedy_search_batch, greedy_search_batch,
@ -184,17 +135,10 @@ def get_parser():
help="The experiment dir", help="The experiment dir",
) )
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument( parser.add_argument(
"--lang-dir", "--lang-dir",
type=Path, type=Path,
default="data/lang_bpe_500", default="data/lang_char",
help="The lang dir containing word table and LG graph", help="The lang dir containing word table and LG graph",
) )
@ -268,7 +212,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--context-size", "--context-size",
type=int, type=int,
default=2, default=1,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram", "2 means tri-gram",
) )
@ -306,9 +250,8 @@ def get_parser():
def decode_one_batch( def decode_one_batch(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
sp: spm.SentencePieceProcessor, lexicon: Lexicon,
batch: dict, batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]: ) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the """Decode one batch and return the result in a dict. The dict has the
@ -326,8 +269,6 @@ def decode_one_batch(
It's the return value of :func:`get_params`. It's the return value of :func:`get_params`.
model: model:
The neural model. The neural model.
sp:
The BPE model.
batch: batch:
It is the return value from iterating It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
@ -367,51 +308,8 @@ def decode_one_batch(
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
) )
for hyp in sp.decode(hyp_tokens): for i in range(encoder_out.size(0)):
hyps.append(hyp.split()) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif ( elif (
params.decoding_method == "greedy_search" params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1 and params.max_sym_per_frame == 1
@ -421,8 +319,8 @@ def decode_one_batch(
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
) )
for hyp in sp.decode(hyp_tokens): for i in range(encoder_out.size(0)):
hyps.append(hyp.split()) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "modified_beam_search": elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
@ -430,8 +328,8 @@ def decode_one_batch(
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
beam=params.beam_size, beam=params.beam_size,
) )
for hyp in sp.decode(hyp_tokens): for i in range(encoder_out.size(0)):
hyps.append(hyp.split()) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
else: else:
batch_size = encoder_out.size(0) batch_size = encoder_out.size(0)
@ -455,7 +353,7 @@ def decode_one_batch(
raise ValueError( raise ValueError(
f"Unsupported decoding method: {params.decoding_method}" f"Unsupported decoding method: {params.decoding_method}"
) )
hyps.append(sp.decode(hyp).split()) hyps.append([lexicon.token_table[idx] for idx in hyp])
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
return {"greedy_search": hyps} return {"greedy_search": hyps}
@ -478,8 +376,7 @@ def decode_dataset(
dl: torch.utils.data.DataLoader, dl: torch.utils.data.DataLoader,
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
sp: spm.SentencePieceProcessor, lexicon: Lexicon,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None, decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]: ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset. """Decode dataset.
@ -491,10 +388,6 @@ def decode_dataset(
It is returned by :func:`get_params`. It is returned by :func:`get_params`.
model: model:
The neural model. The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph: decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest, only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
@ -525,9 +418,8 @@ def decode_dataset(
hyps_dict = decode_one_batch( hyps_dict = decode_one_batch(
params=params, params=params,
model=model, model=model,
sp=sp, lexicon=lexicon,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
word_table=word_table,
batch=batch, batch=batch,
) )
@ -535,8 +427,7 @@ def decode_dataset(
this_batch = [] this_batch = []
assert len(hyps) == len(texts) assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts): for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split() this_batch.append((ref_text, hyp_words))
this_batch.append((ref_words, hyp_words))
results[name].extend(this_batch) results[name].extend(this_batch)
@ -598,7 +489,7 @@ def save_results(
@torch.no_grad() @torch.no_grad()
def main(): def main():
parser = get_parser() parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser) AiShell2AsrDataModule.add_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
@ -607,11 +498,7 @@ def main():
assert params.decoding_method in ( assert params.decoding_method in (
"greedy_search", "greedy_search",
"beam_search",
"fast_beam_search", "fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search", "modified_beam_search",
) )
params.res_dir = params.exp_dir / params.decoding_method params.res_dir = params.exp_dir / params.decoding_method
@ -650,13 +537,10 @@ def main():
logging.info(f"Device: {device}") logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor() lexicon = Lexicon(params.lang_dir)
sp.load(params.bpe_model) params.blank_id = lexicon.token_table["<blk>"]
params.unk_id = lexicon.token_table["<unk>"]
# <blk> and <unk> are defined in local/train_bpe_model.py params.vocab_size = max(lexicon.tokens) + 1
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params) logging.info(params)
@ -744,45 +628,31 @@ def main():
model.eval() model.eval()
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else: else:
decoding_graph = None decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args) aishell2 = AiShell2AsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts() valid_cuts = aishell2.valid_cuts()
test_other_cuts = librispeech.test_other_cuts() test_cuts = aishell2.test_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) # use ios sets for dev and test
test_other_dl = librispeech.test_dataloaders(test_other_cuts) dev_dl = aishell2.valid_dataloaders(valid_cuts)
test_dl = aishell2.test_dataloaders(test_cuts)
test_sets = ["test-clean", "test-other"] test_sets = ["dev", "test"]
test_dl = [test_clean_dl, test_other_dl] test_dl = [dev_dl, test_dl]
for test_set, test_dl in zip(test_sets, test_dl): for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset( results_dict = decode_dataset(
dl=test_dl, dl=test_dl,
params=params, params=params,
model=model, model=model,
sp=sp, lexicon=lexicon,
word_table=word_table,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
) )


@ -22,9 +22,9 @@
Usage: Usage:
./pruned_transducer_stateless5/export.py \ ./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--bpe-model data/lang_bpe_500/bpe.model \ --lang-dir data/lang_char \
--epoch 20 \ --epoch 25 \
--avg 10 --avg 5
It will generate a file exp_dir/pretrained.pt It will generate a file exp_dir/pretrained.pt
@ -34,21 +34,20 @@ you can do:
cd /path/to/exp_dir cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR cd /path/to/egs/aishell2/ASR
./pruned_transducer_stateless5/decode.py \ ./pruned_transducer_stateless5/decode.py \
--exp-dir ./pruned_transducer_stateless5/exp \ --exp-dir ./pruned_transducer_stateless5/exp \
--epoch 9999 \ --epoch 9999 \
--avg 1 \ --avg 1 \
--max-duration 600 \ --max-duration 600 \
--decoding-method greedy_search \ --decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model --lang-dir data/lang_char
""" """
import argparse import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import sentencepiece as spm
import torch import torch
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
@ -58,6 +57,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.lexicon import Lexicon
from icefall.utils import str2bool from icefall.utils import str2bool
@ -115,10 +115,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--lang-dir",
type=str, type=str,
default="data/lang_bpe_500/bpe.model", default="data/lang_char",
help="Path to the BPE model", help="The lang dir",
) )
parser.add_argument( parser.add_argument(
@ -132,7 +132,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--context-size", "--context-size",
type=int, type=int,
default=2, default=1,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram", "2 means tri-gram",
) )
@ -155,12 +155,10 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor() lexicon = Lexicon(params.lang_dir)
sp.load(params.bpe_model) params.blank_id = lexicon.token_table["<blk>"]
params.unk_id = lexicon.token_table["<unk>"]
# <blk> is defined in local/train_bpe_model.py params.vocab_size = max(lexicon.tokens) + 1
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params) logging.info(params)


@ -20,33 +20,24 @@ Usage:
(1) greedy search (1) greedy search
./pruned_transducer_stateless5/pretrained.py \ ./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --lang-dir ./data/lang_char \
--method greedy_search \ --method greedy_search \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav /path/to/bar.wav
(2) beam search (2) modified beam search
./pruned_transducer_stateless5/pretrained.py \ ./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --lang-dir ./data/lang_char \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) modified beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \ --method modified_beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav /path/to/bar.wav
(4) fast beam search (3) fast beam search
./pruned_transducer_stateless5/pretrained.py \ ./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \ --lang-dir ./data/lang_char \
--method fast_beam_search \ --method fast_beam_search \
--beam-size 4 \ --beam-size 4 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -66,7 +57,6 @@ from typing import List
import k2 import k2
import kaldifeat import kaldifeat
import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
from beam_search import ( from beam_search import (
@ -79,6 +69,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model from train import add_model_arguments, get_params, get_transducer_model
from icefall.lexicon import Lexicon
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -95,9 +87,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--lang-dir",
type=str, type=str,
help="""Path to bpe.model.""", help="""Path to lang.
""",
) )
parser.add_argument( parser.add_argument(
@ -165,7 +158,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--context-size", "--context-size",
type=int, type=int,
default=2, default=1,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram", "2 means tri-gram",
) )
@ -216,13 +209,10 @@ def main():
params.update(vars(args)) params.update(vars(args))
sp = spm.SentencePieceProcessor() lexicon = Lexicon(params.lang_dir)
sp.load(params.bpe_model) params.blank_id = lexicon.token_table["<blk>"]
params.unk_id = lexicon.token_table["<unk>"]
# <blk> is defined in local/train_bpe_model.py params.vocab_size = max(lexicon.tokens) + 1
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}") logging.info(f"{params}")
@ -292,8 +282,8 @@ def main():
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
) )
for hyp in sp.decode(hyp_tokens): for i in range(encoder_out.size(0)):
hyps.append(hyp.split()) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.method == "modified_beam_search": elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
@ -302,16 +292,16 @@ def main():
beam=params.beam_size, beam=params.beam_size,
) )
for hyp in sp.decode(hyp_tokens): for i in range(encoder_out.size(0)):
hyps.append(hyp.split()) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.method == "greedy_search" and params.max_sym_per_frame == 1: elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
) )
for hyp in sp.decode(hyp_tokens): for i in range(encoder_out.size(0)):
hyps.append(hyp.split()) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
else: else:
for i in range(num_waves): for i in range(num_waves):
# fmt: off # fmt: off
@ -332,11 +322,11 @@ def main():
else: else:
raise ValueError(f"Unsupported method: {params.method}") raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split()) hyps.append([lexicon.token_table[idx] for idx in hyp])
s = "\n" s = "\n"
for filename, hyp in zip(params.sound_files, hyps): for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp) words = "".join(hyp)
s += f"{filename}:\n{words}\n\n" s += f"{filename}:\n{words}\n\n"
logging.info(s) logging.info(s)