add extract speech tokens

update prepare.sh

update

update

add attach_speech_tokens
This commit is contained in:
yfyeung 2024-11-04 08:33:02 -08:00
parent 390695bcf3
commit f90c3ae3ec
4 changed files with 119 additions and 159 deletions

View File

@ -0,0 +1,81 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gzip
import json
import logging
import os
from pathlib import Path
from tqdm import tqdm
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--subset",
type=str,
default="small",
)
return parser.parse_args()
def attach_speech_tokens(args):
assert args.subset in ("small", "medium", "large"), f"{args.subset}"
src_dir = Path("data/manifests")
output_dir = "data/tokens"
output_dir = Path(output_dir)
assert output_dir.exists(), f"{output_dir} does not exist!"
prefix = "libriheavy"
cuts_path = output_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz"
if cuts_path.is_file():
logging.info(f"{cuts_path} exists - skipping")
return
manifests_path = src_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz"
assert manifests_path.is_file(), f"{manifests_path} does not exist!"
tokens_path = output_dir / f"{prefix}_{args.subset}.jsonl.gz"
assert tokens_path.is_file(), f"{tokens_path} does not exist!"
id2tokens = {}
with gzip.open(tokens_path, "r") as fin:
for line in fin:
line = json.loads(line)
id2tokens[line["key"]] = " ".join(map(str, line["code"]))
with gzip.open(manifests_path, "r") as fin, gzip.open(cuts_path, "w") as fout:
for cut in tqdm(fin, desc="Processing"):
cut = json.loads(cut)
if cut["id"] in id2tokens:
cut["custom"] = {"tokens": id2tokens[cut["id"]]}
fout.write((json.dumps(cut) + "\n").encode())
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
logging.info(vars(args))
attach_speech_tokens(args)

View File

@ -1,157 +0,0 @@
#!/usr/bin/env python3
# Copyright 2024 Xiaomi Corp. (authors: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
from pathlib import Path
from typing import Optional
import torch
from lhotse import CutSet, SupervisionSegment
from lhotse.utils import fastcopy
from tqdm import tqdm
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--subset",
type=str,
default="small",
)
parser.add_argument(
"--model-path",
type=str,
default="download/hubert_base_ls960.pt",
)
parser.add_argument(
"--start",
type=int,
default=0,
help="Process pieces starting from this number (inclusive).",
)
parser.add_argument(
"--stop",
type=int,
default=-1,
help="Stop processing pieces until this number (exclusive).",
)
return parser.parse_args()
@torch.no_grad()
def extract_and_save_one_cuts(
manifests_path,
cuts_path,
):
logging.info(f"Loading {manifests_path}")
cut_set = CutSet.from_file(manifests_path)
logging.info("Extracting tokens")
cuts = []
tokens = " ".join(map(str, tokens))
cut_with_tokens = fastcopy(
cut,
custom={"tokens": tokens},
)
cuts.append(cut_with_tokens)
cuts = CutSet(cuts)
logging.info(f"Saving to {cuts_path}")
cuts.to_file(cuts_path)
def extract_speech_tokens(args):
assert args.subset in ("small", "medium", "large"), f"{args.subset}"
output_dir = (
f"data/tokens/{args.subset}_split" if args.subset != "small" else "data/tokens"
)
output_dir = Path(output_dir)
assert output_dir.exists(), f"{output_dir} does not exist!"
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
prefix = "libriheavy"
if args.subset == "small":
cuts_path = output_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz"
if cuts_path.is_file():
logging.info(f"{cuts_path} exists - skipping")
return
manifests_path = output_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz"
if not manifests_path.is_file():
logging.info(f"{manifests_path} does not exist - skipping it")
return
extract_and_save_one_cuts(
manifests_path,
cuts_path,
model,
apply_tokens,
do_normalize,
window_duration,
shift_duration,
)
else:
num_digits = 8 # num_digits is fixed by lhotse split-lazy
start = args.start
stop = args.stop
assert stop > start, "stop must be larger than start!"
for i in range(start, stop):
idx = f"{i}".zfill(num_digits)
logging.info(f"Processing {idx}/{stop - 1}")
cuts_path = output_dir / f"{prefix}_cuts_{args.subset}.{idx}.jsonl.gz"
if cuts_path.is_file():
logging.info(f"{cuts_path} exists - skipping")
continue
manifests_path = (
output_dir / f"{prefix}_cuts_{args.subset}.{idx}.jsonl.gz"
)
if not manifests_path.is_file():
logging.info(f"{manifests_path} does not exist - skipping it")
continue
extract_and_save_one_cuts(
manifests_path,
cuts_path,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
logging.info(vars(args))
extract_speech_tokens(args)

View File

@ -31,7 +31,9 @@ from icefall.utils import str2bool
class TextNormalizer:
def __init__(self):
self.en_tn_model = EnNormalizer(cache_dir="/tmp/tn", overwrite_cache=False)
self.table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
self.table = str.maketrans(
"’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]"
)
def __call__(self, cut):
text = cut["supervisions"][0]["custom"]["texts"][0]

View File

@ -128,4 +128,38 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
--transcript $lang_dir/text
fi
done
fi
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Extract speech tokens."
for subset in small medium large; do
log "Extract speech tokens for subset: $subset"
output_dir=$tokens_dir/libriheavy_${subset}
mkdir -p $tokens_dir
if [ ! -e $tokens_dir/.extract_completed ]; then
torchrun --nproc_per_node=8 \
--nnodes=1 \
--rdzv_id=2024 \
--rdzv_backend="c10d" \
--rdzv_endpoint="localhost:0" \
`which s3tokenizer` \
--cuts_path $manifests_dir/libriheavy_cuts_${subset}.jsonl.gz \
--device "cuda" \
--output_dir $output_dir \
--batch_size 32 \
--model "speech_tokenizer_v1"
cat $output_dir/part* | gzip > $output_dir/libriheavy_${subset}.jsonl.gz && rm -rf $output_dir
touch $output_dir/..extract_completed
fi
done
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Attach speech tokens."
for subset in small medium large; do
log "Attach speech tokens for subset: $subset"
if [ ! -e $tokens_dir/libriheavy_cuts_${subset}.jsonl.gz ]; then
./local/attach_speech_tokens.py --subset $subset
fi
done
fi