diff --git a/egs/aishell/ASR/local/prepare_char_lm_training_data.py b/egs/aishell/ASR/local/prepare_char_lm_training_data.py
new file mode 100644
index 000000000..e7995680b
--- /dev/null
+++ b/egs/aishell/ASR/local/prepare_char_lm_training_data.py
@@ -0,0 +1,164 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2021  Xiaomi Corporation (authors: Daniel Povey
+#                                                  Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script takes a `tokens.txt` file and a text file such as
+./download/lm/aishell-transcript.txt
+and outputs the LM training data to a supplied directory such
+as data/lm_training_char. The format is as follows:
+
+It creates a PyTorch archive (.pt file), say data/lm_training.pt, which is
+a representation of a dict in the same format as the one used by the
+librispeech recipe, with keys "words" (a k2.RaggedTensor mapping each word
+id to its token ids), "sentences" (a k2.RaggedTensor holding the word ids
+of each sentence) and "sentence_lengths" (a 1-D torch.int32 tensor holding
+the number of tokens in each sentence).
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import k2
+import torch
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-char",
+        type=str,
+        help="""Lang dir of asr model, e.g. data/lang_char""",
+    )
+    parser.add_argument(
+        "--lm-data",
+        type=str,
+        help="""Input LM training data as text, e.g.
+        download/lm/aishell-train-word.txt""",
+    )
+    parser.add_argument(
+        "--lm-archive",
+        type=str,
+        help="""Path to output archive, e.g. data/lm_training_char/lm_data.pt;
+        look at the source of this script to see the format.""",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+
+    if Path(args.lm_archive).exists():
+        logging.warning(f"{args.lm_archive} exists - skipping")
+        return
+
+    # Build token_dict from tokens.txt in order to map characters to token ids.
+    token_dict = {}
+    token_file = args.lang_char + "/tokens.txt"
+
+    with open(token_file, "r") as f:
+        for line in f.readlines():
+            line_list = line.split()
+            token_dict[line_list[0]] = int(line_list[1])
+
+    # word2index is a dictionary from words to integer ids. No need to reserve
+    # space for epsilon, etc.; the words are just used as a convenient way to
+    # compress the sequences of tokens.
+    word2index = dict()
+
+    word2token = []  # Will be a list-of-list-of-int, representing tokens.
+    sentences = []  # Will be a list-of-list-of-int, representing word-ids.
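+    # Illustrative example (not part of the original data): if the input text
+    # contained the single line "你好 世界" and tokens.txt mapped, say,
+    # 你 -> 10, 好 -> 11, 世 -> 12, 界 -> 13 (ids chosen arbitrarily here),
+    # the reading loop below would end up with
+    #     word2index = {"你好": 0, "世界": 1}
+    #     word2token = [[10, 11], [12, 13]]
+    #     sentences  = [[0, 1]]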
+ + if "aishell-lm" in args.lm_data: + num_lines_in_total = 120098.0 + step = 50000 + elif "valid" in args.lm_data: + num_lines_in_total = 14326.0 + step = 3000 + elif "test" in args.lm_data: + num_lines_in_total = 7176.0 + step = 3000 + else: + num_lines_in_total = None + step = None + + processed = 0 + + with open(args.lm_data) as f: + while True: + line = f.readline() + if line == "": + break + + if step and processed % step == 0: + logging.info( + f"Processed number of lines: {processed} " + f"({processed / num_lines_in_total * 100: .3f}%)" + ) + processed += 1 + + line_words = line.split() + for w in line_words: + if w not in word2index: + w_token = [] + for t in w: + if t in token_dict: + w_token.append(token_dict[t]) + else: + w_token.append(token_dict[""]) + word2index[w] = len(word2token) + word2token.append(w_token) + sentences.append([word2index[w] for w in line_words]) + + logging.info("Constructing ragged tensors") + words = k2.ragged.RaggedTensor(word2token) + sentences = k2.ragged.RaggedTensor(sentences) + + output = dict(words=words, sentences=sentences) + + num_sentences = sentences.dim0 + logging.info(f"Computing sentence lengths, num_sentences: {num_sentences}") + sentence_lengths = [0] * num_sentences + for i in range(num_sentences): + if step and i % step == 0: + logging.info( + f"Processed number of lines: {i} ({i / num_sentences * 100: .3f}%)" + ) + + word_ids = sentences[i] + + # NOTE: If word_ids is a tensor with only 1 entry, + # token_ids is a torch.Tensor + token_ids = words[word_ids] + if isinstance(token_ids, k2.RaggedTensor): + token_ids = token_ids.values + + # token_ids is a 1-D tensor containing the BPE tokens + # of the current sentence + + sentence_lengths[i] = token_ids.numel() + + output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32) + + torch.save(output, args.lm_archive) + logging.info(f"Saved to {args.lm_archive}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index 5917668a1..cf4ee7818 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -7,7 +7,7 @@ set -eou pipefail nj=15 stage=-1 -stop_stage=10 +stop_stage=11 # We assume dl_dir (download dir) contains the following # directories and files. If not, they will be downloaded @@ -219,3 +219,93 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then ./local/compile_hlg.py --lang-dir $lang_phone_dir ./local/compile_hlg.py --lang-dir $lang_char_dir fi + +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Generate LM training data" + + log "Processing char based data" + out_dir=data/lm_training_char + mkdir -p $out_dir $dl_dir/lm + + if [ ! -f $dl_dir/lm/aishell-train-word.txt ]; then + cp $lang_phone_dir/transcript_words.txt $dl_dir/lm/aishell-train-word.txt + fi + + ./local/prepare_char_lm_training_data.py \ + --lang-char data/lang_char \ + --lm-data $dl_dir/lm/aishell-train-word.txt \ + --lm-archive $out_dir/lm_data.pt + + if [ ! 
+  if [ ! -f $dl_dir/lm/aishell-valid-word.txt ]; then
+    aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
+    aishell_valid_uid=$dl_dir/aishell/data_aishell/transcript/aishell_valid_uid
+    find $dl_dir/aishell/data_aishell/wav/dev -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_valid_uid
+    awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_valid_uid $aishell_text |
+      cut -d " " -f 2- > $dl_dir/lm/aishell-valid-word.txt
+  fi
+
+  ./local/prepare_char_lm_training_data.py \
+    --lang-char data/lang_char \
+    --lm-data $dl_dir/lm/aishell-valid-word.txt \
+    --lm-archive $out_dir/lm_data_valid.pt
+
+  if [ ! -f $dl_dir/lm/aishell-test-word.txt ]; then
+    aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
+    aishell_test_uid=$dl_dir/aishell/data_aishell/transcript/aishell_test_uid
+    find $dl_dir/aishell/data_aishell/wav/test -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_test_uid
+    awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_test_uid $aishell_text |
+      cut -d " " -f 2- > $dl_dir/lm/aishell-test-word.txt
+  fi
+
+  ./local/prepare_char_lm_training_data.py \
+    --lang-char data/lang_char \
+    --lm-data $dl_dir/lm/aishell-test-word.txt \
+    --lm-archive $out_dir/lm_data_test.pt
+fi
+
+
+if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
+  log "Stage 10: Sort LM training data"
+  # Sort LM training data by sentence length in descending order
+  # for ease of training.
+  #
+  # Sentence length equals the number of tokens
+  # in a sentence.
+
+  out_dir=data/lm_training_char
+  mkdir -p $out_dir
+  ln -snf ../../../librispeech/ASR/local/sort_lm_training_data.py local/
+
+  ./local/sort_lm_training_data.py \
+    --in-lm-data $out_dir/lm_data.pt \
+    --out-lm-data $out_dir/sorted_lm_data.pt \
+    --out-statistics $out_dir/statistics.txt
+
+  ./local/sort_lm_training_data.py \
+    --in-lm-data $out_dir/lm_data_valid.pt \
+    --out-lm-data $out_dir/sorted_lm_data-valid.pt \
+    --out-statistics $out_dir/statistics-valid.txt
+
+  ./local/sort_lm_training_data.py \
+    --in-lm-data $out_dir/lm_data_test.pt \
+    --out-lm-data $out_dir/sorted_lm_data-test.pt \
+    --out-statistics $out_dir/statistics-test.txt
+fi
+
+if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
+  log "Stage 11: Train RNN LM model"
+  python ../../../icefall/rnn_lm/train.py \
+    --start-epoch 0 \
+    --world-size 1 \
+    --num-epochs 20 \
+    --use-fp16 0 \
+    --embedding-dim 512 \
+    --hidden-dim 512 \
+    --num-layers 2 \
+    --batch-size 400 \
+    --exp-dir rnnlm_char/exp \
+    --lm-data data/lm_training_char/sorted_lm_data.pt \
+    --lm-data-valid data/lm_training_char/sorted_lm_data-valid.pt \
+    --vocab-size 4336 \
+    --master-port 12345
+fi
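
For reference, the archives produced in stage 9 (and re-sorted in stage 10) can be
sanity-checked with a few lines of Python. This is only a sketch, not part of the
recipe; the archive path is an assumption based on the defaults used above:

    #!/usr/bin/env python3
    # Sketch: inspect an archive written by ./local/prepare_char_lm_training_data.py.
    import k2  # imported so that torch.load can restore k2.RaggedTensor objects
    import torch

    # Hypothetical path; it matches the default --lm-archive used in stage 9 above.
    d = torch.load("data/lm_training_char/lm_data.pt")

    words = d["words"]  # k2.RaggedTensor: word id -> token ids
    sentences = d["sentences"]  # k2.RaggedTensor: sentence -> word ids
    lengths = d["sentence_lengths"]  # torch.int32 tensor: tokens per sentence

    print(f"{sentences.dim0} sentences, {words.dim0} distinct words")

    # Recover the token ids of the first sentence, mirroring the logic used
    # when computing sentence_lengths in prepare_char_lm_training_data.py.
    token_ids = words[sentences[0]]
    if isinstance(token_ids, k2.RaggedTensor):
        token_ids = token_ids.values
    assert token_ids.numel() == int(lengths[0])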