mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-12 11:32:19 +00:00
add extract speech tokens
update prepare.sh update update add attach_speech_tokens
This commit is contained in:
parent
390695bcf3
commit
f90c3ae3ec
81
egs/libriheavy/TTS/local/attach_speech_tokens.py
Executable file
81
egs/libriheavy/TTS/local/attach_speech_tokens.py
Executable file
@ -0,0 +1,81 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Yifan Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import gzip
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--subset",
|
||||
type=str,
|
||||
default="small",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def attach_speech_tokens(args):
|
||||
assert args.subset in ("small", "medium", "large"), f"{args.subset}"
|
||||
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = "data/tokens"
|
||||
output_dir = Path(output_dir)
|
||||
assert output_dir.exists(), f"{output_dir} does not exist!"
|
||||
|
||||
prefix = "libriheavy"
|
||||
|
||||
cuts_path = output_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz"
|
||||
if cuts_path.is_file():
|
||||
logging.info(f"{cuts_path} exists - skipping")
|
||||
return
|
||||
|
||||
manifests_path = src_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz"
|
||||
assert manifests_path.is_file(), f"{manifests_path} does not exist!"
|
||||
|
||||
tokens_path = output_dir / f"{prefix}_{args.subset}.jsonl.gz"
|
||||
assert tokens_path.is_file(), f"{tokens_path} does not exist!"
|
||||
|
||||
id2tokens = {}
|
||||
with gzip.open(tokens_path, "r") as fin:
|
||||
for line in fin:
|
||||
line = json.loads(line)
|
||||
id2tokens[line["key"]] = " ".join(map(str, line["code"]))
|
||||
|
||||
with gzip.open(manifests_path, "r") as fin, gzip.open(cuts_path, "w") as fout:
|
||||
for cut in tqdm(fin, desc="Processing"):
|
||||
cut = json.loads(cut)
|
||||
if cut["id"] in id2tokens:
|
||||
cut["custom"] = {"tokens": id2tokens[cut["id"]]}
|
||||
fout.write((json.dumps(cut) + "\n").encode())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
args = get_args()
|
||||
logging.info(vars(args))
|
||||
attach_speech_tokens(args)
|
@ -1,157 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Yifan Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, SupervisionSegment
|
||||
from lhotse.utils import fastcopy
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--subset",
|
||||
type=str,
|
||||
default="small",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
default="download/hubert_base_ls960.pt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--start",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Process pieces starting from this number (inclusive).",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--stop",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Stop processing pieces until this number (exclusive).",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def extract_and_save_one_cuts(
|
||||
manifests_path,
|
||||
cuts_path,
|
||||
):
|
||||
logging.info(f"Loading {manifests_path}")
|
||||
cut_set = CutSet.from_file(manifests_path)
|
||||
|
||||
logging.info("Extracting tokens")
|
||||
cuts = []
|
||||
|
||||
tokens = " ".join(map(str, tokens))
|
||||
|
||||
cut_with_tokens = fastcopy(
|
||||
cut,
|
||||
custom={"tokens": tokens},
|
||||
)
|
||||
cuts.append(cut_with_tokens)
|
||||
|
||||
cuts = CutSet(cuts)
|
||||
|
||||
logging.info(f"Saving to {cuts_path}")
|
||||
cuts.to_file(cuts_path)
|
||||
|
||||
|
||||
def extract_speech_tokens(args):
|
||||
assert args.subset in ("small", "medium", "large"), f"{args.subset}"
|
||||
|
||||
output_dir = (
|
||||
f"data/tokens/{args.subset}_split" if args.subset != "small" else "data/tokens"
|
||||
)
|
||||
output_dir = Path(output_dir)
|
||||
assert output_dir.exists(), f"{output_dir} does not exist!"
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
prefix = "libriheavy"
|
||||
|
||||
if args.subset == "small":
|
||||
cuts_path = output_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz"
|
||||
if cuts_path.is_file():
|
||||
logging.info(f"{cuts_path} exists - skipping")
|
||||
return
|
||||
|
||||
manifests_path = output_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz"
|
||||
if not manifests_path.is_file():
|
||||
logging.info(f"{manifests_path} does not exist - skipping it")
|
||||
return
|
||||
|
||||
extract_and_save_one_cuts(
|
||||
manifests_path,
|
||||
cuts_path,
|
||||
model,
|
||||
apply_tokens,
|
||||
do_normalize,
|
||||
window_duration,
|
||||
shift_duration,
|
||||
)
|
||||
else:
|
||||
num_digits = 8 # num_digits is fixed by lhotse split-lazy
|
||||
start = args.start
|
||||
stop = args.stop
|
||||
assert stop > start, "stop must be larger than start!"
|
||||
|
||||
for i in range(start, stop):
|
||||
idx = f"{i}".zfill(num_digits)
|
||||
logging.info(f"Processing {idx}/{stop - 1}")
|
||||
|
||||
cuts_path = output_dir / f"{prefix}_cuts_{args.subset}.{idx}.jsonl.gz"
|
||||
if cuts_path.is_file():
|
||||
logging.info(f"{cuts_path} exists - skipping")
|
||||
continue
|
||||
|
||||
manifests_path = (
|
||||
output_dir / f"{prefix}_cuts_{args.subset}.{idx}.jsonl.gz"
|
||||
)
|
||||
if not manifests_path.is_file():
|
||||
logging.info(f"{manifests_path} does not exist - skipping it")
|
||||
continue
|
||||
|
||||
extract_and_save_one_cuts(
|
||||
manifests_path,
|
||||
cuts_path,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
args = get_args()
|
||||
logging.info(vars(args))
|
||||
extract_speech_tokens(args)
|
@ -31,7 +31,9 @@ from icefall.utils import str2bool
|
||||
class TextNormalizer:
|
||||
def __init__(self):
|
||||
self.en_tn_model = EnNormalizer(cache_dir="/tmp/tn", overwrite_cache=False)
|
||||
self.table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
|
||||
self.table = str.maketrans(
|
||||
"’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]"
|
||||
)
|
||||
|
||||
def __call__(self, cut):
|
||||
text = cut["supervisions"][0]["custom"]["texts"][0]
|
||||
|
@ -128,4 +128,38 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
--transcript $lang_dir/text
|
||||
fi
|
||||
done
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "Stage 4: Extract speech tokens."
|
||||
for subset in small medium large; do
|
||||
log "Extract speech tokens for subset: $subset"
|
||||
output_dir=$tokens_dir/libriheavy_${subset}
|
||||
mkdir -p $tokens_dir
|
||||
if [ ! -e $tokens_dir/.extract_completed ]; then
|
||||
torchrun --nproc_per_node=8 \
|
||||
--nnodes=1 \
|
||||
--rdzv_id=2024 \
|
||||
--rdzv_backend="c10d" \
|
||||
--rdzv_endpoint="localhost:0" \
|
||||
`which s3tokenizer` \
|
||||
--cuts_path $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz \
|
||||
--device "cuda" \
|
||||
--output_dir $output_dir \
|
||||
--batch_size 32 \
|
||||
--model "speech_tokenizer_v1"
|
||||
cat $output_dir/part* | gzip > $output_dir/libriheavy_${subset}.jsonl.gz && rm -rf $output_dir
|
||||
touch $output_dir/..extract_completed
|
||||
fi
|
||||
done
|
||||
fi
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "Stage 5: Attach speech tokens."
|
||||
for subset in small medium large; do
|
||||
log "Attach speech tokens for subset: $subset"
|
||||
if [ ! -e $tokens_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then
|
||||
./local/attach_speech_tokens.py --subset $subset
|
||||
fi
|
||||
done
|
||||
fi
|
||||
|
Loading…
x
Reference in New Issue
Block a user