add and update some files

luomingshuang 2022-04-08 19:00:19 +08:00
parent 5242275b8a
commit d8907769ee
10 changed files with 1117 additions and 131 deletions

View File

@@ -0,0 +1,135 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file displays duration statistics of utterances in a manifest.
You can use the displayed values to choose the minimum/maximum duration
bounds for removing short and long utterances during training.
See the function `remove_short_and_long_utt()`
in ../../../librispeech/ASR/transducer/train.py
for usage.
"""
from lhotse import load_manifest


def main():
    paths = [
        # "./data/fbank/cuts_L_100_pieces.jsonl.gz",
        "./data/fbank/cuts_L_50_pieces.jsonl.gz",
        # "./data/fbank/cuts_DEV.jsonl.gz",
        # "./data/fbank/cuts_TEST_NET.jsonl.gz",
        # "./data/fbank/cuts_TEST_MEETING.jsonl.gz"
    ]

    for path in paths:
        print(f"Starting display the statistics for {path}")
        cuts = load_manifest(path)
        cuts.describe()


if __name__ == "__main__":
    main()
"""
Starting display the statistics for ./data/fbank/cuts_L_50_pieces.jsonl.gz
Cuts count: 2241476
Total duration (hours): 1475.0
Speech duration (hours): 1475.0 (100.0%)
***
Duration statistics (seconds):
mean 2.4
std 1.6
min 0.3
25% 1.3
50% 2.0
75% 2.9
99% 8.2
99.5% 9.3
99.9% 13.5
max 87.0
Starting display the statistics for ./data/fbank/cuts_L_100_pieces.jsonl.gz
Cuts count: 4929619
Total duration (hours): 3361.1
Speech duration (hours): 3361.1 (100.0%)
***
Duration statistics (seconds):
mean 2.5
std 1.7
min 0.3
25% 1.4
50% 2.0
75% 3.0
99% 8.1
99.5% 8.8
99.9% 14.7
max 87.0
Starting display the statistics for ./data/fbank/cuts_DEV.jsonl.gz
Cuts count: 13825
Total duration (hours): 20.0
Speech duration (hours): 20.0 (100.0%)
***
Duration statistics (seconds):
mean 5.2
std 2.2
min 1.0
25% 3.3
50% 4.9
75% 7.0
99% 9.6
99.5% 9.8
99.9% 10.0
max 10.0
Starting display the statistics for ./data/fbank/cuts_TEST_NET.jsonl.gz
Cuts count: 24774
Total duration (hours): 23.1
Speech duration (hours): 23.1 (100.0%)
***
Duration statistics (seconds):
mean 3.4
std 2.6
min 0.1
25% 1.4
50% 2.4
75% 4.8
99% 13.1
99.5% 14.5
99.9% 18.5
max 33.3
Starting display the statistics for ./data/fbank/cuts_TEST_MEETING.jsonl.gz
Cuts count: 8370
Total duration (hours): 15.2
Speech duration (hours): 15.2 (100.0%)
***
Duration statistics (seconds):
mean 6.5
std 3.5
min 0.8
25% 3.7
50% 5.8
75% 8.8
99% 15.2
99.5% 16.0
99.9% 18.8
max 24.6
"""

View File

@@ -0,0 +1,246 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input `lang_dir`, which should contain::
- lang_dir/text,
- lang_dir/words.txt
and generates the following files in the directory `lang_dir`:
- lexicon.txt
- lexicon_disambig.txt
- L.pt
- L_disambig.pt
- tokens.txt
"""
import argparse
import re
from pathlib import Path
from typing import Dict, List
import k2
import torch
from prepare_lang import (
Lexicon,
add_disambig_symbols,
add_self_loops,
write_lexicon,
write_mapping,
)
def lexicon_to_fst_no_sil(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format).
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
loop_state = 0 # words enter and leave from here
next_state = 1 # the next un-allocated state, will be incremented as we go
arcs = []
# The blank symbol <blk> is assigned ID 0 in generate_tokens() below.
assert token2id["<blk>"] == 0
assert word2id["<eps>"] == 0
eps = 0
for word, pieces in lexicon:
assert len(pieces) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
pieces = [
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, pieces[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last piece of this word
i = len(pieces) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, pieces[i], w, 0])
if need_self_loops:
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
"""Check if all the given tokens are in token symbol table.
Args:
token_sym_table:
Token symbol table that contains all the valid tokens.
tokens:
A list of tokens.
Returns:
Return True if there is any token not in the token_sym_table,
otherwise False.
"""
for tok in tokens:
if tok not in token_sym_table:
return True
return False
def generate_lexicon(
token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table.
Args:
token_sym_table:
Token symbol table that mapping token to token ids.
words:
A list of strings representing words.
Returns:
Return a dict whose keys are words and values are the corresponding
tokens.
"""
lexicon = []
for word in words:
chars = list(word.strip(" \t"))
if contain_oov(token_sym_table, chars):
continue
lexicon.append((word, chars))
# The OOV word is <UNK>
lexicon.append(("<UNK>", ["<unk>"]))
return lexicon
def generate_tokens(text_file: str) -> Dict[str, int]:
"""Generate tokens from the given text file.
Args:
text_file:
A file that contains text lines to generate tokens.
Returns:
Return a dict whose keys are tokens and values are token IDs ranging
from 0 to len(keys) - 1.
"""
tokens: Dict[str, int] = dict()
tokens["<blk>"] = 0
tokens["<sos/eos>"] = 1
tokens["<unk>"] = 2
whitespace = re.compile(r"([ \t\r\n]+)")
with open(text_file, "r", encoding="utf-8") as f:
for line in f:
line = re.sub(whitespace, "", line)
tokens_list = list(line)
for token in tokens_list:
if token not in tokens:
tokens[token] = len(tokens)
return tokens
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--lang-dir", type=str, help="The lang directory.")
args = parser.parse_args()
lang_dir = Path(args.lang_dir)
text_file = lang_dir / "text"
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
words = word_sym_table.symbols
excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"]
for w in excluded:
if w in words:
words.remove(w)
token_sym_table = generate_tokens(text_file)
lexicon = generate_lexicon(token_sym_table, words)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
next_token_id = max(token_sym_table.values()) + 1
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in token_sym_table
token_sym_table[disambig] = next_token_id
next_token_id += 1
word_sym_table.add("#0")
word_sym_table.add("<s>")
word_sym_table.add("</s>")
write_mapping(lang_dir / "tokens.txt", token_sym_table)
write_lexicon(lang_dir / "lexicon.txt", lexicon)
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst_no_sil(
lexicon,
token2id=token_sym_table,
word2id=word_sym_table,
)
L_disambig = lexicon_to_fst_no_sil(
lexicon_disambig,
token2id=token_sym_table,
word2id=word_sym_table,
need_self_loops=True,
)
torch.save(L.as_dict(), lang_dir / "L.pt")
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
if __name__ == "__main__":
main()
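
A short usage sketch for the script above, matching how stage 10 of prepare.sh (further down in this commit) invokes it, plus how the resulting L.pt can be loaded back. Reading it via k2.Fsa.from_dict is the usual icefall pattern and is shown here as an assumption, not something this commit does:

# Run from egs/wenetspeech/ASR:
#   python ./local/prepare_char.py --lang-dir data/lang_char
#
# Loading the generated lexicon FST back (sketch):
import k2
import torch

L = k2.Fsa.from_dict(torch.load("data/lang_char/L.pt"))
print(L)  # prints the FST in k2's text format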

View File

@@ -0,0 +1,260 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input `lang_dir`, which should contain::
- lang_dir/text,
- lang_dir/words.txt
and generates the following files in the directory `lang_dir`:
- lexicon.txt
- lexicon_disambig.txt
- L.pt
- L_disambig.pt
- tokens.txt
"""
import argparse
from pathlib import Path
from typing import Dict, List
import k2
import torch
from prepare_lang import (
Lexicon,
add_disambig_symbols,
add_self_loops,
write_lexicon,
write_mapping,
)
from pypinyin import pinyin
from tqdm import tqdm
def lexicon_to_fst_no_sil(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format).
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
loop_state = 0 # words enter and leave from here
next_state = 1 # the next un-allocated state, will be incremented as we go
arcs = []
# The blank symbol <blk> is assigned ID 0 in generate_tokens() below.
assert token2id["<blk>"] == 0
assert word2id["<eps>"] == 0
eps = 0
for word, pieces in lexicon:
assert len(pieces) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
pieces = [
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, pieces[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last piece of this word
i = len(pieces) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, pieces[i], w, 0])
if need_self_loops:
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
"""Check if all the given tokens are in token symbol table.
Args:
token_sym_table:
Token symbol table that contains all the valid tokens.
tokens:
A list of tokens.
Returns:
Return True if there is any token not in the token_sym_table,
otherwise False.
"""
for tok in tokens:
if tok not in token_sym_table:
return True
return False
def generate_lexicon(
token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table.
Args:
token_sym_table:
Token symbol table that mapping token to token ids.
words:
A list of strings representing words.
Returns:
Return a dict whose keys are words and values are the corresponding
tokens.
"""
lexicon = []
for i in tqdm(range(len(words))):
word = words[i]
tokens = []
pinyins = pinyin(word.strip(" \t"))
for pinyin_one in pinyins:
if pinyin_one[0].isupper():
tokens.extend(list(pinyin_one[0]))
else:
tokens.append(pinyin_one[0])
if contain_oov(token_sym_table, tokens):
continue
lexicon.append((word, tokens))
# The OOV word is <UNK>
lexicon.append(("<UNK>", ["<unk>"]))
return lexicon
def generate_tokens(words: List[str]) -> Dict[str, int]:
"""Generate tokens from the given text file.
Args:
words:
The list of words after removing <eps>, !SIL, and so on.
Returns:
Return a dict whose keys are tokens and values are token ids ranged
from 0 to len(keys) - 1.
"""
tokens: Dict[str, int] = dict()
tokens["<blk>"] = 0
tokens["<sos/eos>"] = 1
tokens["<unk>"] = 2
for i in tqdm(range(len(words))):
word = words[i]
pinyins_list = pinyin(word)
for pinyin_one in pinyins_list:
if pinyin_one[0].isupper():
tokens_list = list(pinyin_one[0])
else:
tokens_list = pinyin_one
for token in tokens_list:
if token not in tokens:
tokens[token] = len(tokens)
tokens = sorted(tokens.items(), key=lambda item: item[0])
tokens = dict(tokens)
return tokens
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--lang-dir", type=str, help="The lang directory.")
args = parser.parse_args()
lang_dir = Path(args.lang_dir)
words_file = lang_dir / "words.txt"
word_sym_table = k2.SymbolTable.from_file(words_file)
words = word_sym_table.symbols
excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"]
for w in excluded:
if w in words:
words.remove(w)
token_sym_table = generate_tokens(words)
lexicon = generate_lexicon(token_sym_table, words)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
next_token_id = max(token_sym_table.values()) + 1
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in token_sym_table
token_sym_table[disambig] = next_token_id
next_token_id += 1
word_sym_table.add("#0")
word_sym_table.add("<s>")
word_sym_table.add("</s>")
write_mapping(lang_dir / "tokens.txt", token_sym_table)
write_lexicon(lang_dir / "lexicon.txt", lexicon)
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst_no_sil(
lexicon,
token2id=token_sym_table,
word2id=word_sym_table,
)
L_disambig = lexicon_to_fst_no_sil(
lexicon_disambig,
token2id=token_sym_table,
word2id=word_sym_table,
need_self_loops=True,
)
torch.save(L.as_dict(), lang_dir / "L.pt")
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
if __name__ == "__main__":
main()
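
The only difference from prepare_char.py above is that tokens come from pypinyin rather than single characters. A small sketch of the tokenization the code relies on (outputs are typical pypinyin results and may differ across pypinyin versions):

from pypinyin import pinyin

print(pinyin("中国"))  # e.g. [['zhōng'], ['guó']] -- one syllable per character
print(pinyin("KFC"))   # non-Chinese text passes through, e.g. [['KFC']]
# generate_tokens()/generate_lexicon() above split an upper-case result such
# as 'KFC' into single letters, while pinyin syllables are kept whole.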

View File

@@ -0,0 +1,84 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input a word list without IDs:
    - words_no_ids.txt
and generates a new word list with IDs assigned:
    - words.txt
"""
import argparse
import logging

from tqdm import tqdm


def get_parser():
    parser = argparse.ArgumentParser(
        description="Prepare words.txt",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--input-file",
        default="data/lang_char/words_no_ids.txt",
        type=str,
        help="the words file without ids for WenetSpeech",
    )
    parser.add_argument(
        "--output-file",
        default="data/lang_char/words.txt",
        type=str,
        help="the words file with ids for WenetSpeech",
    )
    return parser


def main():
    parser = get_parser()
    args = parser.parse_args()

    input_file = args.input_file
    output_file = args.output_file

    f = open(input_file, "r", encoding="utf-8")
    lines = f.readlines()

    new_lines = []
    add_words = ["<eps> 0", "!SIL 1", "<SPOKEN_NOISE> 2", "<UNK> 3"]
    new_lines.extend(add_words)

    logging.info("Starting reading the input file")
    for i in tqdm(range(len(lines))):
        x = lines[i]
        idx = 4 + i
        new_line = str(x.strip("\n")) + " " + str(idx)
        new_lines.append(new_line)

    logging.info("Starting writing the words.txt")
    f_out = open(output_file, "w", encoding="utf-8")
    for line in new_lines:
        f_out.write(line)
        f_out.write("\n")


if __name__ == "__main__":
    main()
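
For illustration, a few lines equivalent to the script above (the example words are hypothetical; the four special symbols and the starting ID of 4 mirror the code):

words = ["你好", "中国"]  # contents of words_no_ids.txt, one word per line
header = ["<eps> 0", "!SIL 1", "<SPOKEN_NOISE> 2", "<UNK> 3"]
lines = header + [f"{w} {i + 4}" for i, w in enumerate(words)]
print("\n".join(lines))
# <eps> 0
# !SIL 1
# <SPOKEN_NOISE> 2
# <UNK> 3
# 你好 4
# 中国 5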

View File

@@ -0,0 +1,83 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input "text", the transcript file for WenetSpeech:
    - text
and generates an output file containing the transcripts after Chinese word
segmentation:
    - text_words_segmentation
"""
import argparse

import jieba
from tqdm import tqdm

jieba.enable_paddle()


def get_parser():
    parser = argparse.ArgumentParser(
        description="Chinese Word Segmentation for text",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--input",
        default="data/lang_char/text",
        type=str,
        help="the input text file for WenetSpeech",
    )
    parser.add_argument(
        "--output",
        default="data/lang_char/text_words_segmentation",
        type=str,
        help="the output text file after word segmentation for WenetSpeech",
    )
    return parser


def main():
    parser = get_parser()
    args = parser.parse_args()

    input_file = args.input
    output_file = args.output

    f = open(input_file, "r", encoding="utf-8")
    lines = f.readlines()

    new_lines = []
    for i in tqdm(range(len(lines))):
        x = lines[i].rstrip()
        seg_list = jieba.cut(x, use_paddle=True)
        new_line = " ".join(seg_list)
        new_lines.append(new_line)

    f_new = open(output_file, "w", encoding="utf-8")
    for line in new_lines:
        f_new.write(line)
        f_new.write("\n")


if __name__ == "__main__":
    main()
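
A quick sketch of what the segmentation step produces. The script uses jieba's paddle mode, which needs paddlepaddle installed; the default mode shown here only needs jieba itself, and the exact segmentation may differ slightly between the two modes:

import jieba

line = "北京欢迎您"
print(" ".join(jieba.cut(line)))  # e.g. "北京 欢迎 您"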

View File

@@ -188,7 +188,7 @@ def main():
         a_chars = [z.replace(" ", args.space) for z in a_flat]
-        print(" ".join(a_chars))
+        print("".join(a_chars))
         line = f.readline()

View File

@@ -117,9 +117,9 @@ fi
 if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
   log "Stage 7: Combine features for L"
-  if [ ! -f data/fbank/cuts_L.jsonl.gz ]; then
-    pieces=$(find data/fbank/L_split_${num_splits} -name "cuts_L.*.jsonl.gz")
-    lhotse combine $pieces data/fbank/cuts_L.jsonl.gz
+  if [ ! -f data/fbank/cuts_L_50.jsonl.gz ]; then
+    pieces=$(find data/fbank/L_split_50 -name "cuts_L.*.jsonl.gz")
+    lhotse combine $pieces data/fbank/cuts_L_50.jsonl.gz
   fi
 fi
@@ -134,120 +134,50 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
   lang_char_dir=data/lang_char
   mkdir -p $lang_char_dir

-  gunzip -c data/manifests/supervisions_L.jsonl.gz \
-    | jq '.text' | sed 's/"//g' \
-    | ./local/text2token.py -t "char" > $lang_char_dir/text
+  # Prepare text.
+  if [ ! -f $lang_char_dir/text ]; then
+    gunzip -c data/manifests/supervisions_L.jsonl.gz \
+      | jq '.text' | sed 's/"//g' \
+      | ./local/text2token.py -t "char" > $lang_char_dir/text
+  fi

-  cat $lang_char_dir/text | sed 's/ /\n/g' \
-    | sort -u | sed '/^$/d' > $lang_char_dir/words.txt
-
-  (echo '<SIL>'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) |
-  cat - $lang_char_dir/words.txt | sort | uniq | awk '
-  BEGIN {
-    print "<eps> 0";
-  }
-  {
-    if ($1 == "<s>") {
-      print "<s> is in the vocabulary!" | "cat 1>&2"
-      exit 1;
-    }
-    if ($1 == "</s>") {
-      print "</s> is in the vocabulary!" | "cat 1>&2"
-      exit 1;
-    }
-    printf("%s %d\n", $1, NR);
-  }
-  END {
-    printf("#0 %d\n", NR+1);
-    printf("<s> %d\n", NR+2);
-    printf("</s> %d\n", NR+3);
-  }' > $lang_char_dir/words || exit 1;
-
-  mv $lang_char_dir/words $lang_char_dir/words.txt
+  # The implementation of chinese word segmentation for text,
+  # and it will take about 15 minutes.
+  if [ ! -f $lang_char_dir/text_words_segmentation ]; then
+    python ./local/text2segments.py \
+      --input $lang_char_dir/text \
+      --output $lang_char_dir/text_words_segmentation
+  fi
+
+  cat $lang_char_dir/text_words_segmentation | sed 's/ /\n/g' \
+    | sort -u | sed '/^$/d' | uniq > $lang_char_dir/words_no_ids.txt
+
+  if [ ! -f $lang_char_dir/words.txt ]; then
+    python ./local/prepare_words.py \
+      --input-file $lang_char_dir/words_no_ids.txt \
+      --output-file $lang_char_dir/words.txt
+  fi
 fi

 if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
-  log "Stage 10: Prepare pinyin based lang"
-  lang_pinyin_dir=data/lang_pinyin
-  mkdir -p $lang_pinyin_dir
-
-  gunzip -c data/manifests/supervisions_L.jsonl.gz \
-    | jq '.text' | sed 's/"//g' \
-    | ./local/text2token.py -t "pinyin" > $lang_pinyin_dir/text
-
-  cat $lang_pinyin_dir/text | sed 's/ /\n/g' \
-    | sort -u | sed '/^$/d' > $lang_pinyin_dir/words.txt
-
-  (echo '<SIL>'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) |
-  cat - $lang_pinyin_dir/words.txt | sort | uniq | awk '
-  BEGIN {
-    print "<eps> 0";
-  }
-  {
-    if ($1 == "<s>") {
-      print "<s> is in the vocabulary!" | "cat 1>&2"
-      exit 1;
-    }
-    if ($1 == "</s>") {
-      print "</s> is in the vocabulary!" | "cat 1>&2"
-      exit 1;
-    }
-    printf("%s %d\n", $1, NR);
-  }
-  END {
-    printf("#0 %d\n", NR+1);
-    printf("<s> %d\n", NR+2);
-    printf("</s> %d\n", NR+3);
-  }' > $lang_pinyin_dir/words || exit 1;
-
-  mv $lang_pinyin_dir/words $lang_pinyin_dir/words.txt
+  log "Stage 10: Prepare char based L_disambig.pt"
+  if [ ! -f data/lang_char/L_disambig.pt ]; then
+    python ./local/prepare_char.py \
+      --lang-dir data/lang_char
+  fi
 fi

 if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
-  log "Stage 11: Prepare lazy_pinyin based lang"
-  lang_lazy_pinyin_dir=data/lang_lazy_pinyin
-  mkdir -p $lang_lazy_pinyin_dir
-
-  gunzip -c data/manifests/supervisions_L.jsonl.gz \
-    | jq '.text' | sed 's/"//g' \
-    | ./local/text2token.py -t "lazy_pinyin" > $lang_lazy_pinyin_dir/text
-
-  cat $lang_lazy_pinyin_dir/text | sed 's/ /\n/g' \
-    | sort -u | sed '/^$/d' > $lang_lazy_pinyin_dir/words.txt
-
-  (echo '<SIL>'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) |
-  cat - $lang_lazy_pinyin_dir/words.txt | sort | uniq | awk '
-  BEGIN {
-    print "<eps> 0";
-  }
-  {
-    if ($1 == "<s>") {
-      print "<s> is in the vocabulary!" | "cat 1>&2"
-      exit 1;
-    }
-    if ($1 == "</s>") {
-      print "</s> is in the vocabulary!" | "cat 1>&2"
-      exit 1;
-    }
-    printf("%s %d\n", $1, NR);
-  }
-  END {
-    printf("#0 %d\n", NR+1);
-    printf("<s> %d\n", NR+2);
-    printf("</s> %d\n", NR+3);
-  }' > $lang_lazy_pinyin_dir/words || exit 1;
-
-  mv $lang_lazy_pinyin_dir/words $lang_lazy_pinyin_dir/words.txt
-fi
-
-if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
-  log "Stage 12: Prepare L_disambig.pt"
-  if [ ! -f data/lang_char/L_disambig.pt ]; then
-    python ./local/prepare_lang_wenetspeech.py --lang-dir data/lang_char
-  fi
+  log "Stage 11: Prepare pinyin based L_disambig.pt"
+  lang_pinyin_dir=data/lang_pinyin
+  mkdir -p $lang_pinyin_dir
+  cp -r data/lang_char/words.txt $lang_pinyin_dir/
+  cp -r data/lang_char/text $lang_pinyin_dir/
+  cp -r data/lang_char/text_words_segmentation $lang_pinyin_dir/

   if [ ! -f data/lang_pinyin/L_disambig.pt ]; then
-    python ./local/prepare_lang_wenetspeech.py --lang-dir data/lang_pinyin
-  fi
-
-  if [ ! -f data/lang_lazy_pinyin/L_disambig.pt ]; then
-    python ./local/prepare_lang_wenetspeech.py --lang-dir data/lang_lazy_pinyin
+    python ./local/prepare_pinyin.py \
+      --lang-dir data/lang_pinyin
   fi
 fi

View File

@@ -375,11 +375,15 @@ class WenetSpeechAsrDataModule:
         if self.args.lazy_load:
             logging.info("use lazy cuts")
             cuts_train = CutSet.from_jsonl_lazy(
-                self.args.manifest_dir / "cuts_L.jsonl.gz"
+                self.args.manifest_dir
+                / "cuts_L_50_pieces.jsonl.gz"
+                # use cuts_L_50_pieces.jsonl.gz for original experiments
             )
         else:
             cuts_train = CutSet.from_file(
-                self.args.manifest_dir / "cuts_L.jsonl.gz"
+                self.args.manifest_dir
+                / "cuts_L_50_pieces.jsonl.gz"
+                # use cuts_L_50_pieces.jsonl.gz for original experiments
             )
         return cuts_train
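
The change above only swaps the training manifest to the 50-piece subset; for reference, a minimal sketch of the two loading paths it touches (same lhotse calls as in the diff, path as introduced by this commit):

from lhotse import CutSet

manifest = "data/fbank/cuts_L_50_pieces.jsonl.gz"

# lazy_load branch: cuts are read from the gzipped JSONL on demand, which
# keeps memory low for the ~2.2M cuts in this subset.
cuts_lazy = CutSet.from_jsonl_lazy(manifest)

# Otherwise the whole manifest is parsed into memory up front.
cuts_eager = CutSet.from_file(manifest)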

View File

@@ -25,8 +25,10 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 0 \
-  --exp-dir pruned_transducer_stateless/exp \
-  --max-duration 300
+  --lang-dir data/lang_char \
+  --exp-dir pruned_transducer_stateless/exp-char \
+  --token-type char \
+  --max-duration 200
 """
@@ -44,8 +46,8 @@ from asr_datamodule import WenetSpeechAsrDataModule
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
+from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
-from local.text2token import token2id
 from model import Transducer
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -59,6 +61,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
+from icefall.pinyin_graph_compiler import PinyinCtcTrainingGraphCompiler
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@@ -108,7 +111,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless/exp",
+        default="pruned_transducer_stateless_pinyin/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
@@ -118,7 +121,7 @@ def get_parser():
     parser.add_argument(
         "--lang-dir",
         type=str,
-        default="data/lang_lazy_pinyin",
+        default="data/lang_char",
         help="""The lang dir
         It contains language related input files such as
         "lexicon.txt"
@@ -128,10 +131,9 @@ def get_parser():
     parser.add_argument(
         "--token-type",
         type=str,
-        default="lazy_pinyin",
-        help="""The token type
-        It refers to the token type for modeling, such as
-        char, pinyin, lazy_pinyin.
+        default="char",
+        help="""The type of token
+        It must be in ["char", "pinyin", "lazy_pinyin"].
         """,
     )
@@ -435,16 +437,12 @@ def compute_loss(
     feature_lens = supervisions["num_frames"].to(device)

     texts = batch["supervisions"]["text"]
-    y = ""
-    if params.token_type == "char":
-        y = graph_compiler.texts_to_ids(texts)
+    y = graph_compiler.texts_to_ids(texts)
+    if type(y) == list:
+        y = k2.RaggedTensor(y).to(device)
     else:
-        y = token2id(
-            texts=texts,
-            token_table=graph_compiler.token_table,
-            token_type=params.token_type,
-        )
-    y = k2.RaggedTensor(y).to(device)
+        y = y.to(device)

     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss = model(
@@ -635,9 +633,19 @@ def run(rank, world_size, args):
     logging.info(f"Device: {device}")

     lexicon = Lexicon(params.lang_dir)
-    graph_compiler = CharCtcTrainingGraphCompiler(
-        lexicon=lexicon, device=device, oov="<unk>"
-    )
+    graph_compiler = ""
+    if params.token_type == "char":
+        graph_compiler = CharCtcTrainingGraphCompiler(
+            lexicon=lexicon,
+            device=device,
+        )
+    if params.token_type == "pinyin":
+        graph_compiler = PinyinCtcTrainingGraphCompiler(
+            lang_dir=params.lang_dir,
+            lexicon=lexicon,
+            device=device,
+        )

     params.blank_id = lexicon.token_table["<blk>"]
     params.vocab_size = max(lexicon.tokens) + 1
@@ -672,6 +680,23 @@ def run(rank, world_size, args):
     train_cuts = wenetspeech.train_cuts()

+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 15.0 seconds
+        # You can get the statistics by local/display_manifest_statistics.py.
+        return 1.0 <= c.duration <= 15.0
+
+    def text_to_words(c: Cut):
+        # Convert text to words_segments.
+        text = c.supervisions[0].text
+        text = text.strip("\n").strip("\t")
+        words_cut = graph_compiler.text2words[text]
+        c.supervisions[0].text = words_cut
+        return c
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    if params.token_type == "pinyin":
+        train_cuts = train_cuts.map(text_to_words)
+
     train_dl = wenetspeech.train_dataloaders(train_cuts)

     valid_cuts = wenetspeech.valid_cuts()
     valid_dl = wenetspeech.valid_dataloaders(valid_cuts)
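
The compute_loss() change earlier in this diff works because the two graph compilers return different types from texts_to_ids(): the char compiler returns a plain List[List[int]], while PinyinCtcTrainingGraphCompiler (added in this commit) already returns a k2.RaggedTensor. A small helper capturing that contract (the helper name is ours, not part of the commit):

import k2
import torch


def to_ragged(y, device: torch.device) -> k2.RaggedTensor:
    # List[List[int]] from the char compiler -> wrap in a RaggedTensor;
    # a k2.RaggedTensor from the pinyin compiler -> just move it to the device.
    if isinstance(y, list):
        return k2.RaggedTensor(y).to(device)
    return y.to(device)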

View File

@@ -0,0 +1,219 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Dict, List
import k2
import torch
from tqdm import tqdm
from icefall.lexicon import Lexicon, read_lexicon
class PinyinCtcTrainingGraphCompiler(object):
def __init__(
self,
lang_dir: Path,
lexicon: Lexicon,
device: torch.device,
sos_token: str = "<sos/eos>",
eos_token: str = "<sos/eos>",
oov: str = "<unk>",
):
"""
Args:
lang_dir:
The lang directory, e.g., data/lang_pinyin. It is expected to
contain `text`, `text_words_segmentation` and `lexicon.txt`.
lexicon:
It is built from `lang_dir/lexicon.txt`.
device:
The device to use for operations compiling transcripts to FSAs.
sos_token:
The token that represents the start of a sentence.
eos_token:
The token that represents the end of a sentence.
oov:
Out of vocabulary token. When a word (token) in the transcript
does not exist in the token list, it is replaced with `oov`.
"""
assert oov in lexicon.token_table
self.lang_dir = lang_dir
self.oov_id = lexicon.token_table[oov]
self.token_table = lexicon.token_table
self.device = device
self.sos_id = self.token_table[sos_token]
self.eos_id = self.token_table[eos_token]
self.word_table = lexicon.word_table
self.text2words = convert_text_to_word_segments(
text_filename=self.lang_dir / "text",
words_segments_filename=self.lang_dir / "text_words_segmentation",
)
self.ragged_lexicon = convert_lexicon_to_ragged(
filename=self.lang_dir / "lexicon.txt",
word_table=self.word_table,
token_table=self.token_table,
)
def texts_to_ids(self, texts: List[str]) -> k2.RaggedTensor:
"""Convert a list of word-segmented texts to pinyin-based token IDs.
Args:
texts:
A list of transcript strings. Each string is expected to be
word-segmented already (words separated by spaces), e.g.
['你好 中国', '北京 欢迎 您']
Returns:
Return a k2.RaggedTensor of pinyin-based token IDs, with one
sub-list per input text.
"""
word_ids_list = []
for i in range(len(texts)):
word_ids = []
text = texts[i].strip("\n").strip("\t")
for word in text.split(" "):
if word in self.word_table:
word_ids.append(self.word_table[word])
else:
word_ids.append(self.oov_id)
word_ids_list.append(word_ids)
ragged_indexes = k2.RaggedTensor(word_ids_list, dtype=torch.int32)
ans = self.ragged_lexicon.index(ragged_indexes)
ans = ans.remove_axis(ans.num_axes - 2)
return ans
def compile(
self,
token_ids: List[List[int]],
modified: bool = False,
) -> k2.Fsa:
"""Build a ctc graph from a list-of-list token IDs.
Args:
token_ids:
A list-of-list of integer token IDs.
modified:
See :func:`k2.ctc_graph` for its meaning.
Return:
Return an FsaVec, which is the result of composing a
CTC topology with linear FSAs constructed from the given
piece IDs.
"""
return k2.ctc_graph(token_ids, modified=modified, device=self.device)
def convert_lexicon_to_ragged(
filename: str, word_table: k2.SymbolTable, token_table: k2.SymbolTable
) -> k2.RaggedTensor:
"""Read a lexicon and convert lexicon to a ragged tensor.
Args:
filename:
Path to the lexicon file.
word_table:
The word symbol table.
token_table:
The token symbol table.
Returns:
A k2 ragged tensor with two axes [word][token].
"""
num_words = len(word_table.symbols)
excluded_words = [
"<eps>",
"!SIL",
"<SPOKEN_NOISE>",
"<UNK>",
"#0",
"<s>",
"</s>",
]
row_splits = [0]
token_ids_list = []
lexicon_tmp = read_lexicon(filename)
lexicon = dict(lexicon_tmp)
if len(lexicon_tmp) != len(lexicon):
raise RuntimeError(
"It's assumed that each word has a unique pronunciation"
)
for i in range(num_words):
w = word_table[i]
if w in excluded_words:
row_splits.append(row_splits[-1])
continue
tokens = lexicon[w]
token_ids = [token_table[k] for k in tokens]
row_splits.append(row_splits[-1] + len(token_ids))
token_ids_list.extend(token_ids)
cached_tot_size = row_splits[-1]
row_splits = torch.tensor(row_splits, dtype=torch.int32)
shape = k2.ragged.create_ragged_shape2(
row_splits,
None,
cached_tot_size,
)
values = torch.tensor(token_ids_list, dtype=torch.int32)
return k2.RaggedTensor(shape, values)
def convert_text_to_word_segments(
text_filename: str, words_segments_filename: str
) -> Dict[str, str]:
"""Convert text to word-based segments.
Args:
text_filename:
The file for the original transcripts.
words_segments_filename:
The file after implementing chinese word segmentation
for the original transcripts.
Returns:
A dictionary about text and words_segments.
"""
text2words = {}
f_text = open(text_filename, "r", encoding="utf-8")
text_lines = f_text.readlines()
text_lines = [line.strip("\t") for line in text_lines]
f_words = open(words_segments_filename, "r", encoding="utf-8")
words_lines = f_words.readlines()
words_lines = [line.strip("\t") for line in words_lines]
if len(text_lines) != len(words_lines):
raise RuntimeError(
"The lengths of text and words_segments should be equal."
)
for i in tqdm(range(len(text_lines))):
text = text_lines[i].strip(" ").strip("\n")
words_segments = words_lines[i].strip(" ").strip("\n")
text2words[text] = words_segments
return text2words
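
A hedged usage sketch tying this compiler to the training changes above. It assumes the lang directory produced by prepare.sh stage 11 (text, text_words_segmentation, lexicon.txt, tokens.txt, words.txt) and mirrors how train.py constructs the compiler; the transcript passed to texts_to_ids() must already be word-segmented, which train.py ensures via text_to_words():

from pathlib import Path

import torch

from icefall.lexicon import Lexicon
from icefall.pinyin_graph_compiler import PinyinCtcTrainingGraphCompiler

lang_dir = Path("data/lang_pinyin")
lexicon = Lexicon(lang_dir)
compiler = PinyinCtcTrainingGraphCompiler(
    lang_dir=lang_dir,
    lexicon=lexicon,
    device=torch.device("cpu"),
)

token_ids = compiler.texts_to_ids(["北京 欢迎 您"])  # k2.RaggedTensor
graphs = compiler.compile(token_ids.tolist())  # FsaVec of CTC training graphs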