Make some changes for merging
This commit is contained in:
parent
c1c893bd13
commit
9cc3f61056
README.md (45 lines changed)
@@ -2,6 +2,14 @@
  <img src="https://raw.githubusercontent.com/k2-fsa/icefall/master/docs/source/_static/logo.png" width=168>
</div>

## Introduction

icefall contains ASR recipes for various datasets
using <https://github.com/k2-fsa/k2>.

You can use <https://github.com/k2-fsa/sherpa> to deploy models
trained with icefall.

## Installation

Please refer to <https://icefall.readthedocs.io/en/latest/installation/index.html>
@@ -23,6 +31,8 @@ We provide the following recipes:

  - [Aidatatang_200zh][aidatatang_200zh]
  - [WenetSpeech][wenetspeech]
  - [Alimeeting][alimeeting]
  - [Aishell4][aishell4]
  - [TAL_CSASR][tal_csasr]

### yesno
@@ -262,6 +272,36 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder

We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: <https://colab.research.google.com/drive/1tKr3f0mL17uO_ljdHGKtR7HOmthYHwJG?usp=sharing>

### Aishell4

We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aishell4_pruned_transducer_stateless5].

#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with all subsets)

The best CER(%) results:

|                      | test  |
|----------------------|-------|
| greedy search        | 29.89 |
| fast beam search     | 28.91 |
| modified beam search | 29.08 |

We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: <https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing>
### TAL_CSASR

We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][TAL_CSASR_pruned_transducer_stateless5].

#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss

The best CER(%) results:

|                      | dev  | test |
|----------------------|------|------|
| greedy search        | 7.30 | 7.39 |
| fast beam search     | 7.15 | 7.22 |
| modified beam search | 7.18 | 7.26 |

We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: <https://colab.research.google.com/drive/1DmIx-NloI1CMU5GdZrlse7TRu4y3Dpf8?usp=sharing>
## Deployment with C++

Once you have trained a model in icefall, you may want to deploy it with C++,

@@ -290,6 +330,8 @@ Please see: [


egs/tal_csasr/ASR/README.md (new file, 19 lines)
@@ -0,0 +1,19 @@

# Introduction

This recipe includes several ASR models trained with TAL_CSASR.

[./RESULTS.md](./RESULTS.md) contains the latest results.

# Transducers

There are various folders containing the word `transducer` in this folder.
The following table lists the differences among them.

|                                | Encoder             | Decoder            | Comment                                                              |
|--------------------------------|---------------------|--------------------|----------------------------------------------------------------------|
| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + more layers + random combiner |

The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
We place an additional Conv1d layer right after the input embedding layer, as sketched below.
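
For intuition, here is a minimal sketch of an Embedding + Conv1d stateless decoder. The names, sizes, and the depthwise-conv choice are illustrative assumptions, not the recipe's exact code; the actual decoder lives in the `pruned_transducer_stateless5` folder.

```
import torch
import torch.nn as nn


class StatelessDecoder(nn.Module):
    """The prediction network conditions on a fixed window of previous
    tokens (here 2) instead of an unbounded RNN state."""

    def __init__(self, vocab_size: int, embed_dim: int, context_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # groups=embed_dim makes this a cheap depthwise convolution over
        # the context_size most recent embeddings.
        self.conv = nn.Conv1d(
            embed_dim,
            embed_dim,
            kernel_size=context_size,
            groups=embed_dim,
        )

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, context_size) previous token IDs -> (N, embed_dim)
        emb = self.embedding(y).permute(0, 2, 1)  # (N, embed_dim, context)
        return self.conv(emb).squeeze(-1)


decoder = StatelessDecoder(vocab_size=500, embed_dim=512, context_size=2)
out = decoder(torch.randint(0, 500, (4, 2)))  # -> torch.Size([4, 512])
```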
egs/tal_csasr/ASR/RESULTS.md (new file, 80 lines)
@@ -0,0 +1,80 @@

## Results

### TAL_CSASR Mix Chars and BPEs training results (Pruned Transducer Stateless5)

#### 2022-06-22

Using the code from this PR: <https://github.com/k2-fsa/icefall/pull/428>.

The CERs are:

| decoding method | epoch(iter) | avg | dev | test |
|--|--|--|--|--|
| greedy_search | 30 | 24 | 7.49 | 7.58 |
| modified_beam_search | 30 | 24 | 7.33 | 7.38 |
| fast_beam_search | 30 | 24 | 7.32 | 7.42 |
| greedy_search (use-averaged-model=True) | 30 | 24 | 7.30 | 7.39 |
| modified_beam_search (use-averaged-model=True) | 30 | 24 | 7.15 | 7.22 |
| fast_beam_search (use-averaged-model=True) | 30 | 24 | 7.18 | 7.26 |
| greedy_search | 348000 | 30 | 7.46 | 7.54 |
| modified_beam_search | 348000 | 30 | 7.24 | 7.36 |
| fast_beam_search | 348000 | 30 | 7.25 | 7.39 |
The training command for reproducing is given below:

```
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5"

./pruned_transducer_stateless5/train.py \
  --world-size 6 \
  --num-epochs 30 \
  --start-epoch 1 \
  --exp-dir pruned_transducer_stateless5/exp \
  --lang-dir data/lang_char \
  --max-duration 90
```

The tensorboard training log can be found at
<https://tensorboard.dev/experiment/KaACzXOVR0OM6cy0qbN5hw/#scalars>

The decoding command is:

```
epoch=30
avg=24
use_average_model=True

## greedy search
./pruned_transducer_stateless5/decode.py \
  --epoch $epoch \
  --avg $avg \
  --exp-dir pruned_transducer_stateless5/exp \
  --lang-dir ./data/lang_char \
  --max-duration 800 \
  --use-averaged-model $use_average_model

## modified beam search
./pruned_transducer_stateless5/decode.py \
  --epoch $epoch \
  --avg $avg \
  --exp-dir pruned_transducer_stateless5/exp \
  --lang-dir ./data/lang_char \
  --max-duration 800 \
  --decoding-method modified_beam_search \
  --beam-size 4 \
  --use-averaged-model $use_average_model

## fast beam search
./pruned_transducer_stateless5/decode.py \
  --epoch $epoch \
  --avg $avg \
  --exp-dir ./pruned_transducer_stateless5/exp \
  --lang-dir ./data/lang_char \
  --max-duration 1500 \
  --decoding-method fast_beam_search \
  --beam 4 \
  --max-contexts 4 \
  --max-states 8 \
  --use-averaged-model $use_average_model
```

A pre-trained model and decoding logs can be found at <https://huggingface.co/luomingshuang/icefall_asr_tal-csasr_pruned_transducer_stateless5>
@@ -28,11 +28,12 @@ and generates the text_with_bpe.

import argparse
import logging
import re

import sentencepiece as spm
from tqdm import tqdm

from icefall.utils import tokenize_by_bpe_model


def get_parser():
    parser = argparse.ArgumentParser(

@@ -61,29 +62,6 @@ def get_parser():
    return parser

def tokenize_by_bpe_model(sp, txt):
    tokens = []
    # CJK (China Japan Korea) unicode range is [U+4E00, U+9FFF], ref:
    # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    pattern = re.compile(r"([\u4e00-\u9fff])")
    # Example:
    #   txt   = "你好 ITS'S OKAY 的"
    #   chars = ["你", "好", " ITS'S OKAY ", "的"]
    chars = pattern.split(txt.upper())
    mix_chars = [w for w in chars if len(w.strip()) > 0]
    for ch_or_w in mix_chars:
        # ch_or_w is a single CJK character (e.g., "你"); keep it as-is.
        if pattern.fullmatch(ch_or_w) is not None:
            tokens.append(ch_or_w)
        # ch_or_w contains non-CJK characters (e.g., " IT'S OKAY ");
        # encode it with the bpe model.
        else:
            for p in sp.encode_as_pieces(ch_or_w):
                tokens.append(p)

    return tokens

def main():
    parser = get_parser()
    args = parser.parse_args()

@@ -103,7 +81,7 @@ def main():
    for i in tqdm(range(len(lines))):
        x = lines[i]
        txt_tokens = tokenize_by_bpe_model(sp, x)
        new_line = " ".join(txt_tokens)
        new_line = txt_tokens.replace("/", " ")
        new_lines.append(new_line)

    logging.info("Starting writing the text_with_bpe")
@@ -314,7 +314,8 @@ class TAL_CSASRAsrDataModule:
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                buffer_size=30000,
                num_cuts_for_bins_estimate=20000,
                buffer_size=60000,
                drop_last=self.args.drop_last,
            )
        else:
@@ -117,10 +117,7 @@ class Conformer(EncoderInterface):
        x, pos_emb = self.encoder_pos(x)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # Caution: We assume the subsampling factor is 4!
            lengths = ((x_lens - 1) // 2 - 1) // 2
        lengths = (((x_lens - 1) >> 1) - 1) >> 1
        assert x.size(0) == lengths.max().item()
        mask = make_pad_mask(lengths)
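
The floor-division and bit-shift forms are interchangeable for the non-negative integer lengths used here (`n >> 1 == n // 2`); the shift form avoids the tensor floor-division deprecation warning, which is presumably why the `warnings.catch_warnings()` wrapper above could be dropped. A quick check:

```
import torch

x_lens = torch.tensor([15, 16, 17, 100])
a = ((x_lens - 1) // 2 - 1) // 2    # floor-division form (old)
b = (((x_lens - 1) >> 1) - 1) >> 1  # bit-shift form (new)
assert torch.equal(a, b)  # identical for non-negative integers
```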
@@ -293,8 +290,10 @@ class ConformerEncoder(nn.Module):
        )
        self.num_layers = num_layers

        assert len(set(aux_layers)) == len(aux_layers)

        assert num_layers - 1 not in aux_layers
        self.aux_layers = set(aux_layers + [num_layers - 1])
        self.aux_layers = aux_layers + [num_layers - 1]

        num_channels = encoder_layer.norm_final.num_channels
        self.combiner = RandomCombine(
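
Replacing the `set` with a plain list looks like a TorchScript-compatibility change: a `List[int]` attribute and membership tests on it script cleanly, while Python sets are poorly supported. A minimal sketch of the pattern (toy module with illustrative names, not the recipe's code):

```
from typing import List

import torch


class Picker(torch.nn.Module):
    def __init__(self, aux_layers: List[int], num_layers: int):
        super().__init__()
        assert len(set(aux_layers)) == len(aux_layers)  # no duplicates
        assert num_layers - 1 not in aux_layers
        # A List[int] attribute scripts cleanly; a Python set would not.
        self.aux_layers: List[int] = aux_layers + [num_layers - 1]

    def forward(self, i: int) -> bool:
        return i in self.aux_layers


scripted = torch.jit.script(Picker([2, 5], num_layers=12))
assert scripted(5) and not scripted(3)
```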
@@ -1154,7 +1153,7 @@ class RandomCombine(nn.Module):
        """
        num_inputs = self.num_inputs
        assert len(inputs) == num_inputs
        if not self.training:
        if not self.training or torch.jit.is_scripting():
            return inputs[-1]

        # Shape of weights: (*, num_inputs)
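
With the added `torch.jit.is_scripting()` guard, a scripted model always takes the deterministic branch and simply returns the final layer's output. A minimal sketch of the idiom (toy module, not the recipe's code):

```
import torch
import torch.nn as nn


class MaybeRandom(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Under torch.jit.script, is_scripting() evaluates to True, so the
        # randomized training-only branch below is never taken at runtime.
        if not self.training or torch.jit.is_scripting():
            return x
        return x * torch.rand_like(x)


m = torch.jit.script(MaybeRandom())
m.train()
x = torch.ones(3)
assert torch.equal(m(x), x)  # deterministic even in training mode
```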
@@ -1162,8 +1161,22 @@ class RandomCombine(nn.Module):
        num_frames = inputs[0].numel() // num_channels

        mod_inputs = []
        for i in range(num_inputs - 1):
            mod_inputs.append(self.linear[i](inputs[i]))

        if False:
            # It throws the following error for torch 1.6.0 when using
            # torch script.
            #
            # Expected integer literal for index. ModuleList/Sequential
            # indexing is only supported with integer literals. Enumeration is
            # supported, e.g. 'for index, v in enumerate(self): ...':
            #
            # for i in range(num_inputs - 1):
            #     mod_inputs.append(self.linear[i](inputs[i]))
            assert False
        else:
            for i, linear in enumerate(self.linear):
                if i < num_inputs - 1:
                    mod_inputs.append(linear(inputs[i]))

        mod_inputs.append(inputs[num_inputs - 1])

        ndim = inputs[0].ndim
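
The error message quoted above tells the whole story: TorchScript cannot index an `nn.ModuleList` with a loop variable, but it can enumerate one. A self-contained sketch of the scriptable pattern (toy module with illustrative sizes):

```
from typing import List

import torch
import torch.nn as nn


class Combiner(nn.Module):
    """Toy version of the pattern above: earlier inputs pass through
    per-layer Linears, the final input is used as-is."""

    def __init__(self, num_inputs: int, dim: int):
        super().__init__()
        self.linear = nn.ModuleList(
            nn.Linear(dim, dim) for _ in range(num_inputs - 1)
        )
        self.num_inputs = num_inputs

    def forward(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
        mod_inputs: List[torch.Tensor] = []
        # Enumerating the ModuleList keeps this scriptable;
        # `self.linear[i]` with a loop variable would fail to compile.
        for i, linear in enumerate(self.linear):
            if i < self.num_inputs - 1:
                mod_inputs.append(linear(inputs[i]))
        mod_inputs.append(inputs[self.num_inputs - 1])
        return mod_inputs


scripted = torch.jit.script(Combiner(num_inputs=3, dim=4))
outs = scripted([torch.randn(2, 4) for _ in range(3)])
assert len(outs) == 3
```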

@@ -1181,11 +1194,13 @@ class RandomCombine(nn.Module):
        # ans: (num_frames, num_channels, 1)
        ans = torch.matmul(stacked_inputs, weights)
        # ans: (*, num_channels)
        ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels)

        if __name__ == "__main__":
            # for testing only...
            print("Weights = ", weights.reshape(num_frames, num_inputs))
        ans = ans.reshape(inputs[0].shape[:-1] + [num_channels])

        # The following if causes errors for torch script in torch 1.6.0
        # if __name__ == "__main__":
        #     # for testing only...
        #     print("Weights = ", weights.reshape(num_frames, num_inputs))
        return ans

    def _get_random_weights(
@@ -288,7 +288,8 @@ def decode_one_batch(
            chars = pattern.split(hyp.upper())
            chars_new = []
            for char in chars:
                chars_new.extend(char.strip().split(" "))
                if char != "":
                    chars_new.extend(char.strip().split(" "))
            hyps.append(chars_new)
    elif (
        params.decoding_method == "greedy_search"

@@ -304,7 +305,8 @@ def decode_one_batch(
            chars = pattern.split(hyp.upper())
            chars_new = []
            for char in chars:
                chars_new.extend(char.strip().split(" "))
                if char != "":
                    chars_new.extend(char.strip().split(" "))
            hyps.append(chars_new)
    elif params.decoding_method == "modified_beam_search":
        hyp_tokens = modified_beam_search(

@@ -318,7 +320,8 @@ def decode_one_batch(
            chars = pattern.split(hyp.upper())
            chars_new = []
            for char in chars:
                chars_new.extend(char.strip().split(" "))
                if char != "":
                    chars_new.extend(char.strip().split(" "))
            hyps.append(chars_new)
    else:
        batch_size = encoder_out.size(0)

@@ -350,7 +353,8 @@ def decode_one_batch(
            chars = pattern.split(hyp.upper())
            chars_new = []
            for char in chars:
                chars_new.extend(char.strip().split(" "))
                if char != "":
                    chars_new.extend(char.strip().split(" "))
            hyps.append(chars_new)
    if params.decoding_method == "greedy_search":
        return {"greedy_search": hyps}

@@ -415,7 +419,8 @@ def decode_dataset(
            chars = pattern.split(text.upper())
            chars_new = []
            for char in chars:
                chars_new.extend(char.strip().split(" "))
                if char != "":
                    chars_new.extend(char.strip().split(" "))
            texts[i] = chars_new
        hyps_dict = decode_one_batch(
            params=params,

@@ -648,7 +653,7 @@ def main():

    dev_cuts = tal_csasr.valid_cuts()
    dev_cuts = dev_cuts.map(text_normalize_for_cut)
    dev_dl = tal_csasr.valid_dataloader(dev_cuts)
    dev_dl = tal_csasr.valid_dataloaders(dev_cuts)

    test_cuts = tal_csasr.test_cuts()
    test_cuts = test_cuts.map(text_normalize_for_cut)
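
The repeated one-line fix above guards against the empty strings that `re.split` with a capturing group produces at CJK boundaries; without the guard, `"".strip().split(" ")` returns `[""]` and a spurious empty token lands in every hypothesis, skewing the CER. A quick demonstration:

```
import re

pattern = re.compile(r"([\u4e00-\u9fff])")
hyp = "你好 IT'S OKAY 的"
chars = pattern.split(hyp)
# chars == ['', '你', '', '好', " IT'S OKAY ", '的', '']
chars_new = []
for char in chars:
    if char != "":  # without this guard, "".strip().split(" ") == [""]
        chars_new.extend(char.strip().split(" "))
print(chars_new)  # ['你', '好', "IT'S", 'OKAY', '的']
```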

@@ -1,6 +1,7 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#           2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#

@@ -22,9 +23,10 @@
Usage:
./pruned_transducer_stateless5/export.py \
  --exp-dir ./pruned_transducer_stateless5/exp \
  --bpe-model data/lang_bpe_500/bpe.model \
  --epoch 20 \
  --avg 10
  --lang-dir ./data/lang_char \
  --epoch 30 \
  --avg 24 \
  --use-averaged-model True

It will generate a file exp_dir/pretrained.pt
@@ -34,14 +36,14 @@ you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt

cd /path/to/egs/librispeech/ASR
cd /path/to/egs/tal_csasr/ASR
./pruned_transducer_stateless5/decode.py \
  --exp-dir ./pruned_transducer_stateless5/exp \
  --epoch 9999 \
  --avg 1 \
  --max-duration 600 \
  --epoch 30 \
  --avg 24 \
  --max-duration 800 \
  --decoding-method greedy_search \
  --bpe-model data/lang_bpe_500/bpe.model
  --lang-dir ./data/lang_char
"""

import argparse
@@ -58,6 +60,7 @@ from icefall.checkpoint import (
    find_checkpoints,
    load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool


@@ -115,10 +118,13 @@ def get_parser():
    )

    parser.add_argument(
        "--bpe-model",
        "--lang-dir",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
        default="data/lang_char",
        help="""The lang dir
        It contains language related input files such as
        "lexicon.txt"
        """,
    )

    parser.add_argument(
@@ -146,8 +152,6 @@ def main():
    args = get_parser().parse_args()
    args.exp_dir = Path(args.exp_dir)

    assert args.jit is False, "Support torchscript will be added later"

    params = get_params()
    params.update(vars(args))

@@ -157,12 +161,13 @@ def main():

    logging.info(f"device: {device}")

    bpe_model = params.lang_dir + "/bpe.model"
    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)
    sp.load(bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.vocab_size = sp.get_piece_size()
    lexicon = Lexicon(params.lang_dir)
    params.blank_id = lexicon.token_table["<blk>"]
    params.vocab_size = max(lexicon.tokens) + 1

    logging.info(params)
@@ -252,6 +257,11 @@ def main():
    model.eval()

    if params.jit:
        # We won't use the forward() method of the model in C++, so just
        # ignore it here.
        # Otherwise, one of its arguments is a ragged tensor and is not
        # torch scriptable.
        model.__class__.forward = torch.jit.ignore(model.__class__.forward)
        logging.info("Using torch.jit.script")
        model = torch.jit.script(model)
        filename = params.exp_dir / "cpu_jit.pt"
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#           2022 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -20,34 +21,25 @@ Usage:
(1) greedy search
./pruned_transducer_stateless5/pretrained.py \
  --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
  --bpe-model ./data/lang_bpe_500/bpe.model \
  --method greedy_search \
  --lang-dir ./data/lang_char \
  --decoding-method greedy_search \
  /path/to/foo.wav \
  /path/to/bar.wav

(2) beam search
(2) modified beam search
./pruned_transducer_stateless5/pretrained.py \
  --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
  --bpe-model ./data/lang_bpe_500/bpe.model \
  --method beam_search \
  --lang-dir ./data/lang_char \
  --decoding-method modified_beam_search \
  --beam-size 4 \
  /path/to/foo.wav \
  /path/to/bar.wav

(3) modified beam search
(3) fast beam search
./pruned_transducer_stateless5/pretrained.py \
  --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
  --bpe-model ./data/lang_bpe_500/bpe.model \
  --method modified_beam_search \
  --beam-size 4 \
  /path/to/foo.wav \
  /path/to/bar.wav

(4) fast beam search
./pruned_transducer_stateless5/pretrained.py \
  --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
  --bpe-model ./data/lang_bpe_500/bpe.model \
  --method fast_beam_search \
  --lang-dir ./data/lang_char \
  --decoding-method fast_beam_search \
  --beam-size 4 \
  /path/to/foo.wav \
  /path/to/bar.wav
@@ -62,6 +54,7 @@ Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by

import argparse
import logging
import math
import re
from typing import List

import k2

@@ -79,6 +72,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model

from icefall.lexicon import Lexicon


def get_parser():
    parser = argparse.ArgumentParser(
@@ -95,13 +90,17 @@ def get_parser():
    )

    parser.add_argument(
        "--bpe-model",
        "--lang-dir",
        type=str,
        help="""Path to bpe.model.""",
        default="data/lang_char",
        help="""The lang dir
        It contains language related input files such as
        "lexicon.txt"
        """,
    )

    parser.add_argument(
        "--method",
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="""Possible values are:
@@ -216,13 +215,13 @@ def main():

    params.update(vars(args))

    bpe_model = params.lang_dir + "/bpe.model"
    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)
    sp.load(bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()
    lexicon = Lexicon(params.lang_dir)
    params.blank_id = lexicon.token_table["<blk>"]
    params.vocab_size = max(lexicon.tokens) + 1

    logging.info(f"{params}")
@@ -281,6 +280,7 @@ def main():
    msg += f" with beam size {params.beam_size}"
    logging.info(msg)

    pattern = re.compile(r"([\u4e00-\u9fff])")
    if params.method == "fast_beam_search":
        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
        hyp_tokens = fast_beam_search_one_best(
@@ -292,8 +292,14 @@ def main():
            max_contexts=params.max_contexts,
            max_states=params.max_states,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
        for i in range(encoder_out.size(0)):
            hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
            chars = pattern.split(hyp.upper())
            chars_new = []
            for char in chars:
                if char != "":
                    chars_new.extend(char.strip().split(" "))
            hyps.append(chars_new)
    elif params.method == "modified_beam_search":
        hyp_tokens = modified_beam_search(
            model=model,

@@ -301,17 +307,28 @@ def main():
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
        )

        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
        for i in range(encoder_out.size(0)):
            hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
            chars = pattern.split(hyp.upper())
            chars_new = []
            for char in chars:
                if char != "":
                    chars_new.extend(char.strip().split(" "))
            hyps.append(chars_new)
    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
        hyp_tokens = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
        for i in range(encoder_out.size(0)):
            hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
            chars = pattern.split(hyp.upper())
            chars_new = []
            for char in chars:
                if char != "":
                    chars_new.extend(char.strip().split(" "))
            hyps.append(chars_new)
    else:
        for i in range(num_waves):
            # fmt: off

@@ -332,7 +349,13 @@ def main():
            else:
                raise ValueError(f"Unsupported method: {params.method}")

            hyps.append(sp.decode(hyp).split())
            hyp = sp.decode([lexicon.token_table[idx] for idx in hyp])
            chars = pattern.split(hyp.upper())
            chars_new = []
            for char in chars:
                if char != "":
                    chars_new.extend(char.strip().split(" "))
            hyps.append(chars_new)

    s = "\n"
    for filename, hyp in zip(params.sound_files, hyps):
@@ -954,20 +954,20 @@ def run(rank, world_size, args):
    def remove_short_and_long_utt(c: Cut):
        # Keep only utterances with duration between 1 second and 20 seconds
        #
        # Caution: There is a reason to select 18.0 here. Please see
        # Caution: There is a reason to select 20.0 here. Please see
        # ../local/display_manifest_statistics.py
        #
        # You should use ../local/display_manifest_statistics.py to get
        # an utterance duration distribution for your dataset to select
        # the threshold
        return 1.0 <= c.duration <= 18.0
        return 1.0 <= c.duration <= 20.0

    def text_normalize_for_cut(c: Cut):
        # Normalize the text of each sample
        text = c.supervisions[0].text
        text = text.strip("\n").strip("\t")
        text = text_normalize(text)
        text = "/".join(tokenize_by_bpe_model(sp, text))
        text = tokenize_by_bpe_model(sp, text)
        c.supervisions[0].text = text
        return c
@@ -80,7 +80,8 @@ class CharCtcTrainingGraphCompiler(object):
        return ids

    def texts_to_ids_with_bpe(self, texts: List[str]) -> List[List[int]]:
        """Convert a list of texts to a list-of-list of token IDs.
        """Convert a list of texts (which include chars and bpes)
        to a list-of-list of token IDs.

        Args:
          texts:
@@ -20,6 +20,7 @@ import argparse
import collections
import logging
import os
import re
import subprocess
from collections import defaultdict
from contextlib import contextmanager

@@ -30,6 +31,7 @@ from typing import Dict, Iterable, List, TextIO, Tuple, Union
import k2
import k2.version
import kaldialign
import sentencepiece as spm
import torch
import torch.distributed as dist
import torch.nn as nn
@@ -799,3 +801,40 @@ def optim_step_and_measure_param_change(
        delta = l2_norm(p_orig - p_new) / l2_norm(p_orig)
        relative_change[n] = delta.item()
    return relative_change


def tokenize_by_bpe_model(
    sp: spm.SentencePieceProcessor,
    txt: str,
) -> str:
    """
    Tokenize text with a bpe model. This function is from
    https://github1s.com/wenet-e2e/wenet/blob/main/wenet/dataset/processor.py#L322-L342.

    Args:
      sp: spm.SentencePieceProcessor.
      txt: str

    Return:
      A new string which includes chars and bpes.
    """
    tokens = []
    # CJK (China Japan Korea) unicode range is [U+4E00, U+9FFF], ref:
    # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    pattern = re.compile(r"([\u4e00-\u9fff])")
    # Example:
    #   txt   = "你好 ITS'S OKAY 的"
    #   chars = ["你", "好", " ITS'S OKAY ", "的"]
    chars = pattern.split(txt.upper())
    mix_chars = [w for w in chars if len(w.strip()) > 0]
    for ch_or_w in mix_chars:
        # ch_or_w is a single CJK character (e.g., "你"); keep it as-is.
        if pattern.fullmatch(ch_or_w) is not None:
            tokens.append(ch_or_w)
        # ch_or_w contains non-CJK characters (e.g., " IT'S OKAY ");
        # encode it with the bpe model.
        else:
            for p in sp.encode_as_pieces(ch_or_w):
                tokens.append(p)
    txt_with_bpe = "/".join(tokens)

    return txt_with_bpe
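
For reference, a small usage sketch; the bpe.model path is illustrative, and the exact BPE pieces depend on the trained model:

```
import sentencepiece as spm

from icefall.utils import tokenize_by_bpe_model

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # illustrative path

print(tokenize_by_bpe_model(sp, "你好 it's okay 的"))
# CJK characters stay as single tokens, the English span is BPE-encoded,
# and everything is joined with "/", e.g. 你/好/▁IT/'/S/▁OKAY/的
# (the exact pieces depend on the model).
```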