mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
update inference code
This commit is contained in:
parent
f30a52a254
commit
fea972364d
@ -22,6 +22,7 @@ import random
|
|||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torchaudio
|
import torchaudio
|
||||||
@ -36,6 +37,7 @@ from train import (
|
|||||||
add_model_arguments,
|
add_model_arguments,
|
||||||
get_model,
|
get_model,
|
||||||
get_tokenizer,
|
get_tokenizer,
|
||||||
|
insert_zeros_optimized,
|
||||||
load_F5_TTS_pretrained_checkpoint,
|
load_F5_TTS_pretrained_checkpoint,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -78,7 +80,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--manifest-file",
|
"--manifest-file",
|
||||||
type=str,
|
type=str,
|
||||||
default="/path/seed_tts_eval/seedtts_testset/zh/meta.lst",
|
default=None,
|
||||||
help="The manifest file in seed_tts_eval format",
|
help="The manifest file in seed_tts_eval format",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -90,6 +92,21 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
|
parser.add_argument("-ss", "--swaysampling", default=-1, type=float)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--insert-zero",
|
||||||
|
action="store_true",
|
||||||
|
help="Insert zeros for CosyVoice",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--split-name",
|
||||||
|
type=str,
|
||||||
|
default="wenetspeech4tts",
|
||||||
|
choices=["wenetspeech4tts", "test_zh", "test_en", "test_hard"],
|
||||||
|
help="huggingface dataset split name",
|
||||||
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
@ -243,6 +260,344 @@ def get_inference_prompt(
|
|||||||
return prompts_all
|
return prompts_all
|
||||||
|
|
||||||
|
|
||||||
|
def get_inference_prompt_cosy_voice_huggingface(
|
||||||
|
dataset,
|
||||||
|
speed=1.0,
|
||||||
|
tokenizer="pinyin",
|
||||||
|
polyphone=True,
|
||||||
|
target_sample_rate=24000,
|
||||||
|
n_fft=1024,
|
||||||
|
win_length=1024,
|
||||||
|
n_mel_channels=100,
|
||||||
|
hop_length=256,
|
||||||
|
mel_spec_type="bigvgan",
|
||||||
|
target_rms=0.1,
|
||||||
|
use_truth_duration=False,
|
||||||
|
infer_batch_size=1,
|
||||||
|
num_buckets=200,
|
||||||
|
min_secs=3,
|
||||||
|
max_secs=40,
|
||||||
|
insert_zero=False,
|
||||||
|
):
|
||||||
|
prompts_all = []
|
||||||
|
|
||||||
|
min_tokens = min_secs * target_sample_rate // hop_length
|
||||||
|
max_tokens = max_secs * target_sample_rate // hop_length
|
||||||
|
|
||||||
|
batch_accum = [0] * num_buckets
|
||||||
|
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
|
||||||
|
[[] for _ in range(num_buckets)] for _ in range(6)
|
||||||
|
)
|
||||||
|
|
||||||
|
mel_spectrogram = MelSpec(
|
||||||
|
n_fft=n_fft,
|
||||||
|
hop_length=hop_length,
|
||||||
|
win_length=win_length,
|
||||||
|
n_mel_channels=n_mel_channels,
|
||||||
|
target_sample_rate=target_sample_rate,
|
||||||
|
mel_spec_type=mel_spec_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
utt = dataset[i]["id"]
|
||||||
|
ref_audio_org, ref_sr = (
|
||||||
|
dataset[i]["prompt_audio"]["array"],
|
||||||
|
dataset[i]["prompt_audio"]["sampling_rate"],
|
||||||
|
)
|
||||||
|
ref_audio_org = torch.from_numpy(ref_audio_org).unsqueeze(0).float()
|
||||||
|
audio_tokens = dataset[i]["target_audio_cosy2_tokens"]
|
||||||
|
prompt_audio_tokens = dataset[i]["prompt_audio_cosy2_tokens"]
|
||||||
|
|
||||||
|
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
|
||||||
|
if ref_rms < target_rms:
|
||||||
|
ref_audio_org = ref_audio_org * target_rms / ref_rms
|
||||||
|
|
||||||
|
if ref_sr != target_sample_rate:
|
||||||
|
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
|
||||||
|
ref_audio = resampler(ref_audio_org)
|
||||||
|
else:
|
||||||
|
ref_audio = ref_audio_org
|
||||||
|
input_tokens = prompt_audio_tokens + audio_tokens
|
||||||
|
|
||||||
|
if insert_zero:
|
||||||
|
input_tokens = insert_zeros_optimized(input_tokens)
|
||||||
|
text_list = input_tokens
|
||||||
|
|
||||||
|
# Duration, mel frame length
|
||||||
|
ref_mel_len = ref_audio.shape[-1] // hop_length
|
||||||
|
|
||||||
|
total_mel_len = len(input_tokens)
|
||||||
|
if not insert_zero:
|
||||||
|
total_mel_len = int(total_mel_len / 4 * 15)
|
||||||
|
|
||||||
|
# to mel spectrogram
|
||||||
|
ref_mel = mel_spectrogram(ref_audio)
|
||||||
|
ref_mel = ref_mel.squeeze(0)
|
||||||
|
|
||||||
|
# deal with batch
|
||||||
|
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
|
||||||
|
if total_mel_len > max_tokens:
|
||||||
|
print(
|
||||||
|
f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
assert (
|
||||||
|
min_tokens <= total_mel_len <= max_tokens
|
||||||
|
), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
|
||||||
|
bucket_i = math.floor(
|
||||||
|
(total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets
|
||||||
|
)
|
||||||
|
|
||||||
|
utts[bucket_i].append(utt)
|
||||||
|
ref_rms_list[bucket_i].append(ref_rms)
|
||||||
|
ref_mels[bucket_i].append(ref_mel)
|
||||||
|
ref_mel_lens[bucket_i].append(ref_mel_len)
|
||||||
|
total_mel_lens[bucket_i].append(total_mel_len)
|
||||||
|
# final_text_list[bucket_i].extend(text_list)
|
||||||
|
final_text_list[bucket_i].append(text_list)
|
||||||
|
|
||||||
|
batch_accum[bucket_i] += total_mel_len
|
||||||
|
|
||||||
|
if batch_accum[bucket_i] >= infer_batch_size:
|
||||||
|
prompts_all.append(
|
||||||
|
(
|
||||||
|
utts[bucket_i],
|
||||||
|
ref_rms_list[bucket_i],
|
||||||
|
padded_mel_batch(ref_mels[bucket_i]),
|
||||||
|
ref_mel_lens[bucket_i],
|
||||||
|
total_mel_lens[bucket_i],
|
||||||
|
final_text_list[bucket_i],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
batch_accum[bucket_i] = 0
|
||||||
|
(
|
||||||
|
utts[bucket_i],
|
||||||
|
ref_rms_list[bucket_i],
|
||||||
|
ref_mels[bucket_i],
|
||||||
|
ref_mel_lens[bucket_i],
|
||||||
|
total_mel_lens[bucket_i],
|
||||||
|
final_text_list[bucket_i],
|
||||||
|
) = (
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
|
||||||
|
# add residual
|
||||||
|
for bucket_i, bucket_frames in enumerate(batch_accum):
|
||||||
|
if bucket_frames > 0:
|
||||||
|
prompts_all.append(
|
||||||
|
(
|
||||||
|
utts[bucket_i],
|
||||||
|
ref_rms_list[bucket_i],
|
||||||
|
padded_mel_batch(ref_mels[bucket_i]),
|
||||||
|
ref_mel_lens[bucket_i],
|
||||||
|
total_mel_lens[bucket_i],
|
||||||
|
final_text_list[bucket_i],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# not only leave easy work for last workers
|
||||||
|
random.seed(666)
|
||||||
|
random.shuffle(prompts_all)
|
||||||
|
|
||||||
|
return prompts_all
|
||||||
|
|
||||||
|
|
||||||
|
def get_inference_prompt_cosy_voice(
|
||||||
|
metainfo,
|
||||||
|
speed=1.0,
|
||||||
|
tokenizer="pinyin",
|
||||||
|
polyphone=True,
|
||||||
|
target_sample_rate=24000,
|
||||||
|
n_fft=1024,
|
||||||
|
win_length=1024,
|
||||||
|
n_mel_channels=100,
|
||||||
|
hop_length=256,
|
||||||
|
mel_spec_type="bigvgan",
|
||||||
|
target_rms=0.1,
|
||||||
|
use_truth_duration=False,
|
||||||
|
infer_batch_size=1,
|
||||||
|
num_buckets=200,
|
||||||
|
min_secs=3,
|
||||||
|
max_secs=40,
|
||||||
|
insert_zero=False,
|
||||||
|
):
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||||
|
sys.path.append("/workspace/CosyVoice")
|
||||||
|
from cosyvoice.cli.cosyvoice import CosyVoice2
|
||||||
|
|
||||||
|
cosyvoice = CosyVoice2(
|
||||||
|
"/workspace/CosyVoice2-0.5B", load_jit=False, load_trt=False, fp16=False
|
||||||
|
)
|
||||||
|
prompts_all = []
|
||||||
|
|
||||||
|
min_tokens = min_secs * target_sample_rate // hop_length
|
||||||
|
max_tokens = max_secs * target_sample_rate // hop_length
|
||||||
|
|
||||||
|
batch_accum = [0] * num_buckets
|
||||||
|
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = (
|
||||||
|
[[] for _ in range(num_buckets)] for _ in range(6)
|
||||||
|
)
|
||||||
|
|
||||||
|
mel_spectrogram = MelSpec(
|
||||||
|
n_fft=n_fft,
|
||||||
|
hop_length=hop_length,
|
||||||
|
win_length=win_length,
|
||||||
|
n_mel_channels=n_mel_channels,
|
||||||
|
target_sample_rate=target_sample_rate,
|
||||||
|
mel_spec_type=mel_spec_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(
|
||||||
|
metainfo, desc="Processing prompts..."
|
||||||
|
):
|
||||||
|
# Audio
|
||||||
|
ref_audio_org, ref_sr = torchaudio.load(prompt_wav)
|
||||||
|
|
||||||
|
# cosy voice
|
||||||
|
if ref_sr != 16000:
|
||||||
|
resampler = torchaudio.transforms.Resample(ref_sr, 16000)
|
||||||
|
ref_audio_16k = resampler(ref_audio_org)
|
||||||
|
else:
|
||||||
|
ref_audio_16k = ref_audio_org
|
||||||
|
audio_tokens, prompt_audio_tokens = cosyvoice.inference_speech_token(
|
||||||
|
gt_text, prompt_text, ref_audio_16k, stream=False
|
||||||
|
)
|
||||||
|
|
||||||
|
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio_org)))
|
||||||
|
if ref_rms < target_rms:
|
||||||
|
ref_audio_org = ref_audio_org * target_rms / ref_rms
|
||||||
|
assert (
|
||||||
|
ref_audio_org.shape[-1] > 5000
|
||||||
|
), f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
|
||||||
|
if ref_sr != target_sample_rate:
|
||||||
|
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
|
||||||
|
ref_audio = resampler(ref_audio_org)
|
||||||
|
else:
|
||||||
|
ref_audio = ref_audio_org
|
||||||
|
|
||||||
|
# Text
|
||||||
|
# if len(prompt_text[-1].encode("utf-8")) == 1:
|
||||||
|
# prompt_text = prompt_text + " "
|
||||||
|
# text = [prompt_text + gt_text]
|
||||||
|
# if tokenizer == "pinyin":
|
||||||
|
# text_list = convert_char_to_pinyin(text, polyphone=polyphone)
|
||||||
|
# else:
|
||||||
|
# text_list = text
|
||||||
|
|
||||||
|
# concat two tensors: prompt audio tokens with audio tokens --> shape 1, prompt_audio_tokens + audio_tokens
|
||||||
|
# prompt_audio_tokens shape 1, prompt_audio_tokens
|
||||||
|
# audio_tokens shape 1, audio_tokens
|
||||||
|
prompt_audio_tokens = prompt_audio_tokens.squeeze().cpu().tolist()
|
||||||
|
input_tokens = prompt_audio_tokens + audio_tokens
|
||||||
|
|
||||||
|
# convert it into a list
|
||||||
|
# input_tokens_list = input_tokens.squeeze().cpu().tolist()
|
||||||
|
if insert_zero:
|
||||||
|
input_tokens = insert_zeros_optimized(input_tokens)
|
||||||
|
text_list = input_tokens
|
||||||
|
|
||||||
|
# Duration, mel frame length
|
||||||
|
ref_mel_len = ref_audio.shape[-1] // hop_length
|
||||||
|
if use_truth_duration:
|
||||||
|
gt_audio, gt_sr = torchaudio.load(gt_wav)
|
||||||
|
if gt_sr != target_sample_rate:
|
||||||
|
resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
|
||||||
|
gt_audio = resampler(gt_audio)
|
||||||
|
total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
|
||||||
|
|
||||||
|
# # test vocoder resynthesis
|
||||||
|
# ref_audio = gt_audio
|
||||||
|
else:
|
||||||
|
ref_text_len = len(prompt_text.encode("utf-8"))
|
||||||
|
gen_text_len = len(gt_text.encode("utf-8"))
|
||||||
|
total_mel_len_compute = ref_mel_len + int(
|
||||||
|
ref_mel_len / ref_text_len * gen_text_len / speed
|
||||||
|
)
|
||||||
|
total_mel_len = len(input_tokens)
|
||||||
|
if not insert_zero:
|
||||||
|
total_mel_len = int(total_mel_len / 4 * 15)
|
||||||
|
print(
|
||||||
|
f"total_mel_len_compute: {total_mel_len_compute}, total_mel_len: {total_mel_len}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# to mel spectrogram
|
||||||
|
ref_mel = mel_spectrogram(ref_audio)
|
||||||
|
ref_mel = ref_mel.squeeze(0)
|
||||||
|
|
||||||
|
# deal with batch
|
||||||
|
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
|
||||||
|
assert (
|
||||||
|
min_tokens <= total_mel_len <= max_tokens
|
||||||
|
), f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
|
||||||
|
bucket_i = math.floor(
|
||||||
|
(total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets
|
||||||
|
)
|
||||||
|
|
||||||
|
utts[bucket_i].append(utt)
|
||||||
|
ref_rms_list[bucket_i].append(ref_rms)
|
||||||
|
ref_mels[bucket_i].append(ref_mel)
|
||||||
|
ref_mel_lens[bucket_i].append(ref_mel_len)
|
||||||
|
total_mel_lens[bucket_i].append(total_mel_len)
|
||||||
|
# final_text_list[bucket_i].extend(text_list)
|
||||||
|
final_text_list[bucket_i].append(text_list)
|
||||||
|
|
||||||
|
batch_accum[bucket_i] += total_mel_len
|
||||||
|
|
||||||
|
if batch_accum[bucket_i] >= infer_batch_size:
|
||||||
|
prompts_all.append(
|
||||||
|
(
|
||||||
|
utts[bucket_i],
|
||||||
|
ref_rms_list[bucket_i],
|
||||||
|
padded_mel_batch(ref_mels[bucket_i]),
|
||||||
|
ref_mel_lens[bucket_i],
|
||||||
|
total_mel_lens[bucket_i],
|
||||||
|
final_text_list[bucket_i],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
batch_accum[bucket_i] = 0
|
||||||
|
(
|
||||||
|
utts[bucket_i],
|
||||||
|
ref_rms_list[bucket_i],
|
||||||
|
ref_mels[bucket_i],
|
||||||
|
ref_mel_lens[bucket_i],
|
||||||
|
total_mel_lens[bucket_i],
|
||||||
|
final_text_list[bucket_i],
|
||||||
|
) = (
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
|
||||||
|
# add residual
|
||||||
|
for bucket_i, bucket_frames in enumerate(batch_accum):
|
||||||
|
if bucket_frames > 0:
|
||||||
|
prompts_all.append(
|
||||||
|
(
|
||||||
|
utts[bucket_i],
|
||||||
|
ref_rms_list[bucket_i],
|
||||||
|
padded_mel_batch(ref_mels[bucket_i]),
|
||||||
|
ref_mel_lens[bucket_i],
|
||||||
|
total_mel_lens[bucket_i],
|
||||||
|
final_text_list[bucket_i],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# not only leave easy work for last workers
|
||||||
|
random.seed(666)
|
||||||
|
random.shuffle(prompts_all)
|
||||||
|
|
||||||
|
return prompts_all
|
||||||
|
|
||||||
|
|
||||||
def padded_mel_batch(ref_mels):
|
def padded_mel_batch(ref_mels):
|
||||||
max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
|
max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
|
||||||
padded_ref_mels = []
|
padded_ref_mels = []
|
||||||
@ -275,9 +630,22 @@ def main():
|
|||||||
|
|
||||||
accelerator = Accelerator()
|
accelerator = Accelerator()
|
||||||
device = f"cuda:{accelerator.process_index}"
|
device = f"cuda:{accelerator.process_index}"
|
||||||
|
if args.manifest_file:
|
||||||
metainfo = get_seedtts_testset_metainfo(args.manifest_file)
|
metainfo = get_seedtts_testset_metainfo(args.manifest_file)
|
||||||
prompts_all = get_inference_prompt(
|
# prompts_all = get_inference_prompt(
|
||||||
|
# metainfo,
|
||||||
|
# speed=1.0,
|
||||||
|
# tokenizer="pinyin",
|
||||||
|
# target_sample_rate=24_000,
|
||||||
|
# n_mel_channels=100,
|
||||||
|
# hop_length=256,
|
||||||
|
# mel_spec_type="bigvgan",
|
||||||
|
# target_rms=0.1,
|
||||||
|
# use_truth_duration=False,
|
||||||
|
# infer_batch_size=1,
|
||||||
|
# )
|
||||||
|
|
||||||
|
prompts_all = get_inference_prompt_cosy_voice(
|
||||||
metainfo,
|
metainfo,
|
||||||
speed=1.0,
|
speed=1.0,
|
||||||
tokenizer="pinyin",
|
tokenizer="pinyin",
|
||||||
@ -288,6 +656,26 @@ def main():
|
|||||||
target_rms=0.1,
|
target_rms=0.1,
|
||||||
use_truth_duration=False,
|
use_truth_duration=False,
|
||||||
infer_batch_size=1,
|
infer_batch_size=1,
|
||||||
|
insert_zero=args.insert_zero,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
dataset = datasets.load_dataset(
|
||||||
|
"yuekai/seed_tts_cosy2",
|
||||||
|
split=args.split_name,
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
prompts_all = get_inference_prompt_cosy_voice_huggingface(
|
||||||
|
dataset,
|
||||||
|
speed=1.0,
|
||||||
|
tokenizer="pinyin",
|
||||||
|
target_sample_rate=24_000,
|
||||||
|
n_mel_channels=100,
|
||||||
|
hop_length=256,
|
||||||
|
mel_spec_type="bigvgan",
|
||||||
|
target_rms=0.1,
|
||||||
|
use_truth_duration=False,
|
||||||
|
infer_batch_size=1,
|
||||||
|
insert_zero=args.insert_zero,
|
||||||
)
|
)
|
||||||
|
|
||||||
vocoder = BigVGANInference.from_pretrained(
|
vocoder = BigVGANInference.from_pretrained(
|
||||||
@ -324,6 +712,15 @@ def main():
|
|||||||
ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
|
ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
|
||||||
total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
|
total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)
|
||||||
|
|
||||||
|
# concat final_text_list
|
||||||
|
max_len = max([len(tokens) for tokens in final_text_list])
|
||||||
|
# pad tokens to the same length
|
||||||
|
for i, tokens in enumerate(final_text_list):
|
||||||
|
final_text_list[i] = torch.tensor(
|
||||||
|
tokens + [-1] * (max_len - len(tokens)), dtype=torch.long
|
||||||
|
)
|
||||||
|
final_text_list = torch.stack(final_text_list).to(device)
|
||||||
|
|
||||||
# Inference
|
# Inference
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
generated, _ = model.sample(
|
generated, _ = model.sample(
|
||||||
|
@ -580,7 +580,7 @@ def prepare_input(batch: dict, device: torch.device):
|
|||||||
semantic_tokens = []
|
semantic_tokens = []
|
||||||
for i in range(len(batch["tokens"])):
|
for i in range(len(batch["tokens"])):
|
||||||
tokens = batch["tokens"][i]
|
tokens = batch["tokens"][i]
|
||||||
tokens = insert_zeros_optimized(tokens)
|
# tokens = insert_zeros_optimized(tokens)
|
||||||
semantic_tokens.append(tokens)
|
semantic_tokens.append(tokens)
|
||||||
# pad to the same length, B,T, with pad value -1
|
# pad to the same length, B,T, with pad value -1
|
||||||
max_len = max([len(tokens) for tokens in semantic_tokens])
|
max_len = max([len(tokens) for tokens in semantic_tokens])
|
||||||
|
@ -130,9 +130,6 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
|||||||
data/fbank/${prefix}_cuts_validtest.jsonl.gz \
|
data/fbank/${prefix}_cuts_validtest.jsonl.gz \
|
||||||
data/fbank/${prefix}_cuts_test.jsonl.gz
|
data/fbank/${prefix}_cuts_test.jsonl.gz
|
||||||
|
|
||||||
|
|
||||||
# zcat "data/fbank/${prefix}_cuts_${subset}.jsonl.gz" | head -n 100 | gzip > "data/fbank/${prefix}_cuts_${subset}_top100.jsonl.gz"
|
|
||||||
|
|
||||||
rm data/fbank/${prefix}_cuts_validtest.jsonl.gz
|
rm data/fbank/${prefix}_cuts_validtest.jsonl.gz
|
||||||
|
|
||||||
n=$(( $(gunzip -c data/fbank/${prefix}_cuts_${subset}.jsonl.gz | wc -l) - 800 ))
|
n=$(( $(gunzip -c data/fbank/${prefix}_cuts_${subset}.jsonl.gz | wc -l) - 800 ))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user