mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
add voicebench eval
This commit is contained in:
parent
89781b9bb1
commit
cbf3af31fd
@ -211,3 +211,49 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
|
||||
--token2wav-path /workspace/CosyVoice2-0.5B \
|
||||
--use-lora True
|
||||
fi
|
||||
|
||||
|
||||
if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
|
||||
log "stage 12: Decoding EN voicebench"
|
||||
exp_dir=./qwen_omni/exp_speech2speech_en_continue
|
||||
torchrun --nproc_per_node=2 \
|
||||
./qwen_omni/decode_dist.py \
|
||||
--output-dir $exp_dir/log_voicebench \
|
||||
--speech-encoder-path-or-name models/large-v2.pt \
|
||||
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
--use-flash-attn True \
|
||||
--enable-speech-output True \
|
||||
--checkpoint-path $exp_dir/epoch-10-checkpoint-40000.pt/pytorch_model.bin \
|
||||
--use-lora True --subset-name openbookqa --split-name test
|
||||
fi
|
||||
|
||||
|
||||
if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
|
||||
log "stage 13: Server"
|
||||
exp_dir=./qwen_omni/exp_speech2speech_en_continue
|
||||
python3 ./qwen_omni/server.py \
|
||||
--speech-encoder-path-or-name models/large-v2.pt \
|
||||
--llm-path-or-name models/Qwen2.5-0.5B-Instruct \
|
||||
--checkpoint-path $exp_dir/epoch-10-checkpoint-40000.pt/pytorch_model.bin \
|
||||
--use-flash-attn True \
|
||||
--enable-speech-output True \
|
||||
--use-lora True
|
||||
fi
|
||||
|
||||
if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
|
||||
log "stage 14: Client"
|
||||
# datasets=(alpacaeval wildvoice mmsu advbench bbh ifeval commoneval obqa sd-qa)
|
||||
datasets=(alpacaeval_full wildvoice mmsu advbench bbh ifeval sd-qa)
|
||||
for dataset in ${datasets[@]}; do
|
||||
# sd-qa should use usa split
|
||||
if [ $dataset == "sd-qa" ]; then
|
||||
split_name="usa"
|
||||
else
|
||||
split_name="test"
|
||||
fi
|
||||
echo $dataset $split_name
|
||||
python3 ./qwen_omni/client.py \
|
||||
--subset-name $dataset --split-name $split_name \
|
||||
--output-dir test_result
|
||||
done
|
||||
fi
|
||||
|
142
egs/speech_llm/SPEECH2SPEECH/qwen_omni/client.py
Normal file
142
egs/speech_llm/SPEECH2SPEECH/qwen_omni/client.py
Normal file
@ -0,0 +1,142 @@
|
||||
# client.py
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import requests
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description="Speech-to-Text Client")
|
||||
parser.add_argument(
|
||||
"--server-url",
|
||||
type=str,
|
||||
default="http://localhost:8000",
|
||||
help="URL of the FastAPI server",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
default="hlt-lab/voicebench",
|
||||
help="Hugging Face dataset name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--subset-name",
|
||||
type=str,
|
||||
default="commoneval", # Adjust as needed
|
||||
help="Dataset subset name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--split-name",
|
||||
type=str,
|
||||
default="test", # Adjust as needed
|
||||
help="Dataset split name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir", required=True, type=str, help="Directory to save results"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
output_filename = os.path.join(
|
||||
args.output_dir,
|
||||
f"{args.subset_name}-{args.split_name}.jsonl",
|
||||
)
|
||||
server_decode_url = f"{args.server_url}/decode"
|
||||
|
||||
print("Loading dataset...")
|
||||
|
||||
dataset = load_dataset(
|
||||
args.dataset_name,
|
||||
args.subset_name,
|
||||
split=args.split_name,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
print(f"Dataset loaded with {len(dataset)} samples.")
|
||||
print(f"Sending requests to {server_decode_url}...")
|
||||
print(f"Saving results to {output_filename}")
|
||||
|
||||
with open(output_filename, "w", encoding="utf-8") as outfile:
|
||||
# Iterate directly over the dataset
|
||||
progress_bar = tqdm(dataset, desc="Processing", unit="samples")
|
||||
for item in progress_bar:
|
||||
|
||||
audio_info = item.get("audio")
|
||||
assert (
|
||||
audio_info["sampling_rate"] == 16000
|
||||
), f"Sampling rate is {audio_info['sampling_rate']}, not 16khz"
|
||||
|
||||
# Prepare data for JSON serialization and server request
|
||||
audio_array = audio_info["array"].tolist() # Convert numpy array to list
|
||||
result_dict = {}
|
||||
for key in item.keys():
|
||||
if key != "audio":
|
||||
# Ensure other fields are JSON serializable
|
||||
try:
|
||||
# Attempt to serialize to catch issues early (optional)
|
||||
json.dumps(item[key])
|
||||
result_dict[key] = item[key]
|
||||
except (TypeError, OverflowError):
|
||||
print(
|
||||
f"Warning: Converting non-serializable key '{key}' to string."
|
||||
)
|
||||
result_dict[key] = str(
|
||||
item[key]
|
||||
) # Convert problematic types to string
|
||||
|
||||
payload = {
|
||||
"audio": audio_array,
|
||||
"sampling_rate": 16000,
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(server_decode_url, json=payload, timeout=60)
|
||||
response.raise_for_status()
|
||||
server_response = response.json()
|
||||
decoded_text = server_response.get("text", "")
|
||||
|
||||
# Add the response to the result dictionary
|
||||
result_dict["response"] = decoded_text
|
||||
print(result_dict)
|
||||
# Write result to JSONL file
|
||||
json.dump(result_dict, outfile, ensure_ascii=False)
|
||||
outfile.write("\n")
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"\nError sending request for an item: {e}")
|
||||
error_entry = result_dict # Use the data prepared so far
|
||||
error_entry["error"] = str(e)
|
||||
error_entry["response"] = ""
|
||||
json.dump(error_entry, outfile, ensure_ascii=False)
|
||||
outfile.write("\n")
|
||||
except json.JSONDecodeError:
|
||||
print("\nError decoding server response for an item.")
|
||||
error_entry = result_dict
|
||||
error_entry["error"] = "Invalid JSON response from server"
|
||||
error_entry["response"] = ""
|
||||
json.dump(error_entry, outfile, ensure_ascii=False)
|
||||
outfile.write("\n")
|
||||
except Exception as e:
|
||||
print(f"\nUnexpected error processing an item: {e}")
|
||||
error_entry = result_dict
|
||||
error_entry["error"] = f"Unexpected error: {str(e)}"
|
||||
error_entry["response"] = ""
|
||||
json.dump(error_entry, outfile, ensure_ascii=False)
|
||||
outfile.write("\n")
|
||||
|
||||
# Progress bar updates automatically by iterating over tqdm(dataset)
|
||||
|
||||
# No need to close progress_bar explicitly when iterating directly
|
||||
|
||||
print("Processing finished.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -461,6 +461,15 @@ class AsrDataModule:
|
||||
)
|
||||
return {"test": VoiceAssistant_cuts}
|
||||
|
||||
def test_cuts_voicebench(
|
||||
self,
|
||||
) -> CutSet:
|
||||
logging.info("About to get test cuts")
|
||||
VoiceAssistant_cuts = load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_voice_assistant_small.00000.jsonl.gz"
|
||||
)
|
||||
return {"test": VoiceAssistant_cuts}
|
||||
|
||||
# def train_cuts_en_vocalnet(self) -> CutSet:
|
||||
# logging.info("About to get train cuts")
|
||||
# VoiceAssistant_cuts = load_manifest_lazy(
|
||||
|
256
egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode_dist.py
Normal file
256
egs/speech_llm/SPEECH2SPEECH/qwen_omni/decode_dist.py
Normal file
@ -0,0 +1,256 @@
|
||||
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
|
||||
# 2025 (authors: Yuekai Zhang)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/cli.py
|
||||
""" Example Usage
|
||||
split=test_zh
|
||||
llm_path=f5-tts/exp_zh/checkpoint-805000
|
||||
huggingface-cli download --local-dir f5-tts-small-wenetspeech4tts-basic yuekai/f5-tts-semantic-token-small-wenetspeech4tts-basic
|
||||
model_path=f5-tts-small-wenetspeech4tts-basic/epoch-10-avg-5.pt
|
||||
huggingface-cli download nvidia/bigvgan_v2_24khz_100band_256x --local-dir ./bigvgan_v2_24khz_100band_256x
|
||||
vocoder=./bigvgan_v2_24khz_100band_256x
|
||||
torchrun --nproc_per_node=2 \
|
||||
f5-tts/infer_dist.py \
|
||||
--output_dir $output_dir \
|
||||
--batch_size 1 \
|
||||
--num_workers 2 \
|
||||
--llm-model-name-or-path $llm_path \
|
||||
--flow-matching-model-path $model_path \
|
||||
--decoder-dim 768 --nhead 12 --num-decoder-layers 18 \
|
||||
--use-cosyvoice-semantic-token True \
|
||||
--vocoder-dir $vocoder \
|
||||
--split-name $split -top-k 50 -top-p 0.95 -temperature 0.8 \
|
||||
--tokenizer-dir Qwen/Qwen2.5-0.5B-Instruct
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
import whisper
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader, Dataset, DistributedSampler
|
||||
from tqdm import tqdm
|
||||
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
|
||||
from transformers import AutoTokenizer
|
||||
from web_demo import get_model
|
||||
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
||||
|
||||
# https://github.com/FunAudioLLM/CosyVoice/tree/main/third_party
|
||||
# sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||
try:
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description="extract speech code")
|
||||
parser.add_argument(
|
||||
"--split-name",
|
||||
type=str,
|
||||
default="test",
|
||||
help="huggingface dataset split name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--subset-name",
|
||||
type=str,
|
||||
default="commoneval",
|
||||
help="subset name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir", required=True, type=str, help="dir to save result"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="batch size (per-device) for inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers", type=int, default=2, help="workers for dataloader"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefetch", type=int, default=2, help="prefetch for dataloader"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Checkpoint name or path, default to %(default)r",
|
||||
)
|
||||
# parser.add_argument(
|
||||
# "--top-k",
|
||||
# type=int,
|
||||
# default=50,
|
||||
# help="top k for sampling",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--top-p",
|
||||
# type=float,
|
||||
# default=0.95,
|
||||
# help="top p for sampling",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--temperature",
|
||||
# type=float,
|
||||
# default=0.8,
|
||||
# help="temperature for sampling",
|
||||
# )
|
||||
add_model_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def init_distributed():
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
rank = int(os.environ.get("RANK", 0))
|
||||
print(
|
||||
"Inference on multiple gpus, this gpu {}".format(local_rank)
|
||||
+ ", rank {}, world_size {}".format(rank, world_size)
|
||||
)
|
||||
torch.cuda.set_device(local_rank)
|
||||
dist.init_process_group("nccl")
|
||||
return world_size, local_rank, rank
|
||||
|
||||
|
||||
def preprocess(
|
||||
messages,
|
||||
tokenizer,
|
||||
):
|
||||
"""Preprocesses the data for supervised fine-tuning."""
|
||||
texts = []
|
||||
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
|
||||
for i, msg in enumerate(messages):
|
||||
texts.append(
|
||||
tokenizer.apply_chat_template(
|
||||
msg,
|
||||
tokenize=True,
|
||||
add_generation_prompt=False,
|
||||
chat_template=TEMPLATE,
|
||||
padding="longest",
|
||||
truncation=False,
|
||||
)
|
||||
)
|
||||
max_len_texts = max([len(text) for text in texts])
|
||||
if tokenizer.padding_side == "right":
|
||||
texts = [
|
||||
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
|
||||
for text in texts
|
||||
]
|
||||
else:
|
||||
texts = [
|
||||
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
|
||||
for text in texts
|
||||
]
|
||||
|
||||
input_ids = torch.tensor(texts, dtype=torch.int)
|
||||
|
||||
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
||||
|
||||
return input_ids, attention_mask
|
||||
|
||||
|
||||
def custom_collate(batch):
|
||||
assert len(batch) == 1
|
||||
audio = batch[0]["audio"]
|
||||
assert audio["sampling_rate"] == 16000
|
||||
result = {"audio": audio["array"]}
|
||||
for keys in batch[0].keys():
|
||||
if keys != "audio":
|
||||
result[keys] = batch[0][keys]
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
assert torch.cuda.is_available()
|
||||
world_size, local_rank, rank = init_distributed()
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
|
||||
dataset = load_dataset(
|
||||
"hlt-lab/voicebench",
|
||||
args.subset_name,
|
||||
split=args.split_name,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
model, tokenizer = get_model(args)
|
||||
# tokenizer = AutoTokenizer.from_pretrained(args.llm_path_or_name)
|
||||
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
sampler=sampler,
|
||||
shuffle=False,
|
||||
num_workers=args.num_workers,
|
||||
prefetch_factor=args.prefetch,
|
||||
collate_fn=custom_collate,
|
||||
)
|
||||
|
||||
total_steps = len(dataset)
|
||||
|
||||
if rank == 0:
|
||||
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
|
||||
|
||||
message = [
|
||||
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
|
||||
{"role": "assistant", "content": ""},
|
||||
]
|
||||
input_ids, attention_mask = preprocess([message], tokenizer)
|
||||
results_jsonl_file = open(
|
||||
os.path.join(
|
||||
args.output_dir,
|
||||
f"results-{args.subset_name}-{args.split_name}-{rank}-audio.jsonl",
|
||||
),
|
||||
"w",
|
||||
)
|
||||
for batch in dataloader:
|
||||
audio = batch["audio"]
|
||||
audio = torch.from_numpy(audio).to(device).to(torch.float32)
|
||||
fbank = whisper.log_mel_spectrogram(audio, device=device)
|
||||
fbank = fbank.unsqueeze(0)
|
||||
generated_ids = model.decode(
|
||||
fbank, input_ids.to(device, dtype=torch.long), attention_mask.to(device)
|
||||
)
|
||||
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
result_dict = {}
|
||||
for key in batch.keys():
|
||||
if key != "audio":
|
||||
result_dict[key] = batch[key]
|
||||
result_dict["response"] = hyps[0]
|
||||
json.dump(result_dict, results_jsonl_file)
|
||||
results_jsonl_file.write("\n")
|
||||
|
||||
if rank == 0:
|
||||
progress_bar.update(world_size * args.batch_size)
|
||||
|
||||
if rank == 0:
|
||||
progress_bar.close()
|
||||
|
||||
dist.barrier()
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
103
egs/speech_llm/SPEECH2SPEECH/qwen_omni/server.py
Normal file
103
egs/speech_llm/SPEECH2SPEECH/qwen_omni/server.py
Normal file
@ -0,0 +1,103 @@
|
||||
# server.py
|
||||
import argparse
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import uvicorn
|
||||
import whisper
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from train import DEFAULT_SPEECH_TOKEN, add_model_arguments
|
||||
from transformers import AutoTokenizer
|
||||
from web_demo import get_model
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description="extract speech code")
|
||||
parser.add_argument(
|
||||
"--checkpoint-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Checkpoint name or path, default to %(default)r",
|
||||
)
|
||||
add_model_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
class SpeechRequest(BaseModel):
|
||||
audio: List[float] # Expecting audio as a list of floats (raw waveform)
|
||||
sampling_rate: int = 16000
|
||||
|
||||
|
||||
class TextResponse(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
def preprocess_prompt(tokenizer):
|
||||
"""Preprocesses the prompt template."""
|
||||
texts = [
|
||||
tokenizer.apply_chat_template(
|
||||
message, # Using the hardcoded message
|
||||
tokenize=True,
|
||||
add_generation_prompt=False, # Important for generation
|
||||
chat_template=TEMPLATE,
|
||||
padding=False, # No padding needed for single prompt
|
||||
truncation=False,
|
||||
)
|
||||
]
|
||||
input_ids = torch.tensor(texts, dtype=torch.long)
|
||||
attention_mask = torch.ones_like(
|
||||
input_ids, dtype=torch.bool
|
||||
) # Mask is all True for the prompt
|
||||
return input_ids, attention_mask
|
||||
|
||||
|
||||
args = get_args()
|
||||
model, tokenizer = get_model(args)
|
||||
app = FastAPI()
|
||||
|
||||
device = torch.device("cuda")
|
||||
message = [
|
||||
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}"},
|
||||
{"role": "assistant", "content": ""},
|
||||
]
|
||||
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
|
||||
prompt_input_ids, prompt_attention_mask = preprocess_prompt(tokenizer)
|
||||
prompt_input_ids = prompt_input_ids.to(device)
|
||||
prompt_attention_mask = prompt_attention_mask.to(device)
|
||||
|
||||
|
||||
@app.post("/decode", response_model=TextResponse)
|
||||
async def decode_speech(request: SpeechRequest):
|
||||
"""
|
||||
Receives audio waveform, processes it, and returns the decoded text.
|
||||
"""
|
||||
if request.sampling_rate != 16000:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Only 16kHz sampling rate is supported."
|
||||
)
|
||||
|
||||
try:
|
||||
audio_tensor = torch.tensor(request.audio, dtype=torch.float32).to(device)
|
||||
fbank = whisper.log_mel_spectrogram(audio_tensor, device=device, n_mels=80)
|
||||
fbank = fbank.unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_ids = model.decode(fbank, prompt_input_ids, prompt_attention_mask)
|
||||
|
||||
hyps = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
response_text = hyps[0] if hyps else ""
|
||||
|
||||
return TextResponse(text=response_text)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during processing: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Starting server...")
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
@ -74,7 +74,8 @@ def get_model(params, device="cuda"):
|
||||
speech_encoder_dim, llm.config.hidden_size, params.encoder_projector_ds_rate
|
||||
)
|
||||
|
||||
codec_vocab_size = 4096 + 4
|
||||
# codec_vocab_size = 4096 + 4
|
||||
codec_vocab_size = 6561 + 4
|
||||
config = Qwen2Config(
|
||||
vocab_size=codec_vocab_size,
|
||||
hidden_size=1024,
|
||||
|
Loading…
x
Reference in New Issue
Block a user