From 9cc3f61056023086639768c7550ee22696c284ad Mon Sep 17 00:00:00 2001
From: luomingshuang <739314837@qq.com>
Date: Thu, 23 Jun 2022 16:24:56 +0800
Subject: [PATCH] make some changes for merging
---
README.md | 45 +++++++++-
egs/tal_csasr/ASR/README.md | 19 ++++
egs/tal_csasr/ASR/RESULTS.md | 80 +++++++++++++++++
.../ASR/local/tokenize_with_bpe_model.py | 28 +-----
.../asr_datamodule.py | 3 +-
.../pruned_transducer_stateless5/conformer.py | 39 +++++---
.../pruned_transducer_stateless5/decode.py | 17 ++--
.../pruned_transducer_stateless5/export.py | 44 +++++----
.../pretrained.py | 89 ++++++++++++-------
.../ASR/pruned_transducer_stateless5/train.py | 6 +-
icefall/char_graph_compiler.py | 3 +-
icefall/utils.py | 39 ++++++++
12 files changed, 313 insertions(+), 99 deletions(-)
create mode 100644 egs/tal_csasr/ASR/README.md
create mode 100644 egs/tal_csasr/ASR/RESULTS.md
diff --git a/README.md b/README.md
index 9f8db554c..be00eac50 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,14 @@
+## Introduction
+
+icefall contains ASR recipes for various datasets
+using <https://github.com/lhotse-speech/lhotse>.
+
+You can use <https://github.com/k2-fsa/sherpa> to deploy models
+trained with icefall.
+
## Installation
Please refer to
@@ -23,6 +31,8 @@ We provide the following recipes:
- [Aidatatang_200zh][aidatatang_200zh]
- [WenetSpeech][wenetspeech]
- [Alimeeting][alimeeting]
+ - [Aishell4][aishell4]
+ - [TAL_CSASR][tal_csasr]
### yesno
@@ -262,6 +272,36 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder
 We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tKr3f0mL17uO_ljdHGKtR7HOmthYHwJG?usp=sharing)
+### Aishell4
+
+We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aishell4_pruned_transducer_stateless5].
+
+#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with all subsets)
+
+The best CER(%) results:
+| | test |
+|----------------------|--------|
+| greedy search | 29.89 |
+| fast beam search | 28.91 |
+| modified beam search | 29.08 |
+
+We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)
+
+### TAL_CSASR
+
+We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][TAL_CSASR_pruned_transducer_stateless5].
+
+#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
+
+The best CER(%) results:
+| | dev | test |
+|----------------------|------|------|
+| greedy search | 7.30 | 7.39 |
+| fast beam search     | 7.18 | 7.26 |
+| modified beam search | 7.15 | 7.22 |
+
+We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1DmIx-NloI1CMU5GdZrlse7TRu4y3Dpf8?usp=sharing)
+
## Deployment with C++
Once you have trained a model in icefall, you may want to deploy it with C++,
@@ -290,6 +330,8 @@ Please see: [
diff --git a/egs/tal_csasr/ASR/README.md b/egs/tal_csasr/ASR/README.md
new file mode 100644
index 000000000..a705a2f44
--- /dev/null
+++ b/egs/tal_csasr/ASR/README.md
@@ -0,0 +1,19 @@
+
+# Introduction
+
+This recipe includes some different ASR models trained with TAL_CSASR.
+
+[./RESULTS.md](./RESULTS.md) contains the latest results.
+
+# Transducers
+
+There are various folders containing the name `transducer` in this folder.
+The following table lists the differences among them.
+
+| | Encoder | Decoder | Comment |
+|---------------------------------------|---------------------|--------------------|-----------------------------|
+| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + more layers + random combiner|
+
+The decoder in `pruned_transducer_stateless5` is modified from the paper
+[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
+We place an additional Conv1d layer right after the input embedding layer.
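+
+Below is a minimal sketch of this decoder structure (illustrative only; the
+names and dimensions are not the recipe's actual code). The prediction
+network is just an embedding followed by a causal Conv1d over the last few
+labels, so it keeps no recurrent state:
+
+```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class StatelessDecoder(nn.Module):
+    def __init__(self, vocab_size: int, embed_dim: int, context_size: int = 2):
+        super().__init__()
+        self.embedding = nn.Embedding(vocab_size, embed_dim)
+        self.context_size = context_size
+        # The Conv1d sees only the last `context_size` labels; this replaces
+        # the recurrent state of a conventional prediction network.
+        self.conv = nn.Conv1d(
+            embed_dim, embed_dim, kernel_size=context_size, bias=False
+        )
+
+    def forward(self, y: torch.Tensor) -> torch.Tensor:
+        # y: (N, U) label IDs; returns (N, U, embed_dim).
+        emb = self.embedding(y).permute(0, 2, 1)  # (N, C, U)
+        # Pad on the left so the convolution remains causal.
+        emb = F.pad(emb, (self.context_size - 1, 0))
+        return F.relu(self.conv(emb)).permute(0, 2, 1)
+```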
diff --git a/egs/tal_csasr/ASR/RESULTS.md b/egs/tal_csasr/ASR/RESULTS.md
new file mode 100644
index 000000000..b711fa82b
--- /dev/null
+++ b/egs/tal_csasr/ASR/RESULTS.md
@@ -0,0 +1,80 @@
+## Results
+
+### TAL_CSASR Mix Chars and BPEs training results (Pruned Transducer Stateless5)
+
+#### 2022-06-22
+
+Using the code from this PR: https://github.com/k2-fsa/icefall/pull/428.
+
+The CERs are
+
+|decoding-method | epoch(iter) | avg | dev | test |
+|--|--|--|--|--|
+|greedy_search | 30 | 24 | 7.49 | 7.58|
+|modified_beam_search | 30 | 24 | 7.33 | 7.38|
+|fast_beam_search | 30 | 24 | 7.32 | 7.42|
+|greedy_search(use-averaged-model=True) | 30 | 24 | 7.30 | 7.39|
+|modified_beam_search(use-averaged-model=True) | 30 | 24 | 7.15 | 7.22|
+|fast_beam_search(use-averaged-model=True) | 30 | 24 | 7.18 | 7.26|
+|greedy_search | 348000 | 30 | 7.46 | 7.54|
+|modified_beam_search | 348000 | 30 | 7.24 | 7.36|
+|fast_beam_search | 348000 | 30 | 7.25 | 7.39 |
+
+The training command for reproducing is given below:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5"
+
+./pruned_transducer_stateless5/train.py \
+ --world-size 6 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --exp-dir pruned_transducer_stateless5/exp \
+ --lang-dir data/lang_char \
+ --max-duration 90
+```
+
+The tensorboard training log can be found at
+https://tensorboard.dev/experiment/KaACzXOVR0OM6cy0qbN5hw/#scalars
+
+The decoding command is:
+```
+epoch=30
+avg=24
+use_average_model=True
+
+## greedy search
+./pruned_transducer_stateless5/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir pruned_transducer_stateless5/exp \
+ --lang-dir ./data/lang_char \
+ --max-duration 800 \
+ --use-averaged-model $use_average_model
+
+## modified beam search
+./pruned_transducer_stateless5/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir pruned_transducer_stateless5/exp \
+ --lang-dir ./data/lang_char \
+ --max-duration 800 \
+ --decoding-method modified_beam_search \
+ --beam-size 4 \
+ --use-averaged-model $use_average_model
+
+## fast beam search
+./pruned_transducer_stateless5/decode.py \
+ --epoch $epoch \
+ --avg $avg \
+ --exp-dir ./pruned_transducer_stateless5/exp \
+ --lang-dir ./data/lang_char \
+ --max-duration 1500 \
+ --decoding-method fast_beam_search \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8 \
+ --use-averaged-model $use_average_model
+```
+
+A pre-trained model and decoding logs can be found at
diff --git a/egs/tal_csasr/ASR/local/tokenize_with_bpe_model.py b/egs/tal_csasr/ASR/local/tokenize_with_bpe_model.py
index 68f9e15b9..d7fd838f2 100644
--- a/egs/tal_csasr/ASR/local/tokenize_with_bpe_model.py
+++ b/egs/tal_csasr/ASR/local/tokenize_with_bpe_model.py
@@ -28,11 +28,12 @@ and generates the text_with_bpe.
import argparse
import logging
-import re
import sentencepiece as spm
from tqdm import tqdm
+from icefall.utils import tokenize_by_bpe_model
+
def get_parser():
parser = argparse.ArgumentParser(
@@ -61,29 +62,6 @@ def get_parser():
return parser
-def tokenize_by_bpe_model(sp, txt):
- tokens = []
- # CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref:
- # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
- pattern = re.compile(r"([\u4e00-\u9fff])")
- # Example:
- # txt = "你好 ITS'S OKAY 的"
- # chars = ["你", "好", " ITS'S OKAY ", "的"]
- chars = pattern.split(txt.upper())
- mix_chars = [w for w in chars if len(w.strip()) > 0]
- for ch_or_w in mix_chars:
- # ch_or_w is a single CJK charater(i.e., "你"), do nothing.
- if pattern.fullmatch(ch_or_w) is not None:
- tokens.append(ch_or_w)
- # ch_or_w contains non-CJK charaters(i.e., " IT'S OKAY "),
- # encode ch_or_w using bpe_model.
- else:
- for p in sp.encode_as_pieces(ch_or_w):
- tokens.append(p)
-
- return tokens
-
-
def main():
parser = get_parser()
args = parser.parse_args()
@@ -103,7 +81,7 @@ def main():
for i in tqdm(range(len(lines))):
x = lines[i]
txt_tokens = tokenize_by_bpe_model(sp, x)
- new_line = " ".join(txt_tokens)
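+        # tokenize_by_bpe_model() returns a "/"-joined token string;
+        # write the tokens separated by spaces instead.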
+ new_line = txt_tokens.replace("/", " ")
new_lines.append(new_line)
logging.info("Starting writing the text_with_bpe")
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
index 6ac3747e3..49bfb148b 100644
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py
@@ -314,7 +314,8 @@ class TAL_CSASRAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
- buffer_size=30000,
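+            # Using more cuts to estimate the duration bins and a larger
+            # shuffling buffer gives better bucketing for this large corpus,
+            # at the cost of higher memory usage.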
+ num_cuts_for_bins_estimate=20000,
+ buffer_size=60000,
drop_last=self.args.drop_last,
)
else:
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/conformer.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/conformer.py
index 6f7231f4b..bf3917df0 100644
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/conformer.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/conformer.py
@@ -117,10 +117,7 @@ class Conformer(EncoderInterface):
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
- with warnings.catch_warnings():
- warnings.simplefilter("ignore")
- # Caution: We assume the subsampling factor is 4!
- lengths = ((x_lens - 1) // 2 - 1) // 2
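+        # Caution: we assume the subsampling factor is 4 here. `>>` is used
+        # instead of `//` to avoid the floor-division warning on tensors
+        # (previously suppressed with warnings.catch_warnings()).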
+ lengths = (((x_lens - 1) >> 1) - 1) >> 1
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
@@ -293,8 +290,10 @@ class ConformerEncoder(nn.Module):
)
self.num_layers = num_layers
+ assert len(set(aux_layers)) == len(aux_layers)
+
assert num_layers - 1 not in aux_layers
- self.aux_layers = set(aux_layers + [num_layers - 1])
+ self.aux_layers = aux_layers + [num_layers - 1]
num_channels = encoder_layer.norm_final.num_channels
self.combiner = RandomCombine(
@@ -1154,7 +1153,7 @@ class RandomCombine(nn.Module):
"""
num_inputs = self.num_inputs
assert len(inputs) == num_inputs
- if not self.training:
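+        # In eval mode, or when the model has been scripted for deployment,
+        # skip the random combination and return the final layer's output.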
+ if not self.training or torch.jit.is_scripting():
return inputs[-1]
# Shape of weights: (*, num_inputs)
@@ -1162,8 +1161,22 @@ class RandomCombine(nn.Module):
num_frames = inputs[0].numel() // num_channels
mod_inputs = []
- for i in range(num_inputs - 1):
- mod_inputs.append(self.linear[i](inputs[i]))
+
+ if False:
+ # It throws the following error for torch 1.6.0 when using
+ # torch script.
+ #
+ # Expected integer literal for index. ModuleList/Sequential
+ # indexing is only supported with integer literals. Enumeration is
+ # supported, e.g. 'for index, v in enumerate(self): ...':
+ # for i in range(num_inputs - 1):
+ # mod_inputs.append(self.linear[i](inputs[i]))
+ assert False
+ else:
+ for i, linear in enumerate(self.linear):
+ if i < num_inputs - 1:
+ mod_inputs.append(linear(inputs[i]))
+
mod_inputs.append(inputs[num_inputs - 1])
ndim = inputs[0].ndim
@@ -1181,11 +1194,13 @@ class RandomCombine(nn.Module):
# ans: (num_frames, num_channels, 1)
ans = torch.matmul(stacked_inputs, weights)
# ans: (*, num_channels)
- ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels)
- if __name__ == "__main__":
- # for testing only...
- print("Weights = ", weights.reshape(num_frames, num_inputs))
+ ans = ans.reshape(inputs[0].shape[:-1] + [num_channels])
+
+ # The following if causes errors for torch script in torch 1.6.0
+ # if __name__ == "__main__":
+ # # for testing only...
+ # print("Weights = ", weights.reshape(num_frames, num_inputs))
return ans
def _get_random_weights(
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
index 6923298b1..1ad6ed943 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py
@@ -288,7 +288,8 @@ def decode_one_batch(
chars = pattern.split(hyp.upper())
chars_new = []
for char in chars:
- chars_new.extend(char.strip().split(" "))
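+                # pattern.split() yields empty strings at the boundaries
+                # of CJK characters; skip them.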
+ if char != "":
+ chars_new.extend(char.strip().split(" "))
hyps.append(chars_new)
elif (
params.decoding_method == "greedy_search"
@@ -304,7 +305,8 @@ def decode_one_batch(
chars = pattern.split(hyp.upper())
chars_new = []
for char in chars:
- chars_new.extend(char.strip().split(" "))
+ if char != "":
+ chars_new.extend(char.strip().split(" "))
hyps.append(chars_new)
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
@@ -318,7 +320,8 @@ def decode_one_batch(
chars = pattern.split(hyp.upper())
chars_new = []
for char in chars:
- chars_new.extend(char.strip().split(" "))
+ if char != "":
+ chars_new.extend(char.strip().split(" "))
hyps.append(chars_new)
else:
batch_size = encoder_out.size(0)
@@ -350,7 +353,8 @@ def decode_one_batch(
chars = pattern.split(hyp.upper())
chars_new = []
for char in chars:
- chars_new.extend(char.strip().split(" "))
+ if char != "":
+ chars_new.extend(char.strip().split(" "))
hyps.append(chars_new)
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
@@ -415,7 +419,8 @@ def decode_dataset(
chars = pattern.split(text.upper())
chars_new = []
for char in chars:
- chars_new.extend(char.strip().split(" "))
+ if char != "":
+ chars_new.extend(char.strip().split(" "))
texts[i] = chars_new
hyps_dict = decode_one_batch(
params=params,
@@ -648,7 +653,7 @@ def main():
dev_cuts = tal_csasr.valid_cuts()
dev_cuts = dev_cuts.map(text_normalize_for_cut)
- dev_dl = tal_csasr.valid_dataloader(dev_cuts)
+ dev_dl = tal_csasr.valid_dataloaders(dev_cuts)
test_cuts = tal_csasr.test_cuts()
test_cuts = test_cuts.map(text_normalize_for_cut)
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
index f1269a4bd..8f900208a 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+# 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -22,9 +23,10 @@
Usage:
./pruned_transducer_stateless5/export.py \
--exp-dir ./pruned_transducer_stateless5/exp \
- --bpe-model data/lang_bpe_500/bpe.model \
- --epoch 20 \
- --avg 10
+ --lang-dir ./data/lang_char \
+ --epoch 30 \
+ --avg 24 \
+ --use-averaged-model True
It will generate a file exp_dir/pretrained.pt
@@ -34,14 +36,14 @@ you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
- cd /path/to/egs/librispeech/ASR
+ cd /path/to/egs/tal_csasr/ASR
./pruned_transducer_stateless5/decode.py \
--exp-dir ./pruned_transducer_stateless5/exp \
- --epoch 9999 \
- --avg 1 \
- --max-duration 600 \
+    --epoch 9999 \
+    --avg 1 \
+ --max-duration 800 \
--decoding-method greedy_search \
- --bpe-model data/lang_bpe_500/bpe.model
+ --lang-dir ./data/lang_char
"""
import argparse
@@ -58,6 +60,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
+from icefall.lexicon import Lexicon
from icefall.utils import str2bool
@@ -115,10 +118,13 @@ def get_parser():
)
parser.add_argument(
- "--bpe-model",
+ "--lang-dir",
type=str,
- default="data/lang_bpe_500/bpe.model",
- help="Path to the BPE model",
+ default="data/lang_char",
+ help="""The lang dir
+ It contains language related input files such as
+ "lexicon.txt"
+ """,
)
parser.add_argument(
@@ -146,8 +152,6 @@ def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
- assert args.jit is False, "Support torchscript will be added later"
-
params = get_params()
params.update(vars(args))
@@ -157,12 +161,13 @@ def main():
logging.info(f"device: {device}")
+ bpe_model = params.lang_dir + "/bpe.model"
sp = spm.SentencePieceProcessor()
- sp.load(params.bpe_model)
+ sp.load(bpe_model)
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
- params.vocab_size = sp.get_piece_size()
+ lexicon = Lexicon(params.lang_dir)
+    params.blank_id = lexicon.token_table["<blk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
@@ -252,6 +257,11 @@ def main():
model.eval()
if params.jit:
+ # We won't use the forward() method of the model in C++, so just ignore
+ # it here.
+ # Otherwise, one of its arguments is a ragged tensor and is not
+    # torch scriptable.
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
index 1e100fcbd..dbe213b24 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# 2022 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -20,34 +21,25 @@ Usage:
(1) greedy search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
- --bpe-model ./data/lang_bpe_500/bpe.model \
- --method greedy_search \
+ --lang-dir ./data/lang_char \
+ --decoding-method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
-(2) beam search
+(2) modified beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
- --bpe-model ./data/lang_bpe_500/bpe.model \
- --method beam_search \
+ --lang-dir ./data/lang_char \
+ --decoding-method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
-(3) modified beam search
+(3) fast beam search
./pruned_transducer_stateless5/pretrained.py \
--checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
- --bpe-model ./data/lang_bpe_500/bpe.model \
- --method modified_beam_search \
- --beam-size 4 \
- /path/to/foo.wav \
- /path/to/bar.wav
-
-(4) fast beam search
-./pruned_transducer_stateless5/pretrained.py \
- --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
- --bpe-model ./data/lang_bpe_500/bpe.model \
- --method fast_beam_search \
+ --lang-dir ./data/lang_char \
+ --decoding-method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
@@ -62,6 +54,7 @@ Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by
import argparse
import logging
import math
+import re
from typing import List
import k2
@@ -79,6 +72,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
+from icefall.lexicon import Lexicon
+
def get_parser():
parser = argparse.ArgumentParser(
@@ -95,13 +90,17 @@ def get_parser():
)
parser.add_argument(
- "--bpe-model",
+ "--lang-dir",
type=str,
- help="""Path to bpe.model.""",
+ default="data/lang_char",
+ help="""The lang dir
+ It contains language related input files such as
+ "lexicon.txt"
+ """,
)
parser.add_argument(
- "--method",
+ "--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
@@ -216,13 +215,13 @@ def main():
params.update(vars(args))
+ bpe_model = params.lang_dir + "/bpe.model"
sp = spm.SentencePieceProcessor()
- sp.load(params.bpe_model)
+ sp.load(bpe_model)
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.unk_id = sp.piece_to_id("<unk>")
- params.vocab_size = sp.get_piece_size()
+ lexicon = Lexicon(params.lang_dir)
+    params.blank_id = lexicon.token_table["<blk>"]
+ params.vocab_size = max(lexicon.tokens) + 1
logging.info(f"{params}")
@@ -281,6 +280,7 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
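+    # Matches a single CJK character; used to split decoded text into
+    # CJK chars plus non-CJK (e.g., English) chunks. CJK unicode range
+    # is [U+4E00, U+9FFF].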
+ pattern = re.compile(r"([\u4e00-\u9fff])")
-    if params.method == "fast_beam_search":
+    if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@@ -292,8 +292,14 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
- for hyp in sp.decode(hyp_tokens):
- hyps.append(hyp.split())
+ for i in range(encoder_out.size(0)):
+ hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ chars = pattern.split(hyp.upper())
+ chars_new = []
+ for char in chars:
+ if char != "":
+ chars_new.extend(char.strip().split(" "))
+ hyps.append(chars_new)
-    elif params.method == "modified_beam_search":
+    elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@@ -301,17 +307,28 @@ def main():
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
-
- for hyp in sp.decode(hyp_tokens):
- hyps.append(hyp.split())
+ for i in range(encoder_out.size(0)):
+ hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ chars = pattern.split(hyp.upper())
+ chars_new = []
+ for char in chars:
+ if char != "":
+ chars_new.extend(char.strip().split(" "))
+ hyps.append(chars_new)
-    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
- for hyp in sp.decode(hyp_tokens):
- hyps.append(hyp.split())
+ for i in range(encoder_out.size(0)):
+ hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+ chars = pattern.split(hyp.upper())
+ chars_new = []
+ for char in chars:
+ if char != "":
+ chars_new.extend(char.strip().split(" "))
+ hyps.append(chars_new)
else:
for i in range(num_waves):
# fmt: off
@@ -332,7 +349,13 @@ def main():
else:
-            raise ValueError(f"Unsupported method: {params.method}")
+            raise ValueError(f"Unsupported method: {params.decoding_method}")
- hyps.append(sp.decode(hyp).split())
+ hyp = sp.decode([lexicon.token_table[idx] for idx in hyp])
+ chars = pattern.split(hyp.upper())
+ chars_new = []
+ for char in chars:
+ if char != "":
+ chars_new.extend(char.strip().split(" "))
+ hyps.append(chars_new)
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
index 86822a784..ca35eba45 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
@@ -954,20 +954,20 @@ def run(rank, world_size, args):
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
- # Caution: There is a reason to select 18.0 here. Please see
+ # Caution: There is a reason to select 20.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
- return 1.0 <= c.duration <= 18.0
+ return 1.0 <= c.duration <= 20.0
def text_normalize_for_cut(c: Cut):
# Text normalize for each sample
text = c.supervisions[0].text
text = text.strip("\n").strip("\t")
text = text_normalize(text)
- text = "/".join(tokenize_by_bpe_model(sp, text))
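+        # tokenize_by_bpe_model() now returns the tokens already joined
+        # by "/", so no extra join is needed here.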
+ text = tokenize_by_bpe_model(sp, text)
c.supervisions[0].text = text
return c
diff --git a/icefall/char_graph_compiler.py b/icefall/char_graph_compiler.py
index dbd069f63..235160e14 100644
--- a/icefall/char_graph_compiler.py
+++ b/icefall/char_graph_compiler.py
@@ -80,7 +80,8 @@ class CharCtcTrainingGraphCompiler(object):
return ids
def texts_to_ids_with_bpe(self, texts: List[str]) -> List[List[int]]:
- """Convert a list of texts to a list-of-list of token IDs.
+ """Convert a list of texts (which include chars and bpes)
+ to a list-of-list of token IDs.
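+        Each text is expected to be a string of tokens joined by "/",
+        e.g., the output of icefall.utils.tokenize_by_bpe_model().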
Args:
texts:
diff --git a/icefall/utils.py b/icefall/utils.py
index b38574f0c..c407e7a10 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -20,6 +20,7 @@ import argparse
import collections
import logging
import os
+import re
import subprocess
from collections import defaultdict
from contextlib import contextmanager
@@ -30,6 +31,7 @@ from typing import Dict, Iterable, List, TextIO, Tuple, Union
import k2
import k2.version
import kaldialign
+import sentencepiece as spm
import torch
import torch.distributed as dist
import torch.nn as nn
@@ -799,3 +801,40 @@ def optim_step_and_measure_param_change(
delta = l2_norm(p_orig - p_new) / l2_norm(p_orig)
relative_change[n] = delta.item()
return relative_change
+
+
+def tokenize_by_bpe_model(
+ sp: spm.SentencePieceProcessor,
+ txt: str,
+) -> str:
+ """
+ Tokenize text with bpe model. This function is from
+ https://github1s.com/wenet-e2e/wenet/blob/main/wenet/dataset/processor.py#L322-L342.
+ Args:
+ sp: spm.SentencePieceProcessor.
+ txt: str
+
+ Return:
+      A new string whose tokens (CJK chars and BPE pieces) are joined by "/".
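+
+    Example (illustrative; the exact English pieces depend on the BPE model):
+      txt = "你好 IT'S OKAY 的"
+      tokenize_by_bpe_model(sp, txt)  # -> "你/好/▁IT/'S/▁OKAY/的"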
+ """
+ tokens = []
+ # CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref:
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+ pattern = re.compile(r"([\u4e00-\u9fff])")
+ # Example:
+    # txt = "你好 IT'S OKAY 的"
+    # chars = ["你", "好", " IT'S OKAY ", "的"]
+ chars = pattern.split(txt.upper())
+ mix_chars = [w for w in chars if len(w.strip()) > 0]
+ for ch_or_w in mix_chars:
+        # ch_or_w is a single CJK character (e.g., "你"); do nothing.
+ if pattern.fullmatch(ch_or_w) is not None:
+ tokens.append(ch_or_w)
+        # ch_or_w contains non-CJK characters (e.g., " IT'S OKAY ");
+        # encode ch_or_w using the bpe model.
+ else:
+ for p in sp.encode_as_pieces(ch_or_w):
+ tokens.append(p)
+ txt_with_bpe = "/".join(tokens)
+
+ return txt_with_bpe