diff --git a/egs/multi_zh_en/ASR/prepare.sh b/egs/multi_zh_en/ASR/prepare.sh
index 4f0d0d752..a6808363d 100755
--- a/egs/multi_zh_en/ASR/prepare.sh
+++ b/egs/multi_zh_en/ASR/prepare.sh
@@ -13,6 +13,7 @@ dl_dir=$PWD/download
 . shared/parse_options.sh || exit 1
 
 vocab_sizes=(
+  500
   2000
 )
 
@@ -76,25 +77,31 @@ fi
 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
   log "Stage 4: Prepare Byte BPE based lang"
   mkdir -p data/fbank
-  if [ ! -d ../../aishell2/ASR/data/lang_char ]; then
+  if [ ! -d ../../aishell2/ASR/data/lang_char ] && [ ! -d ./data/lang_char ]; then
     log "Abort! Please run ../../aishell2/ASR/prepare.sh --stage 3 --stop-stage 3"
     exit 1
   fi
 
-  if [ ! -d ../../librispeech/ASR/data/lang_phone ]; then
+  if [ ! -d ../../librispeech/ASR/data/lang_phone ] && [ ! -d ./data/lang_phone ]; then
     log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 5 --stop-stage 5"
     exit 1
   fi
 
-  if [ ! -d ../../librispeech/ASR/data/lang_bpe_500 ]; then
+  if [ ! -d ../../librispeech/ASR/data/lang_bpe_500 ] && [ ! -d ./data/lang_bpe_500 ]; then
     log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 6 --stop-stage 6"
     exit 1
   fi
 
   cd data/
-  ln -svf $(realpath ../../../aishell2/ASR/data/lang_char) .
-  ln -svf $(realpath ../../../librispeech/ASR/data/lang_phone) .
-  ln -svf $(realpath ../../../librispeech/ASR/data/lang_bpe_500) .
+  if [ ! -d ./lang_char ]; then
+    ln -svf $(realpath ../../../aishell2/ASR/data/lang_char) .
+  fi
+  if [ ! -d ./lang_phone ]; then
+    ln -svf $(realpath ../../../librispeech/ASR/data/lang_phone) .
+  fi
+  if [ ! -d ./lang_bpe_500 ]; then
+    ln -svf $(realpath ../../../librispeech/ASR/data/lang_bpe_500) .
+  fi
   cd ../
 
   for vocab_size in ${vocab_sizes[@]}; do
@@ -104,9 +111,11 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
     cat data/lang_char/text data/lang_bpe_500/transcript_words.txt \
       > $lang_dir/text
 
-    ./local/prepare_for_bpe_model.py \
-      --lang-dir ./$lang_dir \
-      --text $lang_dir/text
+    if [ ! -f $lang_dir/transcript_chars.txt ]; then
+      ./local/prepare_for_bpe_model.py \
+        --lang-dir ./$lang_dir \
+        --text $lang_dir/text
+    fi
 
     if [ ! -f $lang_dir/text_words_segmentation ]; then
       python3 ./local/text2segments.py \
@@ -115,6 +124,9 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
 
       cat ./data/lang_bpe_500/transcript_words.txt \
         >> $lang_dir/text_words_segmentation
+
+      cat ./data/lang_char/text \
+        >> $lang_dir/text
     fi
 
     cat $lang_dir/text_words_segmentation | sed 's/ /\n/g' \
@@ -130,7 +142,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
       ./local/train_bbpe_model.py \
         --lang-dir $lang_dir \
         --vocab-size $vocab_size \
-        --transcript $lang_dir/transcript_chars.txt
+        --transcript $lang_dir/text
     fi
 
     if [ ! -f $lang_dir/L_disambig.pt ]; then
diff --git a/egs/multi_zh_en/ASR/zipformer/decode.py b/egs/multi_zh_en/ASR/zipformer/decode.py
index f501c3c30..a54381896 100755
--- a/egs/multi_zh_en/ASR/zipformer/decode.py
+++ b/egs/multi_zh_en/ASR/zipformer/decode.py
@@ -120,6 +120,7 @@ from lhotse.cut import Cut
 from multi_dataset import MultiDataset
 from train import add_model_arguments, get_model, get_params
 
+from icefall import byte_encode, byte_decode
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -193,14 +194,14 @@ def get_parser():
     parser.add_argument(
         "--bpe-model",
         type=str,
-        default="data/lang_bpe_2000/bpe.model",
+        default="data/lang_bbpe_2000/bbpe.model",
         help="Path to the BPE model",
     )
 
     parser.add_argument(
         "--lang-dir",
         type=Path,
-        default="data/lang_bpe_2000",
+        default="data/lang_bbpe_2000",
         help="The lang dir containing word table and LG graph",
     )
 
diff --git a/egs/multi_zh_en/ASR/zipformer/train.py b/egs/multi_zh_en/ASR/zipformer/train.py
index d2b6de242..1c2040525 100755
--- a/egs/multi_zh_en/ASR/zipformer/train.py
+++ b/egs/multi_zh_en/ASR/zipformer/train.py
@@ -80,7 +80,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer2
 
-from icefall import diagnostics
+from icefall import diagnostics, byte_encode
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import (
@@ -96,6 +96,7 @@ from icefall.utils import (
     get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
+    tokenize_by_CJK_char,
 )
 
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -1177,18 +1178,18 @@ def run(rank, world_size, args):
     train_cuts = multi_dataset.train_cuts()
 
     def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 20 seconds
+        # Keep only utterances with duration between 1 second and 12 seconds
         #
-        # Caution: There is a reason to select 20.0 here. Please see
+        # Caution: There is a reason to select 12.0 here. Please see
         # ../local/display_manifest_statistics.py
         #
         # You should use ../local/display_manifest_statistics.py to get
        # an utterance duration distribution for your dataset to select
         # the threshold
-        if c.duration < 1.0 or c.duration > 20.0:
-            # logging.warning(
-            #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            # )
+        if c.duration < 1.0 or c.duration > 12.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+            )
             return False
 
         # In pruned RNN-T, we require that T >= S
@@ -1213,8 +1214,17 @@ def run(rank, world_size, args):
 
         return True
 
+    def tokenize_and_encode_text(c: Cut):
+        # Text normalize for each sample
+        text = c.supervisions[0].text
+        text = byte_encode(tokenize_by_CJK_char(text))
+        c.supervisions[0].text = text
+        return c
+
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
 
+    train_cuts = train_cuts.map(tokenize_and_encode_text)
+
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
         # saved in the middle of an epoch