add and update some files

This commit is contained in:
luomingshuang 2022-04-08 19:00:19 +08:00
parent 5242275b8a
commit d8907769ee
10 changed files with 1117 additions and 131 deletions

View File

@ -0,0 +1,135 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file displays duration statistics of utterances in a manifest.
You can use the displayed values to choose the minimum/maximum duration
thresholds for removing short and long utterances during training.
See the function `remove_short_and_long_utt()`
in ../../../librispeech/ASR/transducer/train.py
for usage.
"""
from lhotse import load_manifest
def main():
    paths = [
        # "./data/fbank/cuts_L_100_pieces.jsonl.gz",
        "./data/fbank/cuts_L_50_pieces.jsonl.gz",
        # "./data/fbank/cuts_DEV.jsonl.gz",
        # "./data/fbank/cuts_TEST_NET.jsonl.gz",
        # "./data/fbank/cuts_TEST_MEETING.jsonl.gz",
    ]

    for path in paths:
        print(f"Starting display the statistics for {path}")
        cuts = load_manifest(path)
        cuts.describe()


if __name__ == "__main__":
    main()
"""
Starting display the statistics for ./data/fbank/cuts_L_50_pieces.jsonl.gz
Cuts count: 2241476
Total duration (hours): 1475.0
Speech duration (hours): 1475.0 (100.0%)
***
Duration statistics (seconds):
mean 2.4
std 1.6
min 0.3
25% 1.3
50% 2.0
75% 2.9
99% 8.2
99.5% 9.3
99.9% 13.5
max 87.0
Starting display the statistics for ./data/fbank/cuts_L_100_pieces.jsonl.gz
Cuts count: 4929619
Total duration (hours): 3361.1
Speech duration (hours): 3361.1 (100.0%)
***
Duration statistics (seconds):
mean 2.5
std 1.7
min 0.3
25% 1.4
50% 2.0
75% 3.0
99% 8.1
99.5% 8.8
99.9% 14.7
max 87.0
Starting display the statistics for ./data/fbank/cuts_DEV.jsonl.gz
Cuts count: 13825
Total duration (hours): 20.0
Speech duration (hours): 20.0 (100.0%)
***
Duration statistics (seconds):
mean 5.2
std 2.2
min 1.0
25% 3.3
50% 4.9
75% 7.0
99% 9.6
99.5% 9.8
99.9% 10.0
max 10.0
Starting display the statistics for ./data/fbank/cuts_TEST_NET.jsonl.gz
Cuts count: 24774
Total duration (hours): 23.1
Speech duration (hours): 23.1 (100.0%)
***
Duration statistics (seconds):
mean 3.4
std 2.6
min 0.1
25% 1.4
50% 2.4
75% 4.8
99% 13.1
99.5% 14.5
99.9% 18.5
max 33.3
Starting display the statistics for ./data/fbank/cuts_TEST_MEETING.jsonl.gz
Cuts count: 8370
Total duration (hours): 15.2
Speech duration (hours): 15.2 (100.0%)
***
Duration statistics (seconds):
mean 6.5
std 3.5
min 0.8
25% 3.7
50% 5.8
75% 8.8
99% 15.2
99.5% 16.0
99.9% 18.8
max 24.6
"""

View File

@ -0,0 +1,246 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input `lang_dir`, which should contain::
- lang_dir/text,
- lang_dir/words.txt
and generates the following files in the directory `lang_dir`:
- lexicon.txt
- lexicon_disambig.txt
- L.pt
- L_disambig.pt
- tokens.txt
"""
import argparse
import re
from pathlib import Path
from typing import Dict, List
import k2
import torch
from prepare_lang import (
Lexicon,
add_disambig_symbols,
add_self_loops,
write_lexicon,
write_mapping,
)
def lexicon_to_fst_no_sil(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format).
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
loop_state = 0 # words enter and leave from here
next_state = 1 # the next un-allocated state, will be incremented as we go
arcs = []
# The blank symbol <blk> is assigned ID 0 in generate_tokens() in this file.
assert token2id["<blk>"] == 0
assert word2id["<eps>"] == 0
eps = 0
for word, pieces in lexicon:
assert len(pieces) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
pieces = [
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, pieces[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last piece of this word
i = len(pieces) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, pieces[i], w, 0])
if need_self_loops:
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
"""Check if all the given tokens are in token symbol table.
Args:
token_sym_table:
Token symbol table that contains all the valid tokens.
tokens:
A list of tokens.
Returns:
Return True if there is any token not in the token_sym_table,
otherwise False.
"""
for tok in tokens:
if tok not in token_sym_table:
return True
return False
def generate_lexicon(
token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table.
Args:
token_sym_table:
The token symbol table, which maps each token to its token ID.
words:
A list of strings representing words.
Returns:
Return a lexicon, i.e., a list of (word, tokens) pairs.
"""
lexicon = []
for word in words:
chars = list(word.strip(" \t"))
if contain_oov(token_sym_table, chars):
continue
lexicon.append((word, chars))
# The OOV word is <UNK>
lexicon.append(("<UNK>", ["<unk>"]))
return lexicon
def generate_tokens(text_file: str) -> Dict[str, int]:
"""Generate tokens from the given text file.
Args:
text_file:
A file that contains text lines to generate tokens.
Returns:
Return a dict whose keys are tokens and values are token IDs ranging
from 0 to len(keys) - 1.
"""
tokens: Dict[str, int] = dict()
tokens["<blk>"] = 0
tokens["<sos/eos>"] = 1
tokens["<unk>"] = 2
whitespace = re.compile(r"([ \t\r\n]+)")
with open(text_file, "r", encoding="utf-8") as f:
for line in f:
line = re.sub(whitespace, "", line)
tokens_list = list(line)
for token in tokens_list:
if token not in tokens:
tokens[token] = len(tokens)
return tokens
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--lang-dir", type=str, help="The lang directory.")
args = parser.parse_args()
lang_dir = Path(args.lang_dir)
text_file = lang_dir / "text"
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
words = word_sym_table.symbols
excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"]
for w in excluded:
if w in words:
words.remove(w)
token_sym_table = generate_tokens(text_file)
lexicon = generate_lexicon(token_sym_table, words)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
next_token_id = max(token_sym_table.values()) + 1
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in token_sym_table
token_sym_table[disambig] = next_token_id
next_token_id += 1
word_sym_table.add("#0")
word_sym_table.add("<s>")
word_sym_table.add("</s>")
write_mapping(lang_dir / "tokens.txt", token_sym_table)
write_lexicon(lang_dir / "lexicon.txt", lexicon)
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst_no_sil(
lexicon,
token2id=token_sym_table,
word2id=word_sym_table,
)
L_disambig = lexicon_to_fst_no_sil(
lexicon_disambig,
token2id=token_sym_table,
word2id=word_sym_table,
need_self_loops=True,
)
torch.save(L.as_dict(), lang_dir / "L.pt")
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
if __name__ == "__main__":
main()
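
Once the script has run (prepare.sh invokes it with --lang-dir data/lang_char), the saved symbol tables and FSTs can be loaded back for inspection. A minimal sketch, assuming that lang dir exists:

from pathlib import Path

import k2
import torch

lang_dir = Path("data/lang_char")  # assumed output directory

# tokens.txt and words.txt are written by write_mapping() above.
token_table = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
word_table = k2.SymbolTable.from_file(lang_dir / "words.txt")

# L.pt was saved with torch.save(L.as_dict(), ...), so it is restored
# with k2.Fsa.from_dict().
L = k2.Fsa.from_dict(torch.load(lang_dir / "L.pt"))

print("Number of tokens:", len(token_table.symbols))
print("Number of words:", len(word_table.symbols))
print("Number of arcs in L:", L.num_arcs)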

View File

@ -0,0 +1,260 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input `lang_dir`, which should contain::
- lang_dir/words.txt
and generates the following files in the directory `lang_dir`:
- lexicon.txt
- lexicon_disambig.txt
- L.pt
- L_disambig.pt
- tokens.txt
"""
import argparse
from pathlib import Path
from typing import Dict, List
import k2
import torch
from prepare_lang import (
Lexicon,
add_disambig_symbols,
add_self_loops,
write_lexicon,
write_mapping,
)
from pypinyin import pinyin
from tqdm import tqdm
def lexicon_to_fst_no_sil(
lexicon: Lexicon,
token2id: Dict[str, int],
word2id: Dict[str, int],
need_self_loops: bool = False,
) -> k2.Fsa:
"""Convert a lexicon to an FST (in k2 format).
Args:
lexicon:
The input lexicon. See also :func:`read_lexicon`
token2id:
A dict mapping tokens to IDs.
word2id:
A dict mapping words to IDs.
need_self_loops:
If True, add self-loop to states with non-epsilon output symbols
on at least one arc out of the state. The input label for this
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
Returns:
Return an instance of `k2.Fsa` representing the given lexicon.
"""
loop_state = 0 # words enter and leave from here
next_state = 1 # the next un-allocated state, will be incremented as we go
arcs = []
# The blank symbol <blk> is assigned ID 0 in generate_tokens() in this file.
assert token2id["<blk>"] == 0
assert word2id["<eps>"] == 0
eps = 0
for word, pieces in lexicon:
assert len(pieces) > 0, f"{word} has no pronunciations"
cur_state = loop_state
word = word2id[word]
pieces = [
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
arcs.append([cur_state, next_state, pieces[i], w, 0])
cur_state = next_state
next_state += 1
# now for the last piece of this word
i = len(pieces) - 1
w = word if i == 0 else eps
arcs.append([cur_state, loop_state, pieces[i], w, 0])
if need_self_loops:
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
arcs.append([loop_state, final_state, -1, -1, 0])
arcs.append([final_state])
arcs = sorted(arcs, key=lambda arc: arc[0])
arcs = [[str(i) for i in arc] for arc in arcs]
arcs = [" ".join(arc) for arc in arcs]
arcs = "\n".join(arcs)
fsa = k2.Fsa.from_str(arcs, acceptor=False)
return fsa
def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
"""Check if all the given tokens are in token symbol table.
Args:
token_sym_table:
Token symbol table that contains all the valid tokens.
tokens:
A list of tokens.
Returns:
Return True if there is any token not in the token_sym_table,
otherwise False.
"""
for tok in tokens:
if tok not in token_sym_table:
return True
return False
def generate_lexicon(
token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
"""Generate a lexicon from a word list and token_sym_table.
Args:
token_sym_table:
The token symbol table, which maps each token to its token ID.
words:
A list of strings representing words.
Returns:
Return a lexicon, i.e., a list of (word, pinyin tokens) pairs.
"""
lexicon = []
for i in tqdm(range(len(words))):
word = words[i]
tokens = []
pinyins = pinyin(word.strip(" \t"))
for pinyin_one in pinyins:
if pinyin_one[0].isupper():
tokens.extend(list(pinyin_one[0]))
else:
tokens.append(pinyin_one[0])
if contain_oov(token_sym_table, tokens):
continue
lexicon.append((word, tokens))
# The OOV word is <UNK>
lexicon.append(("<UNK>", ["<unk>"]))
return lexicon
def generate_tokens(words: List[str]) -> Dict[str, int]:
"""Generate tokens from the given text file.
Args:
words:
The list of words after removing <eps>, !SIL, and so on.
Returns:
Return a dict whose keys are tokens and values are token IDs ranging
from 0 to len(keys) - 1.
"""
tokens: Dict[str, int] = dict()
tokens["<blk>"] = 0
tokens["<sos/eos>"] = 1
tokens["<unk>"] = 2
for i in tqdm(range(len(words))):
word = words[i]
pinyins_list = pinyin(word)
for pinyin_one in pinyins_list:
if pinyin_one[0].isupper():
tokens_list = list(pinyin_one[0])
else:
tokens_list = pinyin_one
for token in tokens_list:
if token not in tokens:
tokens[token] = len(tokens)
tokens = sorted(tokens.items(), key=lambda item: item[0])
tokens = dict(tokens)
return tokens
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--lang-dir", type=str, help="The lang directory.")
args = parser.parse_args()
lang_dir = Path(args.lang_dir)
words_file = lang_dir / "words.txt"
word_sym_table = k2.SymbolTable.from_file(words_file)
words = word_sym_table.symbols
excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"]
for w in excluded:
if w in words:
words.remove(w)
token_sym_table = generate_tokens(words)
lexicon = generate_lexicon(token_sym_table, words)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
next_token_id = max(token_sym_table.values()) + 1
for i in range(max_disambig + 1):
disambig = f"#{i}"
assert disambig not in token_sym_table
token_sym_table[disambig] = next_token_id
next_token_id += 1
word_sym_table.add("#0")
word_sym_table.add("<s>")
word_sym_table.add("</s>")
write_mapping(lang_dir / "tokens.txt", token_sym_table)
write_lexicon(lang_dir / "lexicon.txt", lexicon)
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst_no_sil(
lexicon,
token2id=token_sym_table,
word2id=word_sym_table,
)
L_disambig = lexicon_to_fst_no_sil(
lexicon_disambig,
token2id=token_sym_table,
word2id=word_sym_table,
need_self_loops=True,
)
torch.save(L.as_dict(), lang_dir / "L.pt")
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
if __name__ == "__main__":
main()
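
The main difference from the char-based script above is the pinyin tokenization. A small sketch of that logic in isolation, assuming pypinyin is installed; the exact syllables depend on pypinyin's default style, so the printed values are only illustrative:

from pypinyin import pinyin


def word_to_tokens(word: str):
    # Mirrors generate_lexicon() above: entries whose pinyin output starts
    # with an uppercase letter (e.g. English words or acronyms, which
    # pypinyin typically returns unchanged) are split into characters;
    # Chinese characters are kept as whole pinyin syllables.
    tokens = []
    for p in pinyin(word.strip(" \t")):
        if p[0].isupper():
            tokens.extend(list(p[0]))
        else:
            tokens.append(p[0])
    return tokens


print(word_to_tokens("中国"))  # e.g. ['zhōng', 'guó'] with the default style
print(word_to_tokens("GDP"))  # ['G', 'D', 'P']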

View File

@ -0,0 +1,84 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input words.txt without ids:
- words_no_ids.txt
and generates the new words.txt with related ids.
- words.txt
"""
import argparse
import logging
from tqdm import tqdm
def get_parser():
    parser = argparse.ArgumentParser(
        description="Prepare words.txt",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--input-file",
        default="data/lang_char/words_no_ids.txt",
        type=str,
        help="the words file without ids for WenetSpeech",
    )
    parser.add_argument(
        "--output-file",
        default="data/lang_char/words.txt",
        type=str,
        help="the words file with ids for WenetSpeech",
    )
    return parser
def main():
    parser = get_parser()
    args = parser.parse_args()
    input_file = args.input_file
    output_file = args.output_file

    with open(input_file, "r", encoding="utf-8") as f:
        lines = f.readlines()

    new_lines = []
    add_words = ["<eps> 0", "!SIL 1", "<SPOKEN_NOISE> 2", "<UNK> 3"]
    new_lines.extend(add_words)

    logging.info("Starting to read the input file")
    for i in tqdm(range(len(lines))):
        x = lines[i]
        idx = 4 + i
        new_line = x.strip("\n") + " " + str(idx)
        new_lines.append(new_line)

    logging.info("Starting to write the words.txt")
    with open(output_file, "w", encoding="utf-8") as f_out:
        for line in new_lines:
            f_out.write(line)
            f_out.write("\n")
if __name__ == "__main__":
main()
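
For reference, a tiny self-contained sketch of the ID assignment this script performs, using a toy word list instead of the real words_no_ids.txt:

# Toy stand-in for data/lang_char/words_no_ids.txt (one word per line).
words = ["中国", "北京", "欢迎"]

new_lines = ["<eps> 0", "!SIL 1", "<SPOKEN_NOISE> 2", "<UNK> 3"]
for i, word in enumerate(words):
    # IDs continue after the four predefined symbols, i.e. they start at 4.
    new_lines.append(f"{word} {4 + i}")

print("\n".join(new_lines))
# <eps> 0
# !SIL 1
# <SPOKEN_NOISE> 2
# <UNK> 3
# 中国 4
# 北京 5
# 欢迎 6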

View File

@ -0,0 +1,83 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input "text", which refers to the transcript file for
WenetSpeech:
- text
and generates the output file text_word_segmentation which is implemented
with word segmenting:
- text_words_segmentation
"""
import argparse
import jieba
from tqdm import tqdm
jieba.enable_paddle()
def get_parser():
    parser = argparse.ArgumentParser(
        description="Chinese Word Segmentation for text",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--input",
        default="data/lang_char/text",
        type=str,
        help="the input text file for WenetSpeech",
    )
    parser.add_argument(
        "--output",
        default="data/lang_char/text_words_segmentation",
        type=str,
        help="the word-segmented text file for WenetSpeech",
    )
    return parser
def main():
    parser = get_parser()
    args = parser.parse_args()
    input_file = args.input
    output_file = args.output

    with open(input_file, "r", encoding="utf-8") as f:
        lines = f.readlines()

    new_lines = []
    for i in tqdm(range(len(lines))):
        x = lines[i].rstrip()
        seg_list = jieba.cut(x, use_paddle=True)
        new_line = " ".join(seg_list)
        new_lines.append(new_line)

    with open(output_file, "w", encoding="utf-8") as f_new:
        for line in new_lines:
            f_new.write(line)
            f_new.write("\n")
if __name__ == "__main__":
main()
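
A minimal usage sketch of the segmentation step, assuming jieba is installed. The script above uses jieba's paddle mode, which additionally requires paddlepaddle; the sketch falls back to the default mode, so the exact segmentation shown is only illustrative:

import jieba

line = "你好中国北京欢迎您"
# The script above calls jieba.cut(line, use_paddle=True) after
# jieba.enable_paddle(); the default mode is used here for simplicity.
seg_list = jieba.cut(line)
print(" ".join(seg_list))  # e.g. "你好 中国 北京 欢迎 您"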

View File

@ -188,7 +188,7 @@ def main():
a_chars = [z.replace(" ", args.space) for z in a_flat]
print(" ".join(a_chars))
print("".join(a_chars))
line = f.readline()

View File

@ -117,9 +117,9 @@ fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Combine features for L"
if [ ! -f data/fbank/cuts_L.jsonl.gz ]; then
pieces=$(find data/fbank/L_split_${num_splits} -name "cuts_L.*.jsonl.gz")
lhotse combine $pieces data/fbank/cuts_L.jsonl.gz
if [ ! -f data/fbank/cuts_L_50.jsonl.gz ]; then
pieces=$(find data/fbank/L_split_50 -name "cuts_L.*.jsonl.gz")
lhotse combine $pieces data/fbank/cuts_L_50.jsonl.gz
fi
fi
@ -134,120 +134,50 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
lang_char_dir=data/lang_char
mkdir -p $lang_char_dir
gunzip -c data/manifests/supervisions_L.jsonl.gz \
| jq '.text' | sed 's/"//g' \
| ./local/text2token.py -t "char" > $lang_char_dir/text
# Prepare text.
if [ ! -f $lang_char_dir/text ]; then
gunzip -c data/manifests/supervisions_L.jsonl.gz \
| jq '.text' | sed 's/"//g' \
| ./local/text2token.py -t "char" > $lang_char_dir/text
fi
cat $lang_char_dir/text | sed 's/ /\n/g' \
| sort -u | sed '/^$/d' > $lang_char_dir/words.txt
(echo '<SIL>'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) |
cat - $lang_char_dir/words.txt | sort | uniq | awk '
BEGIN {
print "<eps> 0";
}
{
if ($1 == "<s>") {
print "<s> is in the vocabulary!" | "cat 1>&2"
exit 1;
}
if ($1 == "</s>") {
print "</s> is in the vocabulary!" | "cat 1>&2"
exit 1;
}
printf("%s %d\n", $1, NR);
}
END {
printf("#0 %d\n", NR+1);
printf("<s> %d\n", NR+2);
printf("</s> %d\n", NR+3);
}' > $lang_char_dir/words || exit 1;
# Apply Chinese word segmentation to the text.
# This takes about 15 minutes.
if [ ! -f $lang_char_dir/text_words_segmentation ]; then
python ./local/text2segments.py \
--input $lang_char_dir/text \
--output $lang_char_dir/text_words_segmentation
fi
mv $lang_char_dir/words $lang_char_dir/words.txt
cat $lang_char_dir/text_words_segmentation | sed 's/ /\n/g' \
| sort -u | sed '/^$/d' | uniq > $lang_char_dir/words_no_ids.txt
if [ ! -f $lang_char_dir/words.txt ]; then
python ./local/prepare_words.py \
--input-file $lang_char_dir/words_no_ids.txt \
--output-file $lang_char_dir/words.txt
fi
fi
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
log "Stage 10: Prepare pinyin based lang"
lang_pinyin_dir=data/lang_pinyin
mkdir -p $lang_pinyin_dir
gunzip -c data/manifests/supervisions_L.jsonl.gz \
| jq '.text' | sed 's/"//g' \
| ./local/text2token.py -t "pinyin" > $lang_pinyin_dir/text
cat $lang_pinyin_dir/text | sed 's/ /\n/g' \
| sort -u | sed '/^$/d' > $lang_pinyin_dir/words.txt
(echo '<SIL>'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) |
cat - $lang_pinyin_dir/words.txt | sort | uniq | awk '
BEGIN {
print "<eps> 0";
}
{
if ($1 == "<s>") {
print "<s> is in the vocabulary!" | "cat 1>&2"
exit 1;
}
if ($1 == "</s>") {
print "</s> is in the vocabulary!" | "cat 1>&2"
exit 1;
}
printf("%s %d\n", $1, NR);
}
END {
printf("#0 %d\n", NR+1);
printf("<s> %d\n", NR+2);
printf("</s> %d\n", NR+3);
}' > $lang_pinyin_dir/words || exit 1;
mv $lang_pinyin_dir/words $lang_pinyin_dir/words.txt
log "Stage 10: Prepare char based L_disambig.pt"
if [ ! -f data/lang_char/L_disambig.pt ]; then
python ./local/prepare_char.py \
--lang-dir data/lang_char
fi
fi
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
log "Stage 11: Prepare lazy_pinyin based lang"
lang_lazy_pinyin_dir=data/lang_lazy_pinyin
mkdir -p $lang_lazy_pinyin_dir
log "Stage 11: Prepare pinyin based L_disambig.pt"
lang_pinyin_dir=data/lang_pinyin
mkdir -p $lang_pinyin_dir
gunzip -c data/manifests/supervisions_L.jsonl.gz \
| jq '.text' | sed 's/"//g' \
| ./local/text2token.py -t "lazy_pinyin" > $lang_lazy_pinyin_dir/text
cat $lang_lazy_pinyin_dir/text | sed 's/ /\n/g' \
| sort -u | sed '/^$/d' > $lang_lazy_pinyin_dir/words.txt
(echo '<SIL>'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) |
cat - $lang_lazy_pinyin_dir/words.txt | sort | uniq | awk '
BEGIN {
print "<eps> 0";
}
{
if ($1 == "<s>") {
print "<s> is in the vocabulary!" | "cat 1>&2"
exit 1;
}
if ($1 == "</s>") {
print "</s> is in the vocabulary!" | "cat 1>&2"
exit 1;
}
printf("%s %d\n", $1, NR);
}
END {
printf("#0 %d\n", NR+1);
printf("<s> %d\n", NR+2);
printf("</s> %d\n", NR+3);
}' > $lang_lazy_pinyin_dir/words || exit 1;
mv $lang_lazy_pinyin_dir/words $lang_lazy_pinyin_dir/words.txt
fi
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
log "Stage 12: Prepare L_disambig.pt"
if [ ! -f data/lang_char/L_disambig.pt ]; then
python ./local/prepare_lang_wenetspeech.py --lang-dir data/lang_char
fi
cp -r data/lang_char/words.txt $lang_pinyin_dir/
cp -r data/lang_char/text $lang_pinyin_dir/
cp -r data/lang_char/text_words_segmentation $lang_pinyin_dir/
if [ ! -f data/lang_pinyin/L_disambig.pt ]; then
python ./local/prepare_lang_wenetspeech.py --lang-dir data/lang_pinyin
fi
if [ ! -f data/lang_lazy_pinyin/L_disambig.pt ]; then
python ./local/prepare_lang_wenetspeech.py --lang-dir data/lang_lazy_pinyin
python ./local/prepare_pinyin.py \
--lang-dir data/lang_pinyin
fi
fi
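
After the stages above have completed, the prepared lang directories can be sanity-checked from Python. A sketch under the assumption that the directory layout above was used (data/lang_char and data/lang_pinyin):

from pathlib import Path

# Files expected from the stages above together with prepare_char.py
# and prepare_pinyin.py.
expected = [
    "text",
    "text_words_segmentation",
    "words.txt",
    "lexicon.txt",
    "tokens.txt",
    "L_disambig.pt",
]

for lang_dir in [Path("data/lang_char"), Path("data/lang_pinyin")]:
    missing = [name for name in expected if not (lang_dir / name).exists()]
    if missing:
        print(f"{lang_dir}: missing {missing}")
    else:
        print(f"{lang_dir}: all expected files are present")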

View File

@ -375,11 +375,15 @@ class WenetSpeechAsrDataModule:
if self.args.lazy_load:
logging.info("use lazy cuts")
cuts_train = CutSet.from_jsonl_lazy(
self.args.manifest_dir / "cuts_L.jsonl.gz"
self.args.manifest_dir
/ "cuts_L_50_pieces.jsonl.gz"
# use cuts_L_50_pieces.jsonl.gz for original experiments
)
else:
cuts_train = CutSet.from_file(
self.args.manifest_dir / "cuts_L.jsonl.gz"
self.args.manifest_dir
/ "cuts_L_50_pieces.jsonl.gz"
# use cuts_L_50_pieces.jsonl.gz for original experiments
)
return cuts_train
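
The lazy branch matters for the large L subset manifests: CutSet.from_jsonl_lazy reads cuts from disk on demand instead of loading the whole manifest into memory. A minimal sketch of the two code paths, with the manifest path as an assumption:

from lhotse import CutSet

manifest = "data/fbank/cuts_L_50_pieces.jsonl.gz"  # assumed path

# Lazy: cuts are read from the JSONL file on demand (low memory usage).
cuts_lazy = CutSet.from_jsonl_lazy(manifest)

# Eager: the whole manifest is loaded into memory up front.
cuts_eager = CutSet.from_file(manifest)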

View File

@ -25,8 +25,10 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless/exp \
--max-duration 300
--lang-dir data/lang_char \
--exp-dir pruned_transducer_stateless/exp-char \
--token-type char \
--max-duration 200
"""
@ -44,8 +46,8 @@ from asr_datamodule import WenetSpeechAsrDataModule
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from local.text2token import token2id
from model import Transducer
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
@ -59,6 +61,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.pinyin_graph_compiler import PinyinCtcTrainingGraphCompiler
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
@ -108,7 +111,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless/exp",
default="pruned_transducer_stateless_pinyin/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
@ -118,7 +121,7 @@ def get_parser():
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_lazy_pinyin",
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
@ -128,10 +131,9 @@ def get_parser():
parser.add_argument(
"--token-type",
type=str,
default="lazy_pinyin",
help="""The token type
It refers to the token type for modeling, such as
char, pinyin, lazy_pinyin.
default="char",
help="""The type of token
It must be in ["char", "pinyin", "lazy_pinyin"].
""",
)
@ -435,16 +437,12 @@ def compute_loss(
feature_lens = supervisions["num_frames"].to(device)
texts = batch["supervisions"]["text"]
y = ""
if params.token_type == "char":
y = graph_compiler.texts_to_ids(texts)
y = graph_compiler.texts_to_ids(texts)
if type(y) == list:
y = k2.RaggedTensor(y).to(device)
else:
y = token2id(
texts=texts,
token_table=graph_compiler.token_table,
token_type=params.token_type,
)
y = k2.RaggedTensor(y).to(device)
y = y.to(device)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss = model(
@ -635,9 +633,19 @@ def run(rank, world_size, args):
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon, device=device, oov="<unk>"
)
graph_compiler = ""
if params.token_type == "char":
graph_compiler = CharCtcTrainingGraphCompiler(
lexicon=lexicon,
device=device,
)
if params.token_type == "pinyin":
graph_compiler = PinyinCtcTrainingGraphCompiler(
lang_dir=params.lang_dir,
lexicon=lexicon,
device=device,
)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
@ -672,6 +680,23 @@ def run(rank, world_size, args):
train_cuts = wenetspeech.train_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 15.0 seconds
# You can get the statistics by local/display_manifest_statistics.py.
return 1.0 <= c.duration <= 15.0
def text_to_words(c: Cut):
# Replace the transcript with its word-segmented form.
text = c.supervisions[0].text
text = text.strip("\n").strip("\t")
words_cut = graph_compiler.text2words[text]
c.supervisions[0].text = words_cut
return c
train_cuts = train_cuts.filter(remove_short_and_long_utt)
if params.token_type == "pinyin":
train_cuts = train_cuts.map(text_to_words)
train_dl = wenetspeech.train_dataloaders(train_cuts)
valid_cuts = wenetspeech.valid_cuts()
valid_dl = wenetspeech.valid_dataloaders(valid_cuts)
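
For the pinyin token type, every transcript has to be replaced by its word-segmented form before texts_to_ids() is called, because the pinyin lexicon is keyed by words. A toy sketch of what text_to_words() does, with a plain dict standing in for graph_compiler.text2words (which is built from lang_dir/text and lang_dir/text_words_segmentation):

# Toy stand-in for graph_compiler.text2words; the segmentation shown is
# only illustrative.
text2words = {"你好中国": "你好 中国"}


def text_to_words(text: str) -> str:
    # Mirrors the Cut mapping above: strip the transcript and look up
    # its word-segmented form.
    text = text.strip("\n").strip("\t")
    return text2words[text]


print(text_to_words("你好中国"))  # "你好 中国"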

View File

@ -0,0 +1,219 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Dict, List
import k2
import torch
from tqdm import tqdm
from icefall.lexicon import Lexicon, read_lexicon
class PinyinCtcTrainingGraphCompiler(object):
def __init__(
self,
lang_dir: Path,
lexicon: Lexicon,
device: torch.device,
sos_token: str = "<sos/eos>",
eos_token: str = "<sos/eos>",
oov: str = "<unk>",
):
"""
Args:
lexicon:
It is built from `data/lang_char/lexicon.txt`.
device:
The device to use for operations compiling transcripts to FSAs.
oov:
Out of vocabulary token. When a word(token) in the transcript
does not exist in the token list, it is replaced with `oov`.
"""
assert oov in lexicon.token_table
self.lang_dir = lang_dir
self.oov_id = lexicon.token_table[oov]
self.token_table = lexicon.token_table
self.device = device
self.sos_id = self.token_table[sos_token]
self.eos_id = self.token_table[eos_token]
self.word_table = lexicon.word_table
self.token_table = lexicon.token_table
self.text2words = convert_text_to_word_segments(
text_filename=self.lang_dir / "text",
words_segments_filename=self.lang_dir / "text_words_segmentation",
)
self.ragged_lexicon = convert_lexicon_to_ragged(
filename=self.lang_dir / "lexicon.txt",
word_table=self.word_table,
token_table=self.token_table,
)
def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
"""Convert a list of texts to a list-of-list of pinyin-based token IDs.
Args:
texts:
  A list of word-segmented transcripts, where words are separated by
  spaces. An example containing two strings is given below:
      ['你好 中国', '北京 欢迎 您']
Returns:
Return a list-of-list of pinyin-based token IDs.
"""
word_ids_list = []
for i in range(len(texts)):
word_ids = []
text = texts[i].strip("\n").strip("\t")
for word in text.split(" "):
if word in self.word_table:
word_ids.append(self.word_table[word])
else:
word_ids.append(self.oov_id)
word_ids_list.append(word_ids)
ragged_indexes = k2.RaggedTensor(word_ids_list, dtype=torch.int32)
ans = self.ragged_lexicon.index(ragged_indexes)
ans = ans.remove_axis(ans.num_axes - 2)
return ans
def compile(
self,
token_ids: List[List[int]],
modified: bool = False,
) -> k2.Fsa:
"""Build a ctc graph from a list-of-list token IDs.
Args:
  token_ids:
    A list-of-list of integer token IDs.
modified:
See :func:`k2.ctc_graph` for its meaning.
Return:
Return an FsaVec, which is the result of composing a
CTC topology with linear FSAs constructed from the given
token IDs.
"""
return k2.ctc_graph(token_ids, modified=modified, device=self.device)
def convert_lexicon_to_ragged(
filename: str, word_table: k2.SymbolTable, token_table: k2.SymbolTable
) -> k2.RaggedTensor:
"""Read a lexicon and convert lexicon to a ragged tensor.
Args:
filename:
Path to the lexicon file.
word_table:
The word symbol table.
token_table:
The token symbol table.
Returns:
A k2 ragged tensor with two axes [word][token].
"""
num_words = len(word_table.symbols)
excluded_words = [
"<eps>",
"!SIL",
"<SPOKEN_NOISE>",
"<UNK>",
"#0",
"<s>",
"</s>",
]
row_splits = [0]
token_ids_list = []
lexicon_tmp = read_lexicon(filename)
lexicon = dict(lexicon_tmp)
if len(lexicon_tmp) != len(lexicon):
raise RuntimeError(
"It's assumed that each word has a unique pronunciation"
)
for i in range(num_words):
w = word_table[i]
if w in excluded_words:
row_splits.append(row_splits[-1])
continue
tokens = lexicon[w]
token_ids = [token_table[k] for k in tokens]
row_splits.append(row_splits[-1] + len(token_ids))
token_ids_list.extend(token_ids)
cached_tot_size = row_splits[-1]
row_splits = torch.tensor(row_splits, dtype=torch.int32)
shape = k2.ragged.create_ragged_shape2(
row_splits,
None,
cached_tot_size,
)
values = torch.tensor(token_ids_list, dtype=torch.int32)
return k2.RaggedTensor(shape, values)
def convert_text_to_word_segments(
text_filename: str, words_segments_filename: str
) -> Dict[str, str]:
"""Convert text to word-based segments.
Args:
text_filename:
The file for the original transcripts.
words_segments_filename:
The file after implementing chinese word segmentation
for the original transcripts.
Returns:
A dictionary about text and words_segments.
"""
text2words = {}
f_text = open(text_filename, "r", encoding="utf-8")
text_lines = f_text.readlines()
text_lines = [line.strip("\t") for line in text_lines]
f_words = open(words_segments_filename, "r", encoding="utf-8")
words_lines = f_words.readlines()
words_lines = [line.strip("\t") for line in words_lines]
if len(text_lines) != len(words_lines):
raise RuntimeError(
"The lengths of text and words_segments should be equal."
)
for i in tqdm(range(len(text_lines))):
text = text_lines[i].strip(" ").strip("\n")
words_segments = words_lines[i].strip(" ").strip("\n")
text2words[text] = words_segments
return text2words
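
A usage sketch for this compiler, assuming a pinyin lang dir (containing text, text_words_segmentation, lexicon.txt, tokens.txt and words.txt) has already been prepared, e.g. data/lang_pinyin from prepare.sh; the example transcript is a placeholder whose words are assumed to appear in words.txt:

from pathlib import Path

import torch

from icefall.lexicon import Lexicon
from icefall.pinyin_graph_compiler import PinyinCtcTrainingGraphCompiler

lang_dir = Path("data/lang_pinyin")  # assumed lang dir
lexicon = Lexicon(lang_dir)

compiler = PinyinCtcTrainingGraphCompiler(
    lang_dir=lang_dir,
    lexicon=lexicon,
    device=torch.device("cpu"),
)

# texts_to_ids() expects word-segmented transcripts (words separated by
# spaces); compiler.text2words maps raw transcripts to that form.
token_ids = compiler.texts_to_ids(["你好 中国"])

# compile() builds a CTC training graph (an FsaVec) from the token IDs.
ctc_graph = compiler.compile(token_ids.tolist())
print(ctc_graph.num_arcs)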