Minor fixes to CharCtcGraphCompiler

parent 724e387c6f
commit f2f4087778
@@ -602,11 +602,9 @@ def compute_loss(
     feature_lens = supervisions["num_frames"].to(device)
 
     texts = batch["supervisions"]["text"]
-    y = graph_compiler.texts_to_ids_with_bpe(texts)
-    if type(y) == list:
-        y = k2.RaggedTensor(y).to(device)
-    else:
-        y = y.to(device)
+    y = graph_compiler.texts_to_ids(texts, sep="/")
+    y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
        simple_loss, pruned_loss = model(
            x=feature,
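With the unified texts_to_ids, the caller no longer branches on the return type: the method always returns a plain list of token-ID lists, which k2.RaggedTensor accepts directly. A minimal sketch of that data flow, assuming k2 is installed and using a hypothetical token-to-ID mapping in place of the graph compiler:

    import k2

    # Hypothetical slash-separated transcripts and token mapping, for illustration only.
    texts = ["你/好", "北/京/欢/迎/您"]
    token2id = {t: i + 1 for i, t in enumerate("你好北京欢迎您")}

    # What texts_to_ids(texts, sep="/") returns: one List[int] per utterance.
    y = [[token2id[piece] for piece in text.split("/")] for text in texts]

    # A list-of-lists converts directly; rows may have different lengths.
    y = k2.RaggedTensor(y)
    print(y)  # one ragged row per utterance: [[1, 2], [3, 4, 5, 6, 7]]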
egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py (new file, 141 lines)
@@ -0,0 +1,141 @@
#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.  (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
from pathlib import Path

import lhotse
import torch
from lhotse import (
    CutSet,
    Fbank,
    FbankConfig,
    LilcomChunkyWriter,
    fix_manifests,
    validate_recordings_and_supervisions,
)

from icefall.utils import get_executor, str2bool

# Torch's multithreaded behavior needs to be disabled, or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--kaldi-dir",
        type=str,
        help="""The directory containing the Kaldi-style manifests, namely
        wav.scp, text and segments.""",
    )

    parser.add_argument(
        "--num-mel-bins",
        type=int,
        default=80,
        help="The number of mel filterbank bins.",
    )

    parser.add_argument(
        "--output-dir",
        type=str,
        default="data/fbank",
        help="The directory to write the lhotse manifests and features to.",
    )

    parser.add_argument(
        "--dataset",
        type=str,
        help="The name of the dataset.",
    )

    parser.add_argument(
        "--partition",
        type=str,
        help="The partition of the dataset, e.g. train, valid or test.",
    )

    parser.add_argument(
        "--perturb-speed",
        type=str2bool,
        default=True,
        help="Perturb speed with factors 0.9 and 1.1 on the train subset.",
    )

    parser.add_argument(
        "--num-jobs",
        type=int,
        default=50,
        help="The number of jobs to use for feature extraction.",
    )

    return parser.parse_args()


def prepare_cuts(args):
    logging.info(f"Prepare cuts from {args.kaldi_dir}.")
    recordings, supervisions, _ = lhotse.load_kaldi_data_dir(args.kaldi_dir, 16000)
    recordings, supervisions = fix_manifests(recordings, supervisions)
    validate_recordings_and_supervisions(recordings, supervisions)
    cuts = CutSet.from_manifests(recordings=recordings, supervisions=supervisions)
    return cuts


def compute_feature(args, cuts):
    extractor = Fbank(FbankConfig(num_mel_bins=args.num_mel_bins))
    with get_executor() as ex:  # Initialize the executor only once.
        cuts_filename = f"{args.dataset}_cuts_{args.partition}.jsonl.gz"
        if (args.output_dir / cuts_filename).is_file():
            logging.info(f"{cuts_filename} already exists - skipping.")
            return
        logging.info(f"Processing {cuts_filename}")

        if "train" in args.partition:
            if args.perturb_speed:
                logging.info("Doing speed perturb")
                cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
        cuts = cuts.compute_and_store_features(
            extractor=extractor,
            storage_path=f"{args.output_dir}/{args.dataset}_feats_{args.partition}",
            # when an executor is specified, make more partitions
            num_jobs=args.num_jobs if ex is None else 80,
            executor=ex,
            storage_type=LilcomChunkyWriter,
        )
        cuts.to_file(args.output_dir / cuts_filename)


def main(args):
    args.kaldi_dir = Path(args.kaldi_dir)
    args.output_dir = Path(args.output_dir)
    cuts = prepare_cuts(args)
    compute_feature(args, cuts)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    args = get_args()
    logging.info(vars(args))
    main(args)
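For reference, a hypothetical invocation (paths and dataset name are placeholders, not taken from this commit) would be ./local/prepare_dataset_from_kaldi_dir.py --kaldi-dir data/kaldi/train --dataset wenetspeech --partition train. The cuts manifest is then written to {output-dir}/{dataset}_cuts_{partition}.jsonl.gz and the Lilcom-compressed features are stored under {output-dir}/{dataset}_feats_{partition}; a rerun is skipped if the manifest already exists.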
@@ -54,7 +54,7 @@ class CharCtcTrainingGraphCompiler(object):
         self.sos_id = self.token_table[sos_token]
         self.eos_id = self.token_table[eos_token]
 
-    def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
+    def texts_to_ids(self, texts: List[str], sep: str = "") -> List[List[int]]:
         """Convert a list of texts to a list-of-list of token IDs.
 
         Args:
@@ -63,36 +63,21 @@ class CharCtcTrainingGraphCompiler(object):
             An example containing two strings is given below:
 
               ['你好中国', '北京欢迎您']
+          sep:
+            The separator of the items in one sequence: no separator for plain
+            Chinese (one character per token), "/" for Chinese characters mixed
+            with BPE and pinyin tokens.
         Returns:
           Return a list-of-list of token IDs.
         """
+        assert sep in ("", "/"), sep
         ids: List[List[int]] = []
         whitespace = re.compile(r"([ \t])")
         for text in texts:
-            text = re.sub(whitespace, "", text)
+            if sep == "":
+                text = re.sub(whitespace, "", text)
+            else:
+                text = text.split(sep)
             sub_ids = [
                 self.token_table[txt] if txt in self.token_table else self.oov_id
                 for txt in text
             ]
             ids.append(sub_ids)
         return ids
-
-    def texts_to_ids_with_bpe(self, texts: List[str]) -> List[List[int]]:
-        """Convert a list of texts (which include chars and BPE pieces)
-        to a list-of-list of token IDs.
-
-        Args:
-          texts:
-            It is a list of strings.
-            An example containing two strings is given below:
-
-              [['你', '好', '▁C', 'hina'], ['北', '京', '▁', 'welcome', '您']]
-        Returns:
-          Return a list-of-list of token IDs.
-        """
-        ids: List[List[int]] = []
-        for text in texts:
-            text = text.split("/")
-            sub_ids = [
-                self.token_table[txt] if txt in self.token_table else self.oov_id
-                for txt in text
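The unified method works because iterating a string yields its characters while iterating a split list yields the separated pieces, so the same lookup comprehension serves both schemes. A minimal standalone sketch of that logic, using a plain dict as a stand-in for the k2 symbol table:

    import re
    from typing import Dict, List

    # Hypothetical token table and OOV id, for illustration only.
    token_table: Dict[str, int] = {"你": 1, "好": 2, "▁C": 3, "hina": 4}
    oov_id: int = 0

    def texts_to_ids(texts: List[str], sep: str = "") -> List[List[int]]:
        assert sep in ("", "/"), sep
        whitespace = re.compile(r"([ \t])")
        ids: List[List[int]] = []
        for text in texts:
            # sep == "": strip whitespace and iterate per character;
            # sep == "/": split into mixed char/BPE/pinyin pieces.
            pieces = re.sub(whitespace, "", text) if sep == "" else text.split(sep)
            ids.append([token_table.get(p, oov_id) for p in pieces])
        return ids

    print(texts_to_ids(["你好"]))                    # [[1, 2]]
    print(texts_to_ids(["你/好/▁C/hina"], sep="/"))  # [[1, 2, 3, 4]]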