Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-04 06:34:20 +00:00)
add and update some files
This commit is contained in:
parent 5242275b8a
commit d8907769ee
135 egs/wenetspeech/ASR/local/display_manifest_statistics.py (Normal file)
@@ -0,0 +1,135 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
|
||||
# Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This file displays duration statistics of utterances in a manifest.
|
||||
You can use the displayed value to choose minimum/maximum duration
|
||||
to remove short and long utterances during the training.
|
||||
See the function `remove_short_and_long_utt()`
|
||||
in ../../../librispeech/ASR/transducer/train.py
|
||||
for usage.
|
||||
"""
|
||||
|
||||
|
||||
from lhotse import load_manifest
|
||||
|
||||
|
||||
def main():
|
||||
paths = [ # "./data/fbank/cuts_L_100_pieces.jsonl.gz",
|
||||
"./data/fbank/cuts_L_50_pieces.jsonl.gz",
|
||||
# "./data/fbank/cuts_DEV.jsonl.gz",
|
||||
# "./data/fbank/cuts_TEST_NET.jsonl.gz",
|
||||
# "./data/fbank/cuts_TEST_MEETING.jsonl.gz"
|
||||
]
|
||||
|
||||
for path in paths:
|
||||
print(f"Starting display the statistics for {path}")
|
||||
cuts = load_manifest(path)
|
||||
cuts.describe()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
"""
|
||||
Starting to display the statistics for ./data/fbank/cuts_L_50_pieces.jsonl.gz
|
||||
Cuts count: 2241476
|
||||
Total duration (hours): 1475.0
|
||||
Speech duration (hours): 1475.0 (100.0%)
|
||||
***
|
||||
Duration statistics (seconds):
|
||||
mean 2.4
|
||||
std 1.6
|
||||
min 0.3
|
||||
25% 1.3
|
||||
50% 2.0
|
||||
75% 2.9
|
||||
99% 8.2
|
||||
99.5% 9.3
|
||||
99.9% 13.5
|
||||
max 87.0
|
||||
|
||||
Starting to display the statistics for ./data/fbank/cuts_L_100_pieces.jsonl.gz
|
||||
Cuts count: 4929619
|
||||
Total duration (hours): 3361.1
|
||||
Speech duration (hours): 3361.1 (100.0%)
|
||||
***
|
||||
Duration statistics (seconds):
|
||||
mean 2.5
|
||||
std 1.7
|
||||
min 0.3
|
||||
25% 1.4
|
||||
50% 2.0
|
||||
75% 3.0
|
||||
99% 8.1
|
||||
99.5% 8.8
|
||||
99.9% 14.7
|
||||
max 87.0
|
||||
|
||||
Starting to display the statistics for ./data/fbank/cuts_DEV.jsonl.gz
|
||||
Cuts count: 13825
|
||||
Total duration (hours): 20.0
|
||||
Speech duration (hours): 20.0 (100.0%)
|
||||
***
|
||||
Duration statistics (seconds):
|
||||
mean 5.2
|
||||
std 2.2
|
||||
min 1.0
|
||||
25% 3.3
|
||||
50% 4.9
|
||||
75% 7.0
|
||||
99% 9.6
|
||||
99.5% 9.8
|
||||
99.9% 10.0
|
||||
max 10.0
|
||||
|
||||
Starting to display the statistics for ./data/fbank/cuts_TEST_NET.jsonl.gz
|
||||
Cuts count: 24774
|
||||
Total duration (hours): 23.1
|
||||
Speech duration (hours): 23.1 (100.0%)
|
||||
***
|
||||
Duration statistics (seconds):
|
||||
mean 3.4
|
||||
std 2.6
|
||||
min 0.1
|
||||
25% 1.4
|
||||
50% 2.4
|
||||
75% 4.8
|
||||
99% 13.1
|
||||
99.5% 14.5
|
||||
99.9% 18.5
|
||||
max 33.3
|
||||
|
||||
Starting to display the statistics for ./data/fbank/cuts_TEST_MEETING.jsonl.gz
|
||||
Cuts count: 8370
|
||||
Total duration (hours): 15.2
|
||||
Speech duration (hours): 15.2 (100.0%)
|
||||
***
|
||||
Duration statistics (seconds):
|
||||
mean 6.5
|
||||
std 3.5
|
||||
min 0.8
|
||||
25% 3.7
|
||||
50% 5.8
|
||||
75% 8.8
|
||||
99% 15.2
|
||||
99.5% 16.0
|
||||
99.9% 18.8
|
||||
max 24.6
|
||||
|
||||
"""
|
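Note: the duration statistics above are what the recipe uses to choose the minimum/maximum utterance length for training (the train.py change later in this commit keeps cuts between 1.0 s and 15.0 s). A minimal sketch of applying such a filter with lhotse; the manifest path is illustrative:

from lhotse import load_manifest

cuts = load_manifest("./data/fbank/cuts_L_50_pieces.jsonl.gz")

def remove_short_and_long_utt(c):
    # Bounds read off the statistics above: 99.9% of cuts are under ~14 s.
    return 1.0 <= c.duration <= 15.0

cuts = cuts.filter(remove_short_and_long_utt)
cuts.describe()  # re-check the statistics after filtering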
246 egs/wenetspeech/ASR/local/prepare_char.py (Executable file)
@@ -0,0 +1,246 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Wei Kang,
|
||||
# Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This script takes as input `lang_dir`, which should contain::
|
||||
- lang_dir/text,
|
||||
- lang_dir/words.txt
|
||||
and generates the following files in the directory `lang_dir`:
|
||||
- lexicon.txt
|
||||
- lexicon_disambig.txt
|
||||
- L.pt
|
||||
- L_disambig.pt
|
||||
- tokens.txt
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from prepare_lang import (
|
||||
Lexicon,
|
||||
add_disambig_symbols,
|
||||
add_self_loops,
|
||||
write_lexicon,
|
||||
write_mapping,
|
||||
)
|
||||
|
||||
|
||||
def lexicon_to_fst_no_sil(
|
||||
lexicon: Lexicon,
|
||||
token2id: Dict[str, int],
|
||||
word2id: Dict[str, int],
|
||||
need_self_loops: bool = False,
|
||||
) -> k2.Fsa:
|
||||
"""Convert a lexicon to an FST (in k2 format).
|
||||
Args:
|
||||
lexicon:
|
||||
The input lexicon. See also :func:`read_lexicon`
|
||||
token2id:
|
||||
A dict mapping tokens to IDs.
|
||||
word2id:
|
||||
A dict mapping words to IDs.
|
||||
need_self_loops:
|
||||
If True, add self-loop to states with non-epsilon output symbols
|
||||
on at least one arc out of the state. The input label for this
|
||||
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
|
||||
Returns:
|
||||
Return an instance of `k2.Fsa` representing the given lexicon.
|
||||
"""
|
||||
loop_state = 0 # words enter and leave from here
|
||||
next_state = 1 # the next un-allocated state, will be incremented as we go
|
||||
|
||||
arcs = []
|
||||
|
||||
# The blank symbol <blk> is defined in local/train_bpe_model.py
|
||||
assert token2id["<blk>"] == 0
|
||||
assert word2id["<eps>"] == 0
|
||||
|
||||
eps = 0
|
||||
|
||||
for word, pieces in lexicon:
|
||||
assert len(pieces) > 0, f"{word} has no pronunciations"
|
||||
cur_state = loop_state
|
||||
|
||||
word = word2id[word]
|
||||
pieces = [
|
||||
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
|
||||
]
|
||||
|
||||
for i in range(len(pieces) - 1):
|
||||
w = word if i == 0 else eps
|
||||
arcs.append([cur_state, next_state, pieces[i], w, 0])
|
||||
|
||||
cur_state = next_state
|
||||
next_state += 1
|
||||
|
||||
# now for the last piece of this word
|
||||
i = len(pieces) - 1
|
||||
w = word if i == 0 else eps
|
||||
arcs.append([cur_state, loop_state, pieces[i], w, 0])
|
||||
|
||||
if need_self_loops:
|
||||
disambig_token = token2id["#0"]
|
||||
disambig_word = word2id["#0"]
|
||||
arcs = add_self_loops(
|
||||
arcs,
|
||||
disambig_token=disambig_token,
|
||||
disambig_word=disambig_word,
|
||||
)
|
||||
|
||||
final_state = next_state
|
||||
arcs.append([loop_state, final_state, -1, -1, 0])
|
||||
arcs.append([final_state])
|
||||
|
||||
arcs = sorted(arcs, key=lambda arc: arc[0])
|
||||
arcs = [[str(i) for i in arc] for arc in arcs]
|
||||
arcs = [" ".join(arc) for arc in arcs]
|
||||
arcs = "\n".join(arcs)
|
||||
|
||||
fsa = k2.Fsa.from_str(arcs, acceptor=False)
|
||||
return fsa
|
||||
|
||||
|
||||
def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
|
||||
"""Check if all the given tokens are in token symbol table.
|
||||
Args:
|
||||
token_sym_table:
|
||||
Token symbol table that contains all the valid tokens.
|
||||
tokens:
|
||||
A list of tokens.
|
||||
Returns:
|
||||
Return True if there is any token not in the token_sym_table,
|
||||
otherwise False.
|
||||
"""
|
||||
for tok in tokens:
|
||||
if tok not in token_sym_table:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def generate_lexicon(
|
||||
token_sym_table: Dict[str, int], words: List[str]
|
||||
) -> Lexicon:
|
||||
"""Generate a lexicon from a word list and token_sym_table.
|
||||
Args:
|
||||
token_sym_table:
|
||||
Token symbol table that maps tokens to token ids.
|
||||
words:
|
||||
A list of strings representing words.
|
||||
Returns:
|
||||
Return a list of (word, tokens) pairs, i.e., a character-based lexicon.
|
||||
"""
|
||||
lexicon = []
|
||||
for word in words:
|
||||
chars = list(word.strip(" \t"))
|
||||
if contain_oov(token_sym_table, chars):
|
||||
continue
|
||||
lexicon.append((word, chars))
|
||||
|
||||
# The OOV word is <UNK>
|
||||
lexicon.append(("<UNK>", ["<unk>"]))
|
||||
return lexicon
|
||||
|
||||
|
||||
def generate_tokens(text_file: str) -> Dict[str, int]:
|
||||
"""Generate tokens from the given text file.
|
||||
Args:
|
||||
text_file:
|
||||
A file that contains text lines to generate tokens.
|
||||
Returns:
|
||||
Return a dict whose keys are tokens and values are token ids ranged
|
||||
from 0 to len(keys) - 1.
|
||||
"""
|
||||
tokens: Dict[str, int] = dict()
|
||||
tokens["<blk>"] = 0
|
||||
tokens["<sos/eos>"] = 1
|
||||
tokens["<unk>"] = 2
|
||||
whitespace = re.compile(r"([ \t\r\n]+)")
|
||||
with open(text_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = re.sub(whitespace, "", line)
|
||||
tokens_list = list(line)
|
||||
for token in tokens_list:
|
||||
if token not in tokens:
|
||||
tokens[token] = len(tokens)
|
||||
return tokens
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--lang-dir", type=str, help="The lang directory.")
|
||||
args = parser.parse_args()
|
||||
|
||||
lang_dir = Path(args.lang_dir)
|
||||
text_file = lang_dir / "text"
|
||||
|
||||
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
|
||||
|
||||
words = word_sym_table.symbols
|
||||
|
||||
excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"]
|
||||
for w in excluded:
|
||||
if w in words:
|
||||
words.remove(w)
|
||||
|
||||
token_sym_table = generate_tokens(text_file)
|
||||
|
||||
lexicon = generate_lexicon(token_sym_table, words)
|
||||
|
||||
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
|
||||
|
||||
next_token_id = max(token_sym_table.values()) + 1
|
||||
for i in range(max_disambig + 1):
|
||||
disambig = f"#{i}"
|
||||
assert disambig not in token_sym_table
|
||||
token_sym_table[disambig] = next_token_id
|
||||
next_token_id += 1
|
||||
|
||||
word_sym_table.add("#0")
|
||||
word_sym_table.add("<s>")
|
||||
word_sym_table.add("</s>")
|
||||
|
||||
write_mapping(lang_dir / "tokens.txt", token_sym_table)
|
||||
|
||||
write_lexicon(lang_dir / "lexicon.txt", lexicon)
|
||||
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
|
||||
|
||||
L = lexicon_to_fst_no_sil(
|
||||
lexicon,
|
||||
token2id=token_sym_table,
|
||||
word2id=word_sym_table,
|
||||
)
|
||||
|
||||
L_disambig = lexicon_to_fst_no_sil(
|
||||
lexicon_disambig,
|
||||
token2id=token_sym_table,
|
||||
word2id=word_sym_table,
|
||||
need_self_loops=True,
|
||||
)
|
||||
torch.save(L.as_dict(), lang_dir / "L.pt")
|
||||
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
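A quick sanity check of the files written by prepare_char.py (a sketch, assuming the script was run with --lang-dir data/lang_char):

import k2
import torch

lang_dir = "data/lang_char"

# tokens.txt and words.txt are plain "symbol id" mappings.
token_table = k2.SymbolTable.from_file(f"{lang_dir}/tokens.txt")
word_table = k2.SymbolTable.from_file(f"{lang_dir}/words.txt")
print(token_table["<blk>"], word_table["<eps>"])  # both are expected to be 0

# L.pt and L_disambig.pt were saved with Fsa.as_dict(), so load them back
# with Fsa.from_dict().
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L.pt"))
L_disambig = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
print(L.num_arcs, L_disambig.num_arcs)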
260 egs/wenetspeech/ASR/local/prepare_pinyin.py (Executable file)
@@ -0,0 +1,260 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Wei Kang,
|
||||
# Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This script takes as input `lang_dir`, which should contain::
|
||||
- lang_dir/text,
|
||||
- lang_dir/words.txt
|
||||
and generates the following files in the directory `lang_dir`:
|
||||
- lexicon.txt
|
||||
- lexicon_disambig.txt
|
||||
- L.pt
|
||||
- L_disambig.pt
|
||||
- tokens.txt
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from prepare_lang import (
|
||||
Lexicon,
|
||||
add_disambig_symbols,
|
||||
add_self_loops,
|
||||
write_lexicon,
|
||||
write_mapping,
|
||||
)
|
||||
from pypinyin import pinyin
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def lexicon_to_fst_no_sil(
|
||||
lexicon: Lexicon,
|
||||
token2id: Dict[str, int],
|
||||
word2id: Dict[str, int],
|
||||
need_self_loops: bool = False,
|
||||
) -> k2.Fsa:
|
||||
"""Convert a lexicon to an FST (in k2 format).
|
||||
Args:
|
||||
lexicon:
|
||||
The input lexicon. See also :func:`read_lexicon`
|
||||
token2id:
|
||||
A dict mapping tokens to IDs.
|
||||
word2id:
|
||||
A dict mapping words to IDs.
|
||||
need_self_loops:
|
||||
If True, add self-loop to states with non-epsilon output symbols
|
||||
on at least one arc out of the state. The input label for this
|
||||
self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
|
||||
Returns:
|
||||
Return an instance of `k2.Fsa` representing the given lexicon.
|
||||
"""
|
||||
loop_state = 0 # words enter and leave from here
|
||||
next_state = 1 # the next un-allocated state, will be incremented as we go
|
||||
|
||||
arcs = []
|
||||
|
||||
# The blank symbol <blk> is defined in local/train_bpe_model.py
|
||||
assert token2id["<blk>"] == 0
|
||||
assert word2id["<eps>"] == 0
|
||||
|
||||
eps = 0
|
||||
|
||||
for word, pieces in lexicon:
|
||||
assert len(pieces) > 0, f"{word} has no pronunciations"
|
||||
cur_state = loop_state
|
||||
|
||||
word = word2id[word]
|
||||
pieces = [
|
||||
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
|
||||
]
|
||||
|
||||
for i in range(len(pieces) - 1):
|
||||
w = word if i == 0 else eps
|
||||
arcs.append([cur_state, next_state, pieces[i], w, 0])
|
||||
|
||||
cur_state = next_state
|
||||
next_state += 1
|
||||
|
||||
# now for the last piece of this word
|
||||
i = len(pieces) - 1
|
||||
w = word if i == 0 else eps
|
||||
arcs.append([cur_state, loop_state, pieces[i], w, 0])
|
||||
|
||||
if need_self_loops:
|
||||
disambig_token = token2id["#0"]
|
||||
disambig_word = word2id["#0"]
|
||||
arcs = add_self_loops(
|
||||
arcs,
|
||||
disambig_token=disambig_token,
|
||||
disambig_word=disambig_word,
|
||||
)
|
||||
|
||||
final_state = next_state
|
||||
arcs.append([loop_state, final_state, -1, -1, 0])
|
||||
arcs.append([final_state])
|
||||
|
||||
arcs = sorted(arcs, key=lambda arc: arc[0])
|
||||
arcs = [[str(i) for i in arc] for arc in arcs]
|
||||
arcs = [" ".join(arc) for arc in arcs]
|
||||
arcs = "\n".join(arcs)
|
||||
|
||||
fsa = k2.Fsa.from_str(arcs, acceptor=False)
|
||||
return fsa
|
||||
|
||||
|
||||
def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
|
||||
"""Check if all the given tokens are in token symbol table.
|
||||
Args:
|
||||
token_sym_table:
|
||||
Token symbol table that contains all the valid tokens.
|
||||
tokens:
|
||||
A list of tokens.
|
||||
Returns:
|
||||
Return True if there is any token not in the token_sym_table,
|
||||
otherwise False.
|
||||
"""
|
||||
for tok in tokens:
|
||||
if tok not in token_sym_table:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def generate_lexicon(
|
||||
token_sym_table: Dict[str, int], words: List[str]
|
||||
) -> Lexicon:
|
||||
"""Generate a lexicon from a word list and token_sym_table.
|
||||
Args:
|
||||
token_sym_table:
|
||||
Token symbol table that maps tokens to token ids.
|
||||
words:
|
||||
A list of strings representing words.
|
||||
Returns:
|
||||
Return a list of (word, tokens) pairs, i.e., a pinyin-based lexicon.
|
||||
"""
|
||||
lexicon = []
|
||||
for i in tqdm(range(len(words))):
|
||||
word = words[i]
|
||||
tokens = []
|
||||
pinyins = pinyin(word.strip(" \t"))
|
||||
for pinyin_one in pinyins:
|
||||
if pinyin_one[0].isupper():
|
||||
tokens.extend(list(pinyin_one[0]))
|
||||
else:
|
||||
tokens.append(pinyin_one[0])
|
||||
if contain_oov(token_sym_table, tokens):
|
||||
continue
|
||||
lexicon.append((word, tokens))
|
||||
|
||||
# The OOV word is <UNK>
|
||||
lexicon.append(("<UNK>", ["<unk>"]))
|
||||
return lexicon
|
||||
|
||||
|
||||
def generate_tokens(words: List[str]) -> Dict[str, int]:
|
||||
"""Generate tokens from the given text file.
|
||||
Args:
|
||||
words:
|
||||
The list of words after removing <eps>, !SIL, and so on.
|
||||
Returns:
|
||||
Return a dict whose keys are tokens and values are token ids ranged
|
||||
from 0 to len(keys) - 1.
|
||||
"""
|
||||
tokens: Dict[str, int] = dict()
|
||||
tokens["<blk>"] = 0
|
||||
tokens["<sos/eos>"] = 1
|
||||
tokens["<unk>"] = 2
|
||||
for i in tqdm(range(len(words))):
|
||||
word = words[i]
|
||||
pinyins_list = pinyin(word)
|
||||
for pinyin_one in pinyins_list:
|
||||
if pinyin_one[0].isupper():
|
||||
tokens_list = list(pinyin_one[0])
|
||||
else:
|
||||
tokens_list = pinyin_one
|
||||
for token in tokens_list:
|
||||
if token not in tokens:
|
||||
tokens[token] = len(tokens)
|
||||
tokens = sorted(tokens.items(), key=lambda item: item[0])
|
||||
tokens = dict(tokens)
|
||||
|
||||
return tokens
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--lang-dir", type=str, help="The lang directory.")
|
||||
args = parser.parse_args()
|
||||
|
||||
lang_dir = Path(args.lang_dir)
|
||||
words_file = lang_dir / "words.txt"
|
||||
|
||||
word_sym_table = k2.SymbolTable.from_file(words_file)
|
||||
|
||||
words = word_sym_table.symbols
|
||||
|
||||
excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"]
|
||||
for w in excluded:
|
||||
if w in words:
|
||||
words.remove(w)
|
||||
|
||||
token_sym_table = generate_tokens(words)
|
||||
|
||||
lexicon = generate_lexicon(token_sym_table, words)
|
||||
|
||||
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
|
||||
|
||||
next_token_id = max(token_sym_table.values()) + 1
|
||||
for i in range(max_disambig + 1):
|
||||
disambig = f"#{i}"
|
||||
assert disambig not in token_sym_table
|
||||
token_sym_table[disambig] = next_token_id
|
||||
next_token_id += 1
|
||||
|
||||
word_sym_table.add("#0")
|
||||
word_sym_table.add("<s>")
|
||||
word_sym_table.add("</s>")
|
||||
|
||||
write_mapping(lang_dir / "tokens.txt", token_sym_table)
|
||||
|
||||
write_lexicon(lang_dir / "lexicon.txt", lexicon)
|
||||
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
|
||||
|
||||
L = lexicon_to_fst_no_sil(
|
||||
lexicon,
|
||||
token2id=token_sym_table,
|
||||
word2id=word_sym_table,
|
||||
)
|
||||
|
||||
L_disambig = lexicon_to_fst_no_sil(
|
||||
lexicon_disambig,
|
||||
token2id=token_sym_table,
|
||||
word2id=word_sym_table,
|
||||
need_self_loops=True,
|
||||
)
|
||||
torch.save(L.as_dict(), lang_dir / "L.pt")
|
||||
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
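For reference, the pypinyin behaviour that generate_tokens()/generate_lexicon() above depend on (a sketch; the exact tone-marked output depends on the pypinyin version and its default Style):

from pypinyin import pinyin

# Chinese text is converted syllable by syllable, with tone marks by default,
# e.g. [['nǐ'], ['hǎo']] for "你好".
print(pinyin("你好"))

# Non-Chinese content is passed through unchanged, e.g. [['IT']] for "IT".
# The scripts above detect this via pinyin_one[0].isupper() and fall back
# to splitting such a word into single characters.
print(pinyin("IT"))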
84 egs/wenetspeech/ASR/local/prepare_words.py (Normal file)
@@ -0,0 +1,84 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This script takes as input a word list without ids:
|
||||
- words_no_ids.txt
|
||||
and generates a new words.txt in which each word is assigned an id:
|
||||
- words.txt
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Prepare words.txt",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-file",
|
||||
default="data/lang_char/words_no_ids.txt",
|
||||
type=str,
|
||||
help="the words file without ids for WenetSpeech",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-file",
|
||||
default="data/lang_char/words.txt",
|
||||
type=str,
|
||||
help="the words file with ids for WenetSpeech",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
input_file = args.input_file
|
||||
output_file = args.output_file
|
||||
|
||||
f = open(input_file, "r", encoding="utf-8")
|
||||
lines = f.readlines()
|
||||
new_lines = []
|
||||
add_words = ["<eps> 0", "!SIL 1", "<SPOKEN_NOISE> 2", "<UNK> 3"]
|
||||
new_lines.extend(add_words)
|
||||
|
||||
logging.info("Starting reading the input file")
|
||||
for i in tqdm(range(len(lines))):
|
||||
x = lines[i]
|
||||
idx = 4 + i
|
||||
new_line = str(x.strip("\n")) + " " + str(idx)
|
||||
new_lines.append(new_line)
|
||||
|
||||
logging.info("Starting writing the words.txt")
|
||||
f_out = open(output_file, "w", encoding="utf-8")
|
||||
for line in new_lines:
|
||||
f_out.write(line)
|
||||
f_out.write("\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
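The id assignment above is simply "reserved symbols first, then one id per input line"; an equivalent minimal sketch (file names are illustrative):

reserved = ["<eps> 0", "!SIL 1", "<SPOKEN_NOISE> 2", "<UNK> 3"]

with open("words_no_ids.txt", encoding="utf-8") as f_in, \
        open("words.txt", "w", encoding="utf-8") as f_out:
    for line in reserved:
        print(line, file=f_out)
    # Real words start at id 4, in the order they appear in the input file.
    for idx, word in enumerate(f_in, start=4):
        print(f"{word.strip()} {idx}", file=f_out)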
83 egs/wenetspeech/ASR/local/text2segments.py (Normal file)
@@ -0,0 +1,83 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This script takes as input "text", which refers to the transcript file for
|
||||
WenetSpeech:
|
||||
- text
|
||||
and generates the word-segmented output file:
|
||||
- text_words_segmentation
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
|
||||
import jieba
|
||||
from tqdm import tqdm
|
||||
|
||||
jieba.enable_paddle()
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Chinese Word Segmentation for text",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input",
|
||||
default="data/lang_char/text",
|
||||
type=str,
|
||||
help="the input text file for WenetSpeech",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
default="data/lang_char/text_words_segmentation",
|
||||
type=str,
|
||||
help="the text implemented with words segmenting for WenetSpeech",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
input_file = args.input
|
||||
output_file = args.output
|
||||
|
||||
f = open(input_file, "r", encoding="utf-8")
|
||||
lines = f.readlines()
|
||||
new_lines = []
|
||||
for i in tqdm(range(len(lines))):
|
||||
x = lines[i].rstrip()
|
||||
seg_list = jieba.cut(x, use_paddle=True)
|
||||
new_line = " ".join(seg_list)
|
||||
new_lines.append(new_line)
|
||||
|
||||
f_new = open(output_file, "w", encoding="utf-8")
|
||||
for line in new_lines:
|
||||
f_new.write(line)
|
||||
f_new.write("\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
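For reference, the per-line segmentation above boils down to the following (a sketch; the script's use_paddle=True mode additionally requires the paddlepaddle package, while the default mode shown here does not):

import jieba

line = "北京欢迎您"
# jieba.cut returns a generator of word segments; the script joins them
# with spaces to build text_words_segmentation.
print(" ".join(jieba.cut(line)))  # e.g. "北京 欢迎 您"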
@@ -188,7 +188,7 @@ def main():
|
||||
|
||||
a_chars = [z.replace(" ", args.space) for z in a_flat]
|
||||
|
||||
print(" ".join(a_chars))
|
||||
print("".join(a_chars))
|
||||
line = f.readline()
|
||||
|
||||
|
||||
|
@@ -117,9 +117,9 @@ fi
|
||||
|
||||
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||
log "Stage 7: Combine features for L"
|
||||
if [ ! -f data/fbank/cuts_L.jsonl.gz ]; then
|
||||
pieces=$(find data/fbank/L_split_${num_splits} -name "cuts_L.*.jsonl.gz")
|
||||
lhotse combine $pieces data/fbank/cuts_L.jsonl.gz
|
||||
if [ ! -f data/fbank/cuts_L_50.jsonl.gz ]; then
|
||||
pieces=$(find data/fbank/L_split_50 -name "cuts_L.*.jsonl.gz")
|
||||
lhotse combine $pieces data/fbank/cuts_L_50.jsonl.gz
|
||||
fi
|
||||
fi
|
||||
|
||||
@@ -134,120 +134,50 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
|
||||
lang_char_dir=data/lang_char
|
||||
mkdir -p $lang_char_dir
|
||||
|
||||
gunzip -c data/manifests/supervisions_L.jsonl.gz \
|
||||
| jq '.text' | sed 's/"//g' \
|
||||
| ./local/text2token.py -t "char" > $lang_char_dir/text
|
||||
# Prepare text.
|
||||
if [ ! -f $lang_char_dir/text ]; then
|
||||
gunzip -c data/manifests/supervisions_L.jsonl.gz \
|
||||
| jq '.text' | sed 's/"//g' \
|
||||
| ./local/text2token.py -t "char" > $lang_char_dir/text
|
||||
fi
|
||||
|
||||
cat $lang_char_dir/text | sed 's/ /\n/g' \
|
||||
| sort -u | sed '/^$/d' > $lang_char_dir/words.txt
|
||||
(echo '<SIL>'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) |
|
||||
cat - $lang_char_dir/words.txt | sort | uniq | awk '
|
||||
BEGIN {
|
||||
print "<eps> 0";
|
||||
}
|
||||
{
|
||||
if ($1 == "<s>") {
|
||||
print "<s> is in the vocabulary!" | "cat 1>&2"
|
||||
exit 1;
|
||||
}
|
||||
if ($1 == "</s>") {
|
||||
print "</s> is in the vocabulary!" | "cat 1>&2"
|
||||
exit 1;
|
||||
}
|
||||
printf("%s %d\n", $1, NR);
|
||||
}
|
||||
END {
|
||||
printf("#0 %d\n", NR+1);
|
||||
printf("<s> %d\n", NR+2);
|
||||
printf("</s> %d\n", NR+3);
|
||||
}' > $lang_char_dir/words || exit 1;
|
||||
# The implementation of chinese word segmentation for text,
|
||||
# and it will take about 15 minutes.
|
||||
if [ ! -f $lang_char_dir/text_words_segmentation ]; then
|
||||
python ./local/text2segments.py \
|
||||
--input $lang_char_dir/text \
|
||||
--output $lang_char_dir/text_words_segmentation
|
||||
fi
|
||||
|
||||
mv $lang_char_dir/words $lang_char_dir/words.txt
|
||||
cat $lang_char_dir/text_words_segmentation | sed 's/ /\n/g' \
|
||||
| sort -u | sed '/^$/d' | uniq > $lang_char_dir/words_no_ids.txt
|
||||
|
||||
if [ ! -f $lang_char_dir/words.txt ]; then
|
||||
python ./local/prepare_words.py \
|
||||
--input-file $lang_char_dir/words_no_ids.txt \
|
||||
--output-file $lang_char_dir/words.txt
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
|
||||
log "Stage 10: Prepare pinyin based lang"
|
||||
lang_pinyin_dir=data/lang_pinyin
|
||||
mkdir -p $lang_pinyin_dir
|
||||
|
||||
gunzip -c data/manifests/supervisions_L.jsonl.gz \
|
||||
| jq '.text' | sed 's/"//g' \
|
||||
| ./local/text2token.py -t "pinyin" > $lang_pinyin_dir/text
|
||||
|
||||
cat $lang_pinyin_dir/text | sed 's/ /\n/g' \
|
||||
| sort -u | sed '/^$/d' > $lang_pinyin_dir/words.txt
|
||||
(echo '<SIL>'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) |
|
||||
cat - $lang_pinyin_dir/words.txt | sort | uniq | awk '
|
||||
BEGIN {
|
||||
print "<eps> 0";
|
||||
}
|
||||
{
|
||||
if ($1 == "<s>") {
|
||||
print "<s> is in the vocabulary!" | "cat 1>&2"
|
||||
exit 1;
|
||||
}
|
||||
if ($1 == "</s>") {
|
||||
print "</s> is in the vocabulary!" | "cat 1>&2"
|
||||
exit 1;
|
||||
}
|
||||
printf("%s %d\n", $1, NR);
|
||||
}
|
||||
END {
|
||||
printf("#0 %d\n", NR+1);
|
||||
printf("<s> %d\n", NR+2);
|
||||
printf("</s> %d\n", NR+3);
|
||||
}' > $lang_pinyin_dir/words || exit 1;
|
||||
|
||||
mv $lang_pinyin_dir/words $lang_pinyin_dir/words.txt
|
||||
log "Stage 10: Prepare char based L_disambig.pt"
|
||||
if [ ! -f data/lang_char/L_disambig.pt ]; then
|
||||
python ./local/prepare_char.py \
|
||||
--lang-dir data/lang_char
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
|
||||
log "Stage 11: Prepare lazy_pinyin based lang"
|
||||
lang_lazy_pinyin_dir=data/lang_lazy_pinyin
|
||||
mkdir -p $lang_lazy_pinyin_dir
|
||||
log "Stage 11: Prepare pinyin based L_disambig.pt"
|
||||
lang_pinyin_dir=data/lang_pinyin
|
||||
mkdir -p $lang_pinyin_dir
|
||||
|
||||
gunzip -c data/manifests/supervisions_L.jsonl.gz \
|
||||
| jq '.text' | sed 's/"//g' \
|
||||
| ./local/text2token.py -t "lazy_pinyin" > $lang_lazy_pinyin_dir/text
|
||||
|
||||
cat $lang_lazy_pinyin_dir/text | sed 's/ /\n/g' \
|
||||
| sort -u | sed '/^$/d' > $lang_lazy_pinyin_dir/words.txt
|
||||
(echo '<SIL>'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) |
|
||||
cat - $lang_lazy_pinyin_dir/words.txt | sort | uniq | awk '
|
||||
BEGIN {
|
||||
print "<eps> 0";
|
||||
}
|
||||
{
|
||||
if ($1 == "<s>") {
|
||||
print "<s> is in the vocabulary!" | "cat 1>&2"
|
||||
exit 1;
|
||||
}
|
||||
if ($1 == "</s>") {
|
||||
print "</s> is in the vocabulary!" | "cat 1>&2"
|
||||
exit 1;
|
||||
}
|
||||
printf("%s %d\n", $1, NR);
|
||||
}
|
||||
END {
|
||||
printf("#0 %d\n", NR+1);
|
||||
printf("<s> %d\n", NR+2);
|
||||
printf("</s> %d\n", NR+3);
|
||||
}' > $lang_lazy_pinyin_dir/words || exit 1;
|
||||
|
||||
mv $lang_lazy_pinyin_dir/words $lang_lazy_pinyin_dir/words.txt
|
||||
fi
|
||||
|
||||
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
|
||||
log "Stage 12: Prepare L_disambig.pt"
|
||||
if [ ! -f data/lang_char/L_disambig.pt ]; then
|
||||
python ./local/prepare_lang_wenetspeech.py --lang-dir data/lang_char
|
||||
fi
|
||||
cp -r data/lang_char/words.txt $lang_pinyin_dir/
|
||||
cp -r data/lang_char/text $lang_pinyin_dir/
|
||||
cp -r data/lang_char/text_words_segmentation $lang_pinyin_dir/
|
||||
|
||||
if [ ! -f data/lang_pinyin/L_disambig.pt ]; then
|
||||
python ./local/prepare_lang_wenetspeech.py --lang-dir data/lang_pinyin
|
||||
fi
|
||||
|
||||
if [ ! -f data/lang_lazy_pinyin/L_disambig.pt ]; then
|
||||
python ./local/prepare_lang_wenetspeech.py --lang-dir data/lang_lazy_pinyin
|
||||
python ./local/prepare_pinyin.py \
|
||||
--lang-dir data/lang_pinyin
|
||||
fi
|
||||
fi
|
||||
|
@@ -375,11 +375,15 @@ class WenetSpeechAsrDataModule:
|
||||
if self.args.lazy_load:
|
||||
logging.info("use lazy cuts")
|
||||
cuts_train = CutSet.from_jsonl_lazy(
|
||||
self.args.manifest_dir / "cuts_L.jsonl.gz"
|
||||
self.args.manifest_dir
|
||||
/ "cuts_L_50_pieces.jsonl.gz"
|
||||
# use cuts_L_50_pieces.jsonl.gz for original experiments
|
||||
)
|
||||
else:
|
||||
cuts_train = CutSet.from_file(
|
||||
self.args.manifest_dir / "cuts_L.jsonl.gz"
|
||||
self.args.manifest_dir
|
||||
/ "cuts_L_50_pieces.jsonl.gz"
|
||||
# use cuts_L_50_pieces.jsonl.gz for original experiments
|
||||
)
|
||||
return cuts_train
|
||||
|
||||
|
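The lazy/eager distinction above matters because the L-subset manifests contain millions of cuts (see the statistics earlier in this commit). A minimal sketch of the two loading paths; the path is illustrative:

from lhotse import CutSet

path = "data/fbank/cuts_L_50_pieces.jsonl.gz"

# Eager: parses the whole JSONL file into memory up front.
cuts_eager = CutSet.from_file(path)

# Lazy: cuts are deserialized on the fly while iterating, which keeps
# memory usage roughly constant for very large manifests.
cuts_lazy = CutSet.from_jsonl_lazy(path)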
@@ -25,8 +25,10 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--exp-dir pruned_transducer_stateless/exp \
|
||||
--max-duration 300
|
||||
--lang-dir data/lang_char \
|
||||
--exp-dir pruned_transducer_stateless/exp-char \
|
||||
--token-type char \
|
||||
--max-duration 200
|
||||
"""
|
||||
|
||||
|
||||
@@ -44,8 +46,8 @@ from asr_datamodule import WenetSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.utils import fix_random_seed
|
||||
from local.text2token import token2id
|
||||
from model import Transducer
|
||||
from torch import Tensor
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@@ -59,6 +61,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.pinyin_graph_compiler import PinyinCtcTrainingGraphCompiler
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
|
||||
@@ -108,7 +111,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless/exp",
|
||||
default="pruned_transducer_stateless_pinyin/exp",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
@@ -118,7 +121,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
default="data/lang_lazy_pinyin",
|
||||
default="data/lang_char",
|
||||
help="""The lang dir
|
||||
It contains language related input files such as
|
||||
"lexicon.txt"
|
||||
@@ -128,10 +131,9 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--token-type",
|
||||
type=str,
|
||||
default="lazy_pinyin",
|
||||
help="""The token type
|
||||
It refers to the token type for modeling, such as
|
||||
char, pinyin, lazy_pinyin.
|
||||
default="char",
|
||||
help="""The type of token
|
||||
It must be in ["char", "pinyin", "lazy_pinyin"].
|
||||
""",
|
||||
)
|
||||
|
||||
@@ -435,16 +437,12 @@ def compute_loss(
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
texts = batch["supervisions"]["text"]
|
||||
y = ""
|
||||
if params.token_type == "char":
|
||||
y = graph_compiler.texts_to_ids(texts)
|
||||
|
||||
y = graph_compiler.texts_to_ids(texts)
|
||||
if type(y) == list:
|
||||
y = k2.RaggedTensor(y).to(device)
|
||||
else:
|
||||
y = token2id(
|
||||
texts=texts,
|
||||
token_table=graph_compiler.token_table,
|
||||
token_type=params.token_type,
|
||||
)
|
||||
y = k2.RaggedTensor(y).to(device)
|
||||
y = y.to(device)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss = model(
|
||||
@@ -635,9 +633,19 @@ def run(rank, world_size, args):
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||
lexicon=lexicon, device=device, oov="<unk>"
|
||||
)
|
||||
graph_compiler = ""
|
||||
if params.token_type == "char":
|
||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||
lexicon=lexicon,
|
||||
device=device,
|
||||
)
|
||||
if params.token_type == "pinyin":
|
||||
graph_compiler = PinyinCtcTrainingGraphCompiler(
|
||||
lang_dir=params.lang_dir,
|
||||
lexicon=lexicon,
|
||||
device=device,
|
||||
)
|
||||
|
||||
params.blank_id = lexicon.token_table["<blk>"]
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
@@ -672,6 +680,23 @@ def run(rank, world_size, args):
|
||||
|
||||
train_cuts = wenetspeech.train_cuts()
|
||||
|
||||
def remove_short_and_long_utt(c: Cut):
|
||||
# Keep only utterances with duration between 1 second and 15.0 seconds
|
||||
# You can get the statistics by local/display_manifest_statistics.py.
|
||||
return 1.0 <= c.duration <= 15.0
|
||||
|
||||
def text_to_words(c: Cut):
|
||||
# Convert text to words_segments.
|
||||
text = c.supervisions[0].text
|
||||
text = text.strip("\n").strip("\t")
|
||||
words_cut = graph_compiler.text2words[text]
|
||||
c.supervisions[0].text = words_cut
|
||||
return c
|
||||
|
||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||
if params.token_type == "pinyin":
|
||||
train_cuts = train_cuts.map(text_to_words)
|
||||
|
||||
train_dl = wenetspeech.train_dataloaders(train_cuts)
|
||||
valid_cuts = wenetspeech.valid_cuts()
|
||||
valid_dl = wenetspeech.valid_dataloaders(valid_cuts)
|
||||
|
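The type(y) == list check in compute_loss above is there because the two graph compilers return different types: CharCtcTrainingGraphCompiler.texts_to_ids returns List[List[int]], while PinyinCtcTrainingGraphCompiler.texts_to_ids (added below) returns a k2.RaggedTensor. A sketch of that normalization, under those assumptions:

import k2
import torch

def to_ragged(y, device: torch.device) -> k2.RaggedTensor:
    if isinstance(y, list):
        # char compiler: plain list of lists of token ids.
        return k2.RaggedTensor(y).to(device)
    # pinyin compiler: already a k2.RaggedTensor, just move it to the device.
    return y.to(device)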
219 icefall/pinyin_graph_compiler.py (Normal file)
@@ -0,0 +1,219 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from icefall.lexicon import Lexicon, read_lexicon
|
||||
|
||||
|
||||
class PinyinCtcTrainingGraphCompiler(object):
|
||||
def __init__(
|
||||
self,
|
||||
lang_dir: Path,
|
||||
lexicon: Lexicon,
|
||||
device: torch.device,
|
||||
sos_token: str = "<sos/eos>",
|
||||
eos_token: str = "<sos/eos>",
|
||||
oov: str = "<unk>",
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
lexicon:
|
||||
It is built from `lang_dir/lexicon.txt`, e.g. data/lang_pinyin/lexicon.txt.
|
||||
device:
|
||||
The device to use for operations compiling transcripts to FSAs.
|
||||
oov:
|
||||
Out of vocabulary token. When a word(token) in the transcript
|
||||
does not exist in the token list, it is replaced with `oov`.
|
||||
"""
|
||||
|
||||
assert oov in lexicon.token_table
|
||||
|
||||
self.lang_dir = lang_dir
|
||||
self.oov_id = lexicon.token_table[oov]
|
||||
self.token_table = lexicon.token_table
|
||||
|
||||
self.device = device
|
||||
|
||||
self.sos_id = self.token_table[sos_token]
|
||||
self.eos_id = self.token_table[eos_token]
|
||||
|
||||
self.word_table = lexicon.word_table
|
||||
|
||||
self.text2words = convert_text_to_word_segments(
|
||||
text_filename=self.lang_dir / "text",
|
||||
words_segments_filename=self.lang_dir / "text_words_segmentation",
|
||||
)
|
||||
self.ragged_lexicon = convert_lexicon_to_ragged(
|
||||
filename=self.lang_dir / "lexicon.txt",
|
||||
word_table=self.word_table,
|
||||
token_table=self.token_table,
|
||||
)
|
||||
|
||||
def texts_to_ids(self, texts: List[str]) -> k2.RaggedTensor:
|
||||
"""Convert a list of texts to a list-of-list of pinyin-based token IDs.
|
||||
|
||||
Args:
|
||||
texts:
|
||||
It is a list of strings.
|
||||
An example containing two strings is given below:
|
||||
|
||||
['你好中国', '北京欢迎您']
|
||||
Returns:
|
||||
Return a k2.RaggedTensor containing the pinyin-based token IDs.
|
||||
"""
|
||||
word_ids_list = []
|
||||
for i in range(len(texts)):
|
||||
word_ids = []
|
||||
text = texts[i].strip("\n").strip("\t")
|
||||
for word in text.split(" "):
|
||||
if word in self.word_table:
|
||||
word_ids.append(self.word_table[word])
|
||||
else:
|
||||
word_ids.append(self.oov_id)
|
||||
word_ids_list.append(word_ids)
|
||||
|
||||
ragged_indexes = k2.RaggedTensor(word_ids_list, dtype=torch.int32)
|
||||
ans = self.ragged_lexicon.index(ragged_indexes)
|
||||
ans = ans.remove_axis(ans.num_axes - 2)
|
||||
|
||||
return ans
|
||||
|
||||
def compile(
|
||||
self,
|
||||
token_ids: List[List[int]],
|
||||
modified: bool = False,
|
||||
) -> k2.Fsa:
|
||||
"""Build a ctc graph from a list-of-list token IDs.
|
||||
|
||||
Args:
|
||||
token_ids:
|
||||
It is a list-of-list integer IDs.
|
||||
modified:
|
||||
See :func:`k2.ctc_graph` for its meaning.
|
||||
Return:
|
||||
Return an FsaVec, which is the result of composing a
|
||||
CTC topology with linear FSAs constructed from the given
|
||||
token IDs.
|
||||
"""
|
||||
return k2.ctc_graph(token_ids, modified=modified, device=self.device)
|
||||
|
||||
|
||||
def convert_lexicon_to_ragged(
|
||||
filename: str, word_table: k2.SymbolTable, token_table: k2.SymbolTable
|
||||
) -> k2.RaggedTensor:
|
||||
"""Read a lexicon and convert lexicon to a ragged tensor.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Path to the lexicon file.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
token_table:
|
||||
The token symbol table.
|
||||
Returns:
|
||||
A k2 ragged tensor with two axes [word][token].
|
||||
"""
|
||||
num_words = len(word_table.symbols)
|
||||
excluded_words = [
|
||||
"<eps>",
|
||||
"!SIL",
|
||||
"<SPOKEN_NOISE>",
|
||||
"<UNK>",
|
||||
"#0",
|
||||
"<s>",
|
||||
"</s>",
|
||||
]
|
||||
|
||||
row_splits = [0]
|
||||
token_ids_list = []
|
||||
|
||||
lexicon_tmp = read_lexicon(filename)
|
||||
lexicon = dict(lexicon_tmp)
|
||||
if len(lexicon_tmp) != len(lexicon):
|
||||
raise RuntimeError(
|
||||
"It's assumed that each word has a unique pronunciation"
|
||||
)
|
||||
|
||||
for i in range(num_words):
|
||||
w = word_table[i]
|
||||
if w in excluded_words:
|
||||
row_splits.append(row_splits[-1])
|
||||
continue
|
||||
tokens = lexicon[w]
|
||||
token_ids = [token_table[k] for k in tokens]
|
||||
|
||||
row_splits.append(row_splits[-1] + len(token_ids))
|
||||
token_ids_list.extend(token_ids)
|
||||
|
||||
cached_tot_size = row_splits[-1]
|
||||
row_splits = torch.tensor(row_splits, dtype=torch.int32)
|
||||
|
||||
shape = k2.ragged.create_ragged_shape2(
|
||||
row_splits,
|
||||
None,
|
||||
cached_tot_size,
|
||||
)
|
||||
values = torch.tensor(token_ids_list, dtype=torch.int32)
|
||||
|
||||
return k2.RaggedTensor(shape, values)
|
||||
|
||||
|
||||
def convert_text_to_word_segments(
|
||||
text_filename: str, words_segments_filename: str
|
||||
) -> Dict[str, str]:
|
||||
"""Convert text to word-based segments.
|
||||
|
||||
Args:
|
||||
text_filename:
|
||||
The file for the original transcripts.
|
||||
words_segments_filename:
|
||||
The file obtained by applying Chinese word segmentation
to the original transcripts.
Returns:
A dict mapping each original text line to its word-segmented version.
|
||||
"""
|
||||
text2words = {}
|
||||
|
||||
f_text = open(text_filename, "r", encoding="utf-8")
|
||||
text_lines = f_text.readlines()
|
||||
text_lines = [line.strip("\t") for line in text_lines]
|
||||
|
||||
f_words = open(words_segments_filename, "r", encoding="utf-8")
|
||||
words_lines = f_words.readlines()
|
||||
words_lines = [line.strip("\t") for line in words_lines]
|
||||
|
||||
if len(text_lines) != len(words_lines):
|
||||
raise RuntimeError(
|
||||
"The lengths of text and words_segments should be equal."
|
||||
)
|
||||
|
||||
for i in tqdm(range(len(text_lines))):
|
||||
text = text_lines[i].strip(" ").strip("\n")
|
||||
words_segments = words_lines[i].strip(" ").strip("\n")
|
||||
text2words[text] = words_segments
|
||||
|
||||
return text2words
|
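A usage sketch for the new compiler, assuming lang_dir was prepared by the prepare.sh stages above so that text, text_words_segmentation, lexicon.txt, tokens.txt, words.txt and the L*.pt files all exist (i.e., after prepare_pinyin.py has been run):

from pathlib import Path

import torch
from icefall.lexicon import Lexicon
from icefall.pinyin_graph_compiler import PinyinCtcTrainingGraphCompiler

lang_dir = Path("data/lang_pinyin")
device = torch.device("cpu")

lexicon = Lexicon(lang_dir)
graph_compiler = PinyinCtcTrainingGraphCompiler(
    lang_dir=lang_dir,
    lexicon=lexicon,
    device=device,
)

# Transcripts must already be word-segmented (train.py maps raw text
# through graph_compiler.text2words first).
token_ids = graph_compiler.texts_to_ids(["北京 欢迎 您"])
ctc_graphs = graph_compiler.compile(token_ids.tolist())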