Minor fixes for the commonvoice recipe (#1534)

* init commit

* fix for issue https://github.com/k2-fsa/icefall/issues/1531

* minor fixes
This commit is contained in:
zr_jin 2024-03-08 11:01:11 +08:00 committed by GitHub
parent 5df24c1685
commit ae61bd4090
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 344 additions and 14 deletions

View File

@ -1 +0,0 @@
../../../librispeech/ASR/local/compile_hlg.py

View File

@ -0,0 +1,168 @@
#!/usr/bin/env python3
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input lang_dir and generates HLG from
- H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
- L, the lexicon, built from lang_dir/L_disambig.pt
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from data/lm/G_n_gram.fst.txt
The generated HLG is saved in $lang_dir/HLG.pt
"""
import argparse
import logging
from pathlib import Path
import k2
import torch
from icefall.lexicon import Lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lm",
type=str,
default="G_3_gram",
help="""Stem name for LM used in HLG compiling.
""",
)
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
""",
)
return parser.parse_args()
def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
"""
Args:
lang_dir:
The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
lm:
The language stem base name.
Return:
An FSA representing HLG.
"""
lexicon = Lexicon(lang_dir)
max_token_id = max(lexicon.tokens)
logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
H = k2.ctc_topo(max_token_id)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
logging.info(f"Loading pre-compiled {lm}")
d = torch.load(f"{lang_dir}/lm/{lm}.pt")
G = k2.Fsa.from_dict(d)
else:
logging.info(f"Loading {lm}.fst.txt")
with open(f"{lang_dir}/lm/{lm}.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt")
first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]
L = k2.arc_sort(L)
G = k2.arc_sort(G)
logging.info("Intersecting L and G")
LG = k2.compose(L, G)
logging.info(f"LG shape: {LG.shape}")
logging.info("Connecting LG")
LG = k2.connect(LG)
logging.info(f"LG shape after k2.connect: {LG.shape}")
logging.info(type(LG.aux_labels))
logging.info("Determinizing LG")
LG = k2.determinize(LG)
logging.info(type(LG.aux_labels))
logging.info("Connecting LG after k2.determinize")
LG = k2.connect(LG)
logging.info("Removing disambiguation symbols on LG")
# LG.labels[LG.labels >= first_token_disambig_id] = 0
# see https://github.com/k2-fsa/k2/pull/1140
labels = LG.labels
labels[labels >= first_token_disambig_id] = 0
LG.labels = labels
assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
LG = k2.connect(LG)
LG.aux_labels = LG.aux_labels.remove_values_eq(0)
logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
logging.info("Composing H and LG")
# CAUTION: The name of the inner_labels is fixed
# to `tokens`. If you want to change it, please
# also change other places in icefall that are using
# it.
HLG = k2.compose(H, LG, inner_labels="tokens")
logging.info("Connecting LG")
HLG = k2.connect(HLG)
logging.info("Arc sorting LG")
HLG = k2.arc_sort(HLG)
logging.info(f"HLG.shape: {HLG.shape}")
return HLG
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
if (lang_dir / "HLG.pt").is_file():
logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
return
logging.info(f"Processing {lang_dir}")
HLG = compile_HLG(lang_dir, args.lm)
logging.info(f"Saving HLG.pt to {lang_dir}")
torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -1 +0,0 @@
../../../librispeech/ASR/local/compile_lg.py

View File

@ -0,0 +1,149 @@
#!/usr/bin/env python3
# Copyright 2021-2024 Xiaomi Corp. (authors: Fangjun Kuang,
# Kang Wei,
# Zengrui Jin,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input lang_dir and generates LG from
- L, the lexicon, built from lang_dir/L_disambig.pt
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from lang_dir/lm/G_3_gram.fst.txt
The generated LG is saved in $lang_dir/LG.pt
"""
import argparse
import logging
from pathlib import Path
import k2
import torch
from icefall.lexicon import Lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="""Input and output directory.
""",
)
parser.add_argument(
"--lm",
type=str,
default="G_3_gram",
help="""Stem name for LM used in HLG compiling.
""",
)
return parser.parse_args()
def compile_LG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
"""
Args:
lang_dir:
The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
Return:
An FSA representing LG.
"""
lexicon = Lexicon(lang_dir)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
if Path(f"{lang_dir}/lm/{lm}.pt").is_file():
logging.info(f"Loading pre-compiled {lm}")
d = torch.load(f"{lang_dir}/lm/{lm}.pt")
G = k2.Fsa.from_dict(d)
else:
logging.info(f"Loading {lm}.fst.txt")
with open(f"{lang_dir}/lm/{lm}.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
torch.save(G.as_dict(), f"{lang_dir}/lm/{lm}.pt")
first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]
L = k2.arc_sort(L)
G = k2.arc_sort(G)
logging.info("Intersecting L and G")
LG = k2.compose(L, G)
logging.info(f"LG shape: {LG.shape}")
logging.info("Connecting LG")
LG = k2.connect(LG)
logging.info(f"LG shape after k2.connect: {LG.shape}")
logging.info(type(LG.aux_labels))
logging.info("Determinizing LG")
LG = k2.determinize(LG, k2.DeterminizeWeightPushingType.kLogWeightPushing)
logging.info(type(LG.aux_labels))
logging.info("Connecting LG after k2.determinize")
LG = k2.connect(LG)
logging.info("Removing disambiguation symbols on LG")
# LG.labels[LG.labels >= first_token_disambig_id] = 0
# see https://github.com/k2-fsa/k2/pull/1140
labels = LG.labels
labels[labels >= first_token_disambig_id] = 0
LG.labels = labels
assert isinstance(LG.aux_labels, k2.RaggedTensor)
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
LG = k2.remove_epsilon(LG)
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
LG = k2.connect(LG)
LG.aux_labels = LG.aux_labels.remove_values_eq(0)
logging.info("Arc sorting LG")
LG = k2.arc_sort(LG)
return LG
def main():
args = get_args()
lang_dir = Path(args.lang_dir)
if (lang_dir / "LG.pt").is_file():
logging.info(f"{lang_dir}/LG.pt already exists - skipping")
return
logging.info(f"Processing {lang_dir}")
LG = compile_LG(lang_dir, args.lm)
logging.info(f"Saving LG.pt to {lang_dir}")
torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -52,6 +52,15 @@ def normalize_text(utt: str, language: str) -> str:
return re.sub(r"[^A-ZÀÂÆÇÉÈÊËÎÏÔŒÙÛÜ' ]", "", utt).upper()
elif language == "pl":
return re.sub(r"[^a-ząćęłńóśźżA-ZĄĆĘŁŃÓŚŹŻ' ]", "", utt).upper()
elif language == "yue":
return (
utt.replace(" ", "")
.replace("", "")
.replace("", " ")
.replace("", "")
.replace("", "")
.replace("?", "")
)
else:
raise NotImplementedError(
f"""

View File

@ -381,9 +381,11 @@ class CommonVoiceAsrDataModule:
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
input_strategy=(
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)()
),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(

View File

@ -31,7 +31,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
@ -315,8 +315,8 @@ class CommonVoiceAsrDataModule:
drop_last=self.args.drop_last,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
@ -383,9 +383,11 @@ class CommonVoiceAsrDataModule:
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
input_strategy=(
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)()
),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(

View File

@ -425,9 +425,11 @@ class LibriHeavyAsrDataModule:
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
input_strategy=(
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures()
),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(