Make some changes for merging

luomingshuang 2022-06-23 16:24:56 +08:00
parent c1c893bd13
commit 9cc3f61056
12 changed files with 313 additions and 99 deletions

View File

@@ -2,6 +2,14 @@
<img src="https://raw.githubusercontent.com/k2-fsa/icefall/master/docs/source/_static/logo.png" width=168>
</div>
## Introduction
icefall contains ASR recipes for various datasets
using <https://github.com/k2-fsa/k2>.
You can use <https://github.com/k2-fsa/sherpa> to deploy models
trained with icefall.
## Installation
Please refer to <https://icefall.readthedocs.io/en/latest/installation/index.html>
@@ -23,6 +31,8 @@ We provide the following recipes:
- [Aidatatang_200zh][aidatatang_200zh]
- [WenetSpeech][wenetspeech]
- [Alimeeting][alimeeting]
- [Aishell4][aishell4]
- [TAL_CSASR][tal_csasr]
### yesno
@@ -262,6 +272,36 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tKr3f0mL17uO_ljdHGKtR7HOmthYHwJG?usp=sharing)
### Aishell4
We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aishell4_pruned_transducer_stateless5].
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with all subsets)
The best CER(%) results:
| | test |
|----------------------|--------|
| greedy search | 29.89 |
| fast beam search | 28.91 |
| modified beam search | 29.08 |
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)
### TAL_CSASR
We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][TAL_CSASR_pruned_transducer_stateless5].
#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
The best CER(%) results:
| | dev | test |
|----------------------|------|------|
| greedy search | 7.30 | 7.39 |
| fast beam search | 7.15 | 7.22 |
| modified beam search | 7.18 | 7.26 |
We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1DmIx-NloI1CMU5GdZrlse7TRu4y3Dpf8?usp=sharing)
## Deployment with C++
Once you have trained a model in icefall, you may want to deploy it with C++,
@@ -290,6 +330,8 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
[Aidatatang_200zh_pruned_transducer_stateless2]: egs/aidatatang_200zh/ASR/pruned_transducer_stateless2
[WenetSpeech_pruned_transducer_stateless2]: egs/wenetspeech/ASR/pruned_transducer_stateless2
[Alimeeting_pruned_transducer_stateless2]: egs/alimeeting/ASR/pruned_transducer_stateless2
[Aishell4_pruned_transducer_stateless5]: egs/aishell4/ASR/pruned_transducer_stateless5
[TAL_CSASR_pruned_transducer_stateless5]: egs/tal_csasr/ASR/pruned_transducer_stateless5
[yesno]: egs/yesno/ASR
[librispeech]: egs/librispeech/ASR
[aishell]: egs/aishell/ASR
@@ -299,5 +341,6 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
[aidatatang_200zh]: egs/aidatatang_200zh/ASR
[wenetspeech]: egs/wenetspeech/ASR
[alimeeting]: egs/alimeeting/ASR
[aishell4]: egs/aishell4/ASR
[tal_csasr]: egs/tal_csasr/ASR
[k2]: https://github.com/k2-fsa/k2

View File

@@ -0,0 +1,19 @@
# Introduction
This recipe contains various ASR models trained on the TAL_CSASR dataset.
[./RESULTS.md](./RESULTS.md) contains the latest results.
# Transducers
There are various folders whose names contain `transducer` in this directory.
The following table lists the differences among them.
| | Encoder | Decoder | Comment |
|---------------------------------------|---------------------|--------------------|-----------------------------|
| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + more layers + random combiner|
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
We place an additional Conv1d layer right after the input embedding layer.
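For illustration, here is a minimal sketch of such a stateless decoder (an embedding over the previous `context_size` tokens followed by a Conv1d). The class name, sizes, and the plain Conv1d configuration are illustrative assumptions, not the recipe's exact decoder:
```
import torch
import torch.nn as nn
import torch.nn.functional as F

# A minimal sketch of a stateless transducer decoder: the "state" is just
# the last `context_size` tokens, mixed by a Conv1d after the embedding.
class StatelessDecoder(nn.Module):
    def __init__(self, vocab_size: int = 500, embedding_dim: int = 512,
                 context_size: int = 2):
        super().__init__()
        self.context_size = context_size
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.conv = nn.Conv1d(embedding_dim, embedding_dim,
                              kernel_size=context_size, bias=False)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U) IDs of previously emitted tokens
        emb = self.embedding(y).permute(0, 2, 1)   # (N, C, U)
        # Left-pad so each position only sees past tokens.
        emb = F.pad(emb, (self.context_size - 1, 0))
        return self.conv(emb).permute(0, 2, 1)     # (N, U, C)

decoder = StatelessDecoder()
print(decoder(torch.randint(0, 500, (2, 7))).shape)  # torch.Size([2, 7, 512])
```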

View File

@@ -0,0 +1,80 @@
## Results
### TAL_CSASR training results with mixed Chinese chars and English BPEs (Pruned Transducer Stateless5)
#### 2022-06-22
The results below use the code from this PR: <https://github.com/k2-fsa/icefall/pull/428>.
The CER(%) results are:
|decoding-method | epoch(iter) | avg | dev | test |
|--|--|--|--|--|
|greedy_search | 30 | 24 | 7.49 | 7.58|
|modified_beam_search | 30 | 24 | 7.33 | 7.38|
|fast_beam_search | 30 | 24 | 7.32 | 7.42|
|greedy_search(use-averaged-model=True) | 30 | 24 | 7.30 | 7.39|
|modified_beam_search(use-averaged-model=True) | 30 | 24 | 7.15 | 7.22|
|fast_beam_search(use-averaged-model=True) | 30 | 24 | 7.18 | 7.26|
|greedy_search | 348000 | 30 | 7.46 | 7.54|
|modified_beam_search | 348000 | 30 | 7.24 | 7.36|
|fast_beam_search | 348000 | 30 | 7.25 | 7.39 |
The training command to reproduce the results above is given below:
```
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5"
./pruned_transducer_stateless5/train.py \
--world-size 6 \
--num-epochs 30 \
--start-epoch 1 \
--exp-dir pruned_transducer_stateless5/exp \
--lang-dir data/lang_char \
--max-duration 90
```
The tensorboard training log can be found at
https://tensorboard.dev/experiment/KaACzXOVR0OM6cy0qbN5hw/#scalars
The decoding command is:
```
epoch=30
avg=24
use_average_model=True
## greedy search
./pruned_transducer_stateless5/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir pruned_transducer_stateless5/exp \
--lang-dir ./data/lang_char \
--max-duration 800 \
--use-averaged-model $use_average_model
## modified beam search
./pruned_transducer_stateless5/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir pruned_transducer_stateless5/exp \
--lang-dir ./data/lang_char \
--max-duration 800 \
--decoding-method modified_beam_search \
--beam-size 4 \
--use-averaged-model $use_average_model
## fast beam search
./pruned_transducer_stateless5/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir ./pruned_transducer_stateless5/exp \
--lang-dir ./data/lang_char \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
--use-averaged-model $use_average_model
```
A pre-trained model and decoding logs can be found at <https://huggingface.co/luomingshuang/icefall_asr_tal-csasr_pruned_transducer_stateless5>

View File

@@ -28,11 +28,12 @@ and generates the text_with_bpe.
import argparse
import logging
import re
import sentencepiece as spm
from tqdm import tqdm
from icefall.utils import tokenize_by_bpe_model
def get_parser():
parser = argparse.ArgumentParser(
@@ -61,29 +62,6 @@ def get_parser():
return parser
def tokenize_by_bpe_model(sp, txt):
tokens = []
# CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
pattern = re.compile(r"([\u4e00-\u9fff])")
# Example:
# txt = "你好 ITS'S OKAY 的"
# chars = ["你", "好", " ITS'S OKAY ", "的"]
chars = pattern.split(txt.upper())
mix_chars = [w for w in chars if len(w.strip()) > 0]
for ch_or_w in mix_chars:
# ch_or_w is a single CJK character (e.g., "你"); do nothing.
if pattern.fullmatch(ch_or_w) is not None:
tokens.append(ch_or_w)
# ch_or_w contains non-CJK characters (e.g., " IT'S OKAY "),
# encode ch_or_w using bpe_model.
else:
for p in sp.encode_as_pieces(ch_or_w):
tokens.append(p)
return tokens
def main():
parser = get_parser()
args = parser.parse_args()
@@ -103,7 +81,7 @@ def main():
for i in tqdm(range(len(lines))):
x = lines[i]
txt_tokens = tokenize_by_bpe_model(sp, x)
new_line = " ".join(txt_tokens)
new_line = txt_tokens.replace("/", " ")
new_lines.append(new_line)
logging.info("Starting writing the text_with_bpe")

View File

@@ -314,7 +314,8 @@ class TAL_CSASRAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
buffer_size=30000,
num_cuts_for_bins_estimate=20000,
buffer_size=60000,
drop_last=self.args.drop_last,
)
else:

View File

@@ -117,10 +117,7 @@ class Conformer(EncoderInterface):
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# Caution: We assume the subsampling factor is 4!
lengths = ((x_lens - 1) // 2 - 1) // 2
lengths = (((x_lens - 1) >> 1) - 1) >> 1
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
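As an aside, the new bit-shift form is equivalent to the previous integer-division form, since `>> 1` equals `// 2` for non-negative integers. A quick standalone check (not part of this PR):
```
import torch

# Both formulas compute floor((floor((L - 1) / 2) - 1) / 2), i.e., the
# output length under the assumed subsampling factor of 4.
x_lens = torch.tensor([16, 17, 100, 2001])
assert torch.equal(
    ((x_lens - 1) // 2 - 1) // 2,
    (((x_lens - 1) >> 1) - 1) >> 1,
)
```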
@@ -293,8 +290,10 @@ class ConformerEncoder(nn.Module):
)
self.num_layers = num_layers
assert len(set(aux_layers)) == len(aux_layers)
assert num_layers - 1 not in aux_layers
self.aux_layers = set(aux_layers + [num_layers - 1])
self.aux_layers = aux_layers + [num_layers - 1]
num_channels = encoder_layer.norm_final.num_channels
self.combiner = RandomCombine(
@@ -1154,7 +1153,7 @@ class RandomCombine(nn.Module):
"""
num_inputs = self.num_inputs
assert len(inputs) == num_inputs
if not self.training:
if not self.training or torch.jit.is_scripting():
return inputs[-1]
# Shape of weights: (*, num_inputs)
@@ -1162,8 +1161,22 @@ class RandomCombine(nn.Module):
num_frames = inputs[0].numel() // num_channels
mod_inputs = []
for i in range(num_inputs - 1):
mod_inputs.append(self.linear[i](inputs[i]))
if False:
# It throws the following error for torch 1.6.0 when using
# torch script.
#
# Expected integer literal for index. ModuleList/Sequential
# indexing is only supported with integer literals. Enumeration is
# supported, e.g. 'for index, v in enumerate(self): ...':
# for i in range(num_inputs - 1):
# mod_inputs.append(self.linear[i](inputs[i]))
assert False
else:
for i, linear in enumerate(self.linear):
if i < num_inputs - 1:
mod_inputs.append(linear(inputs[i]))
mod_inputs.append(inputs[num_inputs - 1])
ndim = inputs[0].ndim
@@ -1181,11 +1194,13 @@ class RandomCombine(nn.Module):
# ans: (num_frames, num_channels, 1)
ans = torch.matmul(stacked_inputs, weights)
# ans: (*, num_channels)
ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels)
if __name__ == "__main__":
# for testing only...
print("Weights = ", weights.reshape(num_frames, num_inputs))
ans = ans.reshape(inputs[0].shape[:-1] + [num_channels])
# The following if causes errors for torch script in torch 1.6.0
# if __name__ == "__main__":
# # for testing only...
# print("Weights = ", weights.reshape(num_frames, num_inputs))
return ans
def _get_random_weights(
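The enumerate-based loop above exists because TorchScript rejects `self.linear[i]` when `i` is a loop variable; only integer literals may index a `ModuleList`. A small standalone sketch (names and sizes are assumptions) of the pattern that does compile:
```
from typing import List

import torch
import torch.nn as nn


# Iterating a ModuleList with enumerate() is scriptable, while indexing it
# with a non-literal index is not.
class Combine(nn.Module):
    def __init__(self, num_inputs: int = 3, dim: int = 4):
        super().__init__()
        self.num_inputs = num_inputs
        self.linear = nn.ModuleList(
            [nn.Linear(dim, dim) for _ in range(num_inputs - 1)]
        )

    def forward(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
        out: List[torch.Tensor] = []
        for i, linear in enumerate(self.linear):
            out.append(linear(inputs[i]))
        out.append(inputs[self.num_inputs - 1])
        return out


m = torch.jit.script(Combine())
print(len(m([torch.ones(2, 4) for _ in range(3)])))  # 3
```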

View File

@@ -288,7 +288,8 @@ def decode_one_batch(
chars = pattern.split(hyp.upper())
chars_new = []
for char in chars:
chars_new.extend(char.strip().split(" "))
if char != "":
chars_new.extend(char.strip().split(" "))
hyps.append(chars_new)
elif (
params.decoding_method == "greedy_search"
@@ -304,7 +305,8 @@ def decode_one_batch(
chars = pattern.split(hyp.upper())
chars_new = []
for char in chars:
chars_new.extend(char.strip().split(" "))
if char != "":
chars_new.extend(char.strip().split(" "))
hyps.append(chars_new)
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
@@ -318,7 +320,8 @@ def decode_one_batch(
chars = pattern.split(hyp.upper())
chars_new = []
for char in chars:
chars_new.extend(char.strip().split(" "))
if char != "":
chars_new.extend(char.strip().split(" "))
hyps.append(chars_new)
else:
batch_size = encoder_out.size(0)
@@ -350,7 +353,8 @@ def decode_one_batch(
chars = pattern.split(hyp.upper())
chars_new = []
for char in chars:
chars_new.extend(char.strip().split(" "))
if char != "":
chars_new.extend(char.strip().split(" "))
hyps.append(chars_new)
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
@@ -415,7 +419,8 @@ def decode_dataset(
chars = pattern.split(text.upper())
chars_new = []
for char in chars:
chars_new.extend(char.strip().split(" "))
if char != "":
chars_new.extend(char.strip().split(" "))
texts[i] = chars_new
hyps_dict = decode_one_batch(
params=params,
@@ -648,7 +653,7 @@ def main():
dev_cuts = tal_csasr.valid_cuts()
dev_cuts = dev_cuts.map(text_normalize_for_cut)
dev_dl = tal_csasr.valid_dataloader(dev_cuts)
dev_dl = tal_csasr.valid_dataloaders(dev_cuts)
test_cuts = tal_csasr.test_cuts()
test_cuts = test_cuts.map(text_normalize_for_cut)
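Every decoding branch above applies the same post-processing before scoring: split a code-switched hypothesis into single CJK characters and English words, dropping empty fragments. A standalone sketch of that step:
```
import re

# Split code-switched text into CJK characters and English words; re.split
# with a capturing group keeps the CJK chars and yields empty strings at
# the boundaries, which the `if char != ""` guard above filters out.
pattern = re.compile(r"([\u4e00-\u9fff])")

hyp = "你好 IT'S OKAY 的"
chars = pattern.split(hyp.upper())
chars_new = []
for char in chars:
    if char != "":
        chars_new.extend(char.strip().split(" "))
print(chars_new)  # ['你', '好', "IT'S", 'OKAY', '的']
```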

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
# 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -22,9 +23,10 @@
Usage:
./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10
--lang-dir ./data/lang_char \
--epoch 30 \
--avg 24 \
--use-averaged-model True
It will generate a file exp_dir/pretrained.pt
@@ -34,14 +36,14 @@ you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
cd /path/to/egs/tal_csasr/ASR
./pruned_transducer_stateless5/decode.py \
--exp-dir ./pruned_transducer_stateless5/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--epoch 30 \
--avg 24 \
--max-duration 800 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
--lang-dir ./data/lang_char
"""
import argparse
@@ -58,6 +60,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
@@ -115,10 +118,13 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--lang-dir",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
parser.add_argument(
@@ -146,8 +152,6 @@ def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
assert args.jit is False, "Support torchscript will be added later"
params = get_params()
params.update(vars(args))
@@ -157,12 +161,13 @@
logging.info(f"device: {device}")
bpe_model = params.lang_dir + "/bpe.model"
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
sp.load(bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
@@ -252,6 +257,11 @@ def main():
model.eval()
if params.jit:
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
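A minimal standalone sketch of the `torch.jit.ignore` trick used above (the class and method names are assumptions): mark a non-scriptable `forward()` as ignored and export only the methods that C++ actually calls:
```
import torch
import torch.nn as nn


class Model(nn.Module):
    def forward(self, x):
        # Stand-in for a forward() whose arguments TorchScript cannot handle.
        raise NotImplementedError

    @torch.jit.export
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


model = Model()
# Same pattern as above: ignore forward(), then script the rest.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
scripted = torch.jit.script(model)
print(scripted.encode(torch.ones(2)))  # tensor([2., 2.])
```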

View File

@@ -1,5 +1,6 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# 2022 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -20,34 +21,25 @@ Usage:
(1) greedy search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
--lang-dir ./data/lang_char \
--decoding-method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) beam search
(2) modified beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \
--lang-dir ./data/lang_char \
--decoding-method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) modified beam search
(3) fast beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method fast_beam_search \
--lang-dir ./data/lang_char \
--decoding-method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
@@ -62,6 +54,7 @@ Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by
import argparse
import logging
import math
import re
from typing import List
import k2
@@ -79,6 +72,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.lexicon import Lexicon
def get_parser():
parser = argparse.ArgumentParser(
@@ -95,13 +90,17 @@
)
parser.add_argument(
"--bpe-model",
"--lang-dir",
type=str,
help="""Path to bpe.model.""",
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
parser.add_argument(
"--method",
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
@@ -216,13 +215,13 @@ def main():
params.update(vars(args))
bpe_model = params.lang_dir + "/bpe.model"
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
sp.load(bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
logging.info(f"{params}")
@@ -281,6 +280,7 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
pattern = re.compile(r"([\u4e00-\u9fff])")
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@@ -292,8 +292,14 @@
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for i in range(encoder_out.size(0)):
hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
chars = pattern.split(hyp.upper())
chars_new = []
for char in chars:
if char != "":
chars_new.extend(char.strip().split(" "))
hyps.append(chars_new)
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@@ -301,17 +307,28 @@
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for i in range(encoder_out.size(0)):
hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
chars = pattern.split(hyp.upper())
chars_new = []
for char in chars:
if char != "":
chars_new.extend(char.strip().split(" "))
hyps.append(chars_new)
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for i in range(encoder_out.size(0)):
hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
chars = pattern.split(hyp.upper())
chars_new = []
for char in chars:
if char != "":
chars_new.extend(char.strip().split(" "))
hyps.append(chars_new)
else:
for i in range(num_waves):
# fmt: off
@@ -332,7 +349,13 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyp = sp.decode([lexicon.token_table[idx] for idx in hyp])
chars = pattern.split(hyp.upper())
chars_new = []
for char in chars:
if char != "":
chars_new.extend(char.strip().split(" "))
hyps.append(chars_new)
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):

View File

@@ -954,20 +954,20 @@ def run(rank, world_size, args):
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
# Caution: There is a reason to select 18.0 here. Please see
# Caution: There is a reason to select 20.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 18.0
return 1.0 <= c.duration <= 20.0
def text_normalize_for_cut(c: Cut):
# Text normalize for each sample
text = c.supervisions[0].text
text = text.strip("\n").strip("\t")
text = text_normalize(text)
text = "/".join(tokenize_by_bpe_model(sp, text))
text = tokenize_by_bpe_model(sp, text)
c.supervisions[0].text = text
return c
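A hedged usage sketch of these two helpers, mirroring how the training script applies them to the training cuts (the manifest path here is hypothetical):
```
from lhotse import CutSet

# Hypothetical manifest path; adjust to your data layout.
cuts = CutSet.from_file("data/fbank/tal_csasr_cuts_train_set.jsonl.gz")
cuts = cuts.filter(remove_short_and_long_utt)
cuts = cuts.map(text_normalize_for_cut)
```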

View File

@@ -80,7 +80,8 @@ class CharCtcTrainingGraphCompiler(object):
return ids
def texts_to_ids_with_bpe(self, texts: List[str]) -> List[List[int]]:
"""Convert a list of texts to a list-of-list of token IDs.
"""Convert a list of texts (which include chars and bpes)
to a list-of-list of token IDs.
Args:
texts:

View File

@@ -20,6 +20,7 @@ import argparse
import collections
import logging
import os
import re
import subprocess
from collections import defaultdict
from contextlib import contextmanager
@@ -30,6 +31,7 @@ from typing import Dict, Iterable, List, TextIO, Tuple, Union
import k2
import k2.version
import kaldialign
import sentencepiece as spm
import torch
import torch.distributed as dist
import torch.nn as nn
@@ -799,3 +801,40 @@ def optim_step_and_measure_param_change(
delta = l2_norm(p_orig - p_new) / l2_norm(p_orig)
relative_change[n] = delta.item()
return relative_change
def tokenize_by_bpe_model(
sp: spm.SentencePieceProcessor,
txt: str,
) -> str:
"""
Tokenize text with a BPE model. This function comes from
https://github1s.com/wenet-e2e/wenet/blob/main/wenet/dataset/processor.py#L322-L342.
Args:
sp: spm.SentencePieceProcessor.
txt: The text to be tokenized.
Returns:
A new string containing chars and BPE pieces, joined by "/".
"""
tokens = []
# CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
pattern = re.compile(r"([\u4e00-\u9fff])")
# Example:
# txt = "你好 ITS'S OKAY 的"
# chars = ["你", "好", " ITS'S OKAY ", "的"]
chars = pattern.split(txt.upper())
mix_chars = [w for w in chars if len(w.strip()) > 0]
for ch_or_w in mix_chars:
# ch_or_w is a single CJK character (e.g., "你"); do nothing.
if pattern.fullmatch(ch_or_w) is not None:
tokens.append(ch_or_w)
# ch_or_w contains non-CJK characters (e.g., " IT'S OKAY "),
# encode ch_or_w using bpe_model.
else:
for p in sp.encode_as_pieces(ch_or_w):
tokens.append(p)
txt_with_bpe = "/".join(tokens)
return txt_with_bpe
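A usage sketch of the new helper (the model path is an assumption, and the exact pieces depend on the trained BPE model):
```
import sentencepiece as spm

from icefall.utils import tokenize_by_bpe_model

sp = spm.SentencePieceProcessor()
sp.load("data/lang_char/bpe.model")  # assumed path to a trained BPE model

text = "你好 IT'S OKAY 的"
print(tokenize_by_bpe_model(sp, text))
# e.g. "你/好/▁IT/'/S/▁OKAY/的" -- the actual pieces depend on the model
```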