diff --git a/README.md b/README.md
index 9f8db554c..be00eac50 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,14 @@

+## Introduction
+
+icefall contains ASR recipes for various datasets
+using <https://github.com/lhotse-speech/lhotse>.
+
+You can use <https://github.com/k2-fsa/sherpa> to deploy models
+trained with icefall.
+
 ## Installation

 Please refer to
@@ -23,6 +31,8 @@ We provide the following recipes:
   - [Aidatatang_200zh][aidatatang_200zh]
   - [WenetSpeech][wenetspeech]
   - [Alimeeting][alimeeting]
+  - [Aishell4][aishell4]
+  - [TAL_CSASR][tal_csasr]

 ### yesno

@@ -262,6 +272,36 @@ We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder

 We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1tKr3f0mL17uO_ljdHGKtR7HOmthYHwJG?usp=sharing)

+### Aishell4
+
+We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][Aishell4_pruned_transducer_stateless5].
+
+#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with all subsets)
+
+The best CER(%) results:
+|                      | test  |
+|----------------------|-------|
+| greedy search        | 29.89 |
+| fast beam search     | 28.91 |
+| modified beam search | 29.08 |
+
+We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1z3lkURVv9M7uTiIgf3Np9IntMHEknaks?usp=sharing)
+
+### TAL_CSASR
+
+We provide one model for this recipe: [Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss][TAL_CSASR_pruned_transducer_stateless5].
+
+#### Pruned stateless RNN-T: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss
+
+The best CER(%) results:
+|                      | dev  | test |
+|----------------------|------|------|
+| greedy search        | 7.30 | 7.39 |
+| fast beam search     | 7.15 | 7.22 |
+| modified beam search | 7.18 | 7.26 |
+
+We provide a Colab notebook to run a pre-trained Pruned Transducer Stateless model: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1DmIx-NloI1CMU5GdZrlse7TRu4y3Dpf8?usp=sharing)
+
 ## Deployment with C++

 Once you have trained a model in icefall, you may want to deploy it with C++,
@@ -290,6 +330,8 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
 [Aidatatang_200zh_pruned_transducer_stateless2]: egs/aidatatang_200zh/ASR/pruned_transducer_stateless2
 [WenetSpeech_pruned_transducer_stateless2]: egs/wenetspeech/ASR/pruned_transducer_stateless2
 [Alimeeting_pruned_transducer_stateless2]: egs/alimeeting/ASR/pruned_transducer_stateless2
+[Aishell4_pruned_transducer_stateless5]: egs/aishell4/ASR/pruned_transducer_stateless5
+[TAL_CSASR_pruned_transducer_stateless5]: egs/tal_csasr/ASR/pruned_transducer_stateless5
 [yesno]: egs/yesno/ASR
 [librispeech]: egs/librispeech/ASR
 [aishell]: egs/aishell/ASR
@@ -299,5 +341,6 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad
 [aidatatang_200zh]: egs/aidatatang_200zh/ASR
 [wenetspeech]: egs/wenetspeech/ASR
 [alimeeting]: egs/alimeeting/ASR
+[aishell4]: egs/aishell4/ASR
+[tal_csasr]: egs/tal_csasr/ASR
 [k2]: https://github.com/k2-fsa/k2
-)
diff --git a/egs/tal_csasr/ASR/README.md b/egs/tal_csasr/ASR/README.md
new file mode 100644
index 000000000..a705a2f44
--- /dev/null
+++ b/egs/tal_csasr/ASR/README.md
@@ -0,0 +1,19 @@
+
+# Introduction
+
+This recipe contains several ASR models trained on the TAL_CSASR dataset.
+
+[./RESULTS.md](./RESULTS.md) contains the latest results.
+
+# Transducers
+
+Various folders in this directory contain `transducer` in their names.
+The following table lists the differences among them.
+
+|                                | Encoder             | Decoder            | Comment                                                              |
+|--------------------------------|---------------------|--------------------|----------------------------------------------------------------------|
+| `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + more layers + random combiner |
+
+The decoder is modified from the paper
+[RNN-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
+We place an additional Conv1d layer right after the input embedding layer.
diff --git a/egs/tal_csasr/ASR/RESULTS.md b/egs/tal_csasr/ASR/RESULTS.md
new file mode 100644
index 000000000..b711fa82b
--- /dev/null
+++ b/egs/tal_csasr/ASR/RESULTS.md
@@ -0,0 +1,80 @@
+## Results
+
+### TAL_CSASR mixed chars and BPEs training results (Pruned Transducer Stateless5)
+
+#### 2022-06-22
+
+Using the code from this PR: https://github.com/k2-fsa/icefall/pull/428.
+
+The CERs are:
+
+| decoding method                                | epoch(iter) | avg | dev  | test |
+|------------------------------------------------|-------------|-----|------|------|
+| greedy_search                                  | 30          | 24  | 7.49 | 7.58 |
+| modified_beam_search                           | 30          | 24  | 7.33 | 7.38 |
+| fast_beam_search                               | 30          | 24  | 7.32 | 7.42 |
+| greedy_search (use-averaged-model=True)        | 30          | 24  | 7.30 | 7.39 |
+| modified_beam_search (use-averaged-model=True) | 30          | 24  | 7.15 | 7.22 |
+| fast_beam_search (use-averaged-model=True)     | 30          | 24  | 7.18 | 7.26 |
+| greedy_search                                  | 348000      | 30  | 7.46 | 7.54 |
+| modified_beam_search                           | 348000      | 30  | 7.24 | 7.36 |
+| fast_beam_search                               | 348000      | 30  | 7.25 | 7.39 |
+
+The training command for reproducing the results is given below:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5"
+
+./pruned_transducer_stateless5/train.py \
+  --world-size 6 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --exp-dir pruned_transducer_stateless5/exp \
+  --lang-dir data/lang_char \
+  --max-duration 90
+```
+
+The tensorboard training log can be found at
+https://tensorboard.dev/experiment/KaACzXOVR0OM6cy0qbN5hw/#scalars
+
+The decoding command is:
+```
+epoch=30
+avg=24
+use_average_model=True
+
+## greedy search
+./pruned_transducer_stateless5/decode.py \
+  --epoch $epoch \
+  --avg $avg \
+  --exp-dir pruned_transducer_stateless5/exp \
+  --lang-dir ./data/lang_char \
+  --max-duration 800 \
+  --use-averaged-model $use_average_model
+
+## modified beam search
+./pruned_transducer_stateless5/decode.py \
+  --epoch $epoch \
+  --avg $avg \
+  --exp-dir pruned_transducer_stateless5/exp \
+  --lang-dir ./data/lang_char \
+  --max-duration 800 \
+  --decoding-method modified_beam_search \
+  --beam-size 4 \
+  --use-averaged-model $use_average_model
+
+## fast beam search
+./pruned_transducer_stateless5/decode.py \
+  --epoch $epoch \
+  --avg $avg \
+  --exp-dir ./pruned_transducer_stateless5/exp \
+  --lang-dir ./data/lang_char \
+  --max-duration 1500 \
+  --decoding-method fast_beam_search \
+  --beam 4 \
+  --max-contexts 4 \
+  --max-states 8 \
+  --use-averaged-model $use_average_model
+```
+
+A pre-trained model and decoding logs can be found at
diff --git a/egs/tal_csasr/ASR/local/tokenize_with_bpe_model.py b/egs/tal_csasr/ASR/local/tokenize_with_bpe_model.py
index 68f9e15b9..d7fd838f2 100644
--- a/egs/tal_csasr/ASR/local/tokenize_with_bpe_model.py
+++ b/egs/tal_csasr/ASR/local/tokenize_with_bpe_model.py
@@ -28,11 +28,12 @@ and generates
the text_with_bpe. import argparse import logging -import re import sentencepiece as spm from tqdm import tqdm +from icefall.utils import tokenize_by_bpe_model + def get_parser(): parser = argparse.ArgumentParser( @@ -61,29 +62,6 @@ def get_parser(): return parser -def tokenize_by_bpe_model(sp, txt): - tokens = [] - # CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref: - # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) - pattern = re.compile(r"([\u4e00-\u9fff])") - # Example: - # txt = "你好 ITS'S OKAY 的" - # chars = ["你", "好", " ITS'S OKAY ", "的"] - chars = pattern.split(txt.upper()) - mix_chars = [w for w in chars if len(w.strip()) > 0] - for ch_or_w in mix_chars: - # ch_or_w is a single CJK charater(i.e., "你"), do nothing. - if pattern.fullmatch(ch_or_w) is not None: - tokens.append(ch_or_w) - # ch_or_w contains non-CJK charaters(i.e., " IT'S OKAY "), - # encode ch_or_w using bpe_model. - else: - for p in sp.encode_as_pieces(ch_or_w): - tokens.append(p) - - return tokens - - def main(): parser = get_parser() args = parser.parse_args() @@ -103,7 +81,7 @@ def main(): for i in tqdm(range(len(lines))): x = lines[i] txt_tokens = tokenize_by_bpe_model(sp, x) - new_line = " ".join(txt_tokens) + new_line = txt_tokens.replace("/", " ") new_lines.append(new_line) logging.info("Starting writing the text_with_bpe") diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py index 6ac3747e3..49bfb148b 100644 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -314,7 +314,8 @@ class TAL_CSASRAsrDataModule: max_duration=self.args.max_duration, shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, - buffer_size=30000, + num_cuts_for_bins_estimate=20000, + buffer_size=60000, drop_last=self.args.drop_last, ) else: diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/conformer.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/conformer.py index 6f7231f4b..bf3917df0 100644 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/conformer.py @@ -117,10 +117,7 @@ class Conformer(EncoderInterface): x, pos_emb = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - # Caution: We assume the subsampling factor is 4! 
- lengths = ((x_lens - 1) // 2 - 1) // 2 + lengths = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == lengths.max().item() mask = make_pad_mask(lengths) @@ -293,8 +290,10 @@ class ConformerEncoder(nn.Module): ) self.num_layers = num_layers + assert len(set(aux_layers)) == len(aux_layers) + assert num_layers - 1 not in aux_layers - self.aux_layers = set(aux_layers + [num_layers - 1]) + self.aux_layers = aux_layers + [num_layers - 1] num_channels = encoder_layer.norm_final.num_channels self.combiner = RandomCombine( @@ -1154,7 +1153,7 @@ class RandomCombine(nn.Module): """ num_inputs = self.num_inputs assert len(inputs) == num_inputs - if not self.training: + if not self.training or torch.jit.is_scripting(): return inputs[-1] # Shape of weights: (*, num_inputs) @@ -1162,8 +1161,22 @@ class RandomCombine(nn.Module): num_frames = inputs[0].numel() // num_channels mod_inputs = [] - for i in range(num_inputs - 1): - mod_inputs.append(self.linear[i](inputs[i])) + + if False: + # It throws the following error for torch 1.6.0 when using + # torch script. + # + # Expected integer literal for index. ModuleList/Sequential + # indexing is only supported with integer literals. Enumeration is + # supported, e.g. 'for index, v in enumerate(self): ...': + # for i in range(num_inputs - 1): + # mod_inputs.append(self.linear[i](inputs[i])) + assert False + else: + for i, linear in enumerate(self.linear): + if i < num_inputs - 1: + mod_inputs.append(linear(inputs[i])) + mod_inputs.append(inputs[num_inputs - 1]) ndim = inputs[0].ndim @@ -1181,11 +1194,13 @@ class RandomCombine(nn.Module): # ans: (num_frames, num_channels, 1) ans = torch.matmul(stacked_inputs, weights) # ans: (*, num_channels) - ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels) - if __name__ == "__main__": - # for testing only... - print("Weights = ", weights.reshape(num_frames, num_inputs)) + ans = ans.reshape(inputs[0].shape[:-1] + [num_channels]) + + # The following if causes errors for torch script in torch 1.6.0 + # if __name__ == "__main__": + # # for testing only... 
+ # print("Weights = ", weights.reshape(num_frames, num_inputs)) return ans def _get_random_weights( diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py index 6923298b1..1ad6ed943 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py @@ -288,7 +288,8 @@ def decode_one_batch( chars = pattern.split(hyp.upper()) chars_new = [] for char in chars: - chars_new.extend(char.strip().split(" ")) + if char != "": + chars_new.extend(char.strip().split(" ")) hyps.append(chars_new) elif ( params.decoding_method == "greedy_search" @@ -304,7 +305,8 @@ def decode_one_batch( chars = pattern.split(hyp.upper()) chars_new = [] for char in chars: - chars_new.extend(char.strip().split(" ")) + if char != "": + chars_new.extend(char.strip().split(" ")) hyps.append(chars_new) elif params.decoding_method == "modified_beam_search": hyp_tokens = modified_beam_search( @@ -318,7 +320,8 @@ def decode_one_batch( chars = pattern.split(hyp.upper()) chars_new = [] for char in chars: - chars_new.extend(char.strip().split(" ")) + if char != "": + chars_new.extend(char.strip().split(" ")) hyps.append(chars_new) else: batch_size = encoder_out.size(0) @@ -350,7 +353,8 @@ def decode_one_batch( chars = pattern.split(hyp.upper()) chars_new = [] for char in chars: - chars_new.extend(char.strip().split(" ")) + if char != "": + chars_new.extend(char.strip().split(" ")) hyps.append(chars_new) if params.decoding_method == "greedy_search": return {"greedy_search": hyps} @@ -415,7 +419,8 @@ def decode_dataset( chars = pattern.split(text.upper()) chars_new = [] for char in chars: - chars_new.extend(char.strip().split(" ")) + if char != "": + chars_new.extend(char.strip().split(" ")) texts[i] = chars_new hyps_dict = decode_one_batch( params=params, @@ -648,7 +653,7 @@ def main(): dev_cuts = tal_csasr.valid_cuts() dev_cuts = dev_cuts.map(text_normalize_for_cut) - dev_dl = tal_csasr.valid_dataloader(dev_cuts) + dev_dl = tal_csasr.valid_dataloaders(dev_cuts) test_cuts = tal_csasr.test_cuts() test_cuts = test_cuts.map(text_normalize_for_cut) diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py index f1269a4bd..8f900208a 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # # Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# 2022 Xiaomi Corporation (Author: Mingshuang Luo) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -22,9 +23,10 @@ Usage: ./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 + --lang-dir ./data/lang_char \ + --epoch 30 \ + --avg 24 \ + --use-averaged-model True It will generate a file exp_dir/pretrained.pt @@ -34,14 +36,14 @@ you can do: cd /path/to/exp_dir ln -s pretrained.pt epoch-9999.pt - cd /path/to/egs/librispeech/ASR + cd /path/to/egs/tal_csasr/ASR ./pruned_transducer_stateless5/decode.py \ --exp-dir ./pruned_transducer_stateless5/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ + --epoch 30 \ + --avg 24 \ + --max-duration 800 \ --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model + --lang-dir ./data/lang_char """ import argparse @@ -58,6 +60,7 @@ from icefall.checkpoint import ( find_checkpoints, 
     load_checkpoint,
 )
+from icefall.lexicon import Lexicon
 from icefall.utils import str2bool

@@ -115,10 +118,13 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--lang-dir",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        default="data/lang_char",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
     )

     parser.add_argument(
@@ -146,8 +152,6 @@ def main():
     args = get_parser().parse_args()
     args.exp_dir = Path(args.exp_dir)

-    assert args.jit is False, "Support torchscript will be added later"
-
     params = get_params()
     params.update(vars(args))

@@ -157,12 +161,13 @@ def main():

     logging.info(f"device: {device}")

+    bpe_model = params.lang_dir + "/bpe.model"
     sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    sp.load(bpe_model)

-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    lexicon = Lexicon(params.lang_dir)
+    params.blank_id = lexicon.token_table["<blk>"]
+    params.vocab_size = max(lexicon.tokens) + 1

     logging.info(params)

@@ -252,6 +257,11 @@ def main():
     model.eval()

     if params.jit:
+        # We won't use the forward() method of the model in C++, so just ignore
+        # it here.
+        # Otherwise, one of its arguments is a ragged tensor and is not
+        # torch scriptable.
+        model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
         filename = params.exp_dir / "cpu_jit.pt"
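Note on the `torch.jit.ignore` line added above: scripting then skips `forward()`, whose ragged-tensor argument TorchScript cannot compile, while methods marked `@torch.jit.export` are still compiled and remain callable from C++. A minimal sketch of the same pattern with a toy module; the names below are illustrative, not icefall APIs:

```python
import torch
import torch.nn as nn


class Toy(nn.Module):
    """Toy stand-in for the transducer model; illustrative only."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor, ragged) -> torch.Tensor:
        # `ragged` stands in for the ragged-tensor argument that
        # TorchScript cannot compile.
        return self.encoder(x)

    @torch.jit.export
    def run_encoder(self, x: torch.Tensor) -> torch.Tensor:
        # Exported explicitly, so it is scripted even though forward()
        # is ignored below.
        return self.encoder(x)


model = Toy()
# Same pattern as in export.py: drop forward() from scripting.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
scripted = torch.jit.script(model)
print(scripted.run_encoder(torch.zeros(2, 4)).shape)  # torch.Size([2, 4])
```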
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
index 1e100fcbd..dbe213b24 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#           2022 Xiaomi Corp. (authors: Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -20,34 +21,25 @@ Usage:
 (1) greedy search
 ./pruned_transducer_stateless5/pretrained.py \
     --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
-    --bpe-model ./data/lang_bpe_500/bpe.model \
-    --method greedy_search \
+    --lang-dir ./data/lang_char \
+    --decoding-method greedy_search \
     /path/to/foo.wav \
     /path/to/bar.wav

-(2) beam search
+(2) modified beam search
 ./pruned_transducer_stateless5/pretrained.py \
     --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
-    --bpe-model ./data/lang_bpe_500/bpe.model \
-    --method beam_search \
+    --lang-dir ./data/lang_char \
+    --decoding-method modified_beam_search \
     --beam-size 4 \
     /path/to/foo.wav \
     /path/to/bar.wav

-(3) modified beam search
+(3) fast beam search
 ./pruned_transducer_stateless5/pretrained.py \
     --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
-    --bpe-model ./data/lang_bpe_500/bpe.model \
-    --method modified_beam_search \
-    --beam-size 4 \
-    /path/to/foo.wav \
-    /path/to/bar.wav
-
-(4) fast beam search
-./pruned_transducer_stateless5/pretrained.py \
-    --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \
-    --bpe-model ./data/lang_bpe_500/bpe.model \
-    --method fast_beam_search \
+    --lang-dir ./data/lang_char \
+    --decoding-method fast_beam_search \
     --beam-size 4 \
     /path/to/foo.wav \
     /path/to/bar.wav
@@ -62,6 +54,7 @@ Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by
 import argparse
 import logging
 import math
+import re
 from typing import List

 import k2
@@ -79,6 +72,8 @@ from beam_search import (
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_params, get_transducer_model

+from icefall.lexicon import Lexicon
+

 def get_parser():
     parser = argparse.ArgumentParser(
@@ -95,13 +90,17 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--lang-dir",
         type=str,
-        help="""Path to bpe.model.""",
+        default="data/lang_char",
+        help="""The lang dir
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
     )

     parser.add_argument(
-        "--method",
+        "--decoding-method",
         type=str,
         default="greedy_search",
         help="""Possible values are:
@@ -216,13 +215,13 @@ def main():

     params.update(vars(args))

+    bpe_model = params.lang_dir + "/bpe.model"
     sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    sp.load(bpe_model)

-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.unk_id = sp.piece_to_id("<unk>")
-    params.vocab_size = sp.get_piece_size()
+    lexicon = Lexicon(params.lang_dir)
+    params.blank_id = lexicon.token_table["<blk>"]
+    params.vocab_size = max(lexicon.tokens) + 1

     logging.info(f"{params}")

@@ -281,6 +280,7 @@ def main():
         msg += f" with beam size {params.beam_size}"
     logging.info(msg)

+    pattern = re.compile(r"([\u4e00-\u9fff])")
-    if params.method == "fast_beam_search":
+    if params.decoding_method == "fast_beam_search":
         decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
         hyp_tokens = fast_beam_search_one_best(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        for i in range(encoder_out.size(0)):
+            hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+            chars = pattern.split(hyp.upper())
+            chars_new = []
+            for char in chars:
+                if char != "":
+                    chars_new.extend(char.strip().split(" "))
+            hyps.append(chars_new)
-    elif params.method == "modified_beam_search":
+    elif params.decoding_method == "modified_beam_search":
         hyp_tokens = modified_beam_search(
             model=model,
@@ -301,17 +307,28 @@ def main():
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
-
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        for i in range(encoder_out.size(0)):
+            hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+            chars = pattern.split(hyp.upper())
+            chars_new = []
+            for char in chars:
+                if char != "":
+                    chars_new.extend(char.strip().split(" "))
+            hyps.append(chars_new)
-    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        for i in range(encoder_out.size(0)):
+            hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+            chars = pattern.split(hyp.upper())
+            chars_new = []
+            for char in chars:
+                if char != "":
+                    chars_new.extend(char.strip().split(" "))
+            hyps.append(chars_new)
     else:
         for i in range(num_waves):
             # fmt: off
@@ -332,7 +349,13 @@ def main():
             else:
-                raise ValueError(f"Unsupported method: {params.method}")
+                raise ValueError(f"Unsupported method: {params.decoding_method}")

-            hyps.append(sp.decode(hyp).split())
+            hyp = sp.decode([lexicon.token_table[idx] for idx in hyp])
+            chars = pattern.split(hyp.upper())
+            chars_new = []
+            for char in chars:
+                if char != "":
+                    chars_new.extend(char.strip().split(" "))
+            hyps.append(chars_new)

     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):
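The same post-processing loop appears in all four decoding branches above (and in decode.py): it splits a decoded hypothesis into single CJK characters plus whitespace-separated English words, so that code-switched output can be scored per token. A standalone sketch of that logic, which could also be factored into a shared helper:

```python
import re

# Single CJK characters are captured by the group, so re.split keeps
# them as separate list items.
pattern = re.compile(r"([\u4e00-\u9fff])")


def split_mixed_hyp(hyp: str) -> list:
    chars_new = []
    for piece in pattern.split(hyp.upper()):
        if piece != "":
            chars_new.extend(piece.strip().split(" "))
    return chars_new


print(split_mixed_hyp("你好 it's okay 的"))
# ['你', '好', "IT'S", 'OKAY', '的']
```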
diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
index 86822a784..ca35eba45 100755
--- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py
@@ -954,20 +954,20 @@ def run(rank, world_size, args):

     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
         #
-        # Caution: There is a reason to select 18.0 here. Please see
+        # Caution: There is a reason to select 20.0 here. Please see
         # ../local/display_manifest_statistics.py
         #
         # You should use ../local/display_manifest_statistics.py to get
         # an utterance duration distribution for your dataset to select
         # the threshold
-        return 1.0 <= c.duration <= 18.0
+        return 1.0 <= c.duration <= 20.0

     def text_normalize_for_cut(c: Cut):
         # Text normalize for each sample
         text = c.supervisions[0].text
         text = text.strip("\n").strip("\t")
         text = text_normalize(text)
-        text = "/".join(tokenize_by_bpe_model(sp, text))
+        text = tokenize_by_bpe_model(sp, text)
         c.supervisions[0].text = text
         return c
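For context, both Cut-level hooks above are applied lazily to the training cuts; a minimal sketch of that wiring with lhotse (the manifest path here is an assumption, not taken from this PR):

```python
# Minimal sketch of how train.py applies the hooks above to a lazy
# lhotse CutSet; the manifest path is hypothetical.
from lhotse import CutSet


def remove_short_and_long_utt(c) -> bool:
    # Same rule as above: keep utterances between 1 s and 20 s.
    return 1.0 <= c.duration <= 20.0


cuts = CutSet.from_file("data/fbank/tal_csasr_cuts_train_set.jsonl.gz")
cuts = cuts.filter(remove_short_and_long_utt)
# train.py additionally maps text_normalize_for_cut over the cuts:
# cuts = cuts.map(text_normalize_for_cut)
```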
diff --git a/icefall/char_graph_compiler.py b/icefall/char_graph_compiler.py
index dbd069f63..235160e14 100644
--- a/icefall/char_graph_compiler.py
+++ b/icefall/char_graph_compiler.py
@@ -80,7 +80,8 @@ class CharCtcTrainingGraphCompiler(object):
         return ids

     def texts_to_ids_with_bpe(self, texts: List[str]) -> List[List[int]]:
-        """Convert a list of texts to a list-of-list of token IDs.
+        """Convert a list of texts (which contain both CJK chars and
+        BPE pieces) to a list-of-list of token IDs.

         Args:
           texts:
diff --git a/icefall/utils.py b/icefall/utils.py
index b38574f0c..c407e7a10 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -20,6 +20,7 @@ import argparse
 import collections
 import logging
 import os
+import re
 import subprocess
 from collections import defaultdict
 from contextlib import contextmanager
@@ -30,6 +31,7 @@ from typing import Dict, Iterable, List, TextIO, Tuple, Union
 import k2
 import k2.version
 import kaldialign
+import sentencepiece as spm
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -799,3 +801,40 @@ def optim_step_and_measure_param_change(
         delta = l2_norm(p_orig - p_new) / l2_norm(p_orig)
         relative_change[n] = delta.item()
     return relative_change
+
+
+def tokenize_by_bpe_model(
+    sp: spm.SentencePieceProcessor,
+    txt: str,
+) -> str:
+    """
+    Tokenize text with bpe model. This function is from
+    https://github1s.com/wenet-e2e/wenet/blob/main/wenet/dataset/processor.py#L322-L342.
+    Args:
+      sp: spm.SentencePieceProcessor.
+      txt: str
+
+    Returns:
+      A new string whose CJK characters and BPE pieces are
+      separated by "/".
+    """
+    tokens = []
+    # CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref:
+    # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+    pattern = re.compile(r"([\u4e00-\u9fff])")
+    # Example:
+    #   txt   = "你好 IT'S OKAY 的"
+    #   chars = ["你", "好", " IT'S OKAY ", "的"]
+    chars = pattern.split(txt.upper())
+    mix_chars = [w for w in chars if len(w.strip()) > 0]
+    for ch_or_w in mix_chars:
+        # ch_or_w is a single CJK character (e.g., "你"), do nothing.
+        if pattern.fullmatch(ch_or_w) is not None:
+            tokens.append(ch_or_w)
+        # ch_or_w contains non-CJK characters (e.g., " IT'S OKAY "),
+        # encode ch_or_w using bpe_model.
+        else:
+            for p in sp.encode_as_pieces(ch_or_w):
+                tokens.append(p)
+    txt_with_bpe = "/".join(tokens)
+
+    return txt_with_bpe
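A quick usage sketch of the relocated helper; the BPE model path follows this recipe's lang dir, and the exact pieces in the output depend on the trained BPE model:

```python
import sentencepiece as spm

from icefall.utils import tokenize_by_bpe_model

sp = spm.SentencePieceProcessor()
# The model path follows this recipe's lang dir.
sp.load("data/lang_char/bpe.model")

txt = "你好 IT'S OKAY 的"
print(tokenize_by_bpe_model(sp, txt))
# CJK characters stay as single tokens, the English span is encoded by
# the BPE model, and everything is joined by "/", e.g. something like
# 你/好/▁IT/'/S/▁OKAY/的 (the exact pieces depend on the model).
```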