mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-06 15:44:17 +00:00

Commit d8907769ee (parent 5242275b8a): add and update some files

egs/wenetspeech/ASR/local/display_manifest_statistics.py (new file, 135 lines)
@@ -0,0 +1,135 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
#                                        Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This file displays duration statistics of utterances in a manifest.
You can use the displayed value to choose minimum/maximum duration
to remove short and long utterances during the training.

See the function `remove_short_and_long_utt()`
in ../../../librispeech/ASR/transducer/train.py
for usage.
"""


from lhotse import load_manifest


def main():
    paths = [
        # "./data/fbank/cuts_L_100_pieces.jsonl.gz",
        "./data/fbank/cuts_L_50_pieces.jsonl.gz",
        # "./data/fbank/cuts_DEV.jsonl.gz",
        # "./data/fbank/cuts_TEST_NET.jsonl.gz",
        # "./data/fbank/cuts_TEST_MEETING.jsonl.gz"
    ]

    for path in paths:
        print(f"Starting display the statistics for {path}")
        cuts = load_manifest(path)
        cuts.describe()


if __name__ == "__main__":
    main()

"""
Starting display the statistics for ./data/fbank/cuts_L_50_pieces.jsonl.gz
Cuts count: 2241476
Total duration (hours): 1475.0
Speech duration (hours): 1475.0 (100.0%)
***
Duration statistics (seconds):
mean    2.4
std     1.6
min     0.3
25%     1.3
50%     2.0
75%     2.9
99%     8.2
99.5%   9.3
99.9%   13.5
max     87.0

Starting display the statistics for ./data/fbank/cuts_L_100_pieces.jsonl.gz
Cuts count: 4929619
Total duration (hours): 3361.1
Speech duration (hours): 3361.1 (100.0%)
***
Duration statistics (seconds):
mean    2.5
std     1.7
min     0.3
25%     1.4
50%     2.0
75%     3.0
99%     8.1
99.5%   8.8
99.9%   14.7
max     87.0

Starting display the statistics for ./data/fbank/cuts_DEV.jsonl.gz
Cuts count: 13825
Total duration (hours): 20.0
Speech duration (hours): 20.0 (100.0%)
***
Duration statistics (seconds):
mean    5.2
std     2.2
min     1.0
25%     3.3
50%     4.9
75%     7.0
99%     9.6
99.5%   9.8
99.9%   10.0
max     10.0

Starting display the statistics for ./data/fbank/cuts_TEST_NET.jsonl.gz
Cuts count: 24774
Total duration (hours): 23.1
Speech duration (hours): 23.1 (100.0%)
***
Duration statistics (seconds):
mean    3.4
std     2.6
min     0.1
25%     1.4
50%     2.4
75%     4.8
99%     13.1
99.5%   14.5
99.9%   18.5
max     33.3

Starting display the statistics for ./data/fbank/cuts_TEST_MEETING.jsonl.gz
Cuts count: 8370
Total duration (hours): 15.2
Speech duration (hours): 15.2 (100.0%)
***
Duration statistics (seconds):
mean    6.5
std     3.5
min     0.8
25%     3.7
50%     5.8
75%     8.8
99%     15.2
99.5%   16.0
99.9%   18.8
max     24.6
"""
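For readers who have not used lhotse before, the script above boils down to loading a cut manifest and calling describe(). A minimal sketch, assuming lhotse is installed and using a hypothetical manifest path, that also shows how the reported percentiles translate into a duration filter:

# Sketch only, not part of the commit; the manifest path is hypothetical.
from lhotse import load_manifest

cuts = load_manifest("./data/fbank/cuts_DEV.jsonl.gz")
cuts.describe()  # prints the cut count, total hours and duration percentiles

# Keep utterances between 1 s and 15 s, mirroring remove_short_and_long_utt()
# in the training script.
filtered = cuts.filter(lambda c: 1.0 <= c.duration <= 15.0)
print(sum(1 for _ in filtered), "cuts kept")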
egs/wenetspeech/ASR/local/prepare_char.py (new executable file, 246 lines)
@@ -0,0 +1,246 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
#                                        Wei Kang,
#                                        Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script takes as input `lang_dir`, which should contain::

    - lang_dir/text,
    - lang_dir/words.txt

and generates the following files in the directory `lang_dir`:

    - lexicon.txt
    - lexicon_disambig.txt
    - L.pt
    - L_disambig.pt
    - tokens.txt
"""

import argparse
import re
from pathlib import Path
from typing import Dict, List

import k2
import torch
from prepare_lang import (
    Lexicon,
    add_disambig_symbols,
    add_self_loops,
    write_lexicon,
    write_mapping,
)


def lexicon_to_fst_no_sil(
    lexicon: Lexicon,
    token2id: Dict[str, int],
    word2id: Dict[str, int],
    need_self_loops: bool = False,
) -> k2.Fsa:
    """Convert a lexicon to an FST (in k2 format).

    Args:
      lexicon:
        The input lexicon. See also :func:`read_lexicon`
      token2id:
        A dict mapping tokens to IDs.
      word2id:
        A dict mapping words to IDs.
      need_self_loops:
        If True, add self-loop to states with non-epsilon output symbols
        on at least one arc out of the state. The input label for this
        self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
    Returns:
      Return an instance of `k2.Fsa` representing the given lexicon.
    """
    loop_state = 0  # words enter and leave from here
    next_state = 1  # the next un-allocated state, will be incremented as we go

    arcs = []

    # The blank symbol <blk> is defined in local/train_bpe_model.py
    assert token2id["<blk>"] == 0
    assert word2id["<eps>"] == 0

    eps = 0

    for word, pieces in lexicon:
        assert len(pieces) > 0, f"{word} has no pronunciations"
        cur_state = loop_state

        word = word2id[word]
        pieces = [
            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
        ]

        for i in range(len(pieces) - 1):
            w = word if i == 0 else eps
            arcs.append([cur_state, next_state, pieces[i], w, 0])

            cur_state = next_state
            next_state += 1

        # now for the last piece of this word
        i = len(pieces) - 1
        w = word if i == 0 else eps
        arcs.append([cur_state, loop_state, pieces[i], w, 0])

    if need_self_loops:
        disambig_token = token2id["#0"]
        disambig_word = word2id["#0"]
        arcs = add_self_loops(
            arcs,
            disambig_token=disambig_token,
            disambig_word=disambig_word,
        )

    final_state = next_state
    arcs.append([loop_state, final_state, -1, -1, 0])
    arcs.append([final_state])

    arcs = sorted(arcs, key=lambda arc: arc[0])
    arcs = [[str(i) for i in arc] for arc in arcs]
    arcs = [" ".join(arc) for arc in arcs]
    arcs = "\n".join(arcs)

    fsa = k2.Fsa.from_str(arcs, acceptor=False)
    return fsa


def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
    """Check if all the given tokens are in token symbol table.

    Args:
      token_sym_table:
        Token symbol table that contains all the valid tokens.
      tokens:
        A list of tokens.
    Returns:
      Return True if there is any token not in the token_sym_table,
      otherwise False.
    """
    for tok in tokens:
        if tok not in token_sym_table:
            return True
    return False


def generate_lexicon(
    token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
    """Generate a lexicon from a word list and token_sym_table.

    Args:
      token_sym_table:
        Token symbol table that mapping token to token ids.
      words:
        A list of strings representing words.
    Returns:
      Return a dict whose keys are words and values are the corresponding
      tokens.
    """
    lexicon = []
    for word in words:
        chars = list(word.strip(" \t"))
        if contain_oov(token_sym_table, chars):
            continue
        lexicon.append((word, chars))

    # The OOV word is <UNK>
    lexicon.append(("<UNK>", ["<unk>"]))
    return lexicon


def generate_tokens(text_file: str) -> Dict[str, int]:
    """Generate tokens from the given text file.

    Args:
      text_file:
        A file that contains text lines to generate tokens.
    Returns:
      Return a dict whose keys are tokens and values are token ids ranged
      from 0 to len(keys) - 1.
    """
    tokens: Dict[str, int] = dict()
    tokens["<blk>"] = 0
    tokens["<sos/eos>"] = 1
    tokens["<unk>"] = 2
    whitespace = re.compile(r"([ \t\r\n]+)")
    with open(text_file, "r", encoding="utf-8") as f:
        for line in f:
            line = re.sub(whitespace, "", line)
            tokens_list = list(line)
            for token in tokens_list:
                if token not in tokens:
                    tokens[token] = len(tokens)
    return tokens


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--lang-dir", type=str, help="The lang directory.")
    args = parser.parse_args()

    lang_dir = Path(args.lang_dir)
    text_file = lang_dir / "text"

    word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")

    words = word_sym_table.symbols

    excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"]
    for w in excluded:
        if w in words:
            words.remove(w)

    token_sym_table = generate_tokens(text_file)

    lexicon = generate_lexicon(token_sym_table, words)

    lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)

    next_token_id = max(token_sym_table.values()) + 1
    for i in range(max_disambig + 1):
        disambig = f"#{i}"
        assert disambig not in token_sym_table
        token_sym_table[disambig] = next_token_id
        next_token_id += 1

    word_sym_table.add("#0")
    word_sym_table.add("<s>")
    word_sym_table.add("</s>")

    write_mapping(lang_dir / "tokens.txt", token_sym_table)

    write_lexicon(lang_dir / "lexicon.txt", lexicon)
    write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)

    L = lexicon_to_fst_no_sil(
        lexicon,
        token2id=token_sym_table,
        word2id=word_sym_table,
    )

    L_disambig = lexicon_to_fst_no_sil(
        lexicon_disambig,
        token2id=token_sym_table,
        word2id=word_sym_table,
        need_self_loops=True,
    )
    torch.save(L.as_dict(), lang_dir / "L.pt")
    torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")


if __name__ == "__main__":
    main()
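As a rough illustration of what lexicon_to_fst_no_sil() produces: each word leaves the loop state on its first token (emitting the word id), walks through its remaining tokens on epsilon outputs, and returns to the loop state. A toy sketch with made-up token/word ids (tokens 1=你, 2=好, 3=中; words 1=你好, 2=中), showing the kind of arc string the function hands to k2.Fsa.from_str:

# Sketch only, not part of the commit; the ids and the two-word lexicon are made up.
import k2

arcs = """0 0 3 2 0
0 1 1 1 0
0 2 -1 -1 0
1 0 2 0 0
2"""
L = k2.Fsa.from_str(arcs, acceptor=False)
print(L.num_arcs)  # 4 arcs: 你好 via states 0->1->0, 中 as a self-loop, plus the final arc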
egs/wenetspeech/ASR/local/prepare_pinyin.py (new executable file, 260 lines)
@@ -0,0 +1,260 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
#                                        Wei Kang,
#                                        Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script takes as input `lang_dir`, which should contain::

    - lang_dir/text,
    - lang_dir/words.txt

and generates the following files in the directory `lang_dir`:

    - lexicon.txt
    - lexicon_disambig.txt
    - L.pt
    - L_disambig.pt
    - tokens.txt
"""

import argparse
from pathlib import Path
from typing import Dict, List

import k2
import torch
from prepare_lang import (
    Lexicon,
    add_disambig_symbols,
    add_self_loops,
    write_lexicon,
    write_mapping,
)
from pypinyin import pinyin
from tqdm import tqdm


def lexicon_to_fst_no_sil(
    lexicon: Lexicon,
    token2id: Dict[str, int],
    word2id: Dict[str, int],
    need_self_loops: bool = False,
) -> k2.Fsa:
    """Convert a lexicon to an FST (in k2 format).

    Args:
      lexicon:
        The input lexicon. See also :func:`read_lexicon`
      token2id:
        A dict mapping tokens to IDs.
      word2id:
        A dict mapping words to IDs.
      need_self_loops:
        If True, add self-loop to states with non-epsilon output symbols
        on at least one arc out of the state. The input label for this
        self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
    Returns:
      Return an instance of `k2.Fsa` representing the given lexicon.
    """
    loop_state = 0  # words enter and leave from here
    next_state = 1  # the next un-allocated state, will be incremented as we go

    arcs = []

    # The blank symbol <blk> is defined in local/train_bpe_model.py
    assert token2id["<blk>"] == 0
    assert word2id["<eps>"] == 0

    eps = 0

    for word, pieces in lexicon:
        assert len(pieces) > 0, f"{word} has no pronunciations"
        cur_state = loop_state

        word = word2id[word]
        pieces = [
            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
        ]

        for i in range(len(pieces) - 1):
            w = word if i == 0 else eps
            arcs.append([cur_state, next_state, pieces[i], w, 0])

            cur_state = next_state
            next_state += 1

        # now for the last piece of this word
        i = len(pieces) - 1
        w = word if i == 0 else eps
        arcs.append([cur_state, loop_state, pieces[i], w, 0])

    if need_self_loops:
        disambig_token = token2id["#0"]
        disambig_word = word2id["#0"]
        arcs = add_self_loops(
            arcs,
            disambig_token=disambig_token,
            disambig_word=disambig_word,
        )

    final_state = next_state
    arcs.append([loop_state, final_state, -1, -1, 0])
    arcs.append([final_state])

    arcs = sorted(arcs, key=lambda arc: arc[0])
    arcs = [[str(i) for i in arc] for arc in arcs]
    arcs = [" ".join(arc) for arc in arcs]
    arcs = "\n".join(arcs)

    fsa = k2.Fsa.from_str(arcs, acceptor=False)
    return fsa


def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
    """Check if all the given tokens are in token symbol table.

    Args:
      token_sym_table:
        Token symbol table that contains all the valid tokens.
      tokens:
        A list of tokens.
    Returns:
      Return True if there is any token not in the token_sym_table,
      otherwise False.
    """
    for tok in tokens:
        if tok not in token_sym_table:
            return True
    return False


def generate_lexicon(
    token_sym_table: Dict[str, int], words: List[str]
) -> Lexicon:
    """Generate a lexicon from a word list and token_sym_table.

    Args:
      token_sym_table:
        Token symbol table that mapping token to token ids.
      words:
        A list of strings representing words.
    Returns:
      Return a dict whose keys are words and values are the corresponding
      tokens.
    """
    lexicon = []
    for i in tqdm(range(len(words))):
        word = words[i]
        tokens = []
        pinyins = pinyin(word.strip(" \t"))
        for pinyin_one in pinyins:
            if pinyin_one[0].isupper():
                tokens.extend(list(pinyin_one[0]))
            else:
                tokens.append(pinyin_one[0])
        if contain_oov(token_sym_table, tokens):
            continue
        lexicon.append((word, tokens))

    # The OOV word is <UNK>
    lexicon.append(("<UNK>", ["<unk>"]))
    return lexicon


def generate_tokens(words: List[str]) -> Dict[str, int]:
    """Generate tokens from the given word list.

    Args:
      words:
        The list of words after removing <eps>, !SIL, and so on.
    Returns:
      Return a dict whose keys are tokens and values are token ids ranged
      from 0 to len(keys) - 1.
    """
    tokens: Dict[str, int] = dict()
    tokens["<blk>"] = 0
    tokens["<sos/eos>"] = 1
    tokens["<unk>"] = 2
    for i in tqdm(range(len(words))):
        word = words[i]
        pinyins_list = pinyin(word)
        for pinyin_one in pinyins_list:
            if pinyin_one[0].isupper():
                tokens_list = list(pinyin_one[0])
            else:
                tokens_list = pinyin_one
            for token in tokens_list:
                if token not in tokens:
                    tokens[token] = len(tokens)
    tokens = sorted(tokens.items(), key=lambda item: item[0])
    tokens = dict(tokens)

    return tokens


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--lang-dir", type=str, help="The lang directory.")
    args = parser.parse_args()

    lang_dir = Path(args.lang_dir)
    words_file = lang_dir / "words.txt"

    word_sym_table = k2.SymbolTable.from_file(words_file)

    words = word_sym_table.symbols

    excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"]
    for w in excluded:
        if w in words:
            words.remove(w)

    token_sym_table = generate_tokens(words)

    lexicon = generate_lexicon(token_sym_table, words)

    lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)

    next_token_id = max(token_sym_table.values()) + 1
    for i in range(max_disambig + 1):
        disambig = f"#{i}"
        assert disambig not in token_sym_table
        token_sym_table[disambig] = next_token_id
        next_token_id += 1

    word_sym_table.add("#0")
    word_sym_table.add("<s>")
    word_sym_table.add("</s>")

    write_mapping(lang_dir / "tokens.txt", token_sym_table)

    write_lexicon(lang_dir / "lexicon.txt", lexicon)
    write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)

    L = lexicon_to_fst_no_sil(
        lexicon,
        token2id=token_sym_table,
        word2id=word_sym_table,
    )

    L_disambig = lexicon_to_fst_no_sil(
        lexicon_disambig,
        token2id=token_sym_table,
        word2id=word_sym_table,
        need_self_loops=True,
    )
    torch.save(L.as_dict(), lang_dir / "L.pt")
    torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")


if __name__ == "__main__":
    main()
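The isupper() branch above exists because pypinyin returns non-Chinese runs verbatim, so an English acronym in a transcript would otherwise become one huge token. A hedged illustration, assuming pypinyin is installed (the outputs shown in comments are approximate):

# Sketch only, not part of the commit.
from pypinyin import pinyin

print(pinyin("你好"))     # [['nǐ'], ['hǎo']] with the default tone style
print(pinyin("GDP增长"))  # roughly [['GDP'], ['zēng'], ['zhǎng']]
# "GDP".isupper() is True, so the script splits it into 'G', 'D', 'P'
# rather than keeping 'GDP' as one token.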
egs/wenetspeech/ASR/local/prepare_words.py (new file, 84 lines)
@@ -0,0 +1,84 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script takes as input words.txt without ids:
    - words_no_ids.txt
and generates the new words.txt with related ids.
    - words.txt
"""


import argparse
import logging

from tqdm import tqdm


def get_parser():
    parser = argparse.ArgumentParser(
        description="Prepare words.txt",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--input-file",
        default="data/lang_char/words_no_ids.txt",
        type=str,
        help="the words file without ids for WenetSpeech",
    )
    parser.add_argument(
        "--output-file",
        default="data/lang_char/words.txt",
        type=str,
        help="the words file with ids for WenetSpeech",
    )

    return parser


def main():
    parser = get_parser()
    args = parser.parse_args()

    input_file = args.input_file
    output_file = args.output_file

    f = open(input_file, "r", encoding="utf-8")
    lines = f.readlines()
    new_lines = []
    add_words = ["<eps> 0", "!SIL 1", "<SPOKEN_NOISE> 2", "<UNK> 3"]
    new_lines.extend(add_words)

    logging.info("Starting reading the input file")
    for i in tqdm(range(len(lines))):
        x = lines[i]
        idx = 4 + i
        new_line = str(x.strip("\n")) + " " + str(idx)
        new_lines.append(new_line)

    logging.info("Starting writing the words.txt")
    f_out = open(output_file, "w", encoding="utf-8")
    for line in new_lines:
        f_out.write(line)
        f_out.write("\n")


if __name__ == "__main__":
    main()
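In other words, the script reserves ids 0-3 for the special symbols and numbers the remaining words consecutively from 4. A small sketch of the resulting mapping, with a made-up three-word input:

# Sketch only, not part of the commit; the word list is hypothetical.
words = ["中国", "你好", "北京"]  # one word per line in words_no_ids.txt
header = ["<eps> 0", "!SIL 1", "<SPOKEN_NOISE> 2", "<UNK> 3"]
body = [f"{w} {i + 4}" for i, w in enumerate(words)]
print("\n".join(header + body))
# <eps> 0
# !SIL 1
# <SPOKEN_NOISE> 2
# <UNK> 3
# 中国 4
# 你好 5
# 北京 6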
egs/wenetspeech/ASR/local/text2segments.py (new file, 83 lines)
@@ -0,0 +1,83 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script takes as input "text", which refers to the transcript file for
WenetSpeech:
    - text
and generates the output file text_word_segmentation which is implemented
with word segmenting:
    - text_words_segmentation
"""


import argparse

import jieba
from tqdm import tqdm

jieba.enable_paddle()


def get_parser():
    parser = argparse.ArgumentParser(
        description="Chinese Word Segmentation for text",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--input",
        default="data/lang_char/text",
        type=str,
        help="the input text file for WenetSpeech",
    )
    parser.add_argument(
        "--output",
        default="data/lang_char/text_words_segmentation",
        type=str,
        help="the text implemented with words segmenting for WenetSpeech",
    )

    return parser


def main():
    parser = get_parser()
    args = parser.parse_args()

    input_file = args.input
    output_file = args.output

    f = open(input_file, "r", encoding="utf-8")
    lines = f.readlines()
    new_lines = []
    for i in tqdm(range(len(lines))):
        x = lines[i].rstrip()
        seg_list = jieba.cut(x, use_paddle=True)
        new_line = " ".join(seg_list)
        new_lines.append(new_line)

    f_new = open(output_file, "w", encoding="utf-8")
    for line in new_lines:
        f_new.write(line)
        f_new.write("\n")


if __name__ == "__main__":
    main()
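A hedged example of what the segmentation step produces, assuming jieba is installed (the script itself enables the paddle mode, which additionally requires paddlepaddle):

# Sketch only, not part of the commit; default (non-paddle) mode shown.
import jieba

line = "我爱北京天安门"
print(" ".join(jieba.cut(line)))  # e.g. "我 爱 北京 天安门"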
@@ -188,7 +188,7 @@ def main():
         a_chars = [z.replace(" ", args.space) for z in a_flat]

-        print(" ".join(a_chars))
+        print("".join(a_chars))
         line = f.readline()
@@ -117,9 +117,9 @@ fi
 if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
   log "Stage 7: Combine features for L"
-  if [ ! -f data/fbank/cuts_L.jsonl.gz ]; then
-    pieces=$(find data/fbank/L_split_${num_splits} -name "cuts_L.*.jsonl.gz")
-    lhotse combine $pieces data/fbank/cuts_L.jsonl.gz
+  if [ ! -f data/fbank/cuts_L_50.jsonl.gz ]; then
+    pieces=$(find data/fbank/L_split_50 -name "cuts_L.*.jsonl.gz")
+    lhotse combine $pieces data/fbank/cuts_L_50.jsonl.gz
   fi
 fi
@@ -134,120 +134,50 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
   lang_char_dir=data/lang_char
   mkdir -p $lang_char_dir

-  gunzip -c data/manifests/supervisions_L.jsonl.gz \
-    | jq '.text' | sed 's/"//g' \
-    | ./local/text2token.py -t "char" > $lang_char_dir/text
+  # Prepare text.
+  if [ ! -f $lang_char_dir/text ]; then
+    gunzip -c data/manifests/supervisions_L.jsonl.gz \
+      | jq '.text' | sed 's/"//g' \
+      | ./local/text2token.py -t "char" > $lang_char_dir/text
+  fi

-  cat $lang_char_dir/text | sed 's/ /\n/g' \
-    | sort -u | sed '/^$/d' > $lang_char_dir/words.txt
-  (echo '<SIL>'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) |
-    cat - $lang_char_dir/words.txt | sort | uniq | awk '
-    BEGIN {
-      print "<eps> 0";
-    }
-    {
-      if ($1 == "<s>") {
-        print "<s> is in the vocabulary!" | "cat 1>&2"
-        exit 1;
-      }
-      if ($1 == "</s>") {
-        print "</s> is in the vocabulary!" | "cat 1>&2"
-        exit 1;
-      }
-      printf("%s %d\n", $1, NR);
-    }
-    END {
-      printf("#0 %d\n", NR+1);
-      printf("<s> %d\n", NR+2);
-      printf("</s> %d\n", NR+3);
-    }' > $lang_char_dir/words || exit 1;
-
-  mv $lang_char_dir/words $lang_char_dir/words.txt
+  # The implementation of chinese word segmentation for text,
+  # and it will take about 15 minutes.
+  if [ ! -f $lang_char_dir/text_words_segmentation ]; then
+    python ./local/text2segments.py \
+      --input $lang_char_dir/text \
+      --output $lang_char_dir/text_words_segmentation
+  fi
+
+  cat $lang_char_dir/text_words_segmentation | sed 's/ /\n/g' \
+    | sort -u | sed '/^$/d' | uniq > $lang_char_dir/words_no_ids.txt
+
+  if [ ! -f $lang_char_dir/words.txt ]; then
+    python ./local/prepare_words.py \
+      --input-file $lang_char_dir/words_no_ids.txt \
+      --output-file $lang_char_dir/words.txt
+  fi
 fi

 if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
-  log "Stage 10: Prepare pinyin based lang"
-  lang_pinyin_dir=data/lang_pinyin
-  mkdir -p $lang_pinyin_dir
-
-  gunzip -c data/manifests/supervisions_L.jsonl.gz \
-    | jq '.text' | sed 's/"//g' \
-    | ./local/text2token.py -t "pinyin" > $lang_pinyin_dir/text
-
-  cat $lang_pinyin_dir/text | sed 's/ /\n/g' \
-    | sort -u | sed '/^$/d' > $lang_pinyin_dir/words.txt
-  (echo '<SIL>'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) |
-    cat - $lang_pinyin_dir/words.txt | sort | uniq | awk '
-    BEGIN {
-      print "<eps> 0";
-    }
-    {
-      if ($1 == "<s>") {
-        print "<s> is in the vocabulary!" | "cat 1>&2"
-        exit 1;
-      }
-      if ($1 == "</s>") {
-        print "</s> is in the vocabulary!" | "cat 1>&2"
-        exit 1;
-      }
-      printf("%s %d\n", $1, NR);
-    }
-    END {
-      printf("#0 %d\n", NR+1);
-      printf("<s> %d\n", NR+2);
-      printf("</s> %d\n", NR+3);
-    }' > $lang_pinyin_dir/words || exit 1;
-
-  mv $lang_pinyin_dir/words $lang_pinyin_dir/words.txt
+  log "Stage 10: Prepare char based L_disambig.pt"
+  if [ ! -f data/lang_char/L_disambig.pt ]; then
+    python ./local/prepare_char.py \
+      --lang-dir data/lang_char
+  fi
 fi

 if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
-  log "Stage 11: Prepare lazy_pinyin based lang"
-  lang_lazy_pinyin_dir=data/lang_lazy_pinyin
-  mkdir -p $lang_lazy_pinyin_dir
-
-  gunzip -c data/manifests/supervisions_L.jsonl.gz \
-    | jq '.text' | sed 's/"//g' \
-    | ./local/text2token.py -t "lazy_pinyin" > $lang_lazy_pinyin_dir/text
-
-  cat $lang_lazy_pinyin_dir/text | sed 's/ /\n/g' \
-    | sort -u | sed '/^$/d' > $lang_lazy_pinyin_dir/words.txt
-  (echo '<SIL>'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) |
-    cat - $lang_lazy_pinyin_dir/words.txt | sort | uniq | awk '
-    BEGIN {
-      print "<eps> 0";
-    }
-    {
-      if ($1 == "<s>") {
-        print "<s> is in the vocabulary!" | "cat 1>&2"
-        exit 1;
-      }
-      if ($1 == "</s>") {
-        print "</s> is in the vocabulary!" | "cat 1>&2"
-        exit 1;
-      }
-      printf("%s %d\n", $1, NR);
-    }
-    END {
-      printf("#0 %d\n", NR+1);
-      printf("<s> %d\n", NR+2);
-      printf("</s> %d\n", NR+3);
-    }' > $lang_lazy_pinyin_dir/words || exit 1;
-
-  mv $lang_lazy_pinyin_dir/words $lang_lazy_pinyin_dir/words.txt
-fi
-
-if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
-  log "Stage 12: Prepare L_disambig.pt"
-  if [ ! -f data/lang_char/L_disambig.pt ]; then
-    python ./local/prepare_lang_wenetspeech.py --lang-dir data/lang_char
-  fi
-
+  log "Stage 11: Prepare pinyin based L_disambig.pt"
+  lang_pinyin_dir=data/lang_pinyin
+  mkdir -p $lang_pinyin_dir
+
+  cp -r data/lang_char/words.txt $lang_pinyin_dir/
+  cp -r data/lang_char/text $lang_pinyin_dir/
+  cp -r data/lang_char/text_words_segmentation $lang_pinyin_dir/
+
   if [ ! -f data/lang_pinyin/L_disambig.pt ]; then
-    python ./local/prepare_lang_wenetspeech.py --lang-dir data/lang_pinyin
-  fi
-
-  if [ ! -f data/lang_lazy_pinyin/L_disambig.pt ]; then
-    python ./local/prepare_lang_wenetspeech.py --lang-dir data/lang_lazy_pinyin
+    python ./local/prepare_pinyin.py \
+      --lang-dir data/lang_pinyin
   fi
 fi
@@ -375,11 +375,15 @@ class WenetSpeechAsrDataModule:
         if self.args.lazy_load:
             logging.info("use lazy cuts")
             cuts_train = CutSet.from_jsonl_lazy(
-                self.args.manifest_dir / "cuts_L.jsonl.gz"
+                self.args.manifest_dir
+                / "cuts_L_50_pieces.jsonl.gz"
+                # use cuts_L_50_pieces.jsonl.gz for original experiments
             )
         else:
             cuts_train = CutSet.from_file(
-                self.args.manifest_dir / "cuts_L.jsonl.gz"
+                self.args.manifest_dir
+                / "cuts_L_50_pieces.jsonl.gz"
+                # use cuts_L_50_pieces.jsonl.gz for original experiments
             )
         return cuts_train
@@ -25,8 +25,10 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 0 \
-  --exp-dir pruned_transducer_stateless/exp \
-  --max-duration 300
+  --lang-dir data/lang_char \
+  --exp-dir pruned_transducer_stateless/exp-char \
+  --token-type char \
+  --max-duration 200
 """

@@ -44,8 +46,8 @@ from asr_datamodule import WenetSpeechAsrDataModule
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
+from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
-from local.text2token import token2id
 from model import Transducer
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP

@@ -59,6 +61,7 @@ from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
+from icefall.pinyin_graph_compiler import PinyinCtcTrainingGraphCompiler
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool

@@ -108,7 +111,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless/exp",
+        default="pruned_transducer_stateless_pinyin/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved

@@ -118,7 +121,7 @@ def get_parser():
     parser.add_argument(
         "--lang-dir",
         type=str,
-        default="data/lang_lazy_pinyin",
+        default="data/lang_char",
         help="""The lang dir
         It contains language related input files such as
         "lexicon.txt"

@@ -128,10 +131,9 @@ def get_parser():
     parser.add_argument(
         "--token-type",
         type=str,
-        default="lazy_pinyin",
-        help="""The token type
-        It refers to the token type for modeling, such as
-        char, pinyin, lazy_pinyin.
+        default="char",
+        help="""The type of token
+        It must be in ["char", "pinyin", "lazy_pinyin"].
         """,
     )

@@ -435,16 +437,12 @@ def compute_loss(
     feature_lens = supervisions["num_frames"].to(device)

     texts = batch["supervisions"]["text"]
-    y = ""
-    if params.token_type == "char":
-        y = graph_compiler.texts_to_ids(texts)
+    y = graph_compiler.texts_to_ids(texts)
+    if type(y) == list:
+        y = k2.RaggedTensor(y).to(device)
     else:
-        y = token2id(
-            texts=texts,
-            token_table=graph_compiler.token_table,
-            token_type=params.token_type,
-        )
-        y = k2.RaggedTensor(y).to(device)
+        y = y.to(device)

     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss = model(

@@ -635,9 +633,19 @@ def run(rank, world_size, args):
     logging.info(f"Device: {device}")

     lexicon = Lexicon(params.lang_dir)
-    graph_compiler = CharCtcTrainingGraphCompiler(
-        lexicon=lexicon, device=device, oov="<unk>"
-    )
+    graph_compiler = ""
+    if params.token_type == "char":
+        graph_compiler = CharCtcTrainingGraphCompiler(
+            lexicon=lexicon,
+            device=device,
+        )
+    if params.token_type == "pinyin":
+        graph_compiler = PinyinCtcTrainingGraphCompiler(
+            lang_dir=params.lang_dir,
+            lexicon=lexicon,
+            device=device,
+        )

     params.blank_id = lexicon.token_table["<blk>"]
     params.vocab_size = max(lexicon.tokens) + 1

@@ -672,6 +680,23 @@ def run(rank, world_size, args):

     train_cuts = wenetspeech.train_cuts()

+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 15.0 seconds
+        # You can get the statistics by local/display_manifest_statistics.py.
+        return 1.0 <= c.duration <= 15.0
+
+    def text_to_words(c: Cut):
+        # Convert text to words_segments.
+        text = c.supervisions[0].text
+        text = text.strip("\n").strip("\t")
+        words_cut = graph_compiler.text2words[text]
+        c.supervisions[0].text = words_cut
+        return c
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    if params.token_type == "pinyin":
+        train_cuts = train_cuts.map(text_to_words)
+
     train_dl = wenetspeech.train_dataloaders(train_cuts)
     valid_cuts = wenetspeech.valid_cuts()
     valid_dl = wenetspeech.valid_dataloaders(valid_cuts)
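The new text_to_words map swaps each cut's raw transcript for its pre-computed word-segmented form so that the pinyin graph compiler can split the text on spaces. A sketch with a made-up dictionary standing in for graph_compiler.text2words:

# Sketch only, not part of the commit; text2words here is a toy dict.
text2words = {"我爱北京天安门": "我 爱 北京 天安门"}

def text_to_words(c):
    text = c.supervisions[0].text.strip("\n").strip("\t")
    c.supervisions[0].text = text2words[text]
    return c

# Applied to the training cuts, e.g. train_cuts = train_cuts.map(text_to_words)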
icefall/pinyin_graph_compiler.py (new file, 219 lines)
@@ -0,0 +1,219 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from pathlib import Path
from typing import Dict, List

import k2
import torch
from tqdm import tqdm

from icefall.lexicon import Lexicon, read_lexicon


class PinyinCtcTrainingGraphCompiler(object):
    def __init__(
        self,
        lang_dir: Path,
        lexicon: Lexicon,
        device: torch.device,
        sos_token: str = "<sos/eos>",
        eos_token: str = "<sos/eos>",
        oov: str = "<unk>",
    ):
        """
        Args:
          lexicon:
            It is built from `data/lang_char/lexicon.txt`.
          device:
            The device to use for operations compiling transcripts to FSAs.
          oov:
            Out of vocabulary token. When a word(token) in the transcript
            does not exist in the token list, it is replaced with `oov`.
        """

        assert oov in lexicon.token_table

        self.lang_dir = lang_dir
        self.oov_id = lexicon.token_table[oov]
        self.token_table = lexicon.token_table

        self.device = device

        self.sos_id = self.token_table[sos_token]
        self.eos_id = self.token_table[eos_token]

        self.word_table = lexicon.word_table
        self.token_table = lexicon.token_table

        self.text2words = convert_text_to_word_segments(
            text_filename=self.lang_dir / "text",
            words_segments_filename=self.lang_dir / "text_words_segmentation",
        )
        self.ragged_lexicon = convert_lexicon_to_ragged(
            filename=self.lang_dir / "lexicon.txt",
            word_table=self.word_table,
            token_table=self.token_table,
        )

    def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
        """Convert a list of texts to a list-of-list of pinyin-based token IDs.

        Args:
          texts:
            It is a list of strings.
            An example containing two strings is given below:

                ['你好中国', '北京欢迎您']
        Returns:
          Return a list-of-list of pinyin-based token IDs.
        """
        word_ids_list = []
        for i in range(len(texts)):
            word_ids = []
            text = texts[i].strip("\n").strip("\t")
            for word in text.split(" "):
                if word in self.word_table:
                    word_ids.append(self.word_table[word])
                else:
                    word_ids.append(self.oov_id)
            word_ids_list.append(word_ids)

        ragged_indexes = k2.RaggedTensor(word_ids_list, dtype=torch.int32)
        ans = self.ragged_lexicon.index(ragged_indexes)
        ans = ans.remove_axis(ans.num_axes - 2)

        return ans

    def compile(
        self,
        token_ids: List[List[int]],
        modified: bool = False,
    ) -> k2.Fsa:
        """Build a ctc graph from a list-of-list token IDs.

        Args:
          piece_ids:
            It is a list-of-list integer IDs.
          modified:
            See :func:`k2.ctc_graph` for its meaning.
        Return:
          Return an FsaVec, which is the result of composing a
          CTC topology with linear FSAs constructed from the given
          piece IDs.
        """
        return k2.ctc_graph(token_ids, modified=modified, device=self.device)


def convert_lexicon_to_ragged(
    filename: str, word_table: k2.SymbolTable, token_table: k2.SymbolTable
) -> k2.RaggedTensor:
    """Read a lexicon and convert lexicon to a ragged tensor.

    Args:
      filename:
        Path to the lexicon file.
      word_table:
        The word symbol table.
      token_table:
        The token symbol table.
    Returns:
      A k2 ragged tensor with two axes [word][token].
    """
    num_words = len(word_table.symbols)
    excluded_words = [
        "<eps>",
        "!SIL",
        "<SPOKEN_NOISE>",
        "<UNK>",
        "#0",
        "<s>",
        "</s>",
    ]

    row_splits = [0]
    token_ids_list = []

    lexicon_tmp = read_lexicon(filename)
    lexicon = dict(lexicon_tmp)
    if len(lexicon_tmp) != len(lexicon):
        raise RuntimeError(
            "It's assumed that each word has a unique pronunciation"
        )

    for i in range(num_words):
        w = word_table[i]
        if w in excluded_words:
            row_splits.append(row_splits[-1])
            continue
        tokens = lexicon[w]
        token_ids = [token_table[k] for k in tokens]

        row_splits.append(row_splits[-1] + len(token_ids))
        token_ids_list.extend(token_ids)

    cached_tot_size = row_splits[-1]
    row_splits = torch.tensor(row_splits, dtype=torch.int32)

    shape = k2.ragged.create_ragged_shape2(
        row_splits,
        None,
        cached_tot_size,
    )
    values = torch.tensor(token_ids_list, dtype=torch.int32)

    return k2.RaggedTensor(shape, values)


def convert_text_to_word_segments(
    text_filename: str, words_segments_filename: str
) -> Dict[str, str]:
    """Convert text to word-based segments.

    Args:
      text_filename:
        The file for the original transcripts.
      words_segments_filename:
        The file after implementing chinese word segmentation
        for the original transcripts.
    Returns:
      A dictionary about text and words_segments.
    """
    text2words = {}

    f_text = open(text_filename, "r", encoding="utf-8")
    text_lines = f_text.readlines()
    text_lines = [line.strip("\t") for line in text_lines]

    f_words = open(words_segments_filename, "r", encoding="utf-8")
    words_lines = f_words.readlines()
    words_lines = [line.strip("\t") for line in words_lines]

    if len(text_lines) != len(words_lines):
        raise RuntimeError(
            "The lengths of text and words_segments should be equal."
        )

    for i in tqdm(range(len(text_lines))):
        text = text_lines[i].strip(" ").strip("\n")
        words_segments = words_lines[i].strip(" ").strip("\n")
        text2words[text] = words_segments

    return text2words
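texts_to_ids() above works in two steps: the word-segmented transcript is mapped to word ids, and those ids index a [word][token] ragged lexicon; removing the middle axis then concatenates the tokens of each sentence. A minimal sketch with toy ids:

# Sketch only, not part of the commit; the ids are made up.
import k2
import torch

# tokens of word 0, word 1 and word 2, respectively
ragged_lexicon = k2.RaggedTensor([[1, 2], [3], [4, 5, 6]], dtype=torch.int32)
# word ids of two sentences
sentences = k2.RaggedTensor([[2, 0], [1]], dtype=torch.int32)

ans = ragged_lexicon.index(sentences)
ans = ans.remove_axis(ans.num_axes - 2)
print(ans)  # [[4, 5, 6, 1, 2], [3]]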