Mirror of https://github.com/k2-fsa/icefall.git
Update timit recipe
This commit is contained in:
parent e023a9df98
commit 4beb25c50b
@@ -54,7 +54,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
    """
    Args:
      lang_dir:
-       The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
+       The language directory, e.g., data/lang_phone.

    Return:
      An FSA representing HLG.
@@ -63,18 +63,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
    max_token_id = max(lexicon.tokens)
    logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
    H = k2.ctc_topo(max_token_id)

-   if Path(lang_dir / "L_disambig.pt").is_file():
-       logging.info("Loading L_disambig.pt")
-       d = torch.load(Path(lang_dir/"L_disambig.pt"))
-       L = k2.Fsa.from_dict(d)
-   else:
-       logging.info("Loading L_disambig.fst.txt")
-       with open(Path(lang_dir/"L_disambig.fst.txt")) as f:
-           L = k2.Fsa.from_openfst(f.read(), acceptor=False)
-           torch.save(L_disambig.as_dict(), Path(lang_dir / "L_disambig.pt"))
-
-   #L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
+   L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

    if Path("data/lm/G.pt").is_file():
        logging.info("Loading pre-compiled G")
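For context, the simplified branch assumes that the lexicon FST has already been serialized by the lang-preparation stage, so compile_HLG only needs to deserialize it. A minimal sketch of that save/load round trip with k2 (data/lang_phone is just the example directory named in the docstring above):

import k2
import torch

# Build the FST from its OpenFst text form and cache it as a torch dict
# (this mirrors the fallback branch removed above).
with open("data/lang_phone/L_disambig.fst.txt") as f:
    L_disambig = k2.Fsa.from_openfst(f.read(), acceptor=False)
torch.save(L_disambig.as_dict(), "data/lang_phone/L_disambig.pt")

# This is the only step the updated compile_HLG performs.
L = k2.Fsa.from_dict(torch.load("data/lang_phone/L_disambig.pt"))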
@@ -106,7 +106,7 @@ def get_tokens(lexicon: Lexicon) -> List[str]:
    ans = set()
    for _, tokens in lexicon:
        ans.update(tokens)
-   #sorted_ans = sorted(list(ans))
+   sorted_ans = list(ans)
    return sorted_ans
@@ -275,18 +275,11 @@ def lexicon_to_fst(
    loop_state = 0  # words enter and leave from here
    next_state = 1  # the next un-allocated state, will be incremented as we go.
    arcs = []

-   print('token2id ori: ', token2id)
-   print('word2id ori: ', word2id)
-
    assert token2id["<eps>"] == 0
    assert word2id["<eps>"] == 0

    eps = 0
-   print('token2id new: ', token2id)
-   print('word2id new: ', word2id)
-
-   print(lexicon)
    for word, tokens in lexicon:
        assert len(tokens) > 0, f"{word} has no pronunciations"
        cur_state = loop_state
@@ -306,7 +299,7 @@ def lexicon_to_fst(
        # the other one to the sil_state.
        i = len(tokens) - 1
        w = word if i == 0 else eps
-       tokens[i] = tokens[i] if i >=0 else eps
+       tokens[i] = tokens[i] if i >= 0 else eps
        arcs.append([cur_state, loop_state, tokens[i], w, score])

    if need_self_loops:
@@ -326,7 +319,6 @@ def lexicon_to_fst(
    arcs = [[str(i) for i in arc] for arc in arcs]
    arcs = [" ".join(arc) for arc in arcs]
    arcs = "\n".join(arcs)
-   print(arcs)
    fsa = k2.Fsa.from_str(arcs, acceptor=False)
    return fsa
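For reference, k2.Fsa.from_str (used just above) parses one arc per line in the form "src_state dest_state label aux_label score", with the final state alone on the last line; arcs entering the final state carry label -1. A toy transducer in that format (labels and score are purely illustrative):

import k2

arcs = """0 1 1 2 0.5
1 2 -1 -1 0.0
2"""
fsa = k2.Fsa.from_str(arcs, acceptor=False)
print(fsa)  # a 3-state FST with one token/word arc and one final arc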
@@ -334,9 +326,8 @@ def lexicon_to_fst(
def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)
-   #out_dir = Path("data/lang_phone")
    lexicon_filename = lang_dir / "lexicon.txt"


    lexicon = read_lexicon(lexicon_filename)
    tokens = get_tokens(lexicon)
@@ -386,7 +377,7 @@ def main():
    L.aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
    L_disambig.labels_sym = L.labels_sym
    L_disambig.aux_labels_sym = L.aux_labels_sym
-   L.draw(out_dir / "L.png", title="L")
+   L.draw(lang_dir / "L.png", title="L")
    L_disambig.draw(lang_dir / "L_disambig.png", title="L_disambig")

@@ -59,48 +59,43 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str):
        The lexicon.txt file and the train.text in lang_dir.
    """
    phones = []

    supervisions_train = Path(manifests_dir) / "supervisions_TRAIN.json"
    lexicon = Path(lang_dir) / "lexicon.txt"

    logging.info(f"Loading {supervisions_train}!")
-   with open(supervisions_train, 'r') as load_f:
+   with open(supervisions_train, "r") as load_f:
        load_dicts = json.load(load_f)
        for load_dict in load_dicts:
            idx = load_dict['id']
-           text = load_dict['text']
-           phones_list = list(filter(None, text.split(' ')))
+           text = load_dict["text"]
+           phones_list = list(filter(None, text.split(" ")))

            for phone in phones_list:
                if phone not in phones:
                    phones.append(phone)

-   with open(lexicon, 'w') as f:
+   with open(lexicon, "w") as f:
        for phone in sorted(phones):
            f.write(str(phone) + " " + str(phone))
-           f.write('\n')
+           f.write("\n")
        f.write("<UNK> <UNK>")
-       f.write('\n')
+       f.write("\n")

    return lexicon


def main():
    args = get_args()
    manifests_dir = Path(args.manifests_dir)
    lang_dir = Path(args.lang_dir)

-   logging.info(f"Generating lexicon.txt and train.text")
-   prepare_lexicon(manifests_dir, lang_dir)
+   logging.info("Generating lexicon.txt")
+   lexicon_file = prepare_lexicon(manifests_dir, lang_dir)


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
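For reference, prepare_lexicon above writes a phone-level lexicon in which every phone maps to itself, followed by an <UNK> entry. The resulting lexicon.txt looks roughly like this (the phone symbols below are illustrative TIMIT labels, not taken from the recipe output):

aa aa
ae ae
ah ah
...
<UNK> <UNK>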
@@ -311,26 +311,20 @@ class TimitAsrDataModule(DataModule):
    @lru_cache()
    def train_cuts(self) -> CutSet:
        logging.info("About to get train cuts")
-       cuts_train = load_manifest(
-           self.args.feature_dir / "cuts_TRAIN.json.gz"
-       )
+       cuts_train = load_manifest(self.args.feature_dir / "cuts_TRAIN.json.gz")

        return cuts_train

    @lru_cache()
    def valid_cuts(self) -> CutSet:
        logging.info("About to get dev cuts")
-       cuts_valid = load_manifest(
-           self.args.feature_dir / "cuts_DEV.json.gz"
-       )
+       cuts_valid = load_manifest(self.args.feature_dir / "cuts_DEV.json.gz")

        return cuts_valid

    @lru_cache()
    def test_cuts(self) -> CutSet:
        logging.debug("About to get test cuts")
-       cuts_test = load_manifest(
-           self.args.feature_dir / "cuts_TEST.json.gz"
-       )
+       cuts_test = load_manifest(self.args.feature_dir / "cuts_TEST.json.gz")

        return cuts_test
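For orientation, lhotse's load_manifest returns the CutSet stored in the (optionally gzipped) manifest; a minimal sketch, assuming a hypothetical data/fbank feature directory that uses the file names above:

from lhotse import CutSet, load_manifest

cuts_train = load_manifest("data/fbank/cuts_TRAIN.json.gz")
assert isinstance(cuts_train, CutSet)
print(f"{len(cuts_train)} training cuts")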
@@ -310,7 +310,7 @@ def decode_dataset(
    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
@@ -449,7 +449,6 @@ def main():
    )
    if params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
-       #load_checkpoint(f"tmp/icefall_asr_librispeech_tdnn-lstm_ctc/exp/pretrained.pt", model)
    else:
        start = params.epoch - params.avg + 1
        filenames = []
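The else branch in this hunk follows the usual icefall pattern of averaging the last params.avg checkpoints before decoding. A rough sketch of that pattern, assuming the helpers from icefall.checkpoint, a stand-in model, and illustrative epoch numbers and paths:

import torch
from icefall.checkpoint import average_checkpoints, load_checkpoint

model = torch.nn.Linear(8, 4)  # stand-in for the acoustic model built in main()
epoch, avg, exp_dir = 20, 5, "tdnn_lstm_ctc/exp"  # illustrative values only

if avg == 1:
    load_checkpoint(f"{exp_dir}/epoch-{epoch}.pt", model)
else:
    start = epoch - avg + 1
    filenames = [f"{exp_dir}/epoch-{i}.pt" for i in range(start, epoch + 1)]
    # average_checkpoints returns an averaged state dict over the listed files.
    model.load_state_dict(average_checkpoints(filenames))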
@@ -470,18 +469,6 @@ def main():
    model.eval()

    timit = TimitAsrDataModule(args)
-   # CAUTION: `test_sets` is for displaying only.
-   # If you want to skip test-clean, you have to skip
-   # it inside the for loop. That is, use
-   #
-   #   if test_set == 'test-clean': continue
-   #
-   #test_sets = ["test-clean", "test-other"]
-   #test_sets = ["test-other"]
-   #for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
-   #    if test_set == "test-clean": continue
-   #    if test_set == "test-other": break
-   test_set = "TEST"
    test_dl = timit.test_dataloaders()
    results_dict = decode_dataset(
        dl=test_dl,
@@ -491,7 +478,7 @@ def main():
        lexicon=lexicon,
        G=G,
    )

+   test_set = "TEST"
    save_results(
        params=params, test_set_name=test_set, results_dict=results_dict
    )