mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Add char-based language model training process for aishell. (#945)
* Add char-based language model training process for aishell. Add soft link from librispeech/ASR/local/sort_lm_training_data.py to aishell/ASR/local/ --------- Co-authored-by: lichao <www.563042811@qq.com>
This commit is contained in:
parent
a48812ddb3
commit
6196b4a407
164
egs/aishell/ASR/local/prepare_char_lm_training_data.py
Normal file
164
egs/aishell/ASR/local/prepare_char_lm_training_data.py
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey
|
||||||
|
# Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script takes a `tokens.txt` and a text file such as
|
||||||
|
./download/lm/aishell-transcript.txt
|
||||||
|
and outputs the LM training data to a supplied directory such
|
||||||
|
as data/lm_training_char. The format is as follows:
|
||||||
|
It creates a PyTorch archive (.pt file), say data/lm_training.pt, which is a
|
||||||
|
representation of a dict with the same format with librispeech receipe
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--lang-char",
|
||||||
|
type=str,
|
||||||
|
help="""Lang dir of asr model, e.g. data/lang_char""",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lm-data",
|
||||||
|
type=str,
|
||||||
|
help="""Input LM training data as text, e.g.
|
||||||
|
download/lm/aishell-train-word.txt""",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lm-archive",
|
||||||
|
type=str,
|
||||||
|
help="""Path to output archive, e.g. data/lm_training_char/lm_data.pt;
|
||||||
|
look at the source of this script to see the format.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
|
||||||
|
if Path(args.lm_archive).exists():
|
||||||
|
logging.warning(f"{args.lm_archive} exists - skipping")
|
||||||
|
return
|
||||||
|
|
||||||
|
# make token_dict from tokens.txt in order to map characters to tokens.
|
||||||
|
token_dict = {}
|
||||||
|
token_file = args.lang_char + "/tokens.txt"
|
||||||
|
|
||||||
|
with open(token_file, "r") as f:
|
||||||
|
for line in f.readlines():
|
||||||
|
line_list = line.split()
|
||||||
|
token_dict[line_list[0]] = int(line_list[1])
|
||||||
|
|
||||||
|
# word2index is a dictionary from words to integer ids. No need to reserve
|
||||||
|
# space for epsilon, etc.; the words are just used as a convenient way to
|
||||||
|
# compress the sequences of tokens.
|
||||||
|
word2index = dict()
|
||||||
|
|
||||||
|
word2token = [] # Will be a list-of-list-of-int, representing tokens.
|
||||||
|
sentences = [] # Will be a list-of-list-of-int, representing word-ids.
|
||||||
|
|
||||||
|
if "aishell-lm" in args.lm_data:
|
||||||
|
num_lines_in_total = 120098.0
|
||||||
|
step = 50000
|
||||||
|
elif "valid" in args.lm_data:
|
||||||
|
num_lines_in_total = 14326.0
|
||||||
|
step = 3000
|
||||||
|
elif "test" in args.lm_data:
|
||||||
|
num_lines_in_total = 7176.0
|
||||||
|
step = 3000
|
||||||
|
else:
|
||||||
|
num_lines_in_total = None
|
||||||
|
step = None
|
||||||
|
|
||||||
|
processed = 0
|
||||||
|
|
||||||
|
with open(args.lm_data) as f:
|
||||||
|
while True:
|
||||||
|
line = f.readline()
|
||||||
|
if line == "":
|
||||||
|
break
|
||||||
|
|
||||||
|
if step and processed % step == 0:
|
||||||
|
logging.info(
|
||||||
|
f"Processed number of lines: {processed} "
|
||||||
|
f"({processed / num_lines_in_total * 100: .3f}%)"
|
||||||
|
)
|
||||||
|
processed += 1
|
||||||
|
|
||||||
|
line_words = line.split()
|
||||||
|
for w in line_words:
|
||||||
|
if w not in word2index:
|
||||||
|
w_token = []
|
||||||
|
for t in w:
|
||||||
|
if t in token_dict:
|
||||||
|
w_token.append(token_dict[t])
|
||||||
|
else:
|
||||||
|
w_token.append(token_dict["<unk>"])
|
||||||
|
word2index[w] = len(word2token)
|
||||||
|
word2token.append(w_token)
|
||||||
|
sentences.append([word2index[w] for w in line_words])
|
||||||
|
|
||||||
|
logging.info("Constructing ragged tensors")
|
||||||
|
words = k2.ragged.RaggedTensor(word2token)
|
||||||
|
sentences = k2.ragged.RaggedTensor(sentences)
|
||||||
|
|
||||||
|
output = dict(words=words, sentences=sentences)
|
||||||
|
|
||||||
|
num_sentences = sentences.dim0
|
||||||
|
logging.info(f"Computing sentence lengths, num_sentences: {num_sentences}")
|
||||||
|
sentence_lengths = [0] * num_sentences
|
||||||
|
for i in range(num_sentences):
|
||||||
|
if step and i % step == 0:
|
||||||
|
logging.info(
|
||||||
|
f"Processed number of lines: {i} ({i / num_sentences * 100: .3f}%)"
|
||||||
|
)
|
||||||
|
|
||||||
|
word_ids = sentences[i]
|
||||||
|
|
||||||
|
# NOTE: If word_ids is a tensor with only 1 entry,
|
||||||
|
# token_ids is a torch.Tensor
|
||||||
|
token_ids = words[word_ids]
|
||||||
|
if isinstance(token_ids, k2.RaggedTensor):
|
||||||
|
token_ids = token_ids.values
|
||||||
|
|
||||||
|
# token_ids is a 1-D tensor containing the BPE tokens
|
||||||
|
# of the current sentence
|
||||||
|
|
||||||
|
sentence_lengths[i] = token_ids.numel()
|
||||||
|
|
||||||
|
output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32)
|
||||||
|
|
||||||
|
torch.save(output, args.lm_archive)
|
||||||
|
logging.info(f"Saved to {args.lm_archive}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
|
main()
|
@ -7,7 +7,7 @@ set -eou pipefail
|
|||||||
|
|
||||||
nj=15
|
nj=15
|
||||||
stage=-1
|
stage=-1
|
||||||
stop_stage=10
|
stop_stage=11
|
||||||
|
|
||||||
# We assume dl_dir (download dir) contains the following
|
# We assume dl_dir (download dir) contains the following
|
||||||
# directories and files. If not, they will be downloaded
|
# directories and files. If not, they will be downloaded
|
||||||
@ -219,3 +219,93 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
|||||||
./local/compile_hlg.py --lang-dir $lang_phone_dir
|
./local/compile_hlg.py --lang-dir $lang_phone_dir
|
||||||
./local/compile_hlg.py --lang-dir $lang_char_dir
|
./local/compile_hlg.py --lang-dir $lang_char_dir
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
||||||
|
log "Stage 9: Generate LM training data"
|
||||||
|
|
||||||
|
log "Processing char based data"
|
||||||
|
out_dir=data/lm_training_char
|
||||||
|
mkdir -p $out_dir $dl_dir/lm
|
||||||
|
|
||||||
|
if [ ! -f $dl_dir/lm/aishell-train-word.txt ]; then
|
||||||
|
cp $lang_phone_dir/transcript_words.txt $dl_dir/lm/aishell-train-word.txt
|
||||||
|
fi
|
||||||
|
|
||||||
|
./local/prepare_char_lm_training_data.py \
|
||||||
|
--lang-char data/lang_char \
|
||||||
|
--lm-data $dl_dir/lm/aishell-train-word.txt \
|
||||||
|
--lm-archive $out_dir/lm_data.pt
|
||||||
|
|
||||||
|
if [ ! -f $dl_dir/lm/aishell-valid-word.txt ]; then
|
||||||
|
aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
|
||||||
|
aishell_valid_uid=$dl_dir/aishell/data_aishell/transcript/aishell_valid_uid
|
||||||
|
find $dl_dir/aishell/data_aishell/wav/dev -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_valid_uid
|
||||||
|
awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_valid_uid $aishell_text |
|
||||||
|
cut -d " " -f 2- > $dl_dir/lm/aishell-valid-word.txt
|
||||||
|
fi
|
||||||
|
|
||||||
|
./local/prepare_char_lm_training_data.py \
|
||||||
|
--lang-char data/lang_char \
|
||||||
|
--lm-data $dl_dir/lm/aishell-valid-word.txt \
|
||||||
|
--lm-archive $out_dir/lm_data_valid.pt
|
||||||
|
|
||||||
|
if [ ! -f $dl_dir/lm/aishell-test-word.txt ]; then
|
||||||
|
aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt
|
||||||
|
aishell_test_uid=$dl_dir/aishell/data_aishell/transcript/aishell_test_uid
|
||||||
|
find $dl_dir/aishell/data_aishell/wav/test -name "*.wav" | sed 's/\.wav//g' | awk -F '/' '{print $NF}' > $aishell_test_uid
|
||||||
|
awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $aishell_test_uid $aishell_text |
|
||||||
|
cut -d " " -f 2- > $dl_dir/lm/aishell-test-word.txt
|
||||||
|
fi
|
||||||
|
|
||||||
|
./local/prepare_char_lm_training_data.py \
|
||||||
|
--lang-char data/lang_char \
|
||||||
|
--lm-data $dl_dir/lm/aishell-test-word.txt \
|
||||||
|
--lm-archive $out_dir/lm_data_test.pt
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
|
||||||
|
log "Stage 10: Sort LM training data"
|
||||||
|
# Sort LM training data by sentence length in descending order
|
||||||
|
# for ease of training.
|
||||||
|
#
|
||||||
|
# Sentence length equals to the number of tokens
|
||||||
|
# in a sentence.
|
||||||
|
|
||||||
|
out_dir=data/lm_training_char
|
||||||
|
mkdir -p $out_dir
|
||||||
|
ln -snf ../../../librispeech/ASR/local/sort_lm_training_data.py local/
|
||||||
|
|
||||||
|
./local/sort_lm_training_data.py \
|
||||||
|
--in-lm-data $out_dir/lm_data.pt \
|
||||||
|
--out-lm-data $out_dir/sorted_lm_data.pt \
|
||||||
|
--out-statistics $out_dir/statistics.txt
|
||||||
|
|
||||||
|
./local/sort_lm_training_data.py \
|
||||||
|
--in-lm-data $out_dir/lm_data_valid.pt \
|
||||||
|
--out-lm-data $out_dir/sorted_lm_data-valid.pt \
|
||||||
|
--out-statistics $out_dir/statistics-valid.txt
|
||||||
|
|
||||||
|
./local/sort_lm_training_data.py \
|
||||||
|
--in-lm-data $out_dir/lm_data_test.pt \
|
||||||
|
--out-lm-data $out_dir/sorted_lm_data-test.pt \
|
||||||
|
--out-statistics $out_dir/statistics-test.txt
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
|
||||||
|
log "Stage 11: Train RNN LM model"
|
||||||
|
python ../../../icefall/rnn_lm/train.py \
|
||||||
|
--start-epoch 0 \
|
||||||
|
--world-size 1 \
|
||||||
|
--num-epochs 20 \
|
||||||
|
--use-fp16 0 \
|
||||||
|
--embedding-dim 512 \
|
||||||
|
--hidden-dim 512 \
|
||||||
|
--num-layers 2 \
|
||||||
|
--batch-size 400 \
|
||||||
|
--exp-dir rnnlm_char/exp \
|
||||||
|
--lm-data data/lm_training_char/sorted_lm_data.pt \
|
||||||
|
--lm-data-valid data/lm_training_char/sorted_lm_data-valid.pt \
|
||||||
|
--vocab-size 4336 \
|
||||||
|
--master-port 12345
|
||||||
|
fi
|
||||||
|
Loading…
x
Reference in New Issue
Block a user